"""The ``plotting`` module provides functions for generating plots of
analysis results that are visually consistent and easily reproducible.
Plotting Function Summaries
---------------------------
.. autosummary::
   :nosignatures:
   plot_cosmology_fit
   plot_delta_colors
   plot_delta_mag_vs_pwv
   plot_delta_mag_vs_z
   plot_delta_mu
   plot_derivative_mag_vs_z
   plot_fitted_params
   plot_magnitude
   plot_pwv_mag_effects
   plot_residuals_on_sky
   plot_spectral_template
   plot_year_pwv_vs_time
   compare_prop_effects
   plot_transmission_variation
   plot_flux_variation
   plot_delta_sn_flux
Module Docs
-----------
"""
from copy import copy
from datetime import datetime, timedelta
from typing import *
import matplotlib.dates as mdates
import numpy as np
import pandas as pd
import sncosmo
from astropy import units as u
from astropy.coordinates import SkyCoord
from astropy.cosmology import FlatwCDM
from astropy.cosmology.core import Cosmology
from astropy.time import Time
from matplotlib import pyplot as plt
from matplotlib.ticker import MultipleLocator
from pwv_kpno.defaults import v1_transmission
from pytz import utc
from . import constants as const
from . import models
from .models import PWVTransmissionModel
from .types import Numeric
def _multi_line_plot(
        x_arr: np.ndarray, y_arr: np.ndarray, z_arr: np.ndarray, axis: plt.Axes, label: Optional[str] = None
) -> None:
    """Draw one line per row of ``y_arr`` against ``x_arr``, colored by ``z_arr``.

    Line colors are sampled evenly from the viridis color map, one per entry
    of ``z_arr``, so the i-th row of ``y_arr`` is drawn in the color assigned
    to ``z_arr[i]``. The x-limits of ``axis`` are set to the first and last
    values of ``x_arr``.

    Args:
        x_arr: A 1d array of x values shared by every line
        y_arr: A 2d array whose rows are the y values of each line
        z_arr: A 1d array with one value per row of ``y_arr`` used to pick line colors
        axis: Axis to plot on
        label: Optional format string (e.g. ``'{:g} mm'``) filled in with the
            ``z_arr`` value of each line to build that line's legend label
    """
    # noinspection PyUnresolvedReferences
    line_colors = plt.cm.viridis(np.linspace(0, 1, len(z_arr)))
    for z_value, y_values, line_color in zip(z_arr, y_arr, line_colors):
        plot_kwargs = {'c': line_color}
        if label:
            plot_kwargs['label'] = label.format(z_value)

        axis.plot(x_arr, y_values, **plot_kwargs)

    axis.set_xlim(x_arr[0], x_arr[-1])
def plot_delta_mag_vs_z(
        pwv_arr: np.ndarray,
        z_arr: np.ndarray,
        delta_mag_arr: np.ndarray,
        axis: Optional[plt.Axes] = None,
        label: Optional[str] = None
) -> None:
    """Single panel, multi-line plot of change in magnitude vs redshift color coded by PWV

    One line is drawn for each PWV value.

    Args:
        pwv_arr: 1d array of PWV values
        z_arr: 1d array of redshift values
        delta_mag_arr: 2d array of delta mag values with one row per PWV value
            and one column per redshift value
        axis: Optionally plot on a given axis (defaults to the current axis)
        label: Optional format string for the line labels, filled in with each PWV value
    """
    if axis is None:
        axis = plt.gca()

    _multi_line_plot(z_arr, delta_mag_arr, pwv_arr, axis, label)
    axis.set_xlabel('Redshift', fontsize=20)
    axis.set_xlim(min(z_arr), max(z_arr))
    axis.set_ylabel(r'$\Delta m$', fontsize=20)
def plot_delta_mag_vs_pwv(
        pwv_arr: np.ndarray,
        z_arr: np.ndarray,
        delta_mag_arr: np.ndarray,
        axis: Optional[plt.Axes] = None,
        label: Optional[str] = None
) -> None:
    """Single panel, multi-line plot for change in magnitude vs PWV color coded by redshift

    One line is drawn for each redshift value.

    Args:
        pwv_arr: 1d array of PWV values
        z_arr: 1d array of redshift values
        delta_mag_arr: 2d array of delta mag values with one row per PWV value
            and one column per redshift value (the same layout used by
            ``plot_delta_mag_vs_z``; it is transposed internally)
        axis: Optionally plot on a given axis (defaults to the current axis)
        label: Optional format string for the line labels, filled in with each redshift
    """
    if axis is None:
        axis = plt.gca()

    # Transpose so each row corresponds to a single redshift
    _multi_line_plot(pwv_arr, delta_mag_arr.T, z_arr, axis, label)
    axis.set_xlabel('PWV', fontsize=20)
    axis.set_xlim(min(pwv_arr), max(pwv_arr))
    axis.set_ylabel(r'$\Delta m$', fontsize=20)
# noinspection PyUnusedLocal
def plot_derivative_mag_vs_z(
        pwv_arr: np.ndarray, z_arr: np.ndarray, slope_arr: np.ndarray, axis: Optional[plt.Axes] = None
) -> None:
    """Single panel plot of the slope in delta magnitude vs redshift

    A single line is drawn. ``pwv_arr`` is currently unused. The y-axis label
    states that the slope is evaluated at a reference PWV of 4 mm.

    Args:
        pwv_arr: Array of PWV values (unused)
        z_arr: Array of redshift values
        slope_arr: Slope of delta mag at reference PWV, one value per redshift
        axis: Optionally plot on a given axis (defaults to the current axis)
    """
    if axis is None:
        axis = plt.gca()

    axis.plot(z_arr, slope_arr)
    axis.set_xlabel('Redshift', fontsize=20)
    axis.set_xlim(min(z_arr), max(z_arr))
    axis.set_ylabel(r'$\frac{\Delta \, m}{\Delta \, PWV} |_{PWV = 4 mm}$', fontsize=20)
