# Source code for astropy.cosmology.funcs.comparison

# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""Comparison functions for `astropy.cosmology.Cosmology`.

This module is **NOT** public API. To use these functions, import them from
the top-level namespace -- :mod:`astropy.cosmology`. This module will be
moved.
"""

from __future__ import annotations

import functools
import inspect
from dataclasses import dataclass
from typing import TYPE_CHECKING

import numpy as np
from numpy import False_, True_, ndarray

from astropy import table
from astropy.cosmology.core import Cosmology

if TYPE_CHECKING:
    from collections.abc import Callable
    from typing import Any, TypeAlias

__all__: list[str] = []  # Nothing is scoped here


##############################################################################
# PARAMETERS

# A single ``format`` hint: a bool, None, or a format-name string.
_FormatType: TypeAlias = bool | None | str
# Either one hint used for every argument, or a tuple of per-argument hints.
_FormatsType: TypeAlias = _FormatType | tuple[_FormatType, ...]

# ``format`` values that are accepted when the object being parsed is already
# a Cosmology. The numpy bools also catch the real bools for ops "==" and "in".
_COSMO_AOK: set[Any] = {None, True_, False_, "astropy.cosmology"}


##############################################################################
# UTILITIES


# Instances of these types are wrapped in ``_CosmologyWrapper`` by
# ``_parse_formats`` before they are handed to numpy.
_CANT_BROADCAST: tuple[type, ...] = (table.Row, table.Table)
"""Things that cannot broadcast.

