# Source code for astropy.visualization.stretch

# Licensed under a 3-clause BSD style license - see LICENSE.rst

"""
Classes that deal with stretching, i.e. mapping a range of [0:1] values onto
another set of [0:1] values with a transformation.
"""

import numpy as np

from astropy.utils.decorators import deprecated_attribute

from .transform import BaseTransform, CompositeTransform

__all__ = [
"BaseStretch",
"LinearStretch",
"SqrtStretch",
"PowerStretch",
"PowerDistStretch",
"SquaredStretch",
"LogStretch",
"AsinhStretch",
"SinhStretch",
"HistEqStretch",
"ContrastBiasStretch",
"CompositeStretch",
]

def _logn(n, x, out=None):
"""Calculate the log base n of x."""
# We define this because numpy.emath.logn doesn't support the out
# keyword.
if out is None:
return np.log(x) / np.log(n)
else:
np.log(x, out=out)
np.true_divide(out, np.log(n), out=out)
return out

def _prepare(values, clip=True, out=None):
"""
Prepare the data by optionally clipping and copying, and return the
array that should be subsequently used for in-place calculations.
"""
if clip:
return np.clip(values, 0.0, 1.0, out=out)
else:
if out is None:
return np.array(values, copy=True)
else:
out[:] = np.asarray(values)
return out

[docs]
class BaseStretch(BaseTransform):
"""
Base class for the stretch classes, which when called with an array
of values in the range [0:1], returns an transformed array of values
also in the range [0:1].
"""

@property
def _supports_invalid_kw(self):
return False

return CompositeStretch(other, self)

[docs]
def __call__(self, values, clip=True, out=None):
"""
Transform values using this stretch.

Parameters
----------
values : array-like
The input values, which should already be normalized to the
[0:1] range.
clip : bool, optional
If True (default), values outside the [0:1] range are
clipped to the [0:1] range.
out : ndarray, optional
If specified, the output values will be placed in this array
(typically used for in-place calculations).

Returns
-------
result : ndarray
The transformed values.
"""

@property
def inverse(self):
"""A stretch object that performs the inverse operation."""

[docs]
class LinearStretch(BaseStretch):
"""
A linear stretch with a slope and offset.

The stretch is given by:

.. math::
y = slope * x + intercept

Parameters
----------
slope : float, optional
The slope parameter used in the above formula.  Default is 1.
intercept : float, optional
The intercept parameter used in the above formula.  Default is 0.
"""

def __init__(self, slope=1, intercept=0):
super().__init__()
self.slope = slope
self.intercept = intercept

[docs]
def __call__(self, values, clip=True, out=None):
values = _prepare(values, clip=clip, out=out)
if self.slope != 1:
np.multiply(values, self.slope, out=values)
if self.intercept != 0:
return values

@property
def inverse(self):
"""A stretch object that performs the inverse operation."""
return LinearStretch(1.0 / self.slope, -self.intercept / self.slope)

[docs]
class SqrtStretch(BaseStretch):
r"""
A square root stretch.

The stretch is given by:

.. math::
y = \sqrt{x}
"""

@property
def _supports_invalid_kw(self):
return True

[docs]
def __call__(self, values, clip=True, out=None, invalid=None):
"""
Transform values using this stretch.

Parameters
----------
values : array-like
The input values, which should already be normalized to the
[0:1] range.
clip : bool, optional
If True (default), values outside the [0:1] range are
clipped to the [0:1] range.
out : ndarray, optional
If specified, the output values will be placed in this array
(typically used for in-place calculations).
invalid : None or float, optional
Value to assign NaN values generated by this class.  NaNs in
the input values array are not changed.  This option is
generally used with matplotlib normalization classes, where
the invalid value should map to the matplotlib colormap
"under" value (i.e., any finite value < 0).  If None, then
NaN values are not replaced.  This keyword has no effect if
clip=True.

Returns
-------
result : ndarray
The transformed values.
"""
values = _prepare(values, clip=clip, out=out)
replace_invalid = not clip and invalid is not None
with np.errstate(invalid="ignore"):
if replace_invalid:
idx = values < 0
np.sqrt(values, out=values)

if replace_invalid:
# Assign new NaN (i.e., NaN not in the original input
# values, but generated by this class) to the invalid value.
values[idx] = invalid

return values

@property
def inverse(self):
"""A stretch object that performs the inverse operation."""
return PowerStretch(2)