def plot_pwv_mag_effects(
        pwv_arr: np.ndarray,
        z_arr: np.ndarray,
        delta_mag: dict,
        slopes: np.ndarray,
        bands: List[str],
        figsize: Tuple[Numeric, Numeric] = (10, 8)
) -> Tuple[plt.Figure, np.ndarray]:
    """Multi panel plot with a column for each band and rows for the change in magnitude vs pwv and redshift parameters

    ``delta_mag`` is expected to have band names as keys, and 2d arrays as
    values. Each array should represent the change in magnitude for each
    given PWV (rows) and redshift (columns).

    Rows of the figure are:
      1. Change in magnitude vs redshift, one line per PWV value
      2. Change in magnitude vs PWV, one line per redshift value, with the
         4 mm reference PWV marked by a dashed vertical line
      3. Slope of the change in magnitude vs redshift

    Args:
        pwv_arr: PWV values used in the calculation
        z_arr: Redshift values used in the calculation
        delta_mag: Dictionary with delta mag for each band
        slopes: Slope in delta_mag for each redshift, indexed by band name
        bands: Order of bands to plot. One column is created per band.
        figsize: The size of the figure

    Returns:
        The matplotlib figure and a 2d (3 x ``len(bands)``) array of axes
    """
    # ``squeeze=False`` keeps ``axes`` two-dimensional even when a single band is plotted
    fig, axes = plt.subplots(3, len(bands), figsize=figsize, squeeze=False)
    top_reference_ax, middle_reference_ax, bottom_reference_ax = axes[:, 0]

    for band, axes_column in zip(bands, axes.T):
        top_ax, middle_ax, bottom_ax = axes_column

        # Share y-axes with the first column of each row and the redshift (x) axes of
        # the top and bottom rows with the top-left axis. ``Axes.sharex``/``sharey``
        # replace the ``get_shared_*_axes().join`` pattern, which newer matplotlib
        # versions no longer support.
        bottom_ax.sharex(top_reference_ax)
        if top_ax is not top_reference_ax:
            top_ax.sharex(top_reference_ax)
            top_ax.sharey(top_reference_ax)
            middle_ax.sharey(middle_reference_ax)
            bottom_ax.sharey(bottom_reference_ax)

        # First row
        plot_delta_mag_vs_z(pwv_arr, z_arr, delta_mag[band], top_ax, label='{:g} mm')
        top_ax.axhline(0, linestyle='--', color='k', label='4 mm')
        top_ax.set_title(f'{band[-1]}-band')
        top_ax.set_xlabel('Redshift', fontsize=12)
        top_ax.set_ylabel('')

        # Middle row: the reference PWV is marked on this row's PWV axis
        plot_delta_mag_vs_pwv(pwv_arr, z_arr, delta_mag[band], middle_ax, label='z = {:g}')
        middle_ax.axvline(4, linestyle='--', color='k')
        middle_ax.set_xlabel('PWV', fontsize=12)
        middle_ax.set_ylabel('')

        # Bottom row
        plot_derivative_mag_vs_z(pwv_arr, z_arr, slopes[band], bottom_ax)
        bottom_ax.set_xlabel('Redshift', fontsize=12)
        bottom_ax.set_ylabel('')

    top_reference_ax.autoscale()  # To reset y-range
    top_reference_ax.set_xlim(0.1, 1.1)

    # Remove unnecessary tick labels. ``tick_params`` is used instead of
    # ``set_yticklabels([])`` because shared axes also share their tick formatter,
    # so clearing the labels on one axis would clear them on the first column too.
    for axis in axes[:, 1:].flatten():
        axis.tick_params(labelleft=False)

    # Add legends to the last column
    last_top_ax, last_middle_ax = axes[0, -1], axes[1, -1]
    last_top_ax.legend(bbox_to_anchor=(1, 1.1))
    handles, labels = last_middle_ax.get_legend_handles_labels()
    last_middle_ax.legend(handles[::5], labels[::5], bbox_to_anchor=(1, 1.1))

    # Add y labels
    top_reference_ax.set_ylabel(r'$\Delta m \, \left(PWV,\, z\right)$', fontsize=12)
    middle_reference_ax.set_ylabel(r'$\Delta m \, \left(z,\, PWV\right)$', fontsize=12)
    bottom_reference_ax.set_ylabel(r'$\frac{\Delta \, m}{\Delta \, PWV} |_{4 mm}$', fontsize=12)

    plt.tight_layout()
    return fig, axes
# https://stackoverflow.com/questions/18311909/how-do-i-annotate-with-power-of-ten-formatting
def sci_notation(
        num: Numeric, decimal_digits: int = 1, precision: Optional[int] = None, exponent: Optional[int] = None
) -> str:
    r"""Return a LaTeX string representation of a number in scientific notation.

    Args:
        num: The number to format
        decimal_digits: Number of decimal places the coefficient is rounded to
        precision: Number of decimal places shown for the coefficient
            (defaults to ``decimal_digits``)
        exponent: Use this power of ten instead of inferring it from ``num``

    Returns:
        A string such as ``$1.5\cdot10^{3}$``, or ``$10^{3}$`` when the rounded
        coefficient is exactly one
    """
    infer_exponent = exponent is None
    if infer_exponent:
        # log10 is undefined for zero, so zero is written with an exponent of zero
        exponent = int(np.floor(np.log10(abs(num)))) if num != 0 else 0

    coeff = round(num / float(10 ** exponent), decimal_digits)

    # Rounding can push the coefficient up to 10 (e.g. 9.99 -> 10.0). When the
    # exponent was inferred, bump it so the coefficient stays in [1, 10).
    if infer_exponent and abs(coeff) >= 10:
        exponent += 1
        coeff = round(num / float(10 ** exponent), decimal_digits)

    if coeff == 1:
        return r"$10^{{{}}}$".format(exponent)

    if precision is None:
        precision = decimal_digits

    return r"${0:.{2}f}\cdot10^{{{1:d}}}$".format(coeff, exponent, precision)
