# Source code for astropy.visualization.basic_rgb

# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""
Combine 3 images to produce properly-scaled RGB images with arbitrary scaling.

The three images must be aligned and have the same pixel scale and size.
"""

import numpy as np

from astropy.visualization.interval import ManualInterval
from astropy.visualization.stretch import LinearStretch

# Output dtypes accepted by ``RGBImageMapping.make_rgb_image``: floating point
# (``float`` / ``np.float64``) or 8-bit unsigned integer (``np.uint8``).
# Any other ``output_dtype`` is rejected with a ValueError.
_OUTPUT_IMAGE_FORMATS = [float, np.float64, np.uint8]

# Names exported by ``from <module> import *``; only ``make_rgb`` is public.
__all__ = ["make_rgb"]


class RGBImageMapping:
    """
    Class to map red, green, blue images into either a normalized float or
    an 8-bit image, by performing optional clipping and applying
    a scaling function to each band independently.

    Parameters
    ----------
    interval : `~astropy.visualization.BaseInterval` subclass instance or array-like, optional
        The interval object to apply to the data (either a single instance or
        an array for R, G, B). Default is
        `~astropy.visualization.ManualInterval` with vmin=0.
    stretch : `~astropy.visualization.BaseStretch` subclass instance, optional
        The stretch object to apply to the data. Default is
        `~astropy.visualization.LinearStretch`.

    Raises
    ------
    ValueError
        If ``interval`` is a sized container whose length is not 3.

    """

    def __init__(
        self, interval=ManualInterval(vmin=0, vmax=None), stretch=LinearStretch()
    ):
        # A single interval object has no len(); replicate it so that the
        # same interval is used for the R, G and B bands.
        try:
            len(interval)
        except TypeError:
            interval = 3 * [interval]
        if len(interval) != 3:
            raise ValueError("please provide 1 or 3 instances for interval.")

        self.intervals = interval
        self.stretch = stretch

    def make_rgb_image(self, image_r, image_g, image_b, output_dtype=np.uint8):
        """
        Convert 3 arrays, image_r, image_g, and image_b into a RGB image,
        either as an 8-bit per-channel or normalized image.

        The input images can be int or float, and in any range or bit-depth,
        but must have the same shape (NxM).

        Parameters
        ----------
        image_r : ndarray
            Image to map to red.
        image_g : ndarray
            Image to map to green.
        image_b : ndarray
            Image to map to blue.
        output_dtype : numpy scalar type, optional
            Image output format. Default is np.uint8.

        Returns
        -------
        RGBimage : ndarray
            RGB color image with the specified format as an NxMx3 numpy array.

        Raises
        ------
        ValueError
            If ``output_dtype`` is not one of the supported output formats,
            or if the three input images do not all have the same shape.

        """
        if output_dtype not in _OUTPUT_IMAGE_FORMATS:
            raise ValueError(f"'output_dtype' must be one of {_OUTPUT_IMAGE_FORMATS}!")

        image_r = np.asarray(image_r)
        image_g = np.asarray(image_g)
        image_b = np.asarray(image_b)

        if (image_r.shape != image_g.shape) or (image_g.shape != image_b.shape):
            msg = "The image shapes must match. r: {}, g: {} b: {}"
            raise ValueError(msg.format(image_r.shape, image_g.shape, image_b.shape))

        image_rgb = self.apply_mappings(image_r, image_g, image_b)

        # ``output_dtype`` was validated above, so anything that is not a
        # float format is an unsigned integer format.
        if np.issubdtype(output_dtype, float):
            conv_images = self._convert_images_to_float(image_rgb, output_dtype)
        else:
            conv_images = self._convert_images_to_uint(image_rgb, output_dtype)

        # Stack the three band images along a new last axis.
        return np.dstack(conv_images)

    def apply_mappings(self, image_r, image_g, image_b):
        """
        Apply mapping stretch and intervals to convert images image_r, image_g,
        and image_b to a triplet of normalized images.

        Parameters
        ----------
        image_r : ndarray
            Intensity of image to be mapped to red
        image_g : ndarray
            Intensity of image to be mapped to green.
        image_b : ndarray
            Intensity of image to be mapped to blue.

        Returns
        -------
        image_rgb : ndarray
            Triplet of mapped images based on the specified (per-band)
            intervals and the stretch function

        Notes
        -----
        Each band is normalized as ``(img - vmin) / (vmax - vmin)`` using the
        limits returned by its interval, so an interval for which
        ``vmax == vmin`` results in a division by zero.

        """
        image_rgb = [image_r, image_g, image_b]
        for i, img in enumerate(image_rgb):
            # Using syntax from mpl_normalize.ImageNormalize,
            # but NOT using that class to avoid dependency on matplotlib.

            # Define vmin and vmax (computed on the original, unmodified data)
            vmin, vmax = self.intervals[i].get_limits(img)

            # Work on a float copy: every step below operates in place and
            # the caller's input arrays must not be modified.
            img = np.array(img, copy=True, dtype=float)

            # Normalize based on vmin and vmax
            np.subtract(img, vmin, out=img)
            np.true_divide(img, vmax - vmin, out=img)

            # Clip to the 0 to 1 range
            img = np.clip(img, 0.0, 1.0, out=img)

            # Stretch values; clipping was already done above
            img = self.stretch(img, out=img, clip=False)

            image_rgb[i] = img

        return np.asarray(image_rgb)

    def _convert_images_to_float(self, image_rgb, output_dtype):
        """
        Convert a triplet of normalized images to float.

        Returns a new array of type ``output_dtype``; ``image_rgb`` is not
        modified.
        """
        return image_rgb.astype(output_dtype)

    def _convert_images_to_uint(self, image_rgb, output_dtype):
        """
        Convert a triplet of normalized images to unsigned integer images.

        The normalized values are multiplied by the largest value
        representable in ``output_dtype`` and then cast to that type.
        Returns a new array; ``image_rgb`` is not modified.
        """
        pixmax = float(np.iinfo(output_dtype).max)

        # Scale out of place so the array passed in by the caller is left
        # untouched (an in-place ``*=`` would overwrite it).
        return (image_rgb * pixmax).astype(output_dtype)


def make_rgb(
    image_r,
    image_g,
    image_b,
    interval=ManualInterval(vmin=0, vmax=None),
    stretch=LinearStretch(),
    filename=None,
    output_dtype=np.uint8,
):
    """
    Return a Red/Green/Blue color image from 3 images using a specified
    stretch and interval, for each band *independently*.

    The input images can be int or float, and in any range or bit-depth,
    but must have the same shape (NxM).

    For a more detailed look at the use of this method, see the document
    :ref:`astropy:astropy-visualization-rgb`.

    Parameters
    ----------
    image_r : ndarray
        Image to map to red.
    image_g : ndarray
        Image to map to green.
    image_b : ndarray
        Image to map to blue.
    interval : `~astropy.visualization.BaseInterval` subclass instance or array-like, optional
        The interval object to apply to the data (either a single instance or
        an array for R, G, B). Default is
        `~astropy.visualization.ManualInterval` with vmin=0.
    stretch : `~astropy.visualization.BaseStretch` subclass instance, optional
        The stretch object to apply to the data. Default is
        `~astropy.visualization.LinearStretch`.
    filename : str, optional
        Write the resulting RGB image to a file (file type determined
        from extension).
    output_dtype : numpy scalar type, optional
        Image output data type. Default is np.uint8.

    Returns
    -------
    rgb : ndarray
        RGB (either float or integer with 8-bits per channel) color image
        as an NxMx3 numpy array.

    Raises
    ------
    ValueError
        If ``interval`` is a sized container whose length is not 3, if
        ``output_dtype`` is not a supported output format, or if the input
        images do not all have the same shape.

    Notes
    -----
    This procedure of clipping and then scaling is similar to the DS9
    image algorithm (see the DS9 reference guide:
    http://ds9.si.edu/doc/ref/how.html).

    """
    map_ = RGBImageMapping(interval=interval, stretch=stretch)
    rgb = map_.make_rgb_image(image_r, image_g, image_b, output_dtype=output_dtype)

    if filename:
        # Imported lazily so that matplotlib is only required when the
        # image is actually written to disk.
        import matplotlib.image

        matplotlib.image.imsave(filename, rgb, origin="lower")

    return rgb