"""The ``plotting`` module provides functions for generating plots of
analysis results that are visually consistent and easily reproducible.

Plotting Function Summaries
---------------------------

.. autosummary::
   :nosignatures:

   plot_cosmology_fit
   plot_delta_colors
   plot_delta_mag_vs_pwv
   plot_delta_mag_vs_z
   plot_delta_mu
   plot_derivative_mag_vs_z
   plot_fitted_params
   plot_magnitude
   plot_pwv_mag_effects
   plot_residuals_on_sky
   plot_spectral_template
   plot_spectrum
   plot_year_pwv_vs_time
   compare_prop_effects
   plot_transmission_variation
   plot_flux_variation
   plot_delta_sn_flux

Module Docs
-----------
"""
from copy import copy
from datetime import datetime, timedelta
from typing import *
import matplotlib.dates as mdates
import numpy as np
import pandas as pd
import sncosmo
from astropy import units as u
from astropy.coordinates import SkyCoord
from astropy.cosmology import FlatwCDM
from astropy.cosmology.core import Cosmology
from astropy.time import Time
from matplotlib import pyplot as plt
from matplotlib.ticker import MultipleLocator
from pwv_kpno.defaults import v1_transmission
from pytz import utc
from . import constants as const
from . import models
from .models import PWVTransmissionModel
from .types import Numeric
def _multi_line_plot(
    x_arr: np.ndarray, y_arr: np.ndarray, z_arr: np.ndarray, axis: plt.Axes, label: Optional[str] = None
) -> None:
    """Plot each row of a 2d ``y_arr`` against a shared 1d ``x_arr``.

    Lines are color coded along the viridis color map according to the
    matching entry of ``z_arr``, which holds one value per plotted line.
    Rows of ``y_arr`` are paired with ``z_arr`` values via ``zip``, so any
    extra rows or values beyond the shorter of the two are ignored.

    Args:
        x_arr: A 1d array with the x values shared by every line
        y_arr: A 2d array with one row of y values per line
        z_arr: A 1d array with one color (and label) value per row of ``y_arr``
        axis: Axis to plot on
        label: Optional label template formatted with the ``z_arr`` value of each line
    """

    # Convert so that ``x_arr[0]`` / ``x_arr[-1]`` below are positional even for pandas input
    x_arr = np.asarray(x_arr)

    # noinspection PyUnresolvedReferences
    colors = plt.cm.viridis(np.linspace(0, 1, len(z_arr)))
    for z, y, color in zip(z_arr, y_arr, colors):
        if label:
            axis.plot(x_arr, y, label=label.format(z), c=color)

        else:
            axis.plot(x_arr, y, c=color)

    # Limits follow the input ordering, so a descending x_arr gives an inverted axis
    axis.set_xlim(x_arr[0], x_arr[-1])
def plot_delta_mag_vs_z(
    pwv_arr: np.ndarray,
    z_arr: np.ndarray,
    delta_mag_arr: np.ndarray,
    axis: Optional[plt.Axes] = None,
    label: Optional[str] = None
) -> None:
    """Single panel, multi-line plot of change in magnitude vs redshift color coded by PWV

    Args:
        pwv_arr: Array of PWV values, one per plotted line
        z_arr: Array of redshift values used as the x-axis
        delta_mag_arr: 2d array of delta mag values with one row per PWV
            value and one column per redshift
        axis: Optionally plot on a given axis (defaults to the current axis)
        label: Optional label to format with PWV
    """

    if axis is None:
        axis = plt.gca()

    # Each row (one per PWV) is drawn as a line against redshift
    _multi_line_plot(z_arr, delta_mag_arr, pwv_arr, axis, label)
    axis.set_xlabel('Redshift', fontsize=20)
    axis.set_xlim(min(z_arr), max(z_arr))
    axis.set_ylabel(r'$\Delta m$', fontsize=20)
def plot_delta_mag_vs_pwv(
    pwv_arr: np.ndarray,
    z_arr: np.ndarray,
    delta_mag_arr: np.ndarray,
    axis: Optional[plt.Axes] = None,
    label: Optional[str] = None
) -> None:
    """Single panel, multi-line plot for change in magnitude vs PWV color coded by redshift

    Args:
        pwv_arr: Array of PWV values used as the x-axis
        z_arr: Array of redshift values, one per plotted line
        delta_mag_arr: 2d array of delta mag values with one row per PWV
            value and one column per redshift
        axis: Optionally plot on a given axis (defaults to the current axis)
        label: Optional label to format with redshift
    """

    if axis is None:
        axis = plt.gca()

    # Transpose so each row (one per redshift) is drawn as a line against PWV
    _multi_line_plot(pwv_arr, delta_mag_arr.T, z_arr, axis, label)
    axis.set_xlabel('PWV', fontsize=20)
    axis.set_xlim(min(pwv_arr), max(pwv_arr))
    axis.set_ylabel(r'$\Delta m$', fontsize=20)
# noinspection PyUnusedLocal
def plot_derivative_mag_vs_z(
    pwv_arr: np.ndarray,
    z_arr: np.ndarray,
    slope_arr: np.ndarray,
    axis: Optional[plt.Axes] = None,
    ref_pwv: Numeric = 4
) -> None:
    """Single panel plot of the slope in delta magnitude vs redshift

    Args:
        pwv_arr: Array of PWV values (currently unused)
        z_arr: Array of redshift values
        slope_arr: Slope of delta mag at reference PWV
        axis: Optionally plot on a given axis (defaults to the current axis)
        ref_pwv: Reference PWV the slope was evaluated at (used only in the y label)
    """

    if axis is None:
        axis = plt.gca()

    axis.plot(z_arr, slope_arr)
    axis.set_xlabel('Redshift', fontsize=20)
    axis.set_xlim(min(z_arr), max(z_arr))
    axis.set_ylabel(rf'$\frac{{\Delta \, m}}{{\Delta \, PWV}} |_{{PWV = {ref_pwv:g} mm}}$', fontsize=20)
