# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""Comparison functions for `astropy.cosmology.Cosmology`.
This module is **NOT** public API. To use these functions, import them from
the top-level namespace -- :mod:`astropy.cosmology`. This module will be
moved.
"""
from __future__ import annotations
import functools
import inspect
from dataclasses import dataclass
from typing import TYPE_CHECKING
import numpy as np
from numpy import False_, True_, ndarray
from astropy import table
from astropy.cosmology.core import Cosmology
if TYPE_CHECKING:
from collections.abc import Callable
from typing import Any, TypeAlias
__all__: list[str] = [] # Nothing is scoped here
##############################################################################
# PARAMETERS
# A single ``format`` hint: a bool, `None`, or a format-name string.
_FormatType: TypeAlias = bool | None | str
# Either one hint, or a tuple of hints.
_FormatsType: TypeAlias = _FormatType | tuple[_FormatType, ...]
# ``format`` values accepted when the object being parsed already is a
# Cosmology (checked with ``in`` by ``_parse_format``).
_COSMO_AOK: set[Any] = {None, True_, False_, "astropy.cosmology"}
# NOTE: the numpy bools (``True_`` / ``False_``) also match the builtin
# ``True`` / ``False`` under the "==" and "in" operators, so they cover both.
##############################################################################
# UTILITIES
# Types that get wrapped in ``_CosmologyWrapper`` before being handed to numpy.
_CANT_BROADCAST: tuple[type, ...] = (table.Row, table.Table)
"""Things that cannot broadcast.
Have to deal with things that do not broadcast well. e.g.
`~astropy.table.Row` cannot be used in an array, even if ``dtype=object``
and will raise a segfault when used in a `numpy.ufunc`.
"""
@dataclass(frozen=True)
class _CosmologyWrapper:
"""A private wrapper class to hide things from :mod:`numpy`.
This should never be exposed to the user.
"""
__slots__ = ("wrapped",)
wrapped: Any
@functools.partial(np.frompyfunc, nin=2, nout=1)
def _parse_format(cosmo: Any, format: _FormatType, /) -> Cosmology:
    """Parse Cosmology-like input into Cosmologies, given a format hint.

    The function is wrapped by :func:`numpy.frompyfunc`, so it is applied
    element-wise to its two inputs.

    Parameters
    ----------
    cosmo : |Cosmology|-like, positional-only
        |Cosmology| to parse. A ``_CosmologyWrapper`` is unwrapped first.
    format : bool or None or str, positional-only
        Whether to allow, before equivalence is checked, the object to be
        converted to a |Cosmology|. This allows, e.g. a |Table| to be equivalent
        to a |Cosmology|. `False` will not allow conversion. `True` or
        `None` will, and will use the auto-identification to try to infer the
        correct format. A `str` is assumed to be the correct format to use when
        converting.

    Returns
    -------
    |Cosmology| or ndarray thereof

    Raises
    ------
    TypeError
        If ``cosmo`` is not a |Cosmology| and ``format`` equals `False`.
    ValueError
        If ``cosmo`` is a |Cosmology| and ``format`` is not one of `None`,
        `True`, `False` or ``"astropy.cosmology"``.
    """
    # Deal with private wrapper
    if isinstance(cosmo, _CosmologyWrapper):
        cosmo = cosmo.wrapped

    # Shortcut if already a cosmology
    if isinstance(cosmo, Cosmology):
        if format not in _COSMO_AOK:
            # Sorted so the message does not depend on set iteration order.
            allowed = "/".join(sorted(map(str, _COSMO_AOK)))
            raise ValueError(
                f"for parsing a Cosmology, 'format' must be {allowed}, not {format}"
            )
        return cosmo

    # Not a Cosmology: conversion must be allowed.
    if format == False_:  # catches False and False_
        raise TypeError(
            f"if 'format' is False, arguments must be a Cosmology, not {cosmo}"
        )

    # str->str, None/True/True_->None (i.e. auto-identify the format)
    format = None if format == True_ else format
    return Cosmology.from_format(cosmo, format=format)  # this can error!
def _parse_formats(*cosmos: object, format: _FormatsType) -> ndarray:
    """Convert every |Cosmology|-like argument to a |Cosmology|.

    ``format`` is broadcast, as an object array, to one entry per cosmology
    argument. The cosmology arguments are never broadcast against ``format``,
    so ``format`` cannot determine the output shape.

    Parameters
    ----------
    *cosmos : |Cosmology|-like
        The objects to parse. Each must be convertible to |Cosmology|, as
        specified by the corresponding ``format``.
    format : bool or None or str or array-like thereof, keyword-only
        Whether to allow the object to be converted to a |Cosmology|. This
        allows, e.g. a |Table| to be given instead of a |Cosmology|. `False`
        will not allow conversion. `True` or `None` will, and will use the
        auto-identification to try to infer the correct format. A `str` is
        assumed to be the correct format to use when converting.

    Returns
    -------
    ndarray
        The result of applying ``_parse_format`` to each
        (cosmology, format) pair.

    Raises
    ------
    TypeError
        If any in ``cosmos`` is not a |Cosmology| and the corresponding
        ``format`` equals `False`.
    """
    # One format entry per cosmology argument.
    per_cosmo_formats = np.broadcast_to(np.array(format, dtype=object), len(cosmos))

    # Some inputs do not broadcast well (e.g. an astropy Row cannot be put in
    # an array, even with dtype=object, and segfaults inside a ufunc), so
    # those are hidden from numpy behind the private wrapper.
    shielded = []
    for obj in cosmos:
        if isinstance(obj, _CANT_BROADCAST):
            shielded.append(_CosmologyWrapper(obj))
        else:
            shielded.append(obj)

    return _parse_format(shielded, per_cosmo_formats)
def _comparison_decorator(pyfunc: Callable[..., Any]) -> Callable[..., Any]:
    """Decorator to make wrapper function that parses |Cosmology|-like inputs.

    The wrapper accepts at most as many positional arguments as ``pyfunc``
    has positional-only parameters, converts them with ``_parse_formats``
    and then calls ``pyfunc`` with the parsed objects.

    Parameters
    ----------
    pyfunc : Python function object
        An arbitrary Python function. Its positional-only parameters are the
        cosmology inputs.

    Returns
    -------
    callable[..., Any]
        Wrapped `pyfunc`, as described above. It gains a keyword-only
        ``format`` argument (default `False`).

    Notes
    -----
    All decorated functions should add the following to 'Parameters'.

    format : bool or None or str or array-like thereof, optional keyword-only
        Whether to allow the arguments to be converted to a |Cosmology|. This
        allows, e.g. a |Table| to be given instead of a |Cosmology|. `False`
        (default) will not allow conversion. `True` or `None` will, and will use
        the auto-identification to try to infer the correct format. A `str` is
        assumed to be the correct format to use when converting. Note ``format``
        is broadcast as an object array to match the shape of ``cosmos`` so
        ``format`` cannot determine the output shape.
    """
    sig = inspect.signature(pyfunc)
    # Number of cosmology inputs == number of positional-only parameters.
    nin = sum(
        p.kind == inspect.Parameter.POSITIONAL_ONLY for p in sig.parameters.values()
    )

    # Make wrapper function that parses cosmology-like inputs
    @functools.wraps(pyfunc)
    def wrapper(*cosmos: Any, format: _FormatsType = False, **kwargs: Any) -> bool:
        # Too many inputs is rejected here; too few is left for the wrapped
        # function itself to report when it is called below.
        if len(cosmos) > nin:
            raise TypeError(
                f"{wrapper.__wrapped__.__name__} takes {nin} positional"
                f" arguments but {len(cosmos)} were given"
            )

        # Parse the cosmology-like inputs according to ``format``.
        cosmos = _parse_formats(*cosmos, format=format)

        # Evaluate the wrapped function on the parsed cosmologies.
        return wrapper.__wrapped__(*cosmos, **kwargs)

    return wrapper
##############################################################################
# COMPARISON FUNCTIONS
# Public function: exposed to users as ``astropy.cosmology.cosmology_equal``.
@_comparison_decorator
def cosmology_equal(
    cosmo1: Any, cosmo2: Any, /, *, allow_equivalent: bool = False
) -> bool:
    r"""Return element-wise equality check on the cosmologies.

    .. note::

        Cosmologies are currently scalar in their parameters.

    Parameters
    ----------
    cosmo1, cosmo2 : |Cosmology|-like
        The objects to compare. Must be convertible to |Cosmology|, as specified
        by ``format``.
    format : bool or None or str or tuple thereof, optional keyword-only
        Whether to allow the arguments to be converted to a |Cosmology|. This
        allows, e.g. a |Table| to be given instead of a |Cosmology|. `False`
        (default) will not allow conversion. `True` or `None` will, and will use
        the auto-identification to try to infer the correct format. A `str` is
        assumed to be the correct format to use when converting. Note ``format``
        is broadcast as an object array to match the shape of ``cosmos`` so
        ``format`` cannot determine the output shape.
    allow_equivalent : bool, optional keyword-only
        Whether to allow cosmologies to be equal even if not of the same class.
        For example, an instance of |LambdaCDM| might have :math:`\Omega_0=1`
        and :math:`\Omega_k=0` and therefore be flat, like |FlatLambdaCDM|.

    Returns
    -------
    bool
        The result of the equality (or, with ``allow_equivalent``, the
        equivalence) comparison.

    Examples
    --------
    Assuming the following imports

        >>> import astropy.units as u
        >>> from astropy.cosmology import FlatLambdaCDM

    Two identical cosmologies are equal.

        >>> cosmo1 = FlatLambdaCDM(70 * (u.km/u.s/u.Mpc), 0.3)
        >>> cosmo2 = FlatLambdaCDM(70 * (u.km/u.s/u.Mpc), 0.3)
        >>> cosmology_equal(cosmo1, cosmo2)
        True

    And cosmologies with different parameters are not.

        >>> cosmo3 = FlatLambdaCDM(70 * (u.km/u.s/u.Mpc), 0.4)
        >>> cosmology_equal(cosmo1, cosmo3)
        False

    Two cosmologies may be equivalent even if not of the same class. In these
    examples the |LambdaCDM| has :attr:`~astropy.cosmology.LambdaCDM.Ode0` set
    to the same value calculated in |FlatLambdaCDM|.

        >>> from astropy.cosmology import LambdaCDM
        >>> cosmo3 = LambdaCDM(70 * (u.km/u.s/u.Mpc), 0.3, 0.7)
        >>> cosmology_equal(cosmo1, cosmo3)
        False
        >>> cosmology_equal(cosmo1, cosmo3, allow_equivalent=True)
        True

    While in this example, the cosmologies are not equivalent.

        >>> cosmo4 = FlatLambdaCDM(70 * (u.km/u.s/u.Mpc), 0.3, Tcmb0=3 * u.K)
        >>> cosmology_equal(cosmo3, cosmo4, allow_equivalent=True)
        False

    Also, using the keyword argument, the notion of equality is extended to any
    Python object that can be converted to a |Cosmology|.

        >>> mapping = cosmo2.to_format("mapping")
        >>> cosmology_equal(cosmo1, mapping, format=True)
        True

    Either (or both) arguments can be |Cosmology|-like.

        >>> cosmology_equal(mapping, cosmo2, format=True)
        True

    The list of valid formats, e.g. the mapping used in this example, may be
    checked with ``Cosmology.from_format.list_formats()``.

    As can be seen in the list of formats, not all formats can be
    auto-identified by ``Cosmology.from_format.registry``. Objects of these
    kinds can still be checked for equality, but the correct format string must
    be used.

        >>> yml = cosmo2.to_format("yaml")
        >>> cosmology_equal(cosmo1, yml, format=(None, "yaml"))
        True

    This also works with an array of ``format`` matching the number of
    cosmologies.

        >>> cosmology_equal(mapping, yml, format=[True, "yaml"])
        True
    """
    # TODO! include equality check of metadata

    # Strict mode: plain ``==`` between the two parsed cosmologies.
    if not allow_equivalent:
        return cosmo1 == cosmo2

    # Equivalence mode. The options are: 1) same class & parameters; 2) same
    # class, different parameters; 3) different classes, equivalent
    # parameters; 4) different classes, different parameters.
    # (1) & (3) => True, (2) & (4) => False.
    verdict = cosmo1.__equiv__(cosmo2)
    if verdict is NotImplemented:
        # The first operand could not decide; ask from the other side.
        verdict = cosmo2.__equiv__(cosmo1)
    if verdict is NotImplemented:
        # Neither side knows how to compare: treat as not equivalent.
        return False
    return verdict
@_comparison_decorator
def _cosmology_not_equal(
    cosmo1: Any, cosmo2: Any, /, *, allow_equivalent: bool = False
) -> bool:
    r"""Return element-wise cosmology non-equality check.

    This is the logical negation of ``cosmology_equal`` called with the same
    ``allow_equivalent`` setting.

    .. note::

        Cosmologies are currently scalar in their parameters.

    Parameters
    ----------
    cosmo1, cosmo2 : |Cosmology|-like
        The objects to compare. Must be convertible to |Cosmology|, as specified
        by ``format``.
    format : bool or None or str or tuple thereof, optional keyword-only
        Whether to allow the arguments to be converted to a |Cosmology|. This
        allows, e.g. a |Table| to be given instead of a Cosmology. `False`
        (default) will not allow conversion. `True` or `None` will, and will use
        the auto-identification to try to infer the correct format. A `str` is
        assumed to be the correct format to use when converting. ``format`` is
        broadcast to match the shape of the cosmology arguments. Note that the
        cosmology arguments are not broadcast against ``format``, so it cannot
        determine the output shape.
    allow_equivalent : bool, optional keyword-only
        Whether to allow cosmologies to be equal even if not of the same class.
        For example, an instance of |LambdaCDM| might have :math:`\Omega_0=1`
        and :math:`\Omega_k=0` and therefore be flat, like |FlatLambdaCDM|.

    Returns
    -------
    bool
        `True` when ``cosmology_equal`` reports the cosmologies as not equal.

    See Also
    --------
    astropy.cosmology.cosmology_equal
        Element-wise equality check, with argument conversion to Cosmology.
    """
    # TODO! it might eventually be worth the speed boost to implement some of
    #       the internals of cosmology_equal here, but for now it's a hassle.
    return not cosmology_equal(cosmo1, cosmo2, allow_equivalent=allow_equivalent)