# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""
Functions to determine if a model is separable, i.e.
if the model outputs are independent.
It analyzes ``n_inputs``, ``n_outputs`` and the operators
in a compound model by stepping through the transforms
and creating a ``coord_matrix`` of shape (``n_outputs``, ``n_inputs``).
Each modeling operator is represented by a function which
takes two simple models (or two ``coord_matrix`` arrays) and
returns an array of shape (``n_outputs``, ``n_inputs``).
"""
import numpy as np
from .core import CompoundModel, Model, ModelDefinitionError
from .mappings import Mapping
__all__ = ["is_separable", "separability_matrix"]
def is_separable(transform):
    """
    A separability test for the outputs of a transform.

    Parameters
    ----------
    transform : `~astropy.modeling.core.Model`
        A (compound) model.

    Returns
    -------
    is_separable : ndarray
        A boolean array with size ``transform.n_outputs`` where
        each element indicates whether the output is independent
        and the result of a separable transform.

    Examples
    --------
    >>> from astropy.modeling.models import Shift, Scale, Rotation2D, Polynomial2D
    >>> is_separable(Shift(1) & Shift(2) | Scale(1) & Scale(2))
    array([ True,  True]...)
    >>> is_separable(Shift(1) & Shift(2) | Rotation2D(2))
    array([False, False]...)
    >>> is_separable(
    ...     Shift(1) & Shift(2) | Mapping([0, 1, 0, 1])
    ...     | Polynomial2D(1) & Polynomial2D(2)
    ... )
    array([False, False]...)
    >>> is_separable(Shift(1) & Shift(2) | Mapping([0, 1, 0, 1]))
    array([ True,  True,  True,  True]...)
    """
    if transform.n_inputs == 1 and transform.n_outputs > 1:
        # Special case handled without building the matrix: a single input
        # with several outputs is reported as non-separable for every output.
        return np.zeros(transform.n_outputs, dtype=np.bool_)

    # An output counts as separable only when its row of the separability
    # matrix sums to exactly 1.
    return _separable(transform).sum(1) == 1
def separability_matrix(transform):
    """
    Compute the correlation between outputs and inputs.

    Parameters
    ----------
    transform : `~astropy.modeling.core.Model`
        A (compound) model.

    Returns
    -------
    separable_matrix : ndarray
        A boolean correlation matrix of shape (n_outputs, n_inputs).
        Indicates the dependence of outputs on inputs. For completely
        independent outputs, the diagonal elements are True and
        off-diagonal elements are False.

    Examples
    --------
    >>> from astropy.modeling.models import Shift, Scale, Rotation2D, Polynomial2D
    >>> separability_matrix(Shift(1) & Shift(2) | Scale(1) & Scale(2))
    array([[ True, False], [False,  True]]...)
    >>> separability_matrix(Shift(1) & Shift(2) | Rotation2D(2))
    array([[ True,  True], [ True,  True]]...)
    >>> separability_matrix(
    ...     Shift(1) & Shift(2) | Mapping([0, 1, 0, 1])
    ...     | Polynomial2D(1) & Polynomial2D(2)
    ... )
    array([[ True,  True], [ True,  True]]...)
    >>> separability_matrix(Shift(1) & Shift(2) | Mapping([0, 1, 0, 1]))
    array([[ True, False], [False,  True], [ True, False], [False,  True]]...)
    """
    if transform.n_inputs == 1 and transform.n_outputs > 1:
        # Special case handled without building the matrix: with a single
        # input and several outputs, every output is marked as dependent on
        # that input.
        return np.ones((transform.n_outputs, transform.n_inputs), dtype=np.bool_)

    # Any non-zero entry means "this output depends on this input".
    return _separable(transform) != 0
def _compute_n_outputs(left, right):
    """
    Return the combined number of outputs of two operands.

    The operands are the left and right children of an operation in
    the expression tree of a compound model.

    Parameters
    ----------
    left, right : `astropy.modeling.Model` or ndarray
        A model, or an array produced by `_coord_matrix`.

    Returns
    -------
    noutp : int
        ``n_outputs`` of a model operand, or the number of rows of an
        array operand, summed over both operands.
    """

    def _count_outputs(operand):
        # Models expose the count directly; arrays carry one row per output.
        if isinstance(operand, Model):
            return operand.n_outputs
        return operand.shape[0]

    return _count_outputs(left) + _count_outputs(right)
def _arith_oper(left, right):
    """
    Function corresponding to one of the arithmetic operators
    ['+', '-', '*', '/', '**'].

    This always returns a nonseparable output.

    Parameters
    ----------
    left, right : `astropy.modeling.Model` or ndarray
        A model, or an array produced by `_coord_matrix`.

    Returns
    -------
    result : ndarray
        An all-ones array of shape (n_outputs, n_inputs).

    Raises
    ------
    ModelDefinitionError
        If the two operands differ in ``n_inputs`` or in ``n_outputs``.
    """

    def _dimensions(operand):
        # Returns (n_inputs, n_outputs); for an array operand the shape is
        # read as (n_outputs, n_inputs).
        if isinstance(operand, Model):
            n_out, n_in = operand.n_outputs, operand.n_inputs
        else:
            n_out, n_in = operand.shape
        return n_in, n_out

    left_inputs, left_outputs = _dimensions(left)
    right_inputs, right_outputs = _dimensions(right)

    if (left_inputs, left_outputs) != (right_inputs, right_outputs):
        raise ModelDefinitionError(
            "Unsupported operands for arithmetic operator: left"
            f" (n_inputs={left_inputs}, n_outputs={left_outputs}) and right"
            f" (n_inputs={right_inputs}, n_outputs={right_outputs}); models must have"
            " the same n_inputs and the same n_outputs for this operator."
        )

    # Every output is marked as depending on every input.
    return np.ones((left_outputs, left_inputs))
def _coord_matrix(model, pos, noutp):
    """
    Create an array representing inputs and outputs of a simple model.

    The array has a shape (noutp, model.n_inputs).

    Parameters
    ----------
    model : `astropy.modeling.Model`
        model
    pos : str
        Position of this model in the expression tree.
        One of ['left', 'right'].
    noutp : int
        Number of outputs of the compound model of which the input model
        is a left or right child.

    Returns
    -------
    mat : ndarray
        Array of 0s and 1s of shape (noutp, model.n_inputs). The model's own
        block occupies the top rows for ``pos == "left"`` and the bottom
        rows otherwise.
    """
    mat = np.zeros((noutp, model.n_inputs))

    if isinstance(model, Mapping):
        # One row per output, with a single 1 in the column of the input
        # that the output is taken from.
        rows = []
        for source in model.mapping:
            row = np.zeros((model.n_inputs,))
            row[source] = 1
            rows.append(row)
        block = np.vstack(rows)
        if pos == "left":
            mat[: model.n_outputs, : model.n_inputs] = block
        else:
            mat[-model.n_outputs :, -model.n_inputs :] = block
        return mat

    if not model.separable:
        # this does not work for more than 2 coordinates
        if pos == "left":
            mat[: model.n_outputs, : model.n_inputs] = 1
        else:
            mat[-model.n_outputs :, -model.n_inputs :] = 1
        return mat

    # Separable model: output i depends only on input i.
    diagonal = np.arange(model.n_inputs)
    mat[diagonal, diagonal] = 1
    if pos == "right":
        # Shift whole rows down so the identity block lands in the bottom
        # rows. ``axis=0`` is required: without it ``np.roll`` rolls the
        # flattened array, which misplaces the block when n_inputs > 1.
        mat = np.roll(mat, noutp - model.n_outputs, axis=0)
    return mat
def _cstack(left, right):
    """
    Function corresponding to '&' operation.

    Parameters
    ----------
    left, right : `astropy.modeling.Model` or ndarray
        A model, or an array produced by `_coord_matrix`.

    Returns
    -------
    result : ndarray
        The left and right coordinate matrices, each padded with zeros to
        the combined number of outputs, stacked side by side.
    """
    total_outputs = _compute_n_outputs(left, right)

    if not isinstance(left, Model):
        # Array operand: copy it into the top-left corner of a zero matrix.
        left_block = np.zeros((total_outputs, left.shape[1]))
        left_block[: left.shape[0], : left.shape[1]] = left
    else:
        left_block = _coord_matrix(left, "left", total_outputs)

    if not isinstance(right, Model):
        # Array operand: copy it into the bottom-right corner of a zero matrix.
        right_block = np.zeros((total_outputs, right.shape[1]))
        right_block[-right.shape[0] :, -right.shape[1] :] = right
    else:
        right_block = _coord_matrix(right, "right", total_outputs)

    stacked = np.hstack([left_block, right_block])
    return stacked
def _cdot(left, right):
    """
    Function corresponding to "|" operation.

    Parameters
    ----------
    left, right : `astropy.modeling.Model` or ndarray
        A model, or an array produced by `_coord_matrix`.

    Returns
    -------
    result : ndarray
        Matrix product of the right operand's coordinate matrix with the
        left operand's coordinate matrix.

    Raises
    ------
    ModelDefinitionError
        If the two coordinate matrices have incompatible shapes for the
        matrix product.
    """

    def _as_coord_matrix(operand, position):
        """
        Return the coord_matrix of a model; pass arrays through unchanged.
        """
        if isinstance(operand, Model):
            return _coord_matrix(operand, position, operand.n_outputs)
        return operand

    # The product is taken as (right matrix) . (left matrix), so the right
    # operand supplies the first factor.
    right_matrix = _as_coord_matrix(right, "left")
    left_matrix = _as_coord_matrix(left, "right")

    try:
        return np.dot(right_matrix, left_matrix)
    except ValueError as err:
        # Chain the numpy shape error so the root cause stays visible.
        raise ModelDefinitionError(
            'Models cannot be combined with the "|" operator; '
            f"left coord_matrix is {left_matrix}, right coord_matrix is {right_matrix}"
        ) from err
def _separable(transform):
    """
    Calculate the separability of outputs.

    Parameters
    ----------
    transform : `astropy.modeling.Model`
        A transform (usually a compound model).

    Returns
    -------
    matrix : ndarray or None
        The value returned by ``transform._calculate_separability_matrix()``
        when that is not ``NotImplemented``. Otherwise, for a compound model,
        the result of combining the matrices of its left and right children
        with the function registered for its operator; for any other model,
        its `_coord_matrix`. ``None`` if ``transform`` is not a model.
    """
    # A model may supply its own matrix; NotImplemented means "use the
    # generic computation below".
    custom_matrix = transform._calculate_separability_matrix()
    if custom_matrix is not NotImplemented:
        return custom_matrix

    if isinstance(transform, CompoundModel):
        # Recurse into both children, then combine with the operator's rule.
        children = (_separable(transform.left), _separable(transform.right))
        return _operators[transform.op](*children)

    if isinstance(transform, Model):
        return _coord_matrix(transform, "left", transform.n_outputs)

    return None
# Dispatch table: maps each modeling operator to the function that combines
# the operands and returns the relationship of axes as an array of 0s and 1s.
_operators = {
    "&": _cstack,
    "|": _cdot,
    # All arithmetic operators share the same handler.
    **dict.fromkeys(("+", "-", "*", "/", "**"), _arith_oper),
}