def plot_pwv_mag_effects(
    pwv_arr: np.ndarray,
    z_arr: np.ndarray,
    delta_mag: dict,
    slopes: Dict[str, np.ndarray],
    bands: List[str],
    figsize: Tuple[Numeric, Numeric] = (10, 8)
) -> Tuple[plt.Figure, np.ndarray]:
    """Multi panel plot with a column for each band and rows for the change in magnitude vs pwv and redshift parameters

    ``delta_mag`` is expected to have band names as keys, and 2d arrays as
    values. Each array should represent the change in magnitude for each
    given PWV (rows) and redshift (columns). The reference PWV of 4 mm is
    hard-coded in the plot annotations.

    Args:
        pwv_arr: PWV values used in the calculation
        z_arr: Redshift values used in the calculation
        delta_mag: Dictionary with delta mag for each band
        slopes: Dictionary with the slope in delta_mag for each redshift, keyed by band
        bands: Order of bands to plot (one column per band)
        figsize: The size of the figure

    Returns:
        The matplotlib figure and a 2d (3 x number of bands) array of axes

    Raises:
        ValueError: If ``bands`` is empty
    """

    if not bands:
        raise ValueError('At least one band must be specified')

    # sharey='row' shares y-limits along each row and hides the redundant
    # y tick labels on every column except the first.
    # squeeze=False keeps ``axes`` 2d even when plotting a single band.
    fig, axes = plt.subplots(3, len(bands), figsize=figsize, sharey='row', squeeze=False)
    top_reference_ax, middle_reference_ax, bottom_reference_ax = axes[:, 0]

    # Plot data
    for band, (top_ax, middle_ax, bottom_ax) in zip(bands, axes.T):
        # First row: delta mag vs redshift for each PWV
        plot_delta_mag_vs_z(pwv_arr, z_arr, delta_mag[band], top_ax, label='{:g} mm')
        top_ax.axhline(0, linestyle='--', color='k', label='4 mm')
        top_ax.set_title(f'{band[-1]}-band')
        top_ax.set_xlabel('Redshift', fontsize=12)
        top_ax.set_ylabel('')

        # Middle row: delta mag vs PWV for each redshift, marking the reference PWV
        plot_delta_mag_vs_pwv(pwv_arr, z_arr, delta_mag[band], middle_ax, label='z = {:g}')
        middle_ax.axvline(4, linestyle='--', color='k')
        middle_ax.set_xlabel('PWV', fontsize=12)
        middle_ax.set_ylabel('')

        # Bottom row: slope of delta mag vs redshift
        plot_derivative_mag_vs_z(pwv_arr, z_arr, slopes[band], bottom_ax)
        bottom_ax.set_xlabel('Redshift', fontsize=12)
        bottom_ax.set_ylabel('')

    # The top and bottom rows are both plotted against redshift, so share their x-axis
    for axis in (*axes[0, 1:], *axes[2]):
        axis.sharex(top_reference_ax)

    top_reference_ax.autoscale()  # To reset y-range
    top_reference_ax.set_xlim(0.1, 1.1)

    # Add legends next to the right-most column
    axes[0, -1].legend(bbox_to_anchor=(1, 1.1))
    handles, labels = axes[1, -1].get_legend_handles_labels()
    axes[1, -1].legend(handles[::5], labels[::5], bbox_to_anchor=(1, 1.1))

    # Add y labels
    top_reference_ax.set_ylabel(r'$\Delta m \, \left(PWV,\, z\right)$', fontsize=12)
    middle_reference_ax.set_ylabel(r'$\Delta m \, \left(z,\, PWV\right)$', fontsize=12)
    bottom_reference_ax.set_ylabel(r'$\frac{\Delta \, m}{\Delta \, PWV} |_{4 mm}$', fontsize=12)

    plt.tight_layout()
    return fig, axes
# https://stackoverflow.com/questions/18311909/how-do-i-annotate-with-power-of-ten-formatting
def sci_notation(
    num: 'Numeric', decimal_digits: int = 1, precision: Optional[int] = None, exponent: Optional[int] = None
) -> str:
    r"""Return a LaTeX string representation of a number in scientific notation.

    Args:
        num: The number to format
        decimal_digits: Number of decimal places to round the coefficient to
        precision: Number of decimal places to display (defaults to ``decimal_digits``)
        exponent: Use this power of ten instead of deriving it from ``num``

    Returns:
        A string of the form ``$c\cdot10^{e}$``, or ``$10^{e}$`` when the
        rounded coefficient is exactly one
    """

    auto_exponent = exponent is None
    if auto_exponent:
        # log10 is undefined at zero, so zero is written with an exponent of 0
        exponent = 0 if num == 0 else int(np.floor(np.log10(abs(num))))

    coeff = round(num / float(10 ** exponent), decimal_digits)

    # Rounding can push the coefficient up to 10 (e.g. 9.96 -> 10.0); renormalize
    # unless the caller explicitly fixed the exponent
    if auto_exponent and abs(coeff) >= 10:
        exponent += 1
        coeff = round(num / float(10 ** exponent), decimal_digits)

    if coeff == 1:
        return r"$10^{{{}}}$".format(exponent)

    if precision is None:
        precision = decimal_digits

    return r"${0:.{2}f}\cdot10^{{{1:d}}}$".format(coeff, exponent, precision)