def plot_spectral_template(
        source: Union[str, sncosmo.Source],
        wave_arr: np.ndarray,
        z_arr: np.ndarray,
        pwv: np.ndarray,
        phase: Numeric = 0,
        resolution: Numeric = 2,
        figsize: Tuple[Numeric, Numeric] = (6, 4)
) -> Tuple[plt.Figure, np.ndarray]:
    """Plot a spectral template with overlaid PWV and bandpass throughput curves

    The top panel shows the template flux (divided by 1e-13) at each redshift
    with the atmospheric transmission drawn on a twin y-axis. The bottom panel
    shows the LSST hardware throughput of the r, i, z, and y bands. Both panels
    share the same wavelength axis.

    Args:
        source: ``sncosmo`` source to use as spectral template
        wave_arr: The observer frame wavelengths to plot flux for in Angstroms
        z_arr: The redshifts to plot the template at
        pwv: The PWV to plot the transmission function for
        phase: The phase of the template to plot
        resolution: The resolution of the atmospheric model
        figsize: The size of the figure

    Returns:
        The matplotlib figure and an array of matplotlib axes (top, bottom)
    """
    # The panels are stacked in a single column, so the wavelength axis must be
    # shared by column (``sharex='row'`` would leave the panels independent)
    fig, (top_ax, bottom_ax) = plt.subplots(
        nrows=2, figsize=figsize, sharex='col',
        gridspec_kw={'height_ratios': [4, 1.75]})

    # Plot spectral template at given redshifts. Iterating in reverse draws entries
    # earlier in ``z_arr`` last (on top) while keeping color ``C{i}`` for ``z_arr[i]``.
    model = models.SNModel(source)
    flux_scale = 1e-13
    for i, z in enumerate(reversed(z_arr)):
        color = f'C{len(z_arr) - i - 1}'
        model.set(z=z)
        flux = model.flux(phase, wave_arr) / flux_scale
        top_ax.fill_between(wave_arr, flux, color=color, alpha=.8)
        top_ax.plot(wave_arr, flux, label=f'z = {z}', color=color, zorder=0)

    # Plot transmission function on twin axis at given wavelength resolution
    transmission = v1_transmission(pwv, res=resolution)
    twin_axis = top_ax.twinx()
    twin_axis.plot(transmission.index, transmission, alpha=0.75, color='grey')

    # Plot the band passes
    for b in 'rizy':
        band = sncosmo.get_bandpass(f'lsst_hardware_{b}')
        bottom_ax.plot(band.wave, band.trans, label=f'{b} Band')

    # Format top axis
    top_ax.set_ylim(0, 5)
    top_ax.set_xlim(min(wave_arr), max(wave_arr))
    top_ax.set_ylabel('Flux')
    top_ax.legend(loc='lower left', framealpha=1)

    # Format twin axis
    twin_axis.set_ylim(0, 1)
    twin_axis.set_ylabel('Transmission', rotation=-90, labelpad=12)
    plt.tight_layout()

    # Format bottom axis
    bottom_ax.set_ylim(0, 1)
    bottom_ax.set_xlabel(r'Wavelength $\AA$')
    bottom_ax.xaxis.set_minor_locator(MultipleLocator(500))
    bottom_ax.set_xticks(np.arange(4000, 11001, 2000))
    bottom_ax.legend(loc='lower left', framealpha=1)

    plt.subplots_adjust(hspace=0)
    return fig, np.array([top_ax, bottom_ax])
def plot_spectrum(
        wave: np.ndarray,
        flux: np.ndarray,
        figsize: Tuple[Numeric, Numeric] = (9, 3),
        hardware_only: bool = False
) -> Tuple[plt.Figure, plt.Axes]:
    """Plot a spectrum over the per-filter LSST throughput

    The throughput of the u, g, r, i, z, and y filters is drawn on a twin
    y-axis spanning 0 to 1. The spectrum is drawn in black on the primary axis.

    Args:
        wave: Spectrum wavelengths in Angstroms
        flux: Flux of the spectrum in arbitrary units
        figsize: Size of the figure
        hardware_only: Only include hardware contributions in the plotted filters
            (``lsst_hardware_*`` bandpasses) instead of the total throughput
            (``lsst_total_*`` bandpasses)

    Returns:
        The matplotlib figure and the primary (flux) axis
    """
    fig, axis = plt.subplots(figsize=figsize)
    axis.set_ylabel('Object Flux')
    axis.set_xlim(min(wave), max(wave))
    axis.set_xlabel(r'Wavelength ($\AA$)')

    twin_ax = axis.twinx()
    twin_ax.set_ylim(0, 1)
    twin_ax.set_ylabel('Bandpass Transmission', rotation=270, labelpad=15)

    prefix = 'lsst_hardware_' if hardware_only else 'lsst_total_'
    for filter_abbrev in 'ugrizy':
        bandpass = sncosmo.get_bandpass(f'{prefix}{filter_abbrev}')
        # Evaluate the throughput once and reuse it for both the fill and the outline
        transmission = bandpass(wave)
        twin_ax.fill_between(wave, transmission, alpha=.3, label=filter_abbrev)
        twin_ax.plot(wave, transmission)

    axis.plot(wave, flux, color='k')
    return fig, axis
