# Functions/classes for WCSAxes related to APE14 WCSes
from contextlib import contextmanager
import numpy as np
from astropy import units as u
from astropy.coordinates import ICRS, BaseCoordinateFrame, SkyCoord
from astropy.wcs import WCS
from astropy.wcs.wcsapi import SlicedLowLevelWCS
from .frame import EllipticalFrame, RectangularFrame, RectangularFrame1D
from .transforms import CurvedTransform
# Names exported by a star-import of this module.
__all__ = [
"WCSPixel2WorldTransform",
"WCSWorld2PixelTransform",
"custom_ucd_coord_meta_mapping",
"transform_coord_meta_from_wcs",
]
# Module-level two-axis WCS: axis types "X" and "Y", reference value 0 at
# reference pixel 1, and an increment of 1 along each axis.
IDENTITY = WCS(naxis=2)
IDENTITY.wcs.ctype = ["X", "Y"]
IDENTITY.wcs.crval = [0.0, 0.0]
IDENTITY.wcs.crpix = [1.0, 1.0]
IDENTITY.wcs.cdelt = [1.0, 1.0]
# Display metadata keyed by the last dot-separated component of a world axis
# physical type.  ``transform_coord_meta_from_wcs`` consults this table only
# when no key of CUSTOM_UCD_COORD_META_MAPPING occurs in the physical type, and
# merges the matching entry over its per-axis defaults ("scalar" coordinate
# type, no wrap, format unit equal to the axis unit).
UCD_COORD_META_MAPPING = {
"lon": {"coord_type": "longitude"},
"lat": {"coord_type": "latitude"},
"ra": {"coord_type": "longitude", "format_unit": u.hourangle},
"dec": {"coord_type": "latitude"},
# NOTE(review): "alt" is mapped to longitude and "az" to latitude, which looks
# swapped relative to the usual altitude/azimuth convention -- confirm before
# relying on these two entries.
"alt": {"coord_type": "longitude"},
"az": {"coord_type": "latitude"},
"long": {"coord_type": "longitude"},
}
# Display metadata for specific physical types.  In
# ``transform_coord_meta_from_wcs`` a key matches when it occurs as a substring
# of the axis physical type; entries are tried in insertion order, the first
# match wins, and a match here takes precedence over UCD_COORD_META_MAPPING.
# Each value may override "coord_type", "coord_wrap" and "format_unit".
CUSTOM_UCD_COORD_META_MAPPING = {
"pos.helioprojective.lon": {
"coord_wrap": 180.0 * u.deg,
"format_unit": u.arcsec,
"coord_type": "longitude",
},
"pos.helioprojective.lat": {"format_unit": u.arcsec, "coord_type": "latitude"},
"pos.heliographic.stonyhurst.lon": {
"coord_wrap": 180.0 * u.deg,
"format_unit": u.deg,
"coord_type": "longitude",
},
"pos.heliographic.stonyhurst.lat": {"format_unit": u.deg, "coord_type": "latitude"},
"pos.heliographic.carrington.lon": {
"coord_wrap": 360.0 * u.deg,
"format_unit": u.deg,
"coord_type": "longitude",
},
"pos.heliographic.carrington.lat": {"format_unit": u.deg, "coord_type": "latitude"},
}
def transform_coord_meta_from_wcs(wcs, frame_class, slices=None):
    """
    Build the pixel-to-world transform and the per-axis metadata for a WCS.

    Parameters
    ----------
    wcs : object
        WCS providing ``pixel_n_dim``, ``world_n_dim``,
        ``world_axis_physical_types``, ``world_axis_units`` and
        ``world_axis_names``.  A ``WCS`` instance, or a ``SlicedLowLevelWCS``
        wrapping one, additionally gets CTYPE-based name aliases.
    frame_class : type
        Frame class used to choose the default positions of the axis labels,
        tick labels and ticks.
    slices : iterable, optional
        Converted to a tuple and handed to ``apply_slices``.  Required when
        the WCS has more than two pixel dimensions, in which case it must
        have exactly one element per pixel dimension.

    Returns
    -------
    transform : WCSPixel2WorldTransform
        Transform built from the (possibly sliced) WCS.
    coord_meta : dict
        Lists with one entry per world axis of ``wcs`` under the keys
        ``"name"``, ``"type"``, ``"wrap"``, ``"unit"``, ``"visible"``,
        ``"format_unit"``, ``"default_axislabel_position"``,
        ``"default_ticklabel_position"``, ``"default_ticks_position"`` and
        ``"default_axis_label"``.

    Raises
    ------
    ValueError
        If the WCS has more than two pixel dimensions and ``slices`` is
        missing or has the wrong number of elements.
    """
    if slices is not None:
        slices = tuple(slices)

    if wcs.pixel_n_dim > 2:
        if slices is None:
            raise ValueError(
                "WCS has more than 2 pixel dimensions, so 'slices' should be set"
            )
        if len(slices) != wcs.pixel_n_dim:
            raise ValueError(
                "'slices' should have as many elements as WCS "
                f"has pixel dimensions (should be {wcs.pixel_n_dim})"
            )

    is_fits_wcs = isinstance(wcs, WCS) or (
        isinstance(wcs, SlicedLowLevelWCS) and isinstance(wcs._wcs, WCS)
    )
    world_n_dim = wcs.world_n_dim
    axis_names = wcs.world_axis_names

    coord_meta = {
        "name": [],
        "type": [],
        "wrap": [],
        "unit": [],
        "visible": [],
        "format_unit": [],
    }

    for idx in range(world_n_dim):
        axis_type = wcs.world_axis_physical_types[idx]
        dim_meta = _axis_display_meta(axis_type, u.Unit(wcs.world_axis_units[idx]))
        coord_meta["type"].append(dim_meta["coord_type"])
        coord_meta["wrap"].append(dim_meta["coord_wrap"])
        coord_meta["format_unit"].append(dim_meta["format_unit"])
        coord_meta["unit"].append(dim_meta["axis_unit"])

        axis_name = axis_names[idx] if axis_names else None

        if is_fits_wcs:
            # For FITS-WCS, for backward-compatibility, we need to make sure
            # that we provide aliases based on CTYPE for the name.
            if isinstance(wcs, WCS):
                ctype = wcs.wcs.ctype[idx]
            else:
                # ``is_fits_wcs`` guarantees a SlicedLowLevelWCS around a WCS;
                # ``_world_keep`` maps the axis back to the wrapped WCS.
                ctype = wcs._wcs.wcs.ctype[wcs._world_keep[idx]]
            name = _fits_axis_name(ctype, axis_type, axis_name)
        else:
            name = axis_type or ""
            if axis_name:
                name = (name, axis_name)

        coord_meta["name"].append(name)

    coord_meta["default_axislabel_position"] = [""] * world_n_dim
    coord_meta["default_ticklabel_position"] = [""] * world_n_dim
    coord_meta["default_ticks_position"] = [""] * world_n_dim

    # If the world axis has a name use it, else display the first alias
    # (the physical type when there is one).  The names sequence is guarded
    # the same way as in the loop above so a WCS without names still works.
    fallback_labels = [
        name[0] if isinstance(name, (list, tuple)) else name
        for name in coord_meta["name"]
    ]
    coord_meta["default_axis_label"] = [
        (axis_names[i] if axis_names else "") or fallback_label
        for i, fallback_label in enumerate(fallback_labels)
    ]

    transform_wcs, invert_xy, world_map = apply_slices(wcs, slices)
    transform = WCSPixel2WorldTransform(transform_wcs, invert_xy=invert_xy)

    # Only world axes that survive the slicing are shown.
    coord_meta["visible"] = [i in world_map for i in range(world_n_dim)]

    if frame_class in (RectangularFrame, RectangularFrame1D):
        for index in world_map:
            _set_default_positions(coord_meta, index, "#")

        # In the special and common case where the frame is rectangular and we
        # are dealing with a 2-d WCS (after slicing) for RectangularFrame or a
        # 1-d WCS for RectangularFrame1D, we show all ticks on all axes.
        if (frame_class is RectangularFrame and len(world_map) == 2) or (
            frame_class is RectangularFrame1D and len(world_map) == 1
        ):
            for index in world_map:
                coord_meta["default_ticks_position"][index] = frame_class.spine_names

    elif frame_class is EllipticalFrame:
        # The first longitude goes on spine "h", the first latitude on "c".
        for coord_type, position in (("longitude", "h"), ("latitude", "c")):
            if coord_type in coord_meta["type"]:
                _set_default_positions(
                    coord_meta, coord_meta["type"].index(coord_type), position
                )

    else:
        for index in world_map:
            _set_default_positions(coord_meta, index, frame_class.spine_names)

    return transform, coord_meta