def plot_spectral_template(
    source: Union[str, sncosmo.Source],
    wave_arr: np.ndarray,
    z_arr: np.ndarray,
    pwv: np.ndarray,
    phase: Numeric = 0,
    resolution: Numeric = 2,
    figsize: Tuple[Numeric, Numeric] = (6, 4)
) -> Tuple[plt.Figure, np.ndarray]:
    """Plot a spectral template with overlaid PWV and bandpass throughput curves

    The top panel shows the template flux (divided by 1e-13) at each
    redshift, with the atmospheric transmission on a twin y-axis. The bottom
    panel shows the LSST hardware throughput of the r, i, z, and y bands.
    Both panels share a single wavelength axis.

    Args:
        source: ``sncosmo`` source to use as spectral template
        wave_arr: The observer frame wavelengths to plot flux for in Angstroms
        z_arr: The redshifts to plot the template at
        pwv: The PWV to plot the transmission function for
        phase: The phase of the template to plot
        resolution: The resolution of the atmospheric model
        figsize: The size of the figure

    Returns:
        The matplotlib figure and an array with the top and bottom matplotlib axes
    """

    # sharex='col' so the bandpass panel lines up with the spectrum panel (they touch at hspace=0)
    fig, (top_ax, bottom_ax) = plt.subplots(
        nrows=2, figsize=figsize, sharex='col',
        gridspec_kw={'height_ratios': [4, 1.75]})

    # Plot spectral template at given redshifts. Iterate from the last redshift to
    # the first while keeping each color tied to the value's position in z_arr.
    model = models.SNModel(source)
    flux_scale = 1e-13
    for i, z in enumerate(reversed(z_arr)):
        color = f'C{len(z_arr) - i - 1}'
        model.set(z=z)
        flux = model.flux(phase, wave_arr) / flux_scale
        top_ax.fill_between(wave_arr, flux, color=color, alpha=.8)
        top_ax.plot(wave_arr, flux, label=f'z = {z}', color=color, zorder=0)

    # Plot transmission function on twin axis at given wavelength resolution
    transmission = v1_transmission(pwv, res=resolution)
    twin_axis = top_ax.twinx()
    twin_axis.plot(transmission.index, transmission, alpha=0.75, color='grey')

    # Plot the band passes
    for b in 'rizy':
        band = sncosmo.get_bandpass(f'lsst_hardware_{b}')
        bottom_ax.plot(band.wave, band.trans, label=f'{b} Band')

    # Format top axis (its x-limits also apply to the shared bottom axis)
    top_ax.set_ylim(0, 5)
    top_ax.set_xlim(min(wave_arr), max(wave_arr))
    top_ax.set_ylabel('Flux')
    top_ax.legend(loc='lower left', framealpha=1)

    # Format twin axis
    twin_axis.set_ylim(0, 1)
    twin_axis.set_ylabel('Transmission', rotation=-90, labelpad=12)
    plt.tight_layout()

    # Format bottom axis
    bottom_ax.set_ylim(0, 1)
    bottom_ax.set_xlabel(r'Wavelength $\AA$')
    bottom_ax.xaxis.set_minor_locator(MultipleLocator(500))
    bottom_ax.set_xticks(np.arange(4000, 11001, 2000))
    bottom_ax.legend(loc='lower left', framealpha=1)

    plt.subplots_adjust(hspace=0)
    return fig, np.array([top_ax, bottom_ax])
def plot_spectrum(
    wave: np.ndarray,
    flux: np.ndarray,
    figsize: Tuple[Numeric, Numeric] = (9, 3),
    hardware_only: bool = False
) -> Tuple[plt.Figure, plt.Axes]:
    """Plot a spectrum over the per-filter LSST hardware throughput

    Filter throughputs for the u, g, r, i, z, and y bands are drawn on a
    twin y-axis spanning 0 to 1.

    Args:
        wave: Spectrum wavelengths in Angstroms
        flux: Flux of the spectrum in arbitrary units
        figsize: Size of the figure
        hardware_only: Only include hardware contributions in the plotted filters

    Returns:
        The matplotlib figure and the (flux) axis
    """

    fig, axis = plt.subplots(figsize=figsize)
    axis.set_ylabel('Object Flux')
    axis.set_xlim(min(wave), max(wave))
    axis.set_xlabel(r'Wavelength ($\AA$)')

    twin_ax = axis.twinx()
    twin_ax.set_ylim(0, 1)
    twin_ax.set_ylabel('Bandpass Transmission', rotation=270, labelpad=15)

    prefix = 'lsst_hardware_' if hardware_only else 'lsst_total_'
    for filter_abbrev in 'ugrizy':
        bandpass = sncosmo.get_bandpass(f'{prefix}{filter_abbrev}')
        throughput = bandpass(wave)  # Evaluate once and reuse for the fill and outline
        twin_ax.fill_between(wave, throughput, alpha=.3, label=filter_abbrev)
        twin_ax.plot(wave, throughput)

    axis.plot(wave, flux, color='k')
    return fig, axis