def plot_magnitude(
        mags: Dict[str, np.ndarray], pwv: np.ndarray, z: np.ndarray, figsize: Tuple[Numeric, Numeric] = (9, 6)
) -> Tuple[plt.Figure, np.ndarray]:
    """Multi-panel plot showing magnitudes in different columns vs PWV and redshift in different rows

    Each band gets its own column. The top row plots magnitude vs redshift
    (one line per PWV value) and the bottom row plots magnitude vs PWV (one
    line per redshift). Y-axes are shared along each row.

    Args:
        mags: Simulated magnitude values for each band, with one row per PWV
            value and one column per redshift
        pwv: Array of PWV values
        z: Array of redshift values
        figsize: Size of the figure

    Returns:
        The matplotlib figure and a 2d (2 x ``len(mags)``) array of axes
    """
    # ``squeeze=False`` keeps ``axes`` two-dimensional even when a single band is given
    fig, axes = plt.subplots(2, len(mags), figsize=figsize, sharey='row', squeeze=False)
    for (band, mag_arr), (top_ax, bottom_ax) in zip(mags.items(), axes.T):
        top_ax.set_title(band)
        top_ax.set_xlabel('Redshift')
        _multi_line_plot(z, mag_arr, pwv, top_ax, label='{:g} mm')

        bottom_ax.set_xlabel('PWV')
        _multi_line_plot(pwv, mag_arr.T, z, bottom_ax, label='z = {:.2f}')

    axes[0, 0].set_ylabel('Magnitude')
    axes[1, 0].set_ylabel('Magnitude')

    # Add legends to the last column
    last_top_ax, last_bottom_ax = axes[0, -1], axes[1, -1]
    last_top_ax.legend(bbox_to_anchor=(1, 1.1))
    handles, labels = last_bottom_ax.get_legend_handles_labels()
    last_bottom_ax.legend(handles[::5], labels[::5], bbox_to_anchor=(1, 1.1))

    plt.tight_layout()
    return fig, axes
def plot_fitted_params(
        fitted_params: Dict[str, np.ndarray], pwv_arr: np.ndarray, z_arr: np.ndarray, bands: List[str]
) -> Tuple[plt.Figure, np.ndarray]:
    """Multi-panel plot showing subplots for each salt2 parameter vs redshift.

    Multiple lines included for different PWV values. ``x0`` is shown as
    ``-2.5 log10(x0)``. The final panel shows ``alpha * x1 - beta * c`` using
    ``const.betoule_alpha`` and ``const.betoule_beta``.

    Note:
        Only the fitted parameters of ``bands[0]`` are currently plotted.

    Args:
        fitted_params: Dictionary with fitted parameters in each band. Each value
            is indexed along its last axis in the order of
            ``sncosmo.Model('salt2-extended').param_names``, and the remaining
            axes have one row per PWV value and one column per redshift
        pwv_arr: PWV value used for each supernova fit
        z_arr: Redshift value used for each supernova fit
        bands: Bands to include in the plot. Must be keys of ``fitted_params``

    Returns:
        The matplotlib figure and an array of matplotlib axes
    """
    # Parse the fitted parameters for easier plotting
    model = sncosmo.Model('salt2-extended')
    params_dict = {
        param: fitted_params[bands[0]][..., i] for
        i, param in enumerate(model.param_names)
    }

    fig, axes = plt.subplots(2, 3, figsize=(12, 8))
    for axis, (param, param_vals) in zip(axes.flatten(), params_dict.items()):
        if param == 'x0':
            param_vals = -2.5 * np.log10(param_vals)
            param = r'-2.5 log$_{10}$(x$_{0}$)'

        # Lines are color coded by PWV, so their labels describe the PWV value
        _multi_line_plot(z_arr, param_vals, pwv_arr, axis, label='PWV = {:g} mm')
        axis.set_xlabel('Redshift')
        axis.set_ylabel(param)

    correction_factor = const.betoule_alpha * params_dict['x1'] - const.betoule_beta * params_dict['c']
    _multi_line_plot(z_arr, correction_factor, pwv_arr, axes[-1][-1], label='PWV = {:g} mm')
    label = f'{const.betoule_alpha} * $x_1$ - {const.betoule_beta} * $c$'
    axes[-1][-1].set_ylabel(label)
    axes[-1][-1].legend(bbox_to_anchor=(1, 1.1))

    plt.tight_layout()
    return fig, axes
def plot_delta_colors(
        pwv_arr: np.ndarray, z_arr: np.ndarray, mag_dict: Dict[str, np.ndarray],
        colors: List[Tuple[str, str]], ref_pwv: Numeric = 0
) -> None:
    """Shows the change in SN color as a function of redshift with each SN color coded by PWV

    One panel is drawn per band combination. Colors are plotted relative to
    the color at ``ref_pwv``.

    Args:
        pwv_arr: Array of PWV values
        z_arr: Array of redshift values
        mag_dict: Dictionary with magnitudes for each band, with one row per
            PWV value and one column per redshift
        colors: Band combinations to plot colors for
        ref_pwv: Plot values relative to given reference PWV

    Raises:
        ValueError: If ``ref_pwv`` is not one of the values in ``pwv_arr``
    """
    # Look up the reference row before creating a figure so a bad value fails cleanly
    try:
        ref_idx = list(pwv_arr).index(ref_pwv)

    except ValueError:
        raise ValueError(f'Reference PWV {ref_pwv} is not one of the values in pwv_arr') from None

    num_cols = len(colors)
    _, axes = plt.subplots(ncols=num_cols, figsize=(4 * num_cols, 4), sharey='col')
    if num_cols == 1:
        axes = [axes]

    for axis, (band1, band2) in zip(axes, colors):
        color = mag_dict[band1] - mag_dict[band2]
        delta_color = color - color[ref_idx]
        _multi_line_plot(z_arr, delta_color, pwv_arr, label='{} mm', axis=axis)
        axis.set_xlabel('Redshift', fontsize=14)
        axis.set_ylabel(fr'$\Delta$ ({band1[-1]} - {band2[-1]})', fontsize=14)
        # Applied to every panel so all panels use the same redshift range
        axis.set_xlim(0, max(z_arr))

    # A single legend on the last panel; every panel uses the same PWV lines
    axes[-1].legend(bbox_to_anchor=(1, 1.1))
    plt.tight_layout()