[docs]
class PowerStretch(BaseStretch):
r"""
A power stretch.

The stretch is given by:

.. math::
y = x^a

Parameters
----------
a : float
The power index (see the above formula).  a must be greater
than 0.
"""

power = deprecated_attribute("power", "6.0", alternative="a")

@property
def _supports_invalid_kw(self):
return True

def __init__(self, a):
super().__init__()
if a <= 0:
raise ValueError("a must be > 0")
self.a = a

[docs]
def __call__(self, values, clip=True, out=None, invalid=None):
"""
Transform values using this stretch.

Parameters
----------
values : array-like
The input values, which should already be normalized to the
[0:1] range.
clip : bool, optional
If True (default), values outside the [0:1] range are
clipped to the [0:1] range.
out : ndarray, optional
If specified, the output values will be placed in this array
(typically used for in-place calculations).
invalid : None or float, optional
Value to assign NaN values generated by this class.  NaNs in
the input values array are not changed.  This option is
generally used with matplotlib normalization classes, where
the invalid value should map to the matplotlib colormap
"under" value (i.e., any finite value < 0).  If None, then
NaN values are not replaced.  This keyword has no effect if
clip=True.

Returns
-------
result : ndarray
The transformed values.
"""
values = _prepare(values, clip=clip, out=out)
replace_invalid = (
not clip and invalid is not None and ((-1 < self.a < 0) or (0 < self.a < 1))
)
with np.errstate(invalid="ignore"):
if replace_invalid:
idx = values < 0
np.power(values, self.a, out=values)

if replace_invalid:
# Assign new NaN (i.e., NaN not in the original input
# values, but generated by this class) to the invalid value.
values[idx] = invalid

return values

@property
def inverse(self):
"""A stretch object that performs the inverse operation."""
return PowerStretch(1.0 / self.a)

[docs]
class PowerDistStretch(BaseStretch):
r"""
An alternative power stretch.

The stretch is given by:

.. math::
y = \frac{a^x - 1}{a - 1}

Parameters
----------
a : float, optional
The a parameter used in the above formula.  a must be
greater than or equal to 0, but cannot be set to 1.  Default is
1000.
"""

exp = deprecated_attribute("exp", "6.0", alternative="a")

def __init__(self, a=1000.0):
if a < 0 or a == 1:  # singularity
raise ValueError("a must be >= 0, but cannot be set to 1")
super().__init__()
self.a = a

[docs]
def __call__(self, values, clip=True, out=None):
values = _prepare(values, clip=clip, out=out)
np.power(self.a, values, out=values)
np.subtract(values, 1, out=values)
np.true_divide(values, self.a - 1.0, out=values)
return values

@property
def inverse(self):
"""A stretch object that performs the inverse operation."""
return InvertedPowerDistStretch(a=self.a)

class InvertedPowerDistStretch(BaseStretch):
r"""
Inverse transformation for
~astropy.image.scaling.PowerDistStretch.

The stretch is given by:

.. math::
y = \frac{\log(y (a-1) + 1)}{\log a}

Parameters
----------
a : float, optional
The a parameter used in the above formula.  a must be
greater than or equal to 0, but cannot be set to 1.  Default is
1000.
"""

exp = deprecated_attribute("exp", "6.0", alternative="a")

def __init__(self, a=1000.0):
if a < 0 or a == 1:  # singularity
raise ValueError("a must be >= 0, but cannot be set to 1")
super().__init__()
self.a = a

def __call__(self, values, clip=True, out=None):
values = _prepare(values, clip=clip, out=out)
np.multiply(values, self.a - 1.0, out=values)
_logn(self.a, values, out=values)
return values

@property
def inverse(self):
"""A stretch object that performs the inverse operation."""
return PowerDistStretch(a=self.a)

[docs]
class SquaredStretch(PowerStretch):
r"""
A convenience class for a power stretch of 2.

The stretch is given by:

.. math::
y = x^2
"""

def __init__(self):
super().__init__(2)

@property
def inverse(self):
"""A stretch object that performs the inverse operation."""
return SqrtStretch()

[docs]
class LogStretch(BaseStretch):
r"""
A log stretch.

The stretch is given by:

.. math::
y = \frac{\log{(a x + 1)}}{\log{(a + 1)}}

Parameters
----------
a : float
The a parameter used in the above formula.  a must be
greater than 0.  Default is 1000.
"""