def plot_magnitude(
    mags: Dict[str, np.ndarray], pwv: np.ndarray, z: np.ndarray, figsize: Tuple[Numeric, Numeric] = (9, 6)
) -> Tuple[plt.Figure, np.ndarray]:
    """Multi-panel plot showing magnitudes in different columns vs PWV and redshift in different rows

    Args:
        mags: Simulated magnitude values for each band, as 2d arrays with one
            row per PWV value and one column per redshift
        pwv: Array of PWV values
        z: Array of redshift values
        figsize: Size of the figure

    Returns:
        The matplotlib figure and a 2d (2 x number of bands) array of axes

    Raises:
        ValueError: If ``mags`` is empty
    """

    if not mags:
        raise ValueError('At least one band of magnitudes must be provided')

    # squeeze=False keeps ``axes`` 2d (rows x bands) even for a single band
    fig, axes = plt.subplots(2, len(mags), figsize=figsize, sharey='row', squeeze=False)
    for (band, mag_arr), (top_ax, bottom_ax) in zip(mags.items(), axes.T):
        top_ax.set_title(band)
        top_ax.set_xlabel('Redshift')
        _multi_line_plot(z, mag_arr, pwv, top_ax, label='{:g} mm')

        bottom_ax.set_xlabel('PWV')
        _multi_line_plot(pwv, mag_arr.T, z, bottom_ax, label='z = {:.2f}')

    axes[0, 0].set_ylabel('Magnitude')
    axes[1, 0].set_ylabel('Magnitude')

    # Add legends next to the right-most column
    axes[0, -1].legend(bbox_to_anchor=(1, 1.1))
    handles, labels = axes[1, -1].get_legend_handles_labels()
    axes[1, -1].legend(handles[::5], labels[::5], bbox_to_anchor=(1, 1.1))

    plt.tight_layout()
    return fig, axes
def plot_fitted_params(
    fitted_params: Dict[str, np.ndarray], pwv_arr: np.ndarray, z_arr: np.ndarray, bands: List[str]
) -> Tuple[plt.Figure, np.ndarray]:
    """Multi-panel plot showing subplots for each salt2 parameter vs redshift.

    Multiple lines are included for different PWV values. The final panel
    shows the correction ``alpha * x1 - beta * c`` using the ``betoule_alpha``
    and ``betoule_beta`` constants. Only the fitted parameters of the first
    entry in ``bands`` are plotted.

    Args:
        fitted_params: Dictionary with fitted parameters in each band. The last
            axis of each array indexes the parameters in the order of
            ``sncosmo.Model('salt2-extended').param_names``.
        pwv_arr: PWV value used for each supernova fit
        z_arr: Redshift value used for each supernova fit
        bands: Bands to take fitted parameters from. Must be keys of ``fitted_params``

    Returns:
        The matplotlib figure and a 2x3 array of matplotlib axes
    """

    # The model is only needed for its ordered parameter names
    param_names = sncosmo.Model('salt2-extended').param_names
    band_params = fitted_params[bands[0]]
    params_dict = {param: band_params[..., i] for i, param in enumerate(param_names)}

    fig, axes = plt.subplots(2, 3, figsize=(12, 8))
    for axis, (param, param_vals) in zip(axes.flatten(), params_dict.items()):
        if param == 'x0':
            # Show the amplitude on a magnitude-like scale
            param_vals = -2.5 * np.log10(param_vals)
            param = r'-2.5 log$_{10}$(x$_{0}$)'

        # Lines are color coded (and labelled) by PWV, not redshift
        _multi_line_plot(z_arr, param_vals, pwv_arr, axis, label='PWV = {:g} mm')
        axis.set_xlabel('Redshift')
        axis.set_ylabel(param)

    correction_axis = axes[-1][-1]
    correction_factor = const.betoule_alpha * params_dict['x1'] - const.betoule_beta * params_dict['c']
    _multi_line_plot(z_arr, correction_factor, pwv_arr, correction_axis, label='PWV = {:g} mm')
    correction_axis.set_xlabel('Redshift')
    correction_axis.set_ylabel(f'{const.betoule_alpha} * $x_1$ - {const.betoule_beta} * $c$')
    correction_axis.legend(bbox_to_anchor=(1, 1.1))

    plt.tight_layout()
    return fig, axes
def plot_delta_colors(
    pwv_arr: np.ndarray, z_arr: np.ndarray, mag_dict: Dict[str, np.ndarray],
    colors: List[Tuple[str, str]], ref_pwv: Numeric = 0
) -> None:
    """Shows the change in SN color as a function of redshift with each SN color coded by PWV

    Args:
        pwv_arr: Array of PWV values
        z_arr: Array of redshift values
        mag_dict: Dictionary with magnitudes for each band, as 2d arrays with
            one row per PWV value and one column per redshift
        colors: Band combinations to plot colors for (one panel each)
        ref_pwv: Plot values relative to given reference PWV

    Raises:
        ValueError: If ``ref_pwv`` is not one of the values in ``pwv_arr``
    """

    matching_indices = np.flatnonzero(np.asarray(pwv_arr) == ref_pwv)
    if matching_indices.size == 0:
        raise ValueError(f'Reference PWV {ref_pwv} is not one of the values in pwv_arr')

    ref_idx = matching_indices[0]

    num_cols = len(colors)
    # squeeze=False always returns a 2d array of axes, even for a single color
    fig, axes = plt.subplots(ncols=num_cols, figsize=(4 * num_cols, 4), sharey='col', squeeze=False)
    for axis, (band1, band2) in zip(axes[0], colors):
        color = mag_dict[band1] - mag_dict[band2]
        delta_color = color - color[ref_idx]  # Relative to the reference PWV row
        _multi_line_plot(z_arr, delta_color, pwv_arr, label='{} mm', axis=axis)

        axis.set_xlabel('Redshift', fontsize=14)
        axis.set_ylabel(fr'$\Delta$ ({band1[-1]} - {band2[-1]})', fontsize=14)
        axis.set_xlim(0, max(z_arr))
        axis.legend(bbox_to_anchor=(1, 1.1))

    fig.tight_layout()