# noinspection PyUnusedLocal
def plot_delta_mu(
        mu: np.ndarray, pwv_arr: np.ndarray, z_arr: np.ndarray, cosmo: Cosmology = const.betoule_cosmo
) -> None:
    """Plot the variation in fitted distance modulus as a function of redshift and PWV

    Three panels are drawn side by side, each with one line per PWV value:
    the fitted distance modulus with the cosmological prediction overlaid,
    the difference from that prediction, and the difference from the fit at
    the PWV value ``pwv_arr[4]``.

    Args:
        mu: Array of distance moduli with one row per PWV value and one column per redshift
        pwv_arr: Array of PWV values (the fifth entry is used as the reference PWV)
        z_arr: Array of redshift values
        cosmo: Astropy cosmology to compare results against
    """
    expected_mu = cosmo.distmod(z_arr).value
    residual_mu = mu - expected_mu

    fig, axes = plt.subplots(ncols=3, figsize=(9, 3))
    mu_ax, delta_mu_ax, relative_mu_ax = axes

    # Left panel: fitted distance modulus with the cosmological prediction overlaid
    _multi_line_plot(z_arr, mu, pwv_arr, mu_ax)
    mu_ax.plot(z_arr, expected_mu, linestyle=':', color='k', label='Simulated')
    mu_ax.legend(framealpha=1)

    # Middle panel: residuals from the cosmological prediction
    _multi_line_plot(z_arr, residual_mu, pwv_arr, delta_mu_ax)
    delta_mu_ax.axhline(0, linestyle=':', color='k', label='Simulated')
    delta_mu_ax.legend(framealpha=1)

    # Right panel: change relative to the fit at a hard-coded reference PWV index
    ref_idx = 4
    _multi_line_plot(z_arr, mu - mu[ref_idx], pwv_arr, relative_mu_ax, label='{:g} mm')
    relative_mu_ax.axhline(0, color='k', label=f'PWV={pwv_arr[ref_idx]}')
    relative_mu_ax.legend(framealpha=1, bbox_to_anchor=(1, 1.1))

    y_labels = (r'$\mu$', r'$\mu - \mu_{cosmo}$', r'$\mu - \mu_{pwv_f}$')
    for ax, y_label in zip(axes, y_labels):
        ax.set_ylabel(y_label, fontsize=12)
        ax.set_xlabel('Redshift', fontsize=12)

    plt.tight_layout()
def plot_year_pwv_vs_time(
        pwv_series: pd.Series, figsize: Tuple[Numeric, Numeric] = (8, 4), missing: Optional[Numeric] = 1
) -> Tuple[plt.Figure, plt.Axes]:
    """Plot PWV measurements taken over a single year as a function of time.

    Season boundaries are drawn at approximate equinox and solstice dates, and
    the mean and standard deviation of each season are printed and plotted.
    Seasons use southern-hemisphere naming: summer runs from the December
    solstice through the March equinox and so spans the new year.

    The year is taken from the earliest measurement. The index is compared
    against UTC datetimes, so it should be timezone aware.

    Set ``missing=None`` to disable plotting of missing data windows.

    Args:
        pwv_series: Measured PWV index by datetime
        figsize: Size of the figure in inches
        missing: Highlight time ranges larger than given number of days with missing PWV

    Returns:
        The matplotlib figure and axis
    """
    pwv_series = pwv_series.sort_index()
    times = pwv_series.index

    # Calculate rolling average of PWV time series
    rolling_mean_pwv = pwv_series.rolling(window='7D').mean()

    # In practice these dates vary by year so the values used here are approximate
    year = times[0].year
    mar_equinox = datetime(year, 3, 20, tzinfo=utc)
    jun_solstice = datetime(year, 6, 21, tzinfo=utc)
    sep_equinox = datetime(year, 9, 22, tzinfo=utc)
    dec_solstice = datetime(year, 12, 21, tzinfo=utc)

    # Separate data based on season. Each boundary date belongs to the season it starts,
    # so no measurement falls between seasons.
    summer_pwv = pwv_series[(times < mar_equinox) | (times >= dec_solstice)]
    fall_pwv = pwv_series[(times >= mar_equinox) & (times < jun_solstice)]
    winter_pwv = pwv_series[(times >= jun_solstice) & (times < sep_equinox)]
    spring_pwv = pwv_series[(times >= sep_equinox) & (times < dec_solstice)]

    print(f'Summer Average: {summer_pwv.mean(): .2f} +\\- {summer_pwv.std(): .2f} mm')
    print(f'Fall Average:  {fall_pwv.mean(): .2f} +\\- {fall_pwv.std(): .2f} mm')
    print(f'Winter Average: {winter_pwv.mean(): .2f} +\\- {winter_pwv.std(): .2f} mm')
    print(f'Spring Average: {spring_pwv.mean(): .2f} +\\- {spring_pwv.std(): .2f} mm')

    fig, axis = plt.subplots(figsize=figsize)
    axis.set_ylabel('PWV (mm)')
    axis.set_xlabel('Time of Year')
    axis.set_ylim(0, 20)
    ylow, yhigh = axis.get_ylim()

    def shade_time_range(start, end):
        """Shade the full vertical extent of the axis between two times"""

        axis.fill_between(
            x=[start, end], y1=[ylow, ylow], y2=[yhigh, yhigh],
            color='lightgrey', alpha=0.5, zorder=0)

    # Plot windows of missing pwv
    if missing:
        missing_interval = timedelta(days=missing)
        year_start = datetime(year, 1, 1, tzinfo=utc)
        year_end = datetime(year, 12, 31, 23, 59, 59, tzinfo=utc)

        # Gaps between consecutive measurements
        delta_t = times[1:] - times[:-1]
        for gap_start in np.where(delta_t > missing_interval)[0]:
            shade_time_range(times[gap_start], times[gap_start + 1])

        # Gaps at the beginning and end of the year
        if (times[0] - year_start) > missing_interval:
            shade_time_range(year_start, times[0])

        if (year_end - times[-1]) > missing_interval:
            shade_time_range(times[-1], year_end)

    # Mark the season boundaries
    for boundary in (mar_equinox, jun_solstice, sep_equinox, dec_solstice):
        axis.axvline(boundary, linestyle='--', color='k', zorder=1)

    # Plot measured PWV
    axis.scatter(times, pwv_series, s=1, alpha=.2, zorder=2)

    # Plot rolling average
    axis.plot(rolling_mean_pwv.index, rolling_mean_pwv, color='C1', label='Rolling Avg.', zorder=3, linewidth=2)

    # Plot seasonal averages at the midpoint of each season's measurements.
    # Summer spans the new year, so its marker is positioned using only its
    # January - March data (falling back to all summer data if there is none).
    summer_position_data = summer_pwv[summer_pwv.index < mar_equinox]
    if summer_position_data.empty:
        summer_position_data = summer_pwv

    seasons = (
        (summer_pwv, summer_position_data),
        (fall_pwv, fall_pwv),
        (winter_pwv, winter_pwv),
        (spring_pwv, spring_pwv),
    )

    # Only the first seasonal marker gets a legend entry
    seasonal_label_kwargs = {'label': 'Seasonal Avg.'}
    for season_pwv, position_data in seasons:
        if season_pwv.empty:
            continue

        avg = season_pwv.mean()
        std = season_pwv.std()
        first, last = position_data.index.min(), position_data.index.max()
        x = last - (last - first) / 2
        axis.errorbar([x], [avg], yerr=[std], color='k', zorder=4, linewidth=2, capsize=10, capthick=2)
        axis.scatter([x], [avg], color='k', s=100, marker='+', zorder=4, **seasonal_label_kwargs)
        seasonal_label_kwargs = {}

    # Format x labels to be three letter month abbreviations
    axis.xaxis.set_major_locator(mdates.MonthLocator())
    axis.xaxis.set_major_formatter(mdates.DateFormatter('%b'))
    axis.legend(framealpha=1)

    # Repeat the PWV scale on the right-hand side of the plot
    axis.twinx().set_ylim(axis.get_ylim())
    return fig, axis