def _axis_display_meta(axis_type, axis_unit):
    """Return the coordinate type, wrap and units used to display one axis."""
    dim_meta = {
        "coord_type": "scalar",
        "coord_wrap": None,
        "format_unit": axis_unit,
        "axis_unit": axis_unit,
    }
    if axis_type is None:
        return dim_meta

    # Custom entries match on a substring of the full physical type and take
    # precedence; the first matching entry wins.
    for ucd, meta in CUSTOM_UCD_COORD_META_MAPPING.items():
        if ucd in axis_type:
            dim_meta.update(meta)
            return dim_meta

    # Otherwise look up the trailing component, ignoring a "custom:" marker in
    # the first component (which is also the last for undotted types).
    components = axis_type.split(".")
    components[0] = components[0].replace("custom:", "")
    meta = UCD_COORD_META_MAPPING.get(components[-1])
    if meta is not None:
        dim_meta.update(meta)
    return dim_meta


def _fits_axis_name(ctype, axis_type, axis_name):
    """Return the name, or tuple of aliases, of a FITS-WCS world axis."""
    full = ctype.lower()
    short = ctype[:4].replace("-", "").lower()
    names = [full] if full == short else [full, short]
    if axis_type and axis_type not in names:
        names.insert(0, axis_type)
    if axis_name and axis_name not in names:
        names.append(axis_name)
    return tuple(names) if len(names) > 1 else names[0]


