import numbers
import numpy as np
from astropy.wcs.wcsapi import BaseLowLevelWCS, wcs_info_str
from astropy.utils import isiterable
__all__ = ['sanitize_slices', 'SlicedLowLevelWCS']
def sanitize_slices(slices, ndim):
    """
    Normalize an indexing expression into a list of exactly ``ndim`` entries.

    Parameters
    ----------
    slices : int, slice, Ellipsis, or tuple or list of these
        The index to normalize. A single non-sequence value is treated as a
        one-element index.
    ndim : int
        The number of dimensions being indexed.

    Returns
    -------
    list
        A list of length ``ndim`` whose entries are integers or `slice`
        objects. A single ``Ellipsis`` is expanded to as many ``slice(None)``
        as needed, and missing trailing dimensions are padded with
        ``slice(None)``.

    Raises
    ------
    ValueError
        If more dimensions are indexed than ``ndim``.
    IndexError
        If an entry is iterable, if more than one ``Ellipsis`` is given, if a
        slice has a step other than 1, or if an entry is neither an integer
        nor a slice.
    """
    if not isinstance(slices, (tuple, list)):  # We just have a single int
        slices = (slices,)

    # An Ellipsis may stand for zero dimensions (e.g. ``[..., 0, 1]`` on a
    # 2-d WCS), so it must not count towards the dimensionality of the index.
    n_ellipsis = sum(1 for s in slices if s is Ellipsis)
    n_explicit = len(slices) - (1 if n_ellipsis else 0)
    if n_explicit > ndim:
        raise ValueError(
            f"The dimensionality of the specified slice {slices} can not be greater "
            f"than the dimensionality ({ndim}) of the wcs.")

    if any(np.iterable(s) for s in slices):
        raise IndexError("This slice is invalid, only integer or range slices are supported.")

    slices = list(slices)

    if n_ellipsis:
        if n_ellipsis > 1:
            raise IndexError("an index can only have a single ellipsis ('...')")

        # Replace the Ellipsis with the correct number of slice(None)s
        e_ind = slices.index(Ellipsis)
        n_fill = ndim - (len(slices) - 1)
        slices[e_ind:e_ind + 1] = [slice(None)] * n_fill

    for slc in slices:
        if isinstance(slc, slice):
            # A step of 0 is not a valid slice either, so compare against the
            # accepted values rather than relying on truthiness.
            if slc.step not in (None, 1):
                raise IndexError("Slicing WCS with a step is not supported.")
        elif not isinstance(slc, numbers.Integral):
            raise IndexError("Only integer or range slices are accepted.")

    # Dimensions that were not mentioned are kept whole.
    slices.extend([slice(None)] * (ndim - len(slices)))

    return slices
def combine_slices(slice1, slice2):
    """
    Given two slices that can be applied to a 1-d array, find the resulting
    slice that corresponds to the combination of both slices. We assume that
    slice2 can be an integer, but slice1 cannot.

    Parameters
    ----------
    slice1 : slice
        The slice that is applied first.
    slice2 : slice or int
        The slice or index applied to the result of ``slice1``.

    Returns
    -------
    slice or int
        An integer if ``slice2`` is an integer, otherwise a step-free slice.
        Bounds are combined arithmetically: the start of ``slice1`` is added
        to the bounds of ``slice2`` and the stop is clipped to the stop of
        ``slice1``. Negative bounds are not resolved against an array length.

    Raises
    ------
    ValueError
        If either slice has a step other than ``None`` or 1.
    """
    # A step of 1 is equivalent to no step (and is accepted by
    # sanitize_slices), so only genuinely strided slices are rejected.
    for slc in (slice1, slice2):
        if isinstance(slc, slice) and slc.step not in (None, 1):
            raise ValueError('Only slices with steps of 1 are supported')

    offset = slice1.start  # None means slice1 starts at the beginning

    if isinstance(slice2, numbers.Integral):
        return slice2 if offset is None else slice2 + offset

    if offset is None:
        start = slice2.start
        stop = slice2.stop
    else:
        # slice2 is expressed relative to the start of slice1.
        start = offset if slice2.start is None else offset + slice2.start
        stop = None if slice2.stop is None else offset + slice2.stop

    # The combined slice can never extend past the end of slice1.
    if slice1.stop is not None:
        stop = slice1.stop if stop is None else min(slice1.stop, stop)

    return slice(start, stop)