def plot_delta_mu(
    mu: np.ndarray,
    pwv_arr: np.ndarray,
    z_arr: np.ndarray,
    cosmo: Cosmology = const.betoule_cosmo,
    ref_idx: int = 4
) -> None:
    """Plot the variation in fitted distance modulus as a function of redshift and PWV

    Three panels are drawn: the distance modulus, its difference from the
    given cosmology, and its difference from the row of ``mu`` at the
    reference PWV index.

    Args:
        mu: 2d array of distance moduli with one row per PWV value and one column per redshift
        pwv_arr: Array of PWV values
        z_arr: Array of redshift values
        cosmo: Astropy cosmology to compare results against
        ref_idx: Index into ``pwv_arr`` / ``mu`` of the reference PWV for the third panel
    """

    cosmo_mu = cosmo.distmod(z_arr).value
    delta_mu = mu - cosmo_mu

    fig, axes = plt.subplots(ncols=3, figsize=(9, 3))
    mu_ax, delta_mu_ax, relative_mu_ax = axes

    _multi_line_plot(z_arr, mu, pwv_arr, mu_ax)
    mu_ax.plot(z_arr, cosmo_mu, linestyle=':', color='k', label='Simulated')
    mu_ax.legend(framealpha=1)

    _multi_line_plot(z_arr, delta_mu, pwv_arr, delta_mu_ax)
    delta_mu_ax.axhline(0, linestyle=':', color='k', label='Simulated')
    delta_mu_ax.legend(framealpha=1)

    _multi_line_plot(z_arr, mu - mu[ref_idx], pwv_arr, relative_mu_ax, label='{:g} mm')
    relative_mu_ax.axhline(0, color='k', label=f'PWV={pwv_arr[ref_idx]}')
    relative_mu_ax.legend(framealpha=1, bbox_to_anchor=(1, 1.1))

    mu_ax.set_ylabel(r'$\mu$', fontsize=12)
    delta_mu_ax.set_ylabel(r'$\mu - \mu_{cosmo}$', fontsize=12)
    relative_mu_ax.set_ylabel(r'$\mu - \mu_{pwv_f}$', fontsize=12)
    for ax in axes:
        ax.set_xlabel('Redshift', fontsize=12)

    fig.tight_layout()
def _shade_time_window(axis: plt.Axes, start: datetime, end: datetime, ylow: Numeric, yhigh: Numeric) -> None:
    """Shade the region between two times (and two y values) in light grey behind other artists"""

    axis.fill_between(
        x=[start, end], y1=[ylow, ylow], y2=[yhigh, yhigh],
        color='lightgrey', alpha=0.5, zorder=0)