def _set_default_positions(coord_meta, index, position):
    """Set the default axis label, tick label and tick position of one axis."""
    coord_meta["default_axislabel_position"][index] = position
    coord_meta["default_ticklabel_position"][index] = position
    coord_meta["default_ticks_position"][index] = position
def apply_slices(wcs, slices):
    """
    Slice ``wcs`` for plotting and report how its world axes map back.

    Parameters
    ----------
    wcs : object
        Input WCS, possibly already a ``SlicedLowLevelWCS``.
    slices : sequence or None
        One entry per pixel dimension.  The entries ``"x"`` (required) and
        ``"y"`` (optional) mark the plotted dimensions and are replaced by
        ``slice(None)``; the other entries are passed on unchanged.  The
        sequence is reversed before being given to ``SlicedLowLevelWCS``.

    Returns
    -------
    transform_wcs : object
        ``wcs`` itself when ``slices`` is None, otherwise the sliced WCS.
    invert_xy : bool
        True when ``"x"`` comes after ``"y"`` in ``slices``.
    world_map : list or tuple of int
        For each world axis of ``transform_wcs``, the index of that axis in
        the input WCS.
    """
    if slices is None:
        # Nothing to slice: every world axis maps onto itself.
        return wcs, False, list(range(wcs.world_n_dim))

    kept_axes = (
        list(wcs._world_keep)
        if isinstance(wcs, SlicedLowLevelWCS)
        else list(range(wcs.world_n_dim))
    )

    selection = list(slices)
    selection[selection.index("x")] = slice(None)

    swap_axes = False
    if "y" in slices:
        selection[selection.index("y")] = slice(None)
        swap_axes = slices.index("x") > slices.index("y")

    sliced_wcs = SlicedLowLevelWCS(wcs, list(reversed(selection)))
    axis_lookup = tuple(map(kept_axes.index, sliced_wcs._world_keep))
    return sliced_wcs, swap_axes, axis_lookup
def wcsapi_to_celestial_frame(wcs):
    """
    Return the celestial frame described by a WCS, if there is one.

    Parameters
    ----------
    wcs : object
        WCS exposing ``world_axis_object_classes``, a mapping whose values
        start with ``(class, args, kwargs)``.

    Returns
    -------
    frame or None
        For the first ``SkyCoord`` entry, its ``"frame"`` keyword (``ICRS()``
        when absent); for the first ``BaseCoordinateFrame`` subclass, an
        instance built from its keywords; ``None`` when no entry qualifies.
    """
    for cls, _, kwargs, *_ in wcs.world_axis_object_classes.values():
        # Entries may carry something other than a class (for example a
        # serialized class name); ``issubclass`` would raise TypeError on
        # those, so they are skipped instead.
        if not isinstance(cls, type):
            continue
        if issubclass(cls, SkyCoord):
            return kwargs.get("frame", ICRS())
        if issubclass(cls, BaseCoordinateFrame):
            return cls(**kwargs)
    return None