class SlicedLowLevelWCS(BaseLowLevelWCS):
    """
    A low-level WCS that presents a sliced view of another low-level WCS.

    Integer slices drop the corresponding pixel dimension; range slices keep
    the dimension and shift its pixel origin by the slice start. A world
    dimension is kept only if the axis correlation matrix of the wrapped WCS
    links it to at least one kept pixel dimension.

    Parameters
    ----------
    wcs : low-level WCS object
        The WCS to slice. If it is itself a `SlicedLowLevelWCS`, the new
        slices are merged with its existing ones and the innermost WCS is
        wrapped directly, so wrappers never nest.
    slices : int, slice, Ellipsis, or tuple or list of these
        The slices to apply, in array order (the reverse of pixel order), in
        any form accepted by `sanitize_slices`.

    Raises
    ------
    ValueError
        If the sliced WCS would have no pixel or no world dimension left.
    """

    def __init__(self, wcs, slices):

        slices = sanitize_slices(slices, wcs.pixel_n_dim)

        if isinstance(wcs, SlicedLowLevelWCS):
            # Here we combine the current slices with the previous slices
            # to avoid ending up with many nested WCSes
            self._wcs = wcs._wcs
            slices_original = wcs._slices_array.copy()
            for ipixel in range(wcs.pixel_n_dim):
                # Map the kept pixel axis of the already-sliced WCS back to
                # the array-order position in the innermost WCS.
                ipixel_orig = wcs._wcs.pixel_n_dim - 1 - wcs._pixel_keep[ipixel]
                ipixel_new = wcs.pixel_n_dim - 1 - ipixel
                slices_original[ipixel_orig] = combine_slices(slices_original[ipixel_orig],
                                                              slices[ipixel_new])
            self._slices_array = slices_original
        else:
            self._wcs = wcs
            self._slices_array = slices

        # Same slices, in pixel order rather than array order.
        self._slices_pixel = self._slices_array[::-1]

        # figure out which pixel dimensions have been kept, then use axis correlation
        # matrix to figure out which world dims are kept
        self._pixel_keep = np.nonzero([not isinstance(self._slices_pixel[ip], numbers.Integral)
                                       for ip in range(self._wcs.pixel_n_dim)])[0]

        # axis_correlation_matrix[world, pixel]
        self._world_keep = np.nonzero(
            self._wcs.axis_correlation_matrix[:, self._pixel_keep].any(axis=1))[0]

        if len(self._pixel_keep) == 0 or len(self._world_keep) == 0:
            raise ValueError("Cannot slice WCS: the resulting WCS should have "
                             "at least one pixel and one world dimension.")

    @property
    def pixel_n_dim(self):
        """Number of pixel dimensions kept after slicing."""
        return len(self._pixel_keep)

    @property
    def world_n_dim(self):
        """Number of world dimensions kept after slicing."""
        return len(self._world_keep)

    @property
    def world_axis_physical_types(self):
        """Physical types of the wrapped WCS for the kept world dimensions."""
        return [self._wcs.world_axis_physical_types[i] for i in self._world_keep]

    @property
    def world_axis_units(self):
        """Units of the wrapped WCS for the kept world dimensions."""
        return [self._wcs.world_axis_units[i] for i in self._world_keep]

    @property
    def pixel_axis_names(self):
        """Pixel axis names of the wrapped WCS for the kept pixel dimensions."""
        return [self._wcs.pixel_axis_names[i] for i in self._pixel_keep]

    @property
    def world_axis_names(self):
        """World axis names of the wrapped WCS for the kept world dimensions."""
        return [self._wcs.world_axis_names[i] for i in self._world_keep]

    def _pixel_to_world_values_all(self, *pixel_arrays):
        """
        Evaluate the wrapped WCS at the given sliced pixel coordinates and
        return its result for *all* of its world dimensions, kept or not.
        """
        pixel_arrays = tuple(map(np.asanyarray, pixel_arrays))
        pixel_arrays_new = []
        ipix_curr = -1
        for ipix in range(self._wcs.pixel_n_dim):
            slc = self._slices_pixel[ipix]
            # numbers.Integral (not int) so that NumPy integers, which are
            # treated as dropped dimensions in __init__, are handled here too.
            if isinstance(slc, numbers.Integral):
                pixel_arrays_new.append(slc)
            else:
                ipix_curr += 1
                if slc.start is not None:
                    pixel_arrays_new.append(pixel_arrays[ipix_curr] + slc.start)
                else:
                    pixel_arrays_new.append(pixel_arrays[ipix_curr])

        pixel_arrays_new = np.broadcast_arrays(*pixel_arrays_new)
        return self._wcs.pixel_to_world_values(*pixel_arrays_new)

    def pixel_to_world_values(self, *pixel_arrays):
        """
        Convert pixel coordinates of the sliced WCS to world values.

        Dropped pixel dimensions are filled in with their integer slice value
        and kept ones are offset by their slice start before the wrapped WCS
        is called. Only the kept world dimensions are returned; a single
        array (not a tuple) is returned when one world dimension remains.
        """
        world_arrays = self._pixel_to_world_values_all(*pixel_arrays)

        # Detect the case of a length 0 array
        if isinstance(world_arrays, np.ndarray) and not world_arrays.shape:
            return world_arrays

        if self._wcs.world_n_dim > 1:
            # Select the dimensions of the original WCS we are keeping.
            world_arrays = [world_arrays[iw] for iw in self._world_keep]
            # If there is only one world dimension (after slicing) we shouldn't return a tuple.
            if self.world_n_dim == 1:
                world_arrays = world_arrays[0]

        return world_arrays

    def array_index_to_world_values(self, *index_arrays):
        """
        Convert array indices (reverse pixel order) of the sliced WCS to
        world values; see `pixel_to_world_values`.
        """
        return self.pixel_to_world_values(*index_arrays[::-1])

    def world_to_pixel_values(self, *world_arrays):
        """
        Convert world values of the sliced WCS to pixel coordinates.

        World dimensions that were dropped are filled in with the values the
        wrapped WCS has at the sliced-out position, the wrapped WCS is called,
        and the slice starts are subtracted from the result. Only the kept
        pixel dimensions are returned; a single array (not a tuple) is
        returned when one pixel dimension remains.
        """
        world_arrays = tuple(map(np.asanyarray, world_arrays))

        # Dropped world dimensions need real values, not an arbitrary
        # placeholder: by construction they are uncorrelated with the kept
        # pixel axes, so evaluating at pixel 0 of every kept axis yields the
        # world value fixed by the integer slices.
        if len(self._world_keep) < self._wcs.world_n_dim:
            sliced_out_world_coords = self._pixel_to_world_values_all(
                *([0] * len(self._pixel_keep)))
        else:
            sliced_out_world_coords = None

        world_arrays_new = []
        iworld_curr = -1
        for iworld in range(self._wcs.world_n_dim):
            if iworld in self._world_keep:
                iworld_curr += 1
                world_arrays_new.append(world_arrays[iworld_curr])
            else:
                world_arrays_new.append(sliced_out_world_coords[iworld])

        world_arrays_new = np.broadcast_arrays(*world_arrays_new)
        pixel_arrays = self._wcs.world_to_pixel_values(*world_arrays_new)
        # A WCS with a single pixel dimension returns a bare array rather than
        # a tuple; calling list() on it would split it into its elements.
        if self._wcs.pixel_n_dim == 1:
            pixel_arrays = [pixel_arrays]
        else:
            pixel_arrays = list(pixel_arrays)

        for ipixel in range(self._wcs.pixel_n_dim):
            slc = self._slices_pixel[ipixel]
            if isinstance(slc, slice) and slc.start is not None:
                pixel_arrays[ipixel] -= slc.start

        pixel = tuple(pixel_arrays[ip] for ip in self._pixel_keep)
        if self.pixel_n_dim == 1:
            pixel = pixel[0]
        return pixel

    def world_to_array_index_values(self, *world_arrays):
        """
        Convert world values of the sliced WCS to integer array indices
        (reverse pixel order), rounding pixel coordinates to the nearest
        integer. A single array is returned when one pixel dimension remains.
        """
        pixel_arrays = self.world_to_pixel_values(*world_arrays)
        # With one pixel dimension the result is a bare array, which must not
        # be reversed or iterated element-wise.
        if self.pixel_n_dim == 1:
            pixel_arrays = (pixel_arrays,)
        else:
            pixel_arrays = pixel_arrays[::-1]
        array_indices = tuple(np.asarray(np.floor(pixel + 0.5), dtype=np.int_)
                              for pixel in pixel_arrays)
        return array_indices[0] if self.pixel_n_dim == 1 else array_indices

    @property
    def world_axis_object_components(self):
        """Object components of the wrapped WCS for the kept world dimensions."""
        return [self._wcs.world_axis_object_components[idx] for idx in self._world_keep]

    @property
    def world_axis_object_classes(self):
        """
        Object classes of the wrapped WCS, restricted to the keys referenced
        by the kept world axis object components.
        """
        keys_keep = [item[0] for item in self.world_axis_object_components]
        return {key: value
                for key, value in self._wcs.world_axis_object_classes.items()
                if key in keys_keep}

    @property
    def array_shape(self):
        """
        Shape obtained by applying the slices to the wrapped WCS's array
        shape, or `None` if the wrapped WCS has no (or an empty) array shape.
        """
        shape = self._wcs.array_shape
        if not shape:
            return None
        # Slice a zero-memory broadcast array to let NumPy compute the shape.
        return np.broadcast_to(0, shape)[tuple(self._slices_array)].shape

    @property
    def pixel_shape(self):
        """`array_shape` in pixel (reversed) order, or `None` if unavailable."""
        shape = self.array_shape
        if not shape:
            return None
        return shape[::-1]

    @property
    def pixel_bounds(self):
        """
        Pixel bounds of the wrapped WCS for the kept pixel dimensions, shifted
        by the slice start where one is set, or `None` if the wrapped WCS has
        no pixel bounds.
        """
        wrapped_bounds = self._wcs.pixel_bounds
        if wrapped_bounds is None:
            return None

        bounds = []
        for idx in self._pixel_keep:
            start = self._slices_pixel[idx].start
            if start is None:
                bounds.append(wrapped_bounds[idx])
            else:
                imin, imax = wrapped_bounds[idx]
                bounds.append((imin - start, imax - start))

        return tuple(bounds)

    @property
    def axis_correlation_matrix(self):
        """
        Axis correlation matrix of the wrapped WCS restricted to the kept
        world rows and kept pixel columns.
        """
        return self._wcs.axis_correlation_matrix[self._world_keep][:, self._pixel_keep]

    def __repr__(self):
        return wcs_info_str(self)

    def __str__(self):
        return wcs_info_str(self)