exp = deprecated_attribute("exp", "6.0", alternative="a")

@property
def _supports_invalid_kw(self):
return True

def __init__(self, a=1000.0):
super().__init__()
if a <= 0:  # singularity
raise ValueError("a must be > 0")
self.a = a

[docs]
def __call__(self, values, clip=True, out=None, invalid=None):
"""
Transform values using this stretch.

Parameters
----------
values : array-like
The input values, which should already be normalized to the
[0:1] range.
clip : bool, optional
If True (default), values outside the [0:1] range are
clipped to the [0:1] range.
out : ndarray, optional
If specified, the output values will be placed in this array
(typically used for in-place calculations).
invalid : None or float, optional
Value to assign NaN values generated by this class.  NaNs in
the input values array are not changed.  This option is
generally used with matplotlib normalization classes, where
the invalid value should map to the matplotlib colormap
"under" value (i.e., any finite value < 0).  If None, then
NaN values are not replaced.  This keyword has no effect if
clip=True.

Returns
-------
result : ndarray
The transformed values.
"""
values = _prepare(values, clip=clip, out=out)
replace_invalid = not clip and invalid is not None
with np.errstate(invalid="ignore"):
if replace_invalid:
idx = values < 0
np.multiply(values, self.a, out=values)
np.log(values, out=values)
np.true_divide(values, np.log(self.a + 1.0), out=values)

if replace_invalid:
# Assign new NaN (i.e., NaN not in the original input
# values, but generated by this class) to the invalid value.
values[idx] = invalid

return values

@property
def inverse(self):
"""A stretch object that performs the inverse operation."""
return InvertedLogStretch(self.a)

class InvertedLogStretch(BaseStretch):
r"""
Inverse transformation for ~astropy.image.scaling.LogStretch.

The stretch is given by:

.. math::
y = \frac{e^{y \log{a + 1}} - 1}{a} \\
y = \frac{e^{y} (a + 1) - 1}{a}

Parameters
----------
a : float, optional
The a parameter used in the above formula.  a must be
greater than 0.  Default is 1000.
"""

exp = deprecated_attribute("exp", "6.0", alternative="a")

def __init__(self, a):
super().__init__()
if a <= 0:  # singularity
raise ValueError("a must be > 0")
self.a = a

def __call__(self, values, clip=True, out=None):
values = _prepare(values, clip=clip, out=out)
np.multiply(values, np.log(self.a + 1.0), out=values)
np.exp(values, out=values)
np.subtract(values, 1.0, out=values)
np.true_divide(values, self.a, out=values)
return values

@property
def inverse(self):
"""A stretch object that performs the inverse operation."""
return LogStretch(self.a)

[docs]
class AsinhStretch(BaseStretch):
r"""
An asinh stretch.

The stretch is given by:

.. math::
y = \frac{{\rm asinh}(x / a)}{{\rm asinh}(1 / a)}.

Parameters
----------
a : float, optional
The a parameter used in the above formula. The value of this
parameter is where the asinh curve transitions from linear to
logarithmic behavior, expressed as a fraction of the normalized
image. The stretch becomes more linear as the a value is
increased. a must be greater than 0. Default is 0.1.
"""

def __init__(self, a=0.1):
super().__init__()
if a <= 0:
raise ValueError("a must be > 0")
self.a = a

[docs]
def __call__(self, values, clip=True, out=None):
values = _prepare(values, clip=clip, out=out)
np.true_divide(values, self.a, out=values)
np.arcsinh(values, out=values)
np.true_divide(values, np.arcsinh(1.0 / self.a), out=values)
return values

@property
def inverse(self):
"""A stretch object that performs the inverse operation."""
return SinhStretch(a=1.0 / np.arcsinh(1.0 / self.a))

[docs]
class SinhStretch(BaseStretch):
r"""
A sinh stretch.

The stretch is given by:

.. math::
y = \frac{{\rm sinh}(x / a)}{{\rm sinh}(1 / a)}

Parameters
----------
a : float, optional
The a parameter used in the above formula. The stretch
becomes more linear as the a value is increased. a must
be greater than 0. Default is 1/3.
"""

def __init__(self, a=1.0 / 3.0):
super().__init__()
if a <= 0:
raise ValueError("a must be > 0")
self.a = a

[docs]
def __call__(self, values, clip=True, out=None):
values = _prepare(values, clip=clip, out=out)
np.true_divide(values, self.a, out=values)
np.sinh(values, out=values)
np.true_divide(values, np.sinh(1.0 / self.a), out=values)
return values