class WCSWorld2PixelTransform(CurvedTransform):
    """
    WCS transformation from world to pixel coordinates.

    Parameters
    ----------
    wcs : object
        WCS providing ``pixel_n_dim``, ``world_n_dim`` and
        ``world_to_pixel_values``; at most two pixel dimensions.
    invert_xy : bool, optional
        If True, the order of the pixel arrays returned by the WCS is
        reversed before they are stacked into the output.

    Raises
    ------
    ValueError
        If ``wcs`` has more than two pixel dimensions.
    """

    has_inverse = True
    # Replaced per instance by the result of ``wcsapi_to_celestial_frame``.
    frame_in = None

    def __init__(self, wcs, invert_xy=False):
        super().__init__()

        if wcs.pixel_n_dim > 2:
            raise ValueError("Only pixel_n_dim =< 2 is supported")

        self.wcs = wcs
        self.invert_xy = invert_xy

        self.frame_in = wcsapi_to_celestial_frame(wcs)

    def __eq__(self, other):
        # Equal only when wrapping the very same WCS object with the same
        # axis order.
        return (
            isinstance(other, type(self))
            and self.wcs is other.wcs
            and self.invert_xy == other.invert_xy
        )

    @property
    def input_dims(self):
        """Number of world coordinates expected per input point."""
        return self.wcs.world_n_dim

    def transform(self, world):
        """
        Convert world coordinates to pixel coordinates.

        Parameters
        ----------
        world : array
            One row per point and one column per world axis.

        Returns
        -------
        pixel : array
            One row per input point; an empty input gives shape ``(0, 2)``.

        Raises
        ------
        ValueError
            If the number of columns differs from ``wcs.world_n_dim``.
        """
        # Convert to a list of arrays, one per world axis
        world = list(world.T)

        if len(world) != self.wcs.world_n_dim:
            raise ValueError(
                f"Expected {self.wcs.world_n_dim} world coordinates, got {len(world)} "
            )

        if len(world[0]) == 0:
            # Return the empty result directly: it is already laid out with
            # one row per point, so it must not go through the transpose
            # below (which would turn it into shape (2, 0)).
            return np.zeros((0, 2))

        pixel = self.wcs.world_to_pixel_values(*world)

        if self.invert_xy:
            pixel = pixel[::-1]

        # NOTE(review): unlike WCSPixel2WorldTransform, a single-axis result
        # is not wrapped in a list before transposing -- confirm the output
        # shape expected by callers when pixel_n_dim == 1.
        return np.array(pixel).T

    transform_non_affine = transform

    def inverted(self):
        """
        Return the inverse of the transform.
        """
        return WCSPixel2WorldTransform(self.wcs, invert_xy=self.invert_xy)
class WCSPixel2WorldTransform(CurvedTransform):
    """
    WCS transformation from pixel to world coordinates.

    Parameters
    ----------
    wcs : object
        WCS providing ``pixel_n_dim``, ``world_n_dim`` and
        ``pixel_to_world_values``; at most two pixel dimensions.
    invert_xy : bool, optional
        If True, the order of the input pixel columns is reversed before
        they are passed to the WCS.

    Raises
    ------
    ValueError
        If ``wcs`` has more than two pixel dimensions.
    """

    has_inverse = True

    def __init__(self, wcs, invert_xy=False):
        super().__init__()

        if wcs.pixel_n_dim > 2:
            raise ValueError("Only pixel_n_dim =< 2 is supported")

        self.wcs = wcs
        self.invert_xy = invert_xy

        # Frame returned by ``wcsapi_to_celestial_frame`` for this WCS.
        self.frame_out = wcsapi_to_celestial_frame(wcs)

    def __eq__(self, other):
        # Equal only when wrapping the very same WCS object with the same
        # axis order.
        return (
            isinstance(other, type(self))
            and self.wcs is other.wcs
            and self.invert_xy == other.invert_xy
        )

    @property
    def output_dims(self):
        """Number of world coordinates produced per input point."""
        return self.wcs.world_n_dim

    def transform(self, pixel):
        """
        Convert pixel coordinates to world coordinates.

        Parameters
        ----------
        pixel : array
            One row per point and one column per pixel axis.

        Returns
        -------
        world : array
            One row per input point and one column per world axis; an empty
            input gives shape ``(0, wcs.world_n_dim)``.

        Raises
        ------
        ValueError
            If the number of columns differs from ``wcs.pixel_n_dim``.
        """
        # Convert to a list of arrays, one per pixel axis
        pixel = list(pixel.T)

        if len(pixel) != self.wcs.pixel_n_dim:
            raise ValueError(
                f"Expected {self.wcs.pixel_n_dim} pixel coordinates, got {len(pixel)} "
            )

        if len(pixel[0]) == 0:
            # Return the empty result directly: it is already laid out with
            # one row per point, so it must not go through the list wrapping
            # and transpose below (which would scramble its shape).
            return np.zeros((0, self.wcs.world_n_dim))

        if self.invert_xy:
            pixel = pixel[::-1]

        world = self.wcs.pixel_to_world_values(*pixel)

        # Wrap a single-axis result so the transpose yields one column.
        if self.wcs.world_n_dim == 1:
            world = [world]

        return np.array(world).T

    transform_non_affine = transform

    def inverted(self):
        """
        Return the inverse of the transform.
        """
        return WCSWorld2PixelTransform(self.wcs, invert_xy=self.invert_xy)