# Source code for astropy.modeling.separable

# Licensed under a 3-clause BSD style license - see LICENSE.rst

"""
Functions to determine if a model is separable, i.e.
if the model outputs are independent.

It analyzes n_inputs, n_outputs and the operators
in a compound model by stepping through the transforms
and creating a coord_matrix of shape (n_outputs, n_inputs).

Each modeling operator is represented by a function which
takes two simple models (or two coord_matrix arrays) and
returns an array of shape (n_outputs, n_inputs).

"""

import numpy as np

from .core import Model, ModelDefinitionError, CompoundModel
from .mappings import Mapping

# Names exported by ``from ... import *``; the underscore-prefixed helpers
# defined further down in this module are not part of this list.
__all__ = ["is_separable", "separability_matrix"]

def is_separable(transform):
    """
    A separability test for the outputs of a transform.

    Parameters
    ----------
    transform : `~astropy.modeling.core.Model`
        A (compound) model.

    Returns
    -------
    is_separable : ndarray
        A boolean array with size ``transform.n_outputs`` where
        each element indicates whether the output is independent
        and the result of a separable transform.

    Examples
    --------
    >>> from astropy.modeling.models import (Shift, Scale, Rotation2D,
    ...                                      Polynomial2D, Mapping)
    >>> is_separable(Shift(1) & Shift(2) | Scale(1) & Scale(2))
    array([ True,  True]...)
    >>> is_separable(Shift(1) & Shift(2) | Rotation2D(2))
    array([False, False]...)
    >>> is_separable(Shift(1) & Shift(2) | Mapping([0, 1, 0, 1]) | \
        Polynomial2D(1) & Polynomial2D(2))
    array([False, False]...)
    >>> is_separable(Shift(1) & Shift(2) | Mapping([0, 1, 0, 1]))
    array([ True,  True,  True,  True]...)

    """
    if transform.n_inputs == 1 and transform.n_outputs > 1:
        # One input shared by several outputs: none of them is reported
        # as separable.
        return np.zeros(transform.n_outputs, dtype=np.bool_)
    separable_matrix = _separable(transform)
    # An output counts as separable only when its row of the matrix sums
    # to exactly 1.
    return separable_matrix.sum(1) == 1

def separability_matrix(transform):
    """
    Compute the correlation between outputs and inputs.

    Parameters
    ----------
    transform : `~astropy.modeling.core.Model`
        A (compound) model.

    Returns
    -------
    separable_matrix : ndarray
        A boolean correlation matrix of shape (n_outputs, n_inputs).
        Indicates the dependence of outputs on inputs. For completely
        independent outputs, the diagonal elements are True and
        off-diagonal elements are False.

    Examples
    --------
    >>> from astropy.modeling.models import (Shift, Scale, Rotation2D,
    ...                                      Polynomial2D, Mapping)
    >>> separability_matrix(Shift(1) & Shift(2) | Scale(1) & Scale(2))
    array([[ True, False], [False,  True]]...)
    >>> separability_matrix(Shift(1) & Shift(2) | Rotation2D(2))
    array([[ True,  True], [ True,  True]]...)
    >>> separability_matrix(Shift(1) & Shift(2) | Mapping([0, 1, 0, 1]) | \
        Polynomial2D(1) & Polynomial2D(2))
    array([[ True,  True], [ True,  True]]...)
    >>> separability_matrix(Shift(1) & Shift(2) | Mapping([0, 1, 0, 1]))
    array([[ True, False], [False,  True], [ True, False], [False,  True]]...)

    """
    if transform.n_inputs == 1 and transform.n_outputs > 1:
        # One input shared by several outputs: every output is marked as
        # depending on it.
        return np.ones((transform.n_outputs, transform.n_inputs),
                       dtype=np.bool_)
    # Any non-zero entry of the numeric matrix means "this output depends
    # on this input".
    return _separable(transform) != 0

def _compute_n_outputs(left, right):
    """
    Return the combined number of outputs of two operands.

    The operands are the left and right children of one operation in
    the expression tree of a compound model.

    Parameters
    ----------
    left, right : `astropy.modeling.Model` or ndarray
        If an operand is an array, it is the output of `coord_matrix`
        and its first dimension is taken as its number of outputs.

    Returns
    -------
    noutp : int
        Sum of the number of outputs of ``left`` and ``right``.

    """
    def _outputs_of(operand):
        # Models carry the count as an attribute; arrays carry it as
        # their number of rows.
        if isinstance(operand, Model):
            return operand.n_outputs
        return operand.shape[0]

    return _outputs_of(left) + _outputs_of(right)

def _arith_oper(left, right):
    """
    Function corresponding to one of the arithmetic operators
    ['+', '-', '*', '/', '**'].

    This always returns a nonseparable output.

    Parameters
    ----------
    left, right : `astropy.modeling.Model` or ndarray
        If input is of an array, it is the output of `coord_matrix`.

    Returns
    -------
    result : ndarray
        Array of ones with shape (n_outputs, n_inputs).

    Raises
    ------
    ModelDefinitionError
        If the two operands differ in their number of inputs or in
        their number of outputs.
    """
    def _dimensions(operand):
        # Returns (n_inputs, n_outputs); for an array the shape is read
        # as (n_outputs, n_inputs).
        if isinstance(operand, Model):
            return operand.n_inputs, operand.n_outputs
        n_rows, n_cols = operand.shape
        return n_cols, n_rows

    left_inputs, left_outputs = _dimensions(left)
    right_inputs, right_outputs = _dimensions(right)

    dimensions_match = (
        (left_inputs, left_outputs) == (right_inputs, right_outputs))
    if not dimensions_match:
        raise ModelDefinitionError(
            "Unsupported operands for arithmetic operator: left (n_inputs={}, "
            "n_outputs={}) and right (n_inputs={}, n_outputs={}); "
            "models must have the same n_inputs and the same "
            "n_outputs for this operator.".format(
                left_inputs, left_outputs, right_inputs, right_outputs))

    # Every output is treated as depending on every input.
    return np.ones((left_outputs, left_inputs))