# Ignore arguments with uppercase letters
# noinspection PyPep8Naming
def plot_cosmology_fit(
        data: pd.DataFrame, abs_mag: Numeric, H0: Numeric, Om0: Numeric, w0: Numeric, alpha: Numeric, beta: Numeric
) -> Tuple[plt.figure, np.ndarray, np.ndarray]:
    """Plot a cosmological fit to a set of supernova data.

    The top panel shows the standardized distance modulus of each supernova
    (sorted by redshift) with the fitted flat wCDM distance modulus overlaid.
    The bottom panel shows the residuals with a y-range symmetric about zero.
    Error bars in both panels use the ``mb_err`` column.

    Args:
        data: Results from the snat_sim fitting pipeline
        abs_mag: Intrinsic absolute magnitude of SNe Ia
        H0: Fitted Hubble constant at z = 0 in [km/sec/Mpc]
        Om0: Omega matter density in units of the critical density at z=0
        w0: Dark energy equation of state
        alpha: Fitted nuisance parameter for supernova stretch correction
        beta: Fitted nuisance parameter for supernova color correction

    Returns:
        The matplotlib figure, fitted distance modulus, and tabulated residuals
    """
    sorted_data = data.sort_values('z')
    redshift = sorted_data.z
    mu_errors = sorted_data.mb_err

    fitted_cosmology = FlatwCDM(H0=H0, Om0=Om0, w0=w0)
    fitted_mu = fitted_cosmology.distmod(redshift).value

    # Standardize each supernova using the stretch (x1) and color (c) corrections
    stretch_correction = alpha * sorted_data.x1
    color_correction = beta * sorted_data.c
    measured_mu = sorted_data.snat_sim.calc_distmod(abs_mag) + stretch_correction - color_correction
    residuals = measured_mu - fitted_mu

    fig, (top_ax, bottom_ax) = plt.subplots(2, sharex='col', gridspec_kw={'height_ratios': [2, 1]})

    # Top panel: measured distance modulus with the fitted model
    top_ax.errorbar(redshift, measured_mu, yerr=mu_errors, linestyle='')
    top_ax.scatter(redshift, measured_mu, s=1)
    top_ax.plot(redshift, fitted_mu, color='k', alpha=.75)

    # Bottom panel: residuals about the fit
    bottom_ax.axhline(0, color='k', alpha=.75, linestyle='--')
    bottom_ax.errorbar(redshift, residuals, yerr=mu_errors, linestyle='')
    bottom_ax.scatter(redshift, residuals, s=1)

    # Style the plot
    top_ax.set_ylabel(r'$\mu = m^*_B - M_B + \alpha x_1 - \beta c$')
    bottom_ax.set_ylabel('Residuals')

    # Make the residual y-range symmetric about zero
    half_range = max(np.abs(bottom_ax.get_ylim()))
    bottom_ax.set_ylim(-half_range, half_range)
    bottom_ax.set_xlim(xmin=0)

    fig.subplots_adjust(hspace=0.1)
    return fig, fitted_mu, residuals