def plot_year_pwv_vs_time(
    pwv_series: pd.Series, figsize: Tuple[Numeric, Numeric] = (8, 4), missing: Numeric = 1
) -> Tuple[plt.Figure, plt.Axes]:
    """Plot PWV measurements taken over a single year as a function of time.

    Seasons follow the southern hemisphere (summer spans December through
    March). Seasonal averages and standard deviations are printed and drawn
    as error bars. Set ``missing=None`` to disable plotting of missing data
    windows.

    Args:
        pwv_series: Measured PWV indexed by a timezone aware datetime index
        figsize: Size of the figure in inches
        missing: Highlight time ranges larger than given number of days with missing PWV

    Returns:
        The matplotlib figure and axis

    Raises:
        ValueError: If ``pwv_series`` is empty
    """

    if pwv_series.empty:
        raise ValueError('Cannot plot an empty PWV series')

    pwv_series = pwv_series.sort_index()
    times = pwv_series.index

    # Calculate rolling average of PWV time series
    rolling_mean_pwv = pwv_series.rolling(window='7D').mean()

    # In practice these dates vary by year so the values used here are approximate
    year = times[0].year
    mar_equinox = datetime(year, 3, 20, tzinfo=utc)
    jun_solstice = datetime(year, 6, 21, tzinfo=utc)
    sep_equinox = datetime(year, 9, 22, tzinfo=utc)
    dec_solstice = datetime(year, 12, 21, tzinfo=utc)

    # Separate data based on season. Each season includes its start date so
    # that no measurement falls between two seasons.
    summer_pwv = pwv_series[(times < mar_equinox) | (times >= dec_solstice)]
    fall_pwv = pwv_series[(times >= mar_equinox) & (times < jun_solstice)]
    winter_pwv = pwv_series[(times >= jun_solstice) & (times < sep_equinox)]
    spring_pwv = pwv_series[(times >= sep_equinox) & (times < dec_solstice)]

    for name, season in (('Summer', summer_pwv), ('Fall', fall_pwv), ('Winter', winter_pwv), ('Spring', spring_pwv)):
        print(f'{name} Average: {season.mean(): .2f} +/- {season.std(): .2f} mm')

    fig, axis = plt.subplots(figsize=figsize)
    axis.set_ylabel('PWV (mm)')
    axis.set_xlabel('Time of Year')
    axis.set_ylim(0, 20)
    ylow, yhigh = axis.get_ylim()

    # Plot windows of missing pwv
    if missing:
        missing_interval = timedelta(days=missing)
        year_start = datetime(year, 1, 1, tzinfo=utc)
        year_end = datetime(year, 12, 31, 23, 59, 59, tzinfo=utc)

        # Gaps between consecutive measurements
        gaps = times[1:] - times[:-1]
        for gap_index in np.where(gaps > missing_interval)[0]:
            _shade_time_window(axis, times[gap_index], times[gap_index + 1], ylow, yhigh)

        # Gaps at the beginning and end of the year
        if (times[0] - year_start) > missing_interval:
            _shade_time_window(axis, year_start, times[0], ylow, yhigh)

        if (year_end - times[-1]) > missing_interval:
            _shade_time_window(axis, times[-1], year_end, ylow, yhigh)

    # Mark season boundaries
    for boundary_date in (mar_equinox, jun_solstice, sep_equinox, dec_solstice):
        axis.axvline(boundary_date, linestyle='--', color='k', zorder=1)

    # Plot measured PWV
    axis.scatter(times, pwv_series, s=1, alpha=.2, zorder=2)

    # Plot rolling average
    axis.plot(rolling_mean_pwv.index, rolling_mean_pwv, color='C1', label='Rolling Avg.', zorder=3, linewidth=2)

    # Plot seasonal averages. Summer spans the new year, so its marker is
    # positioned using only the data from the start of the year (if any).
    summer_start_of_year = summer_pwv[summer_pwv.index < mar_equinox]
    summer_position = summer_pwv if summer_start_of_year.empty else summer_start_of_year
    seasons = (
        (summer_pwv, summer_position),
        (fall_pwv, fall_pwv),
        (winter_pwv, winter_pwv),
        (spring_pwv, spring_pwv),
    )

    label_kwargs = {'label': 'Seasonal Avg.'}
    for season, position_data in seasons:
        if season.empty:
            continue

        avg = season.mean()
        std = season.std()
        x = position_data.index.max() - (position_data.index.max() - position_data.index.min()) / 2
        axis.errorbar([x], [avg], yerr=[std], color='k', zorder=4, linewidth=2, capsize=10, capthick=2)
        axis.scatter([x], [avg], color='k', s=100, marker='+', zorder=4, **label_kwargs)
        label_kwargs = {}  # Only the first plotted season gets a legend entry

    # Format x labels to be three letter month abbreviations
    locator = mdates.MonthLocator()
    formatter = mdates.DateFormatter('%b')
    axis.xaxis.set_major_locator(locator)
    axis.xaxis.set_major_formatter(formatter)

    axis.legend(framealpha=1)
    axis.twinx().set_ylim(axis.get_ylim())
    return fig, axis
# Ignore arguments with uppercase letters
# noinspection PyPep8Naming
def plot_cosmology_fit(
    data: pd.DataFrame, abs_mag: Numeric, H0: Numeric, Om0: Numeric, w0: Numeric, alpha: Numeric, beta: Numeric
) -> Tuple[plt.Figure, np.ndarray, np.ndarray]:
    """Plot a cosmological fit to a set of supernova data.

    The top panel shows the measured distance modulus (with ``mb_err`` error
    bars) and the FlatwCDM prediction. The bottom panel shows the residuals
    on a y-range symmetric about zero.

    Args:
        data: Results from the snat_sim fitting pipeline
        abs_mag: Intrinsic absolute magnitude of SNe Ia
        H0: Fitted Hubble constant at z = 0 in [km/sec/Mpc]
        Om0: Omega matter density in units of the critical density at z=0
        w0: Dark energy equation of state
        alpha: Fitted nuisance parameter for supernova stretch correction
        beta: Fitted nuisance parameter for supernova color correction

    Returns:
        The matplotlib figure, fitted distance modulus, and tabulated residuals
    """

    data = data.sort_values('z')
    fitted_mu = FlatwCDM(H0=H0, Om0=Om0, w0=w0).distmod(data.z).value
    # Standardized distance modulus with stretch and color corrections
    measured_mu = data.snat_sim.calc_distmod(abs_mag) + alpha * data.x1 - beta * data.c
    residuals = measured_mu - fitted_mu

    fig, (top_ax, bottom_ax) = plt.subplots(2, sharex='col', gridspec_kw={'height_ratios': [2, 1]})
    top_ax.errorbar(data.z, measured_mu, yerr=data.mb_err, linestyle='')
    top_ax.scatter(data.z, measured_mu, s=1)
    top_ax.plot(data.z, fitted_mu, color='k', alpha=.75)

    bottom_ax.axhline(0, color='k', alpha=.75, linestyle='--')
    bottom_ax.errorbar(data.z, residuals, yerr=data.mb_err, linestyle='')
    bottom_ax.scatter(data.z, residuals, s=1)

    # Style the plot
    top_ax.set_ylabel(r'$\mu = m^*_B - M_B + \alpha x_1 - \beta c$')
    bottom_ax.set_ylabel('Residuals')
    bottom_ax.set_xlabel('Redshift')

    # Make the residual axis symmetric about zero
    bottom_ax_lim = max(np.abs(bottom_ax.get_ylim()))
    bottom_ax.set_ylim(-bottom_ax_lim, bottom_ax_lim)
    bottom_ax.set_xlim(left=0)

    fig.subplots_adjust(hspace=0.1)
    return fig, fitted_mu, residuals