Have to deal with things that do not broadcast well. e.g.
`~astropy.table.Row` cannot be used in an array, even if ``dtype=object``
and will raise a segfault when used in a `numpy.ufunc`.
"""


@dataclass(frozen=True)
class _CosmologyWrapper:
    """A private wrapper class to hide things from :mod:`numpy`.

    This should never be exposed to the user.
    """

    __slots__ = ("wrapped",)

    wrapped: Any


@functools.partial(np.frompyfunc, nin=2, nout=1)
def _parse_format(cosmo: Any, format: _FormatType, /) -> Cosmology:
    """Parse Cosmology-like input into Cosmologies, given a format hint.

    The function is wrapped by `numpy.frompyfunc`, so it is applied
    element-wise when called with array-like arguments.

    Parameters
    ----------
    cosmo : |Cosmology|-like, positional-only
        |Cosmology| to parse. A private ``_CosmologyWrapper`` is unwrapped
        before anything else happens.
    format : bool or None or str, positional-only
        Whether to allow, before equivalence is checked, the object to be
        converted to a |Cosmology|. This allows, e.g. a |Table| to be equivalent
        to a |Cosmology|. `False` will not allow conversion. `True` or `None`
        will, and will use the auto-identification to try to infer the correct
        format. A `str` is assumed to be the correct format to use when
        converting.

    Returns
    -------
    |Cosmology| or object array thereof
        ``cosmo`` itself when it already is a |Cosmology|, otherwise the result
        of ``Cosmology.from_format``.

    Raises
    ------
    ValueError
        If ``cosmo`` is a |Cosmology| and ``format`` is not one of `None`,
        `True`, `False` or ``"astropy.cosmology"``.
    TypeError
        If ``cosmo`` is not a |Cosmology| and ``format`` equals `False`.

    Notes
    -----
    Any exception raised by ``Cosmology.from_format`` during conversion is
    propagated unchanged.
    """
    # Deal with private wrapper
    if isinstance(cosmo, _CosmologyWrapper):
        cosmo = cosmo.wrapped

    # Shortcut if already a cosmology: nothing to convert, only validate.
    if isinstance(cosmo, Cosmology):
        if format not in _COSMO_AOK:
            # Sorted so the message does not depend on set iteration order.
            allowed = "/".join(sorted(map(str, _COSMO_AOK)))
            raise ValueError(
                f"for parsing a Cosmology, 'format' must be {allowed}, not {format}"
            )
        return cosmo

    # A conversion is needed; refuse when it was explicitly disallowed.
    if format == False_:  # catches False and False_
        raise TypeError(
            f"if 'format' is False, arguments must be a Cosmology, not {cosmo}"
        )

    # str->str, None/True/True_->None (None requests auto-identification)
    format = None if format == True_ else format
    return Cosmology.from_format(cosmo, format=format)  # this can error!


def _parse_formats(*cosmos: object, format: _FormatsType) -> ndarray:
    """Convert several Cosmology-like objects to |Cosmology| in one call.

    ``format`` is broadcast, as an object array, to one entry per positional
    argument. The positional arguments themselves are never broadcast against
    ``format``, so ``format`` cannot change the number of results.

    Parameters
    ----------
    *cosmos : |Cosmology|-like
        The objects to parse. Each must be convertible to |Cosmology|, as
        specified by the corresponding ``format``.

    format : bool or None or str or array-like thereof, keyword-only
        Whether to allow the object to be converted to a |Cosmology|. This
        allows, e.g. a |Table| to stand in for a |Cosmology|. `False` will not
        allow conversion. `True` or `None` will, and will use the
        auto-identification to try to infer the correct format. A `str` is
        assumed to be the correct format to use when converting.

    Returns
    -------
    ndarray
        The element-wise result of ``_parse_format`` applied to each
        argument and its format.

    Raises
    ------
    TypeError
        If any in ``cosmos`` is not a |Cosmology| and the corresponding
        ``format`` equals `False`.
    """
    # One format hint per positional argument.
    per_arg_formats = np.broadcast_to(np.array(format, dtype=object), len(cosmos))

    # Some objects do not broadcast well: astropy.row cannot be used in an
    # array, even if dtype=object, and will raise a segfault when used in a
    # ufunc. Hide those behind the private wrapper; pass the rest through.
    prepared = [
        _CosmologyWrapper(obj) if isinstance(obj, _CANT_BROADCAST) else obj
        for obj in cosmos
    ]

    # parse each cosmo & format
    return _parse_format(prepared, per_arg_formats)


def _comparison_decorator(pyfunc: Callable[..., Any]) -> Callable[..., Any]:
    """Decorator to make wrapper function that parses |Cosmology|-like inputs.

    The number of positional-only parameters of ``pyfunc`` is taken as the
    number of cosmology inputs. The returned wrapper accepts at most that many
    positional arguments, parses them with ``_parse_formats`` according to
    the ``format`` keyword, and calls ``pyfunc`` with the parsed objects and
    any remaining keyword arguments.

    Parameters
    ----------
    pyfunc : Python function object
        An arbitrary Python function.

    Returns
    -------
    callable[..., Any]
        Wrapped `pyfunc`, as described above. The original function stays
        available as ``__wrapped__``.

    Raises
    ------
    TypeError
        Raised by the wrapper, at call time, if more positional arguments are
        given than ``pyfunc`` has positional-only parameters.

    Notes
    -----
    All decorated functions should add the following to 'Parameters'.

    format : bool or None or str or array-like thereof, optional keyword-only
        Whether to allow the arguments to be converted to a |Cosmology|. This
        allows, e.g. a |Table| to be given instead a |Cosmology|. `False`
        (default) will not allow conversion. `True` or `None` will, and will use
        the auto-identification to try to infer the correct format. A `str` is
        assumed to be the correct format to use when converting. Note ``format``
        is broadcast as an object array to match the shape of ``cosmos`` so
        ``format`` cannot determine the output shape.
    """
    # Only positional-only parameters are counted as cosmology inputs.
    parameters = inspect.signature(pyfunc).parameters.values()
    nin = sum(p.kind == inspect.Parameter.POSITIONAL_ONLY for p in parameters)

    # Make wrapper function that parses cosmology-like inputs
    @functools.wraps(pyfunc)
    def wrapper(*cosmos: Any, format: _FormatsType = False, **kwargs: Any) -> bool:
        if len(cosmos) > nin:
            raise TypeError(
                f"{wrapper.__wrapped__.__name__} takes {nin} positional"
                f" arguments but {len(cosmos)} were given"
            )
        # Convert the inputs to Cosmology objects as permitted by ``format``.
        parsed = _parse_formats(*cosmos, format=format)
        # Too few inputs are not checked here; that is left to the call below.
        return wrapper.__wrapped__(*parsed, **kwargs)

    return wrapper


##############################################################################
# COMPARISON FUNCTIONS


@_comparison_decorator
def cosmology_equal(
    cosmo1: Any, cosmo2: Any, /, *, allow_equivalent: bool = False
) -> bool:
    r"""Return element-wise equality check on the cosmologies.

    .. note::

        Cosmologies are currently scalar in their parameters.

    Parameters
    ----------
    cosmo1, cosmo2 : |Cosmology|-like
        The objects to compare. Must be convertible to |Cosmology|, as specified
        by ``format``.
    format : bool or None or str or tuple thereof, optional keyword-only
        Whether to allow the arguments to be converted to a |Cosmology|. This
        allows, e.g. a |Table| to be given instead a |Cosmology|. `False`
        (default) will not allow conversion. `True` or `None` will, and will use
        the auto-identification to try to infer the correct format. A `str` is
        assumed to be the correct format to use when converting. Note ``format``
        is broadcast as an object array to match the shape of ``cosmos`` so
        ``format`` cannot determine the output shape.
    allow_equivalent : bool, optional keyword-only
        Whether to allow cosmologies to be equal even if not of the same class.
        For example, an instance of |LambdaCDM| might have :math:`\Omega_0=1`
        and :math:`\Omega_k=0` and therefore be flat, like |FlatLambdaCDM|.

    Returns
    -------
    bool
        The result of ``cosmo1 == cosmo2`` or, when ``allow_equivalent`` is
        true, of the ``__equiv__`` check (`False` if neither argument
        implements it for the other).

    Examples
    --------
    Assuming the following imports

        >>> import astropy.units as u
        >>> from astropy.cosmology import FlatLambdaCDM

    Two identical cosmologies are equal.

        >>> cosmo1 = FlatLambdaCDM(70 * (u.km/u.s/u.Mpc), 0.3)
        >>> cosmo2 = FlatLambdaCDM(70 * (u.km/u.s/u.Mpc), 0.3)
        >>> cosmology_equal(cosmo1, cosmo2)
        True

    And cosmologies with different parameters are not.

        >>> cosmo3 = FlatLambdaCDM(70 * (u.km/u.s/u.Mpc), 0.4)
        >>> cosmology_equal(cosmo1, cosmo3)
        False

    Two cosmologies may be equivalent even if not of the same class. In these
    examples the |LambdaCDM| has :attr:`~astropy.cosmology.LambdaCDM.Ode0` set
    to the same value calculated in |FlatLambdaCDM|.

        >>> from astropy.cosmology import LambdaCDM
        >>> cosmo3 = LambdaCDM(70 * (u.km/u.s/u.Mpc), 0.3, 0.7)
        >>> cosmology_equal(cosmo1, cosmo3)
        False
        >>> cosmology_equal(cosmo1, cosmo3, allow_equivalent=True)
        True

    While in this example, the cosmologies are not equivalent.

        >>> cosmo4 = FlatLambdaCDM(70 * (u.km/u.s/u.Mpc), 0.3, Tcmb0=3 * u.K)
        >>> cosmology_equal(cosmo3, cosmo4, allow_equivalent=True)
        False

    Also, using the keyword argument, the notion of equality is extended to any
    Python object that can be converted to a |Cosmology|.

        >>> mapping = cosmo2.to_format("mapping")
        >>> cosmology_equal(cosmo1, mapping, format=True)
        True

    Either (or both) arguments can be |Cosmology|-like.

        >>> cosmology_equal(mapping, cosmo2, format=True)
        True

    The list of valid formats, e.g. the mapping used in this example, may be
    checked with ``Cosmology.from_format.list_formats()``.

    As can be seen in the list of formats, not all formats can be
    auto-identified by ``Cosmology.from_format.registry``. Objects of these
    kinds can still be checked for equality, but the correct format string must
    be used.

        >>> yml = cosmo2.to_format("yaml")
        >>> cosmology_equal(cosmo1, yml, format=(None, "yaml"))
        True

    This also works with an array of ``format`` matching the number of
    cosmologies.

        >>> cosmology_equal(mapping, yml, format=[True, "yaml"])
        True
    """
    # Check parameter equality
    if not allow_equivalent:
        return cosmo1 == cosmo2

    # Check parameter equivalence
    # The options are: 1) same class & parameters; 2) same class, different
    # parameters; 3) different classes, equivalent parameters; 4) different
    # classes, different parameters. (1) & (3) => True, (2) & (4) => False.
    equivalent = cosmo1.__equiv__(cosmo2)
    if equivalent is NotImplemented:
        # that failed, try from 'other'
        equivalent = cosmo2.__equiv__(cosmo1)
    if equivalent is NotImplemented:
        # Neither side knows how to compare against the other.
        equivalent = False

    # TODO! include equality check of metadata

    return equivalent
@_comparison_decorator
def _cosmology_not_equal(
    cosmo1: Any, cosmo2: Any, /, *, allow_equivalent: bool = False
) -> bool:
    r"""Return element-wise cosmology non-equality check.

    .. note::

        Cosmologies are currently scalar in their parameters.

    Parameters
    ----------
    cosmo1, cosmo2 : |Cosmology|-like
        The objects to compare. Must be convertible to |Cosmology|, as specified
        by ``format``.
    format : bool or None or str or tuple thereof, optional keyword-only
        Whether to allow the arguments to be converted to a |Cosmology|. This
        allows, e.g. a |Table| to be given instead a Cosmology. `False`
        (default) will not allow conversion. `True` or `None` will, and will use
        the auto-identification to try to infer the correct format. A `str` is
        assumed to be the correct format to use when converting. ``format`` is
        broadcast to match the shape of the cosmology arguments. Note that the
        cosmology arguments are not broadcast against ``format``, so it cannot
        determine the output shape.
    allow_equivalent : bool, optional keyword-only
        Whether to allow cosmologies to be equal even if not of the same class.
        For example, an instance of |LambdaCDM| might have :math:`\Omega_0=1`
        and :math:`\Omega_k=0` and therefore be flat, like |FlatLambdaCDM|.

    Returns
    -------
    bool
        The logical negation of ``cosmology_equal`` for the same arguments.

    See Also
    --------
    astropy.cosmology.cosmology_equal
        Element-wise equality check, with argument conversion to Cosmology.
    """
    # TODO! it might eventually be worth the speed boost to implement some of
    #       the internals of cosmology_equal here, but for now it's a hassle.
    return not cosmology_equal(cosmo1, cosmo2, allow_equivalent=allow_equivalent)