def plot_residuals_on_sky(
        ra: np.ndarray,
        dec: np.ndarray,
        residual: np.ndarray,
        cmap: str = 'coolwarm',
        figsize: Tuple[Numeric, Numeric] = (8, 4)
) -> Tuple[plt.Figure, plt.Axes]:
    """Plot hubble residuals as a function of supernova coordinates.

    Coordinates are converted to galactic coordinates and drawn on an Aitoff
    projection. The color scale is symmetric about zero so that residuals of
    equal size but opposite sign map to equally saturated colors.

    Args:
        ra: Right Ascension for each supernova in degrees
        dec: Declination of each supernova in degrees
        residual: Hubble residual for each supernova
        cmap: Name of the matplotlib color map to use
        figsize: The size of the figure

    Returns:
        The matplotlib figure and axis
    """
    sn_coord = SkyCoord(ra, dec, unit=u.deg).galactic
    fig, axis = plt.subplots(figsize=figsize, subplot_kw={'projection': 'aitoff'})
    axis.grid(True)

    # Symmetric color limits. ``np.nanmax`` ignores NaN residuals, whereas the
    # builtin ``max`` gives an order-dependent result when NaNs are present.
    vlim = np.nanmax(np.abs(residual))
    scat = axis.scatter(
        sn_coord.l.wrap_at('180d').radian, sn_coord.b.radian,
        c=residual, cmap=cmap, vmin=-vlim, vmax=vlim)

    # Attach the color bar to this figure/axis explicitly rather than via pyplot's global state
    fig.colorbar(scat, ax=axis).set_label('Hubble Residual', rotation=270, labelpad=15)
    return fig, axis
def compare_prop_effects(
        pwv_data: pd.Series,
        static: models.StaticPWVTrans,
        seasonal: models.SeasonalPWVTrans,
        variable: models.VariablePWVTrans,
        figsize: Tuple[float, float] = (9, 6)
) -> Tuple[plt.Figure, plt.Axes]:
    """Compare the Zenith PWV assumed by different propagation effects

    Measured PWV values are drawn as a grey scatter plot behind the PWV
    assumed by each propagation effect, evaluated once per day over the
    span of ``pwv_data``. The propagation effects passed in are not modified.

    Args:
        pwv_data: Series with PWV values and a Datetime index
        static: Static propagation effect
        seasonal: Seasonal Propagation effect
        variable: Variable Propagation effect
        figsize: The size of the figure

    Returns:
        The matplotlib figure and axis
    """
    # Disable airmass scaling so we get values at zenith. The PWV model is copied as
    # well because a shallow copy of ``variable`` alone still shares the caller's PWV
    # model, and overriding ``calc_airmass`` on it would leak into the caller's object.
    variable = copy(variable)
    # noinspection PyProtectedMember
    variable._pwv_model = copy(variable._pwv_model)
    # noinspection PyProtectedMember
    variable._pwv_model.calc_airmass = lambda *args, **kwargs: 1

    pwv_data = pwv_data.sort_index()
    x_vals = np.arange(pwv_data.index[0], pwv_data.index[-1], timedelta(days=1)).astype(datetime)
    mjd_vals = Time(x_vals).mjd

    fig, axis = plt.subplots(figsize=figsize)
    axis.scatter(pwv_data.index, pwv_data.values, s=2, alpha=.1, color='grey')
    axis.plot(x_vals, variable.assumed_pwv(mjd_vals), label='Variable')
    axis.plot(x_vals, seasonal.assumed_pwv(mjd_vals), label='Seasonal', linewidth=2.5)
    axis.plot([x_vals[0], x_vals[-1]], [static['pwv'], static['pwv']], label='Static', linewidth=2.5)

    axis.set_ylabel('PWV (mm)')
    axis.set_xlabel('Date (UTC)')
    axis.set_ylim(0, 25)

    # A single legend call: a second ``legend()`` call would replace this one and drop its styling
    axis.legend(loc='upper right', framealpha=1)
    return fig, axis
def plot_transmission_variation(
        pwv1: float, pwv2: float,
        wave_min: float = 6500,
        wave_max: float = 10000,
        resolution: int = 9,
        figsize: Tuple[Numeric, Numeric] = (8, 4)
) -> Tuple[plt.figure, plt.Axes]:
    """Compare the atmospheric transmission function for two PWV concentrations

    The transmission at the lower PWV is drawn as a black line, and the region
    between it and the transmission at the higher PWV is filled in.

    Args:
        pwv1: The first PWV concentration to plot the transmission for
        pwv2: The second PWV concentration to plot the transmission for
        wave_min: Minimum wavelength to plot
        wave_max: Maximum wavelength to plot
        resolution: Resolution to bin the atmospheric transmission function to
        figsize: The size of the figure

    Returns:
        The matplotlib figure and axis
    """
    low_pwv = min(pwv1, pwv2)
    high_pwv = max(pwv1, pwv2)
    wavelengths = np.arange(wave_min, wave_max)
    low_trans, high_trans = (
        v1_transmission(pwv=pwv_value, wave=wavelengths, res=resolution)
        for pwv_value in (low_pwv, high_pwv)
    )

    fig, axis = plt.subplots(figsize=figsize)
    axis.plot(wavelengths, low_trans, color='k', label=f'PWV = {low_pwv} mm', linewidth=1.5)
    axis.fill_between(wavelengths, low_trans, high_trans, label=f'PWV = {high_pwv} mm', alpha=.75)

    # Labels and limits
    axis.set_title(r'Change in atmospheric PWV transmission flux due to $\Delta$PWV')
    axis.set_xlabel(r'Wavelength ($\AA$)')
    axis.set_ylabel('PWV Transmission')
    axis.set_xlim(wave_min, wave_max)
    axis.set_ylim(.5, 1)
    axis.legend(framealpha=1)
    return fig, axis
