import numpy as np
from astropy.visualization import PowerDistStretch
from matplotlib import pyplot as plt

# Plot PowerDistStretch curves for a range of ``a`` values over inputs in
# [0, 1], one labelled line per value. The a=1000 curve is drawn with a
# heavier line so it stands out (presumably because 1000 is the class's
# default ``a`` -- NOTE(review): confirm against the astropy docs).
fig, ax = plt.subplots(figsize=(5, 5))

x = np.linspace(0, 1, 100)
a_vals = (0.001, 0.05, 0.3, 0.8, 1.2, 3, 10, 30, 100, 1000)
highlight_a = 1000  # the ``a`` value whose curve gets the thick line
for a in a_vals:
    lw = 3 if a == highlight_a else 1
    stretch = PowerDistStretch(a)
    # clip=True is forwarded to the stretch call; presumably it clips the
    # output to [0, 1] (astropy stretch ``clip`` option -- confirm).
    ax.plot(x, stretch(x, clip=True), label=f'{a=}', lw=lw)

# Equal-aspect unit square with a faint dotted y = x reference line
# (the identity mapping) for comparison.
ax.axis('equal')
ax.plot(x, x, ls='dotted', color='k', alpha=0.3)
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.set_xlabel('Input Value')
ax.set_ylabel('Output Value')
# Take the title from the class itself rather than the leaked loop variable,
# so it does not depend on the loop having run at least once.
ax.set_title(PowerDistStretch.__name__)
ax.legend(loc='upper left', fontsize=8)