def _coord_matrix(model, pos, noutp):
    """
    Create an array representing inputs and outputs of a simple model.

    The array has a shape (noutp, model.n_inputs).

    Parameters
    ----------
    model : `astropy.modeling.Model`
        The model to represent.
    pos : str
        Position of this model in the expression tree.
        One of ['left', 'right'].
    noutp : int
        Number of outputs of the compound model of which the input model
        is a left or right child.

    Returns
    -------
    mat : ndarray
        Array of zeros and ones; a one marks an output row that depends
        on an input column.

    """
    n_in = model.n_inputs
    n_out = model.n_outputs
    on_left = pos == 'left'

    if isinstance(model, Mapping):
        # One row per mapped output, with a single 1 in the column of the
        # input that the output is taken from.
        rows = []
        for source in model.mapping:
            row = np.zeros((n_in,))
            row[source] = 1
            rows.append(row)
        selection = np.vstack(rows)

        mat = np.zeros((noutp, n_in))
        if on_left:
            mat[:n_out, :n_in] = selection
        else:
            mat[-n_out:, -n_in:] = selection
        return mat

    if model.separable:
        mat = np.zeros((noutp, n_in))
        for k in range(n_in):
            mat[k, k] = 1
        if pos == 'right':
            # NOTE(review): np.roll is called without an axis, so it shifts
            # the flattened array. That lands on the intended row when the
            # matrix has a single column; confirm for multi-input models.
            mat = np.roll(mat, (noutp - n_out))
        return mat

    # this does not work for more than 2 coordinates
    mat = np.zeros((noutp, n_in))
    if on_left:
        mat[:n_out, :n_in] = 1
    else:
        mat[-n_out:, -n_in:] = 1
    return mat

def _cstack(left, right):
    """
    Function corresponding to '&' operation.

    The left operand is placed in the upper-left block and the right
    operand in the lower-right block of the result; the remaining entries
    stay zero.

    Parameters
    ----------
    left, right : `astropy.modeling.Model` or ndarray
        If input is of an array, it is the output of `coord_matrix`.

    Returns
    -------
    result : ndarray
        Array whose number of rows is the combined number of outputs of
        ``left`` and ``right`` and whose columns are the columns of the
        left block followed by those of the right block.

    """
    noutp = _compute_n_outputs(left, right)

    if isinstance(left, Model):
        cleft = _coord_matrix(left, 'left', noutp)
    else:
        cleft = np.zeros((noutp, left.shape[1]))
        cleft[: left.shape[0], : left.shape[1]] = left
    if isinstance(right, Model):
        cright = _coord_matrix(right, 'right', noutp)
    else:
        cright = np.zeros((noutp, right.shape[1]))
        # Copy the operand's own matrix. Filling the block with ones would
        # discard the separability already computed for a nested compound
        # right-hand operand and report all of its outputs as coupled.
        cright[-right.shape[0]:, -right.shape[1]:] = right

    return np.hstack([cleft, cright])

def _cdot(left, right):
    """
    Function corresponding to "|" operation.

    Parameters
    ----------
    left, right : `astropy.modeling.Model` or ndarray
        If input is of an array, it is the output of `coord_matrix`.

    Returns
    -------
    result : ndarray
        Matrix product of the array for ``right`` with the array for
        ``left``.

    Raises
    ------
    ModelDefinitionError
        If the two arrays have shapes that cannot be multiplied.
    """
    def _as_coord_matrix(operand, position):
        """
        Return ``operand`` unchanged if it is already an array, otherwise
        build its coord_matrix for the given position.
        """
        if not isinstance(operand, Model):
            return operand
        return _coord_matrix(operand, position, operand.n_outputs)

    # The product is taken as matrix(right) . matrix(left), so the operands
    # trade places here: ``right`` becomes the first factor.
    first = _as_coord_matrix(right, 'left')
    second = _as_coord_matrix(left, 'right')

    try:
        product = np.dot(first, second)
    except ValueError:
        # The message names the operands as the caller passed them.
        raise ModelDefinitionError(
            'Models cannot be combined with the "|" operator; '
            'left coord_matrix is {}, right coord_matrix is {}'.format(
                second, first))
    return product

def _separable(transform):
    """
    Build the array describing how outputs depend on inputs.

    Compound models are processed recursively: the arrays of the left and
    right children are computed first and then combined by the function
    registered for the model's operator in ``_operators``.

    Parameters
    ----------
    transform : `astropy.modeling.Model`
        A transform (usually a compound model).

    Returns
    -------
    coord_matrix : ndarray
        For a simple model, the result of ``_coord_matrix`` with
        ``transform.n_outputs`` rows; for a compound model, whatever the
        operator function returns for the two child arrays. Nothing is
        returned explicitly when ``transform`` is not a model.
    """
    if not isinstance(transform, CompoundModel):
        if isinstance(transform, Model):
            return _coord_matrix(transform, 'left', transform.n_outputs)
        return None

    left_matrix = _separable(transform.left)
    right_matrix = _separable(transform.right)
    combine = _operators[transform.op]
    return combine(left_matrix, right_matrix)

# Maps each modeling operator to the function that computes the array
# representing the relationship between the axes of its two operands.
_operators = {'&': _cstack, '|': _cdot}
# All arithmetic operators share a single handler.
_operators.update({symbol: _arith_oper
                   for symbol in ('+', '-', '*', '/', '**')})