def plot_flux_variation(
        pwv1: float,
        pwv2: float,
        z: float = .55,
        wave_min: float = 6500,
        wave_max: float = 10000,
        resolution: int = 9,
        figsize: Tuple[Numeric, Numeric] = (8, 4)
) -> Tuple[plt.figure, plt.Axes]:
    """Compare the PWV absorbed flux of a SN Ia for two PWV concentrations

    A ``salt2-extended`` model with a static PWV propagation effect is
    evaluated at phase 0 for both PWV values. The lower-PWV flux is drawn as a
    black line and the region between it and the higher-PWV flux is filled in.

    Args:
        pwv1: The first PWV concentration to plot the flux for
        pwv2: The second PWV concentration to plot the flux for
        z: The redshift of the SN Ia
        wave_min: Minimum wavelength to plot
        wave_max: Maximum wavelength to plot
        resolution: Resolution to bin the atmospheric transmission function to
        figsize: The size of the figure

    Returns:
        The matplotlib figure and axis
    """
    low_pwv = min(pwv1, pwv2)
    high_pwv = max(pwv1, pwv2)

    # Supernova model with PWV absorption added as an effect in the ``'obs'`` frame
    sn_model = models.SNModel('salt2-extended')
    pwv_effect = models.StaticPWVTrans(transmission_res=resolution)
    sn_model.add_effect(pwv_effect, '', 'obs')

    wavelengths = np.arange(wave_min, wave_max)
    sn_model.set(z=z, pwv=low_pwv)
    flux_low = sn_model.flux(0, wavelengths)

    # Only the PWV changes between evaluations; the redshift stays fixed
    sn_model.set(pwv=high_pwv)
    flux_high = sn_model.flux(0, wavelengths)

    fig, axis = plt.subplots(figsize=figsize)
    axis.plot(wavelengths, flux_low, color='k', label=f'PWV = {low_pwv} mm', linewidth=1.5)
    axis.fill_between(wavelengths, flux_low, flux_high, label=f'PWV = {high_pwv} mm', alpha=.75)

    # Labels and limits
    axis.set_title(r'Change in SN Ia flux due to $\Delta$PWV')
    axis.set_xlabel(r'Wavelength ($\AA$)')
    axis.set_ylabel('SN Ia Flux')
    axis.set_xlim(wave_min, wave_max)
    axis.legend(framealpha=1)
    return fig, axis
def plot_delta_sn_flux(
        pwv: float = 4, wave_min: float = 6500, wave_max: float = 10000, figsize: Tuple[Numeric, Numeric] = (8, 6)
) -> Tuple[plt.figure, np.array]:
    """Plot the change to spectroscopic SN Ia flux over wavelength and redshift

    An imshow style plot with the change in flux along the color axis. The
    figure has three stacked panels sharing the wavelength axis:
      - top: the PWV transmission function (in percent)
      - middle: flux at ``pwv`` minus flux at zero PWV, vs wavelength and redshift
      - bottom: LSST total throughput of the r, i, z, and y filters

    Args:
        pwv: The PWV concentration to use when determining the change in flux
        wave_min: Minimum wavelength to plot
        wave_max: Maximum wavelength to plot
        figsize: The size of the figure

    Returns:
        The matplotlib figure and an array with the middle and bottom matplotlib
        axes (the top axis is not included)
    """
    sn_model = models.SNModel('salt2-extended')
    sn_model.add_effect(models.StaticPWVTrans(transmission_res=10), '', 'obs')
    transmission_model = PWVTransmissionModel(resolution=5)
    wavelengths = np.arange(wave_min, wave_max)

    # One row per redshift: flux with PWV absorption minus flux without it
    flux_differences = []
    for redshift in np.arange(0.0001, 1.01, .1):
        sn_model.set(z=redshift, pwv=pwv)
        absorbed_flux = sn_model.flux(0, wavelengths)
        sn_model.set(pwv=0)
        flux_differences.append(absorbed_flux - sn_model.flux(0, wavelengths))

    fig, (top_axis, middle_axis, bottom_ax) = plt.subplots(
        nrows=3, figsize=figsize, sharex='col',
        gridspec_kw={'height_ratios': [1.75, 4, 1.75]})

    transmission = transmission_model.calc_transmission(pwv, wavelengths)
    top_axis.plot(wavelengths, 100 * transmission)
    middle_axis.imshow(
        flux_differences,
        origin='lower',
        extent=[wave_min, wave_max, 0, 1],
        aspect='auto',
        cmap='Blues_r')

    # Plot the band passes
    for filter_abbrev in 'rizy':
        bandpass = sncosmo.get_bandpass(f'lsst_total_{filter_abbrev}')
        bottom_ax.plot(bandpass.wave, bandpass.trans, label=f'{filter_abbrev}')

    # Format each axis
    top_axis.set_title(f'Change in spectral SN Ia flux due to PWV ({pwv} mm)')
    top_axis.set_ylabel('Transmission (%)')
    middle_axis.set_ylabel('Redshift (z)')

    bottom_ax.set_ylim(0, 1)
    bottom_ax.set_xlim(wave_min, wave_max)
    bottom_ax.set_xlabel(r'Wavelength $\AA$')
    bottom_ax.set_ylabel('Filters')
    bottom_ax.xaxis.set_minor_locator(MultipleLocator(250))
    bottom_ax.set_xticks(np.arange(wave_min, wave_max + 1, 500))
    bottom_ax.set_yticks([0, .25, .5, .75])
    bottom_ax.legend(loc='lower left', framealpha=1)

    plt.subplots_adjust(hspace=0)
    return fig, np.array([middle_axis, bottom_ax])