import numpy as np
from astropy.visualization import LuptonAsinhStretch
from matplotlib import pyplot as plt

# Build a 2x2 grid of panels; each panel shows the stretch curves for one Q.
fig, ax = plt.subplots(2, 2, figsize=(8, 8), layout='constrained')
ax = ax.ravel()  # flatten the grid so panels can be addressed by a single index

# Input samples spanning [0, 1] and the stretch values drawn in every panel.
x = np.linspace(0, 1, num=100)
stretches = (0.05, 0.1, 0.2, 0.5, 1, 5, 10)

Qs = (2, 5, 8, 10)
# Every curve comes from the same class, so its name is fixed for all titles.
stretch_name = LuptonAsinhStretch.__name__

for idx, Q in enumerate(Qs):
    panel = ax[idx]

    # One curve per stretch value, evaluated on the shared input grid.
    for st in stretches:
        transform = LuptonAsinhStretch(stretch=st, Q=Q)
        panel.plot(x, transform(x, clip=True), label=f'{st=}')

    panel.axis('equal')
    # Dotted y = x reference line; it has no label, so it stays out of the
    # legend.
    panel.plot(x, x, ls='dotted', color='k', alpha=0.3)
    panel.set_xlim(0, 1)
    panel.set_ylim(0, 1)
    panel.set(
        xlabel='Input Value',
        ylabel='Output Value',
        title=f'{stretch_name}, {Q=}',
    )
    panel.legend(loc='lower right', fontsize=8)