def plot_residuals_on_sky(
    ra: np.ndarray,
    dec: np.ndarray,
    residual: np.ndarray,
    cmap: str = 'coolwarm',
    figsize: Tuple[Numeric, Numeric] = (8, 4)
) -> Tuple[plt.Figure, plt.Axes]:
    """Plot Hubble residuals as a function of supernova coordinates.

    Coordinates are converted to galactic longitude / latitude and drawn on
    an Aitoff projection, with a color scale symmetric about zero.

    Args:
        ra: Right Ascension for each supernova in degrees
        dec: Declination of each supernova in degrees
        residual: Hubble residual for each supernova
        cmap: Name of the matplotlib color map to use
        figsize: The size of the figure

    Returns:
        The matplotlib figure and axis
    """

    sn_coord = SkyCoord(ra, dec, unit=u.deg).galactic

    fig, axis = plt.subplots(figsize=figsize, subplot_kw={'projection': 'aitoff'})
    axis.grid(True)

    # Symmetric color limits; NaN residuals are ignored when choosing the range
    vlim = np.nanmax(np.abs(residual))
    scatter = axis.scatter(
        sn_coord.l.wrap_at('180d').radian, sn_coord.b.radian,
        c=residual, cmap=cmap, vmin=-vlim, vmax=vlim)

    # Attach the colorbar to this figure / axis explicitly rather than the "current" ones
    fig.colorbar(scatter, ax=axis).set_label('Hubble Residual', rotation=270, labelpad=15)
    return fig, axis
def compare_prop_effects(
    pwv_data: pd.Series,
    static: models.StaticPWVTrans,
    seasonal: models.SeasonalPWVTrans,
    variable: models.VariablePWVTrans,
    figsize: Tuple[float, float] = (9, 6)
) -> Tuple[plt.Figure, plt.Axes]:
    """Compare the Zenith PWV assumed by different propagation effects

    The passed propagation effects are not modified.

    Args:
        pwv_data: Series with PWV values and a Datetime index
        static: Static propagation effect
        seasonal: Seasonal Propagation effect
        variable: Variable Propagation effect
        figsize: The size of the figure

    Returns:
        The matplotlib figure and axis
    """

    # Copy both the effect and its PWV model. A shallow copy of the effect alone
    # would share ``_pwv_model``, so patching it below would alter the caller's object.
    variable = copy(variable)
    # noinspection PyProtectedMember
    variable._pwv_model = copy(variable._pwv_model)

    # Disable airmass scaling so we get values at zenith
    # noinspection PyProtectedMember
    variable._pwv_model.calc_airmass = lambda *args, **kwargs: 1

    pwv_data = pwv_data.sort_index()
    x_vals = np.arange(pwv_data.index[0], pwv_data.index[-1], timedelta(days=1)).astype(datetime)
    mjd_vals = Time(x_vals).mjd

    fig, axis = plt.subplots(figsize=figsize)
    axis.scatter(pwv_data.index, pwv_data.values, s=2, alpha=.1, color='grey')
    axis.plot(x_vals, variable.assumed_pwv(mjd_vals), label='Variable')
    axis.plot(x_vals, seasonal.assumed_pwv(mjd_vals), label='Seasonal', linewidth=2.5)
    axis.plot([x_vals[0], x_vals[-1]], [static['pwv'], static['pwv']], label='Static', linewidth=2.5)

    axis.set_ylabel('PWV (mm)')
    axis.set_xlabel('Date (UTC)')
    axis.set_ylim(0, 25)

    # Single legend call so the requested placement / opacity is not overwritten
    axis.legend(loc='upper right', framealpha=1)
    return fig, axis
def plot_transmission_variation(
    pwv1: float, pwv2: float,
    wave_min: float = 6500,
    wave_max: float = 10000,
    resolution: int = 9,
    figsize: Tuple[Numeric, Numeric] = (8, 4)
) -> Tuple[plt.Figure, plt.Axes]:
    """Compare the atmospheric transmission function for two PWV concentrations

    The transmission for the lower PWV is drawn as a line and the region
    between it and the higher-PWV transmission is filled. The order of
    ``pwv1`` and ``pwv2`` does not matter.

    Args:
        pwv1: The first PWV concentration to plot the transmission for
        pwv2: The second PWV concentration to plot the transmission for
        wave_min: Minimum wavelength to plot
        wave_max: Maximum wavelength to plot
        resolution: Bin the atmospheric transmission function to a lower resolution
        figsize: The size of the figure

    Returns:
        The matplotlib figure and axis
    """

    low_pwv = min(pwv1, pwv2)
    high_pwv = max(pwv1, pwv2)

    wave = np.arange(wave_min, wave_max)
    low_transmission = v1_transmission(pwv=low_pwv, wave=wave, res=resolution)
    high_transmission = v1_transmission(pwv=high_pwv, wave=wave, res=resolution)

    fig, axis = plt.subplots(figsize=figsize)
    axis.plot(wave, low_transmission, color='k', label=f'PWV = {low_pwv} mm', linewidth=1.5)
    axis.fill_between(wave, low_transmission, high_transmission, label=f'PWV = {high_pwv} mm', alpha=.75)

    axis.set_title(r'Change in atmospheric PWV transmission flux due to $\Delta$PWV')
    axis.set_xlabel(r'Wavelength ($\AA$)')
    axis.set_ylabel('PWV Transmission')
    axis.set_xlim(wave_min, wave_max)
    axis.set_ylim(.5, 1)
    axis.legend(framealpha=1)

    return fig, axis