@property
def inverse(self):
"""A stretch object that performs the inverse operation."""
return AsinhStretch(a=1.0 / np.sinh(1.0 / self.a))

[docs]
class HistEqStretch(BaseStretch):
"""
A histogram equalization stretch.

Parameters
----------
data : array-like
The data defining the equalization.
values : array-like, optional
The input image values, which should already be normalized to
the [0:1] range.
"""

def __init__(self, data, values=None):
# Assume data is not necessarily normalized at this point
self.data = np.sort(data.ravel())
self.data = self.data[np.isfinite(self.data)]
vmin = self.data.min()
vmax = self.data.max()
self.data = (self.data - vmin) / (vmax - vmin)

# Compute relative position of each pixel
if values is None:
self.values = np.linspace(0.0, 1.0, len(self.data))
else:
self.values = values

[docs]
def __call__(self, values, clip=True, out=None):
values = _prepare(values, clip=clip, out=out)
values[:] = np.interp(values, self.data, self.values)
return values

@property
def inverse(self):
"""A stretch object that performs the inverse operation."""
return InvertedHistEqStretch(self.data, values=self.values)

class InvertedHistEqStretch(BaseStretch):
"""
Inverse transformation for ~astropy.image.scaling.HistEqStretch.

Parameters
----------
data : array-like
The data defining the equalization.
values : array-like, optional
The input image values, which should already be normalized to
the [0:1] range.
"""

def __init__(self, data, values=None):
self.data = data[np.isfinite(data)]
if values is None:
self.values = np.linspace(0.0, 1.0, len(self.data))
else:
self.values = values

def __call__(self, values, clip=True, out=None):
values = _prepare(values, clip=clip, out=out)
values[:] = np.interp(values, self.values, self.data)
return values

@property
def inverse(self):
"""A stretch object that performs the inverse operation."""
return HistEqStretch(self.data, values=self.values)

[docs]
class ContrastBiasStretch(BaseStretch):
r"""
A stretch that takes into account contrast and bias.

The stretch is given by:

.. math::
y = (x - {\rm bias}) * {\rm contrast} + 0.5

and the output values are clipped to the [0:1] range.

Parameters
----------
contrast : float
The contrast parameter (see the above formula).

bias : float
The bias parameter (see the above formula).
"""

def __init__(self, contrast, bias):
super().__init__()
self.contrast = contrast
self.bias = bias

[docs]
def __call__(self, values, clip=True, out=None):
# As a special case here, we only clip *after* the
# transformation since it does not map [0:1] to [0:1]
values = _prepare(values, clip=False, out=out)

np.subtract(values, self.bias, out=values)
np.multiply(values, self.contrast, out=values)

if clip:
np.clip(values, 0, 1, out=values)

return values

@property
def inverse(self):
"""A stretch object that performs the inverse operation."""
return InvertedContrastBiasStretch(self.contrast, self.bias)

class InvertedContrastBiasStretch(BaseStretch):
"""
Inverse transformation for ContrastBiasStretch.

Parameters
----------
contrast : float
The contrast parameter (see
~astropy.visualization.ConstrastBiasStretch).

bias : float
The bias parameter (see
~astropy.visualization.ConstrastBiasStretch).
"""

def __init__(self, contrast, bias):
super().__init__()
self.contrast = contrast
self.bias = bias

def __call__(self, values, clip=True, out=None):
# As a special case here, we only clip *after* the
# transformation since it does not map [0:1] to [0:1]
values = _prepare(values, clip=False, out=out)
np.subtract(values, 0.5, out=values)
np.true_divide(values, self.contrast, out=values)

if clip:
np.clip(values, 0, 1, out=values)

return values

@property
def inverse(self):
"""A stretch object that performs the inverse operation."""
return ContrastBiasStretch(self.contrast, self.bias)

[docs]
class CompositeStretch(CompositeTransform, BaseStretch):
"""
A combination of two stretches.

Parameters
----------
stretch_1 : :class:astropy.visualization.BaseStretch
The first stretch to apply.
stretch_2 : :class:astropy.visualization.BaseStretch
The second stretch to apply.
"""

[docs]
def __call__(self, values, clip=True, out=None):
return self.transform_2(
self.transform_1(values, clip=clip, out=out), clip=clip, out=out
)