# Source code for astropy.io.registry.base

# Licensed under a 3-clause BSD style license - see LICENSE.rst

import contextlib
import re
import warnings
from operator import itemgetter

import numpy as np

# Public API of this module: only the exception type is exported; the
# underscore-prefixed registry base class defined below is not listed.
__all__ = ["IORegistryError"]


class IORegistryError(Exception):
    """Custom error for registry clashes."""
# ----------------------------------------------------------------------------- class _UnifiedIORegistryBase: """Base class for registries in Astropy's Unified IO. This base class provides identification functions and miscellaneous utilities. For an example how to build a registry subclass we suggest :class:`~astropy.io.registry.UnifiedInputRegistry`, which enables read-only registries. These higher-level subclasses will probably serve better as a baseclass, for instance :class:`~astropy.io.registry.UnifiedIORegistry` subclasses both :class:`~astropy.io.registry.UnifiedInputRegistry` and :class:`~astropy.io.registry.UnifiedOutputRegistry` to enable both reading from and writing to files. .. versionadded:: 5.0 """ def __init__(self): # registry of identifier functions self._identifiers = {} # what this class can do: e.g. 'read' &/or 'write' self._registries = {} self._registries["identify"] = { "attr": "_identifiers", "column": "Auto-identify", } self._registries_order = ("identify",) # match keys in `_registries` # If multiple formats are added to one class the update of the docs is quite # expensive. Classes for which the doc update is temporarily delayed are added # to this set. self._delayed_docs_classes = set() @property def available_registries(self): """Available registries. Returns ------- ``dict_keys`` """ return self._registries.keys() def get_formats(self, data_class=None, filter_on=None): """ Get the list of registered formats as a `~astropy.table.Table`. Parameters ---------- data_class : class or None, optional Filter readers/writer to match data class (default = all classes). filter_on : str or None, optional Which registry to show. E.g. "identify" If None search for both. Default is None. Returns ------- format_table : :class:`~astropy.table.Table` Table of available I/O formats. Raises ------ ValueError If ``filter_on`` is not None nor a registry name. 
""" from astropy.table import Table # set up the column names colnames = ( "Data class", "Format", *[self._registries[k]["column"] for k in self._registries_order], "Deprecated", ) i_dataclass = colnames.index("Data class") i_format = colnames.index("Format") i_regstart = colnames.index( self._registries[self._registries_order[0]]["column"] ) i_deprecated = colnames.index("Deprecated") # registries regs = set() for k in self._registries.keys() - {"identify"}: regs |= set(getattr(self, self._registries[k]["attr"])) format_classes = sorted(regs, key=itemgetter(0)) # the format classes from all registries except "identify" rows = [] for fmt, cls in format_classes: # see if can skip, else need to document in row if data_class is not None and not self._is_best_match( data_class, cls, format_classes ): continue # flags for each registry has_ = { k: "Yes" if (fmt, cls) in getattr(self, v["attr"]) else "No" for k, v in self._registries.items() } # Check if this is a short name (e.g. 'rdb') which is deprecated in # favor of the full 'ascii.rdb'. ascii_format_class = ("ascii." + fmt, cls) # deprecation flag deprecated = "Yes" if ascii_format_class in format_classes else "" # add to rows rows.append( ( cls.__name__, fmt, *[has_[n] for n in self._registries_order], deprecated, ) ) # filter_on can be in self_registries_order or None if str(filter_on).lower() in self._registries_order: index = self._registries_order.index(str(filter_on).lower()) rows = [row for row in rows if row[i_regstart + index] == "Yes"] elif filter_on is not None: raise ValueError( 'unrecognized value for "filter_on": {0}.\n' f"Allowed are {self._registries_order} and None." ) # Sorting the list of tuples is much faster than sorting it after the # table is created. (#5262) if rows: # Indices represent "Data Class", "Deprecated" and "Format". 
data = list( zip(*sorted(rows, key=itemgetter(i_dataclass, i_deprecated, i_format))) ) else: data = None # make table # need to filter elementwise comparison failure issue # https://github.com/numpy/numpy/issues/6784 with warnings.catch_warnings(): warnings.simplefilter(action="ignore", category=FutureWarning) format_table = Table(data, names=colnames) if not np.any(format_table["Deprecated"].data == "Yes"): format_table.remove_column("Deprecated") return format_table @contextlib.contextmanager def delay_doc_updates(self, cls): """Contextmanager to disable documentation updates when registering reader and writer. The documentation is only built once when the contextmanager exits. .. versionadded:: 1.3 Parameters ---------- cls : class Class for which the documentation updates should be delayed. Notes ----- Registering multiple readers and writers can cause significant overhead because the documentation of the corresponding ``read`` and ``write`` methods are build every time. Examples -------- see for example the source code of ``astropy.table.__init__``. """ self._delayed_docs_classes.add(cls) yield self._delayed_docs_classes.discard(cls) for method in self._registries.keys() - {"identify"}: self._update__doc__(cls, method) # ========================================================================= # Identifier methods def register_identifier(self, data_format, data_class, identifier, force=False): """ Associate an identifier function with a specific data type. Parameters ---------- data_format : str The data format identifier. This is the string that is used to specify the data type when reading/writing. data_class : class The class of the object that can be written. identifier : function A function that checks the argument specified to `read` or `write` to determine whether the input can be interpreted as a table of type ``data_format``. 
This function should take the following arguments: - ``origin``: A string ``"read"`` or ``"write"`` identifying whether the file is to be opened for reading or writing. - ``path``: The path to the file. - ``fileobj``: An open file object to read the file's contents, or `None` if the file could not be opened. - ``*args``: Positional arguments for the `read` or `write` function. - ``**kwargs``: Keyword arguments for the `read` or `write` function. One or both of ``path`` or ``fileobj`` may be `None`. If they are both `None`, the identifier will need to work from ``args[0]``. The function should return True if the input can be identified as being of format ``data_format``, and False otherwise. force : bool, optional Whether to override any existing function if already present. Default is ``False``. Examples -------- To set the identifier based on extensions, for formats that take a filename as a first argument, you can do for example .. code-block:: python from astropy.io.registry import register_identifier from astropy.table import Table def my_identifier(*args, **kwargs): return isinstance(args[0], str) and args[0].endswith('.tbl') register_identifier('ipac', Table, my_identifier) unregister_identifier('ipac', Table) """ if not (data_format, data_class) in self._identifiers or force: # noqa: E713 self._identifiers[(data_format, data_class)] = identifier else: raise IORegistryError( f"Identifier for format {data_format!r} and class" f" {data_class.__name__!r} is already defined" ) def unregister_identifier(self, data_format, data_class): """ Unregister an identifier function. Parameters ---------- data_format : str The data format identifier. data_class : class The class of the object that can be read/written. 
""" if (data_format, data_class) in self._identifiers: self._identifiers.pop((data_format, data_class)) else: raise IORegistryError( f"No identifier defined for format {data_format!r} and class" f" {data_class.__name__!r}" ) def identify_format(self, origin, data_class_required, path, fileobj, args, kwargs): """Loop through identifiers to see which formats match. Parameters ---------- origin : str A string ``"read`` or ``"write"`` identifying whether the file is to be opened for reading or writing. data_class_required : object The specified class for the result of `read` or the class that is to be written. path : str or path-like or None The path to the file or None. fileobj : file-like or None. An open file object to read the file's contents, or ``None`` if the file could not be opened. args : sequence Positional arguments for the `read` or `write` function. Note that these must be provided as sequence. kwargs : dict-like Keyword arguments for the `read` or `write` function. Note that this parameter must be `dict`-like. Returns ------- valid_formats : list List of matching formats. """ valid_formats = [] for data_format, data_class in self._identifiers: if self._is_best_match(data_class_required, data_class, self._identifiers): if self._identifiers[(data_format, data_class)]( origin, path, fileobj, *args, **kwargs ): valid_formats.append(data_format) return valid_formats # ========================================================================= # Utils def _get_format_table_str(self, data_class, filter_on): """``get_formats()``, without column "Data class", as a str.""" format_table = self.get_formats(data_class, filter_on) format_table.remove_column("Data class") format_table_str = "\n".join(format_table.pformat(max_lines=-1)) return format_table_str def _is_best_match(self, class1, class2, format_classes): """Determine if class2 is the "best" match for class1 in the list of classes. It is assumed that (class2 in classes) is True. 
class2 is the best match if: - ``class1`` is a subclass of ``class2`` AND - ``class2`` is the nearest ancestor of ``class1`` that is in classes (which includes the case that ``class1 is class2``) """ if issubclass(class1, class2): classes = {cls for fmt, cls in format_classes} for parent in class1.__mro__: if parent is class2: # class2 is closest registered ancestor return True if parent in classes: # class2 was superseded return False return False def _get_valid_format(self, mode, cls, path, fileobj, args, kwargs): """ Returns the first valid format that can be used to read/write the data in question. Mode can be either 'read' or 'write'. """ valid_formats = self.identify_format(mode, cls, path, fileobj, args, kwargs) if len(valid_formats) == 0: format_table_str = self._get_format_table_str(cls, mode.capitalize()) raise IORegistryError( "Format could not be identified based on the" " file name or contents, please provide a" " 'format' argument.\n" f"The available formats are:\n{format_table_str}" ) elif len(valid_formats) > 1: return self._get_highest_priority_format(mode, cls, valid_formats) return valid_formats[0] def _get_highest_priority_format(self, mode, cls, valid_formats): """ Returns the reader or writer with the highest priority. If it is a tie, error. """ if mode == "read": format_dict = self._readers mode_loader = "reader" elif mode == "write": format_dict = self._writers mode_loader = "writer" best_formats = [] current_priority = -np.inf for format in valid_formats: try: _, priority = format_dict[(format, cls)] except KeyError: # We could throw an exception here, but get_reader/get_writer handle # this case better, instead maximally deprioritise the format. 
priority = -np.inf if priority == current_priority: best_formats.append(format) elif priority > current_priority: best_formats = [format] current_priority = priority if len(best_formats) > 1: raise IORegistryError( "Format is ambiguous - options are:" f" {', '.join(sorted(valid_formats, key=itemgetter(0)))}" ) return best_formats[0] def _update__doc__(self, data_class, readwrite): """ Update the docstring to include all the available readers / writers for the ``data_class.read``/``data_class.write`` functions (respectively). Don't update if the data_class does not have the relevant method. """ # abort if method "readwrite" isn't on data_class if not hasattr(data_class, readwrite): return from .interface import UnifiedReadWrite FORMATS_TEXT = "The available built-in formats are:" # Get the existing read or write method and its docstring class_readwrite_func = getattr(data_class, readwrite) if not isinstance(class_readwrite_func.__doc__, str): # No docstring--could just be test code, or possibly code compiled # without docstrings return lines = class_readwrite_func.__doc__.splitlines() # Find the location of the existing formats table if it exists sep_indices = [ii for ii, line in enumerate(lines) if FORMATS_TEXT in line] if sep_indices: # Chop off the existing formats table, including the initial blank line chop_index = sep_indices[0] lines = lines[:chop_index] # Find the minimum indent, skipping the first line because it might be odd matches = [re.search(r"(\S)", line) for line in lines[1:]] left_indent = " " * min(match.start() for match in matches if match) # Get the available unified I/O formats for this class # Include only formats that have a reader, and drop the 'Data class' column format_table = self.get_formats(data_class, readwrite.capitalize()) format_table.remove_column("Data class") # Get the available formats as a table, then munge the output of pformat() # a bit and put it into the docstring. 
new_lines = format_table.pformat(max_lines=-1, max_width=80) table_rst_sep = re.sub("-", "=", new_lines[1]) new_lines[1] = table_rst_sep new_lines.insert(0, table_rst_sep) new_lines.append(table_rst_sep) # Check for deprecated names and include a warning at the end. if "Deprecated" in format_table.colnames: new_lines.extend( [ "", "Deprecated format names like ``aastex`` will be " "removed in a future version. Use the full ", "name (e.g. ``ascii.aastex``) instead.", ] ) new_lines = [FORMATS_TEXT, ""] + new_lines lines.extend([left_indent + line for line in new_lines]) # Depending on Python version and whether class_readwrite_func is # an instancemethod or classmethod, one of the following will work. if isinstance(class_readwrite_func, UnifiedReadWrite): class_readwrite_func.__class__.__doc__ = "\n".join(lines) else: try: class_readwrite_func.__doc__ = "\n".join(lines) except AttributeError: class_readwrite_func.__func__.__doc__ = "\n".join(lines)