def plot_flux_variation(
    pwv1: float,
    pwv2: float,
    z: float = .55,
    wave_min: float = 6500,
    wave_max: float = 10000,
    resolution: int = 9,
    figsize: Tuple[Numeric, Numeric] = (8, 4)
) -> Tuple[plt.Figure, plt.Axes]:
    """Compare the PWV absorbed flux of a SN Ia for two PWV concentrations

    Flux is evaluated at phase 0 of the ``salt2-extended`` model with a static
    PWV propagation effect. The lower-PWV flux is drawn as a line and the
    region between it and the higher-PWV flux is filled.

    Args:
        pwv1: The first PWV concentration to plot the flux for
        pwv2: The second PWV concentration to plot the flux for
        z: The redshift of the SN Ia
        wave_min: Minimum wavelength to plot
        wave_max: Maximum wavelength to plot
        resolution: Bin the atmospheric transmission function to a lower resolution
        figsize: The size of the figure

    Returns:
        The matplotlib figure and axis
    """

    low_pwv = min(pwv1, pwv2)
    high_pwv = max(pwv1, pwv2)

    model = models.SNModel('salt2-extended')
    model.add_effect(models.StaticPWVTrans(transmission_res=resolution), '', 'obs')

    wave = np.arange(wave_min, wave_max)
    model.set(z=z, pwv=low_pwv)
    low_pwv_flux = model.flux(0, wave)

    model.set(pwv=high_pwv)
    high_pwv_flux = model.flux(0, wave)

    fig, axis = plt.subplots(figsize=figsize)
    axis.plot(wave, low_pwv_flux, color='k', label=f'PWV = {low_pwv} mm', linewidth=1.5)
    axis.fill_between(wave, low_pwv_flux, high_pwv_flux, label=f'PWV = {high_pwv} mm', alpha=.75)

    axis.set_title(r'Change in SN Ia flux due to $\Delta$PWV')
    axis.set_xlabel(r'Wavelength ($\AA$)')
    axis.set_ylabel('SN Ia Flux')
    axis.set_xlim(wave_min, wave_max)
    axis.legend(framealpha=1)

    return fig, axis
def plot_delta_sn_flux(
    pwv: float = 4, wave_min: float = 6500, wave_max: float = 10000, figsize: Tuple[Numeric, Numeric] = (8, 6)
) -> Tuple[plt.figure, np.array]:
    """Plot the change in spectroscopic SN Ia flux over wavelength and redshift.

    Three stacked panels share a wavelength axis:
      * top: the PWV transmission (in percent) at the given PWV
      * middle: an imshow style image of the change in flux (flux at ``pwv``
        minus flux at zero PWV) with redshift along the y-axis
      * bottom: the LSST r, i, z, and y filter throughputs

    Args:
        pwv: The PWV concentration to use when determining the change in flux
        wave_min: Minimum wavelength to plot
        wave_max: Maximum wavelength to plot
        figsize: The size of the figure

    Returns:
        The matplotlib figure and an array holding the middle (flux) and
        bottom (filter) axes. The top transmission axis is not returned.
    """

    supernova = models.SNModel('salt2-extended')
    supernova.add_effect(models.StaticPWVTrans(transmission_res=10), '', 'obs')
    pwv_model = PWVTransmissionModel(resolution=5)

    wavelengths = np.arange(wave_min, wave_max)

    # One row of flux differences per redshift, ordered from low to high z
    flux_change = []
    for redshift in np.arange(0.0001, 1.01, .1):
        supernova.set(z=redshift, pwv=pwv)
        absorbed_flux = supernova.flux(0, wavelengths)
        supernova.set(pwv=0)
        unabsorbed_flux = supernova.flux(0, wavelengths)
        flux_change.append(absorbed_flux - unabsorbed_flux)

    fig, (transmission_ax, flux_ax, filter_ax) = plt.subplots(
        nrows=3, figsize=figsize, sharex='col',
        gridspec_kw={'height_ratios': [1.75, 4, 1.75]})

    transmission = pwv_model.calc_transmission(pwv, wavelengths)
    transmission_ax.plot(wavelengths, 100 * transmission)

    # origin='lower' puts the lowest redshift row at the bottom of the image
    flux_ax.imshow(
        flux_change,
        origin='lower',
        extent=[wave_min, wave_max, 0, 1],
        aspect='auto',
        cmap='Blues_r')

    # Plot the band passes
    for filter_name in 'rizy':
        bandpass = sncosmo.get_bandpass(f'lsst_total_{filter_name}')
        filter_ax.plot(bandpass.wave, bandpass.trans, label=f'{filter_name}')

    # Format each axis
    transmission_ax.set_title(f'Change in spectral SN Ia flux due to PWV ({pwv} mm)')
    transmission_ax.set_ylabel('Transmission (%)')
    flux_ax.set_ylabel('Redshift (z)')

    filter_ax.set_ylim(0, 1)
    filter_ax.set_xlim(wave_min, wave_max)
    filter_ax.set_xlabel(r'Wavelength $\AA$')
    filter_ax.set_ylabel('Filters')
    filter_ax.xaxis.set_minor_locator(MultipleLocator(250))
    filter_ax.set_xticks(np.arange(wave_min, wave_max + 1, 500))
    filter_ax.set_yticks([0, .25, .5, .75])
    filter_ax.legend(loc='lower left', framealpha=1)

    plt.subplots_adjust(hspace=0)
    return fig, np.array([flux_ax, filter_ax])