Fitting models in parallel with N-dimensional data#

In some cases, you may want to fit a model many times to data. For example, you may have a spectral cube (with two celestial axes and one spectral axis) and you want to fit a 1D model (which could be either a simple Gaussian model or a complex compound model with multiple lines and a continuum) to each individual spectrum in the cube. Alternatively, you may have a cube with two celestial axes, one spectral axis, and one time axis, and you want to fit a 2D model to each 2D celestial plane in the cube. Provided each model fit can be treated as independent, there are significant performance benefits to carrying out these model fits in parallel.

The parallel_fit_dask() function is ideally suited to these use cases. It makes it simple to set up fitting of M-dimensional models to N-dimensional datasets and leverages the power of the dask package to efficiently parallelize the problem, running it either on multiple processes of a single machine or in a distributed environment. You do not need to know how to use dask in order to use this function, but you will need to make sure you have dask installed.

Note that the approach described here differs from model sets (see Fitting Model Sets), which are a way of fitting a linear model with a vector of parameters to a data array; in that specific case the fitting can be truly vectorized, so it will likely not benefit from the approach described here.

Getting started#

To demonstrate the use of this function, we will work through a simple example of fitting a 1D model to a small spectral cube (if you are interested in accessing the file, you can find it at l1448_13co.fits, but the code below will automatically download it).

We start by downloading the cube and extracting the data and WCS:

>>> from astropy.wcs import WCS
>>> from astropy.io import fits
>>> from astropy.utils.data import get_pkg_data_filename
>>> filename = get_pkg_data_filename('l1448/l1448_13co.fits')  
>>> with fits.open(filename) as hdulist:
...     data = hdulist[0].data
...     wcs = WCS(hdulist[0].header)  

We extract a sub-cube spatially for the purpose of demonstration:

>>> data = data[:, 25:75, 35:85]
>>> wcs = wcs[:, 25:75, 35:85]

This is a cube of a star-formation region traced by the 13CO line. We can look at one of the channels:

>>> import matplotlib.pyplot as plt
>>> ax = plt.subplot(1, 1, 1, projection=wcs, slices=('x', 'y', 20))
>>> ax.imshow(data[20, :, :])  

[Figure: parallel-fitting-4.png]

We can also extract a spectrum for one of the celestial positions:

>>> ax = plt.subplot(1, 1, 1, projection=wcs, slices=(5, 5, 'x'))
>>> ax.plot(data[:, 5, 5])  

[Figure: parallel-fitting-5.png]

We now set up a model to fit this; we will use a simple Gaussian model, with some reasonable initial guesses for the parameters:

>>> from astropy import units as u
>>> from astropy.modeling.models import Gaussian1D
>>> model = Gaussian1D(amplitude=1 * u.one, mean=4000 * u.m / u.s, stddev=500 * u.m / u.s)

The data does not have any units in this case, so we use u.one as the unit, which indicates it is dimensionless.

Before fitting this to all spectra in the cube, it’s a good idea to test the model with at least one of the spectra manually. To do this, we need to extract the x-axis of the spectra:

>>> import numpy as np
>>> x = wcs.pixel_to_world(0, 0, np.arange(data.shape[0]))[1]
>>> x
<SpectralCoord
   (target: <ICRS Coordinate: (ra, dec, distance) in (deg, deg, kpc)
                (57.66, 0., 1000.)
             (pm_ra_cosdec, pm_dec, radial_velocity) in (mas / yr, mas / yr, km / s)
                (0., 0., 0.)>)
  [2528.19489695, 2594.61850695, 2661.04211695, 2727.46572695,
   2793.88933695, 2860.31294695, 2926.73655695, 2993.16016695,
   ...
   5716.52817695, 5782.95178695, 5849.37539695, 5915.79900695,
   5982.22261695] m / s>

We can now carry out the fit:

>>> from astropy.modeling.fitting import TRFLSQFitter
>>> fitter = TRFLSQFitter()
>>> model_fit_single = fitter(model, x, data[:, 5, 5])
>>> ax = plt.subplot(1, 1, 1)
>>> ax.plot(x, data[:, 5, 5], '.', label='data')  
>>> ax.plot(x, model(x), label='initial model')  
>>> ax.plot(x, model_fit_single(x), label='fitted model')  
>>> ax.legend()  

[Figure: parallel-fitting-9.png]

The model seems to work! We can now use the parallel_fit_dask() function to fit all spectra in the cube:

>>> from astropy.modeling.fitting import parallel_fit_dask
>>> model_fit = parallel_fit_dask(model=model,
...                               fitter=fitter,
...                               data=data,
...                               world=wcs,
...                               fitting_axes=0,
...                               data_unit=u.one,
...                               scheduler='synchronous')

The arguments in this case are as follows:

  • model= is the initial model. While in our case the initial parameters were specified as scalars, it is possible to pass in a model that has array parameters if you want different initial parameters as a function of location in the dataset (see the sketch after this list).

  • fitter= is the fitter instance.

  • data= is the N-dimensional dataset, in our case the 3D spectral cube.

  • world= provides information about the world coordinates for the fit, for example the spectral coordinates for a spectrum. This can be specified in different ways, but above we have chosen to pass in the WCS object for the dataset, from which the spectral axis coordinates will be extracted.

  • fitting_axes= specifies which axis or axes contain the data to fit. In our example, we are fitting the spectra, which in NumPy notation corresponds to the first axis of the cube, so we specify fitting_axes=0.

  • data_unit= specifies the unit to use for the data. In our case, the data has no unit, but because we are using units for the spectral axis, we need to specify u.one here.
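As an illustration of the point about model= above, here is a minimal sketch of constructing a model whose initial mean varies from spectrum to spectrum; it assumes that the parameter arrays should match the shape of the data along the non-fitting (celestial) axes:

>>> mean_guess = np.full(data.shape[1:], 4000.) * u.m / u.s  # one initial mean per spectrum
>>> model_varying = Gaussian1D(amplitude=1 * u.one, mean=mean_guess,
...                            stddev=500 * u.m / u.s)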

We can now take a look at the parameter maps:

>>> fig = plt.figure(figsize=(10, 5))
>>> ax = fig.add_subplot(1, 3, 1)
>>> ax.set_title('Amplitude')  
>>> ax.imshow(model_fit.amplitude.value, vmin=0, vmax=5, origin='lower')  
>>> ax = fig.add_subplot(1, 3, 2)
>>> ax.set_title('Mean')  
>>> ax.imshow(model_fit.mean.value, vmin=2500, vmax=6000, origin='lower')  
>>> ax = fig.add_subplot(1, 3, 3)
>>> ax.set_title('Standard deviation')  
>>> ax.imshow(model_fit.stddev.value, vmin=0, vmax=2000, origin='lower')  

[Figure: parallel-fitting-11.png]

There are a number of pixels that appear to have issues. Inspecting the histogram of means, we can see that a lot of values are not at all in the spectral range we are fitting:

>>> _ = plt.hist(model_fit.mean.value.ravel(), bins=100)
>>> plt.yscale('log')
>>> plt.xlabel('mean')  
>>> plt.ylabel('number')  

[Figure: parallel-fitting-12.png]

We can set the bounds on the mean and try the fit again

>>> model.mean.bounds = (3000, 6000) * u.m / u.s
>>> model_fit = parallel_fit_dask(model=model,
...                               fitter=fitter,
...                               data=data,
...                               world=wcs,
...                               fitting_axes=0,
...                               data_unit=u.one,
...                               scheduler='synchronous')

and we can visualize the results:

>>> fig = plt.figure(figsize=(10, 5))
>>> ax = fig.add_subplot(1, 3, 1)
>>> ax.set_title('Amplitude')  
>>> ax.imshow(model_fit.amplitude.value, vmin=0, vmax=5, origin='lower')  
>>> ax = fig.add_subplot(1, 3, 2)
>>> ax.set_title('Mean')  
>>> ax.imshow(model_fit.mean.value, vmin=2500, vmax=6000, origin='lower')  
>>> ax = fig.add_subplot(1, 3, 3)
>>> ax.set_title('Standard deviation')  
>>> ax.imshow(model_fit.stddev.value, vmin=0, vmax=2000, origin='lower')  

[Figure: parallel-fitting-14.png]

The amplitude map no longer contains any problematic pixels.

World input#

The example above demonstrated that it is possible to pass an astropy.wcs.WCS object to the world= argument in order to determine the world coordinates for the fit (e.g. the spectral axis values for a spectral fit). It is also possible to pass in a tuple of arrays; if you do this, the tuple should have one item per fitting axis. It is most efficient to pass in a tuple of 1D arrays, but if the world coordinates vary over the axes being iterated over, you can also pass in a tuple of N-dimensional arrays giving the coordinates of each individual pixel (arrays that are neither 1D nor fully N-dimensional are also accepted, as long as they can be broadcast to the data shape).
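For example, for the cube above we could pass the spectral axis values extracted earlier instead of the full WCS. This is a minimal sketch; it assumes a plain quantity array is acceptable for the fitting axis, so we convert the SpectralCoord to a Quantity:

>>> model_fit = parallel_fit_dask(model=model,
...                               fitter=fitter,
...                               data=data,
...                               world=(u.Quantity(x),),  # one array per fitting axis
...                               fitting_axes=0,
...                               data_unit=u.one,
...                               scheduler='synchronous')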

Multiprocessing#

By default, parallel_fit_dask() will make use of multi-processing to parallelize the fitting. If you write a script to carry out the fitting, you will likely need to move your code inside a:

if __name__ == "__main__":

    ...

clause as otherwise Python will execute the whole code in the script many times, and potentially recursively, rather than just parallelizing the fitting.
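For example, a complete stand-alone fitting script might be structured as follows. This is a minimal sketch: load_cube() is a hypothetical helper standing in for whatever code reads your data, and the other names match the example above.

# fit_cube.py -- minimal sketch of a stand-alone fitting script
from astropy import units as u
from astropy.modeling.models import Gaussian1D
from astropy.modeling.fitting import TRFLSQFitter, parallel_fit_dask

def main():
    data, wcs = load_cube()  # hypothetical helper returning the data array and its WCS
    model = Gaussian1D(amplitude=1 * u.one, mean=4000 * u.m / u.s,
                       stddev=500 * u.m / u.s)
    model_fit = parallel_fit_dask(model=model, fitter=TRFLSQFitter(),
                                  data=data, world=wcs,
                                  fitting_axes=0, data_unit=u.one)

if __name__ == "__main__":
    # Worker processes import this module; the guard ensures only the
    # main process actually starts the fitting.
    main()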

Performance#

The parallel_fit_dask() function splits the data into chunks, each of which is then sent to a different process. The size of these chunks is critical to obtaining good performance. If we split the data into one chunk per fit, the process would be inefficient due to significant overhead from inter-process communication. Conversely, if we split the data into fewer chunks than there are available processes, we will not utilize all the available computational power. If we split the data into slightly more chunks than there are processes, inefficiencies can arise as well. For example, splitting the data into five chunks with four available processes means the four processes will first fit four chunks, and then a single process will be held up fitting the remaining chunk. Therefore, it is important to carefully consider how the data is split.

To control the splitting of the data, use the chunk_n_max= keyword argument. This determines how many individual fits will be carried out in each chunk. For example, when fitting a model to individual spectra in a spectral cube, setting chunk_n_max=100 means each chunk will contain 100 spectra. As a general guide, you will likely want to set this to be roughly the number of fits to be carried out in the data divided by several times the number of available processes. For example, if you need to fit 100,000 spectra and have 8 processes available, setting chunk_n_max=1000 would be reasonable. This configuration would break the data into 100 chunks, meaning each process will need to handle approximately a dozen chunks. Additionally, fitting 1,000 spectra per chunk will take enough time to avoid being dominated by communication overhead.
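For instance, reusing the fit set up above (a sketch; with our small cube each chunk then contains 100 spectra):

>>> model_fit = parallel_fit_dask(model=model,
...                               fitter=fitter,
...                               data=data,
...                               world=wcs,
...                               fitting_axes=0,
...                               data_unit=u.one,
...                               chunk_n_max=100)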

The default value for chunk_n_max is 500.

Diagnostics#

One of the challenges of fitting a model many times is understanding what went wrong when issues arise. By default, if a fit fails with a warning or an exception, the parameters for that fit are set to NaN, and no warning or exception is shown to the user. However, it can be helpful to have more information, such as the specific warning or exception that occurred.

You can control this by setting the diagnostics= argument. This allows you to choose whether to output information about:

  • Failed fits with errors (diagnostics='error'),

  • Fits with errors or warnings (diagnostics='error+warn'), or

  • All fits (diagnostics='all').

If the diagnostics option is specified, you will also need to specify diagnostics_path, which should be the path to a folder that will contain all the output. Each fit that needs to be output will be assigned a sub-folder named after the indices along the axes of the data (excluding the fitting axes). The output will include (if appropriate):

  • error.log, containing details of any exceptions that occurred

  • warn.log, containing any warnings

You may also want to automatically create a plot of the fit, inspect the data being fit, or examine the model. To do this, you can pass a function to diagnostics_callable. See parallel_fit_dask() for more information about the arguments this function should accept.
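For example, to save information only about failed fits (a sketch; the output folder name is arbitrary):

>>> model_fit = parallel_fit_dask(model=model,
...                               fitter=fitter,
...                               data=data,
...                               world=wcs,
...                               fitting_axes=0,
...                               data_unit=u.one,
...                               diagnostics='error',
...                               diagnostics_path='diagnostics_output')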

Schedulers#

By default, parallel_fit_dask() will make use of the 'processes' scheduler, which means that multiple processes on your local machine will be used. You can override this with the scheduler= keyword argument: either set it to the name of a scheduler (such as 'synchronous'), or set it to 'default' to make use of whatever dask scheduler is currently active, which allows you, for example, to use a dask.distributed scheduler.
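For example, to run the fits on a dask.distributed cluster you can create a client and then ask parallel_fit_dask() to use the active scheduler. This is a minimal sketch; it assumes dask.distributed is installed and relies on the standard dask behavior that creating a Client registers it as the active scheduler:

>>> from dask.distributed import Client
>>> client = Client(n_workers=4)  # becomes the active dask scheduler
>>> model_fit = parallel_fit_dask(model=model,
...                               fitter=fitter,
...                               data=data,
...                               world=wcs,
...                               fitting_axes=0,
...                               data_unit=u.one,
...                               scheduler='default')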