"""The ``plotting`` module provides functions for generating plots of
analysis results that are visually consistent and easily reproducible.

Plotting Function Summaries
---------------------------

Module Docs
-----------
"""
from copy import copy
from datetime import datetime, timedelta
from typing import *
import matplotlib.dates as mdates
import numpy as np
import pandas as pd
import sncosmo
from astropy import units as u
from astropy.coordinates import SkyCoord
from astropy.cosmology import FlatwCDM
from astropy.cosmology.core import Cosmology
from astropy.time import Time
from matplotlib import pyplot as plt
from matplotlib.ticker import MultipleLocator
from pwv_kpno.defaults import v1_transmission
from pytz import utc
from . import constants as const
from . import models
from .models import PWVTransmissionModel
from .types import Numeric
def _multi_line_plot(
        x_arr: np.ndarray, y_arr: np.ndarray, z_arr: np.ndarray, axis: plt.Axes, label: Optional[str] = None
) -> None:
    """Plot a 2d y array vs a 1d x array.

    Lines are color coded according to the values of ``z_arr`` using colors
    sampled evenly from the ``viridis`` color map.

    Args:
        x_arr: A 1d array with x-Values
        y_arr: A 2d array with y values for multiple lines (one line per row)
        z_arr: One z (color) value for each row of ``y_arr``
        axis: Axis to plot on
        label: Optional label to format with ``z_arr`` values
    """

    # noinspection PyUnresolvedReferences
    colors = plt.cm.viridis(np.linspace(0, 1, len(z_arr)))
    for z, y, color in zip(z_arr, y_arr, colors):
        # Draw each line exactly once, attaching a legend label only when requested
        if label:
            axis.plot(x_arr, y, label=label.format(z), c=color)

        else:
            axis.plot(x_arr, y, c=color)

    axis.set_xlim(x_arr[0], x_arr[-1])
def plot_delta_mag_vs_z(
        pwv_arr: np.ndarray,
        z_arr: np.ndarray,
        delta_mag_arr: np.ndarray,
        axis: Optional[plt.Axes] = None,
        label: Optional[str] = None
) -> None:
    """Single panel, multi-line plot of change in magnitude vs redshift color coded by PWV

    Args:
        pwv_arr: Array of PWV values
        z_arr: Array of redshift values
        delta_mag_arr: 2d array of delta mag values with one row per PWV value
        axis: Optionally plot on a given axis (defaults to the current axis)
        label: Optional label to format with PWV
    """

    if axis is None:
        axis = plt.gca()

    _multi_line_plot(z_arr, delta_mag_arr, pwv_arr, axis, label)
    axis.set_xlabel('Redshift', fontsize=20)
    axis.set_xlim(min(z_arr), max(z_arr))
    axis.set_ylabel(r'$\Delta m$', fontsize=20)
def plot_delta_mag_vs_pwv(
        pwv_arr: np.ndarray,
        z_arr: np.ndarray,
        delta_mag_arr: np.ndarray,
        axis: Optional[plt.Axes] = None,
        label: Optional[str] = None
) -> None:
    """Single panel, multi-line plot for change in magnitude vs PWV color coded by redshift

    Args:
        pwv_arr: Array of PWV values
        z_arr: Array of redshift values
        delta_mag_arr: 2d array of delta mag values with one row per PWV value
        axis: Optionally plot on a given axis (defaults to the current axis)
        label: Optional label to format with redshift
    """

    if axis is None:
        axis = plt.gca()

    # Transpose so each plotted line corresponds to a single redshift
    _multi_line_plot(pwv_arr, delta_mag_arr.T, z_arr, axis, label)
    axis.set_xlabel('PWV', fontsize=20)
    axis.set_xlim(min(pwv_arr), max(pwv_arr))
    axis.set_ylabel(r'$\Delta m$', fontsize=20)
# noinspection PyUnusedLocal
def plot_derivative_mag_vs_z(
        pwv_arr: np.ndarray,
        z_arr: np.ndarray,
        slope_arr: np.ndarray,
        axis: Optional[plt.Axes] = None,
        ref_pwv: Numeric = 4
) -> None:
    """Single panel plot of the slope in delta magnitude vs redshift

    Args:
        pwv_arr: Array of PWV values (accepted for API consistency; not used)
        z_arr: Array of redshift values
        slope_arr: Slope of delta mag at reference PWV
        axis: Optionally plot on a given axis (defaults to the current axis)
        ref_pwv: Reference PWV (mm) the slope was evaluated at; only used in the y label
    """

    if axis is None:
        axis = plt.gca()

    axis.plot(z_arr, slope_arr)
    axis.set_xlabel('Redshift', fontsize=20)
    axis.set_xlim(min(z_arr), max(z_arr))
    axis.set_ylabel(rf'$\frac{{\Delta \, m}}{{\Delta \, PWV}} |_{{PWV = {ref_pwv:g} mm}}$', fontsize=20)
def plot_pwv_mag_effects(
        pwv_arr: np.ndarray,
        z_arr: np.ndarray,
        delta_mag: Dict[str, np.ndarray],
        slopes: Dict[str, np.ndarray],
        bands: List[str],
        figsize: Tuple[Numeric, Numeric] = (10, 8)
) -> Tuple[plt.Figure, np.ndarray]:
    """Multi panel plot with a column for each band and rows for the change in magnitude vs pwv and redshift parameters

    ``delta_mag`` is expected to have band names as keys, and 2d arrays as
    values. Each array should represent the change in magnitude for each
    given PWV and redshift. The reference PWV is drawn at 4 mm.

    Args:
        pwv_arr: PWV values used in the calculation
        z_arr: Redshift values used in the calculation
        delta_mag: Dictionary with delta mag for each band
        slopes: Slope in delta_mag for each redshift, keyed by band
        bands: Order of bands to plot (one column per band)
        figsize: The size of the figure

    Returns:
        The matplotlib figure and a (3, len(bands)) array of axes
    """

    # ``squeeze=False`` keeps ``axes`` 2d even when plotting a single band.
    # Sharing y along each row keeps the rows comparable across bands.
    fig, axes = plt.subplots(3, len(bands), figsize=figsize, sharey='row', squeeze=False)
    top_reference_ax = axes[0, 0]
    middle_reference_ax = axes[1, 0]
    bottom_reference_ax = axes[2, 0]

    # Plot data
    for band, (top_ax, middle_ax, bottom_ax) in zip(bands, axes.T):
        # First row
        plot_delta_mag_vs_z(pwv_arr, z_arr, delta_mag[band], top_ax, label='{:g} mm')
        top_ax.axhline(0, linestyle='--', color='k', label='4 mm')
        top_ax.set_xlabel('Redshift', fontsize=12)

        # Middle row: mark the reference PWV on the PWV axis
        plot_delta_mag_vs_pwv(pwv_arr, z_arr, delta_mag[band], middle_ax, label='z = {:g}')
        middle_ax.axvline(4, linestyle='--', color='k')
        middle_ax.set_xlabel('PWV', fontsize=12)

        # Bottom row
        plot_derivative_mag_vs_z(pwv_arr, z_arr, slopes[band], bottom_ax)
        bottom_ax.set_xlabel('Redshift', fontsize=12)

        # The top and bottom rows both have redshift on the x axis
        if top_ax is not top_reference_ax:
            top_ax.sharex(top_reference_ax)

        bottom_ax.sharex(top_reference_ax)

    top_reference_ax.autoscale()  # To reset y-range
    top_reference_ax.set_xlim(0.1, 1.1)

    # Remove unnecessary tick labels from all but the first column
    for axis in axes.T[1:].flatten():
        axis.tick_params(labelleft=False)

    # Add legends to the last column
    last_top_ax, last_middle_ax = axes[0, -1], axes[1, -1]
    last_top_ax.legend(bbox_to_anchor=(1, 1.1))
    handles, labels = last_middle_ax.get_legend_handles_labels()
    # Only show every fifth redshift to keep the legend readable
    last_middle_ax.legend(handles[::5], labels[::5], bbox_to_anchor=(1, 1.1))

    # Add y labels
    top_reference_ax.set_ylabel(r'$\Delta m \, \left(PWV,\, z\right)$', fontsize=12)
    middle_reference_ax.set_ylabel(r'$\Delta m \, \left(z,\, PWV\right)$', fontsize=12)
    bottom_reference_ax.set_ylabel(r'$\frac{\Delta \, m}{\Delta \, PWV} |_{4 mm}$', fontsize=12)

    return fig, axes
# https://stackoverflow.com/questions/18311909/how-do-i-annotate-with-power-of-ten-formatting
def sci_notation(
        num: 'Numeric', decimal_digits: int = 1, precision: Optional[int] = None, exponent: Optional[int] = None
) -> str:
    r"""Return a LaTeX string representation of number in scientific notation.

    Args:
        num: The number to format
        decimal_digits: Number of decimal places to round the coefficient to
        precision: Number of decimal places displayed (defaults to ``decimal_digits``)
        exponent: Use this exponent instead of the number's order of magnitude

    Returns:
        A string like ``$1.2\cdot10^{4}$``, ``$10^{4}$`` when the coefficient
        rounds to exactly one, or ``$0$`` for zero.
    """

    if num == 0 and exponent is None:
        # log10(0) is undefined, so there is no natural exponent to use
        return r"$0$"

    auto_exponent = exponent is None
    if auto_exponent:
        exponent = int(np.floor(np.log10(abs(num))))

    coeff = round(num / float(10 ** exponent), decimal_digits)

    # Rounding can push the coefficient up to 10 (e.g. 9.96 -> 10.0); renormalize
    if auto_exponent and abs(coeff) >= 10:
        exponent += 1
        coeff = round(num / float(10 ** exponent), decimal_digits)

    if coeff == 1:
        return r"$10^{{{}}}$".format(exponent)

    if precision is None:
        precision = decimal_digits

    return r"${0:.{2}f}\cdot10^{{{1:d}}}$".format(coeff, exponent, precision)
def plot_spectral_template(
        source: Union[str, sncosmo.Source],
        wave_arr: np.ndarray,
        z_arr: np.ndarray,
        pwv: np.ndarray,
        phase: Numeric = 0,
        resolution: Numeric = 2,
        figsize: Tuple[Numeric, Numeric] = (6, 4)
) -> Tuple[plt.Figure, np.array]:
    """Plot a spectral template with overlaid PWV and bandpass throughput curves

    Args:
        source: ``sncosmo`` source to use as spectral template
        wave_arr: The observer frame wavelengths to plot flux for in Angstroms
        z_arr: The redshifts to plot the template at
        pwv: The PWV to plot the transmission function for
        phase: The phase of the template to plot
        resolution: The resolution of the atmospheric model
        figsize: The size of the figure

    Returns:
        The matplotlib figure and an array of matplotlib axes
    """

    # Both panels show wavelength, so share the x axis down the column
    fig, (top_ax, bottom_ax) = plt.subplots(
        nrows=2, figsize=figsize, sharex='col',
        gridspec_kw={'height_ratios': [4, 1.75]})

    # Plot spectral template at given redshifts
    model = models.SNModel(source)
    flux_scale = 1e-13
    # Iterate in reverse so the first entries of ``z_arr`` are drawn last,
    # while each redshift keeps the color matching its position in ``z_arr``
    for i, z in enumerate(reversed(z_arr)):
        color = f'C{len(z_arr) - i - 1}'
        model.set(z=z)
        flux = model.flux(phase, wave_arr) / flux_scale
        top_ax.fill_between(wave_arr, flux, color=color, alpha=.8)
        top_ax.plot(wave_arr, flux, label=f'z = {z}', color=color, zorder=0)

    # Plot transmission function on twin axis at given wavelength resolution
    transmission = v1_transmission(pwv, res=resolution)
    twin_axis = top_ax.twinx()
    twin_axis.plot(transmission.index, transmission, alpha=0.75, color='grey')

    # Plot the band passes
    for b in 'rizy':
        band = sncosmo.get_bandpass(f'lsst_hardware_{b}')
        bottom_ax.plot(band.wave, band.trans, label=f'{b} Band')

    # Format top axis
    top_ax.set_ylim(0, 5)
    top_ax.set_xlim(min(wave_arr), max(wave_arr))
    top_ax.legend(loc='lower left', framealpha=1)

    # Format twin axis
    twin_axis.set_ylim(0, 1)
    twin_axis.set_ylabel('Transmission', rotation=-90, labelpad=12)

    # Format bottom axis
    bottom_ax.set_ylim(0, 1)
    bottom_ax.set_xlabel(r'Wavelength $\AA$')
    bottom_ax.set_xticks(np.arange(4000, 11001, 2000))
    bottom_ax.legend(loc='lower left', framealpha=1)

    return fig, np.array([top_ax, bottom_ax])
def plot_spectrum(
        wave: np.array,
        flux: np.array,
        figsize: Tuple[Numeric, Numeric] = (9, 3),
        hardware_only: bool = False
) -> Tuple[plt.Figure, plt.Axes]:
    """Plot a spectrum over the per-filter LSST throughput

    Args:
        wave: Spectrum wavelengths in Angstroms
        flux: Flux of the spectrum in arbitrary units
        figsize: Size of the figure
        hardware_only: Only include hardware contributions in the plotted filters
            (uses the ``lsst_hardware_*`` bandpasses instead of ``lsst_total_*``)

    Returns:
        The matplotlib figure and axis
    """

    fig, axis = plt.subplots(figsize=figsize)
    axis.set_ylabel('Object Flux')
    axis.set_xlim(min(wave), max(wave))
    axis.set_xlabel(r'Wavelength ($\AA$)')

    # Bandpass throughput is shown on a secondary y axis
    twin_ax = axis.twinx()
    twin_ax.set_ylim(0, 1)
    twin_ax.set_ylabel('Bandpass Transmission', rotation=270, labelpad=15)

    prefix = 'lsst_hardware_' if hardware_only else 'lsst_total_'
    for filter_abbrev in 'ugrizy':
        bandpass = sncosmo.get_bandpass(f'{prefix}{filter_abbrev}')
        twin_ax.fill_between(wave, bandpass(wave), alpha=.3, label=filter_abbrev)
        twin_ax.plot(wave, bandpass(wave))

    axis.plot(wave, flux, color='k')
    return fig, axis
def plot_magnitude(
        mags: Dict[str, np.ndarray], pwv: np.ndarray, z: np.ndarray, figsize: Tuple[Numeric, Numeric] = (9, 6)
) -> Tuple[plt.Figure, np.ndarray]:
    """Multi-panel plot showing magnitudes in different columns vs PWV and redshift in different rows

    Args:
        mags: Simulated magnitude values for each band (2d arrays with one row per PWV value)
        pwv: Array of PWV values
        z: Array of redshift values
        figsize: Size of the figure

    Returns:
        The matplotlib figure and a (2, len(mags)) array of axes
    """

    # ``squeeze=False`` keeps ``axes`` 2d so a single band can be unpacked per column
    fig, axes = plt.subplots(2, len(mags), figsize=figsize, sharey='row', squeeze=False)
    for mag_arr, (top_ax, bottom_ax) in zip(mags.values(), axes.T):
        _multi_line_plot(z, mag_arr, pwv, top_ax, label='{:g} mm')
        _multi_line_plot(pwv, mag_arr.T, z, bottom_ax, label='z = {:.2f}')

    # Add legends to the last column
    last_top_ax, last_bottom_ax = axes[0, -1], axes[1, -1]
    last_top_ax.legend(bbox_to_anchor=(1, 1.1))
    handles, labels = last_bottom_ax.get_legend_handles_labels()
    # Only show every fifth redshift to keep the legend readable
    last_bottom_ax.legend(handles[::5], labels[::5], bbox_to_anchor=(1, 1.1))

    return fig, axes
def plot_fitted_params(
        fitted_params: Dict[str, np.ndarray], pwv_arr: np.ndarray, z_arr: np.ndarray, bands: List[str]
) -> Tuple[plt.Figure, np.array]:
    """Multi-panel plot showing subplots for each salt2 parameter vs redshift.

    Multiple lines included for different PWV values. The final panel shows
    the correction ``alpha * x1 - beta * c`` using ``const.betoule_alpha``
    and ``const.betoule_beta``.

    Args:
        fitted_params: Dictionary with fitted parameters in each band
        pwv_arr: PWV value used for each supernova fit
        z_arr: Redshift value used for each supernova fit
        bands: Bands to include in the plot. Must be keys of ``fitted_params``
            (only the first band is currently plotted)

    Returns:
        The matplotlib figure and an array of matplotlib axes
    """

    # Parse the fitted parameters for easier plotting. The last axis of each
    # array is indexed in the order of the salt2 model's parameter names.
    model = sncosmo.Model('salt2-extended')
    params_dict = {
        param: fitted_params[bands[0]][..., i] for
        i, param in enumerate(model.param_names)
    }

    fig, axes = plt.subplots(2, 3, figsize=(12, 8))
    for axis, (param, param_vals) in zip(axes.flatten(), params_dict.items()):
        if param == 'x0':
            # Show the amplitude on a magnitude-like scale
            param_vals = -2.5 * np.log10(param_vals)
            param = r'-2.5 log$_{10}$(x$_{0}$)'

        # Each line corresponds to a single PWV value
        _multi_line_plot(z_arr, param_vals, pwv_arr, axis, label='PWV = {:g} mm')
        axis.set_ylabel(param)
        axis.set_xlabel('Redshift')

    # The final panel shows the combined stretch and color correction
    correction_factor = const.betoule_alpha * params_dict['x1'] - const.betoule_beta * params_dict['c']
    correction_ax = axes[-1][-1]
    _multi_line_plot(z_arr, correction_factor, pwv_arr, correction_ax, label='PWV = {:g} mm')

    label = f'{const.betoule_alpha} * $x_1$ - {const.betoule_beta} * $c$'
    correction_ax.set_ylabel(label)
    correction_ax.set_xlabel('Redshift')
    correction_ax.legend(bbox_to_anchor=(1, 1.1))

    return fig, axes
# noinspection PyUnboundLocalVariable
def plot_delta_colors(
        pwv_arr: np.ndarray, z_arr: np.ndarray, mag_dict: Dict[str, np.ndarray],
        colors: List[Tuple[str, str]], ref_pwv: Numeric = 0
) -> None:
    """Shows the change in SN color as a function of redshift with each SN color coded by PWV

    One panel is drawn per band combination. The legend is attached to the
    final panel only.

    Args:
        pwv_arr: Array of PWV values
        z_arr: Array of redshift values
        mag_dict: Dictionary with magnitudes for each band
        colors: Band combinations to plot colors for
        ref_pwv: Plot values relative to given reference PWV (must be an element of ``pwv_arr``)
    """

    n_panels = len(colors)

    # ``squeeze=False`` always yields a 2d grid, so a single panel needs no special case
    fig, axes_grid = plt.subplots(
        ncols=n_panels, figsize=(4 * n_panels, 4), sharey='col', squeeze=False)

    reference_row = list(pwv_arr).index(ref_pwv)

    for ax, band_pair in zip(axes_grid[0], colors):
        first_band, second_band = band_pair
        sn_color = mag_dict[first_band] - mag_dict[second_band]
        relative_color = sn_color - sn_color[reference_row]
        _multi_line_plot(z_arr, relative_color, pwv_arr, label='{} mm', axis=ax)

        ax.set_xlabel('Redshift', fontsize=14)
        ax.set_ylabel(fr'$\Delta$ ({first_band[-1]} - {second_band[-1]})', fontsize=14)
        ax.set_xlim(0, max(z_arr))

    ax.legend(bbox_to_anchor=(1, 1.1))
# noinspection PyUnusedLocal
def plot_delta_mu(
        mu: np.ndarray,
        pwv_arr: np.ndarray,
        z_arr: np.ndarray,
        cosmo: Cosmology = const.betoule_cosmo,
        ref_idx: int = 4
) -> None:
    """Plot the variation in fitted distance modulus as a function of redshift and PWV

    Three panels are drawn: the fitted distance modulus, its difference from
    ``cosmo``, and its difference from the reference PWV row.

    Args:
        mu: 2d array of distance moduli with one row per PWV value
        pwv_arr: Array of PWV values
        z_arr: Array of redshift values
        cosmo: Astropy cosmology to compare results against
        ref_idx: Index in ``pwv_arr`` of the reference PWV used in the third panel
    """

    cosmo_mu = cosmo.distmod(z_arr).value
    delta_mu = mu - cosmo_mu

    fig, axes = plt.subplots(ncols=3, figsize=(9, 3))
    mu_ax, delta_mu_ax, relative_mu_ax = axes

    _multi_line_plot(z_arr, mu, pwv_arr, mu_ax)
    mu_ax.plot(z_arr, cosmo_mu, linestyle=':', color='k', label='Simulated')

    _multi_line_plot(z_arr, delta_mu, pwv_arr, delta_mu_ax)
    delta_mu_ax.axhline(0, linestyle=':', color='k', label='Simulated')

    _multi_line_plot(z_arr, mu - mu[ref_idx], pwv_arr, relative_mu_ax, label='{:g} mm')
    relative_mu_ax.axhline(0, color='k', label=f'PWV={pwv_arr[ref_idx]}')
    relative_mu_ax.legend(framealpha=1, bbox_to_anchor=(1, 1.1))

    mu_ax.set_ylabel(r'$\mu$', fontsize=12)
    delta_mu_ax.set_ylabel(r'$\mu - \mu_{cosmo}$', fontsize=12)
    relative_mu_ax.set_ylabel(r'$\mu - \mu_{pwv_f}$', fontsize=12)
    for ax in axes:
        ax.set_xlabel('Redshift', fontsize=12)
def plot_year_pwv_vs_time(
        pwv_series: pd.Series, figsize: Tuple[Numeric, Numeric] = (8, 4), missing: Optional[Numeric] = 1
) -> Tuple[plt.Figure, plt.Axes]:
    """Plot PWV measurements taken over a single year as a function of time.

    Seasons follow the southern hemisphere convention (summer spans the new
    year) using approximate equinox / solstice dates. The mean and standard
    deviation of each season are printed to stdout.
    Set ``missing=None`` to disable plotting of missing data windows.

    Args:
        pwv_series: Measured PWV index by datetime. The index is compared
            against UTC-aware datetimes, so it is presumably timezone aware.
        figsize: Size of the figure in inches
        missing: Highlight time ranges larger than given number of days with missing PWV

    Returns:
        The matplotlib figure and axis
    """

    pwv_series = pwv_series.sort_index()

    # Calculate rolling average of PWV time series
    rolling_mean_pwv = pwv_series.rolling(window='7D').mean()

    # In practice these dates vary by year so the values used here are approximate
    year = pwv_series.index[0].year
    mar_equinox = datetime(year, 3, 20, tzinfo=utc)
    jun_equinox = datetime(year, 6, 21, tzinfo=utc)
    sep_equinox = datetime(year, 9, 22, tzinfo=utc)
    dec_equinox = datetime(year, 12, 21, tzinfo=utc)

    # Separate data based on season
    summer_pwv = pwv_series[(pwv_series.index < mar_equinox) | (pwv_series.index > dec_equinox)]
    fall_pwv = pwv_series[(pwv_series.index > mar_equinox) & (pwv_series.index < jun_equinox)]
    winter_pwv = pwv_series[(pwv_series.index > jun_equinox) & (pwv_series.index < sep_equinox)]
    spring_pwv = pwv_series[(pwv_series.index > sep_equinox) & (pwv_series.index < dec_equinox)]

    print(f'Summer Average: {summer_pwv.mean(): .2f} +\\- {summer_pwv.std(): .2f} mm')
    print(f'Fall Average: {fall_pwv.mean(): .2f} +\\- {fall_pwv.std(): .2f} mm')
    print(f'Winter Average: {winter_pwv.mean(): .2f} +\\- {winter_pwv.std(): .2f} mm')
    print(f'Spring Average: {spring_pwv.mean(): .2f} +\\- {spring_pwv.std(): .2f} mm')

    fig, axis = plt.subplots(figsize=figsize)
    axis.set_ylabel('PWV (mm)')
    axis.set_xlabel('Time of Year')
    axis.set_ylim(0, 20)
    ylow, yhigh = axis.get_ylim()

    def shade_window(start, end):
        """Shade the full y-range of the plot between two times"""

        axis.fill_between(
            x=[start, end], y1=[ylow, ylow], y2=[yhigh, yhigh],
            color='lightgrey', alpha=0.5, zorder=0)

    def plot_season_average(season, x, **kwargs):
        """Draw the mean and standard deviation of a season at position ``x``"""

        avg = season.mean()
        std = season.std()
        axis.errorbar([x], [avg], yerr=[std], color='k', zorder=4, linewidth=2, capsize=10, capthick=2)
        axis.scatter([x], [avg], color='k', s=100, marker='+', zorder=4, **kwargs)

    # Plot windows of missing pwv
    if missing:
        missing_interval = timedelta(days=missing)
        year_start = datetime(year, 1, 1, tzinfo=utc)
        year_end = datetime(year, 12, 31, 23, 59, 59, tzinfo=utc)

        # Gaps between consecutive measurements
        delta_t = pwv_series.index[1:] - pwv_series.index[:-1]
        for index in np.where(delta_t > missing_interval)[0]:
            shade_window(pwv_series.index[index], pwv_series.index[index + 1])

        # Gaps at the start and end of the year
        if (pwv_series.index[0] - year_start) > missing_interval:
            shade_window(year_start, pwv_series.index[0])

        if (year_end - pwv_series.index[-1]) > missing_interval:
            shade_window(pwv_series.index[-1], year_end)

    # Mark the season boundaries
    for equinox_date in (mar_equinox, jun_equinox, sep_equinox, dec_equinox):
        axis.axvline(equinox_date, linestyle='--', color='k', zorder=1)

    # Plot measured PWV
    axis.scatter(pwv_series.index, pwv_series, s=1, alpha=.2, zorder=2)

    # Plot rolling average
    axis.plot(rolling_mean_pwv.index, rolling_mean_pwv, color='C1', label='Rolling Avg.', zorder=3, linewidth=2)

    # Plot seasonal averages.
    # Summer is plotted separately because it spans the new year. Its marker is
    # placed at the center of the January - March portion of the data.
    if not summer_pwv.empty:
        summer_subset = summer_pwv[summer_pwv.index < mar_equinox]
        if summer_subset.empty:
            summer_subset = summer_pwv

        summer_x = summer_subset.index.max() - (summer_subset.index.max() - summer_subset.index.min()) / 2
        plot_season_average(summer_pwv, summer_x, label='Seasonal Avg.')

    for season in (fall_pwv, winter_pwv, spring_pwv):
        if season.empty:
            continue

        x = season.index.max() - (season.index.max() - season.index.min()) / 2
        plot_season_average(season, x)

    # Format x labels to be three letter month abbreviations
    axis.xaxis.set_major_locator(mdates.MonthLocator())
    axis.xaxis.set_major_formatter(mdates.DateFormatter('%b'))

    return fig, axis
# Ignore arguments with uppercase letters
# noinspection PyPep8Naming
def plot_cosmology_fit(
        data: pd.DataFrame, abs_mag: Numeric, H0: Numeric, Om0: Numeric, w0: Numeric, alpha: Numeric, beta: Numeric
) -> Tuple[plt.figure, np.ndarray, np.ndarray]:
    """Plot a cosmological fit to a set of supernova data.

    The upper panel shows the measured distance modulus with the fitted
    ``FlatwCDM`` cosmology overlaid; the lower panel shows the residuals on a
    y-range that is symmetric about zero.

    Args:
        data: Results from the snat_sim fitting pipeline
        abs_mag: Intrinsic absolute magnitude of SNe Ia
        H0: Fitted Hubble constant at z = 0 in [km/sec/Mpc]
        Om0: Omega matter density in units of the critical density at z=0
        w0: Dark energy equation of state
        alpha: Fitted nuisance parameter for supernova stretch correction
        beta: Fitted nuisance parameter for supernova color correction

    Returns:
        The matplotlib figure, fitted distance modulus, and tabulated residuals
    """

    sorted_data = data.sort_values('z')
    redshift = sorted_data.z
    error = sorted_data.mb_err

    cosmology = FlatwCDM(H0=H0, Om0=Om0, w0=w0)
    fitted_mu = cosmology.distmod(redshift).value

    # Distance modulus including the stretch and color corrections
    corrections = alpha * sorted_data.x1 - beta * sorted_data.c
    measured_mu = sorted_data.snat_sim.calc_distmod(abs_mag) + corrections
    residuals = measured_mu - fitted_mu

    fig, (top_ax, bottom_ax) = plt.subplots(2, sharex='col', gridspec_kw={'height_ratios': [2, 1]})

    # Upper panel: measurements with the fitted cosmology
    top_ax.errorbar(redshift, measured_mu, yerr=error, linestyle='')
    top_ax.scatter(redshift, measured_mu, s=1)
    top_ax.plot(redshift, fitted_mu, color='k', alpha=.75)

    # Lower panel: residuals about zero
    bottom_ax.axhline(0, color='k', alpha=.75, linestyle='--')
    bottom_ax.errorbar(redshift, residuals, yerr=error, linestyle='')
    bottom_ax.scatter(redshift, residuals, s=1)

    # Style the plot
    top_ax.set_ylabel(r'$\mu = m^*_B - M_B + \alpha x_1 - \beta c$')

    # Make the residual axis symmetric about zero
    half_range = max(np.abs(bottom_ax.get_ylim()))
    bottom_ax.set_ylim(-half_range, half_range)

    return fig, fitted_mu, residuals
def plot_residuals_on_sky(
        ra: np.array,
        dec: np.array,
        residual: np.array,
        cmap: str = 'coolwarm',
        figsize: Tuple[Numeric, Numeric] = (8, 4)
) -> Tuple[plt.figure, plt.Axes]:
    """Plot hubble residuals as a function of supernova coordinates.

    Coordinates are converted to galactic and drawn on an aitoff projection,
    with a color scale symmetric about zero residual.

    Args:
        ra: Right Ascension for each supernova (degrees)
        dec: Declination of each supernova (degrees)
        residual: Hubble residual for each supernova
        cmap: Name of the matplotlib color map to use
        figsize: The size of the figure

    Returns:
        The matplotlib figure and axis
    """

    galactic = SkyCoord(ra, dec, unit=u.deg).galactic

    # Wrap longitude to [-180, 180) degrees for the aitoff projection
    longitude = galactic.l.wrap_at('180d').radian
    latitude = galactic.b.radian

    fig, axis = plt.subplots(figsize=figsize, subplot_kw={'projection': 'aitoff'})

    # Symmetric color limits center zero residual in the colormap
    color_limit = max(np.abs(residual))
    points = axis.scatter(longitude, latitude, c=residual, cmap=cmap, vmin=-color_limit, vmax=color_limit)

    colorbar = plt.colorbar(points)
    colorbar.set_label('Hubble Residual', rotation=270, labelpad=15)
    return fig, axis
def compare_prop_effects(
        pwv_data: pd.Series,
        static: models.StaticPWVTrans,
        seasonal: models.SeasonalPWVTrans,
        variable: models.VariablePWVTrans,
        figsize: Tuple[float, float] = (9, 6)
) -> Tuple[plt.Figure, plt.Axes]:
    """Compare the Zenith PWV assumed by different propagation effects

    The plot is drawn on a new figure of size ``figsize``. The passed
    ``variable`` effect is not modified.

    Args:
        pwv_data: Series with PWV values and a Datetime index
        static: Static propagation effect
        seasonal: Seasonal Propagation effect
        variable: Variable Propagation effect
        figsize: The size of the figure

    Returns:
        The matplotlib figure and axis
    """

    # A shallow copy of the effect would still share ``_pwv_model`` with the
    # caller, so copy the PWV model too before overriding its airmass method
    variable = copy(variable)
    # noinspection PyProtectedMember
    variable._pwv_model = copy(variable._pwv_model)

    # Disable airmass scaling so we get values at zenith
    # noinspection PyProtectedMember
    variable._pwv_model.calc_airmass = lambda *args, **kwargs: 1

    pwv_data = pwv_data.sort_index()
    x_vals = np.arange(pwv_data.index[0], pwv_data.index[-1], timedelta(days=1)).astype(datetime)
    mjd_vals = Time(x_vals).mjd

    fig, axis = plt.subplots(figsize=figsize)
    axis.scatter(pwv_data.index, pwv_data.values, s=2, alpha=.1, color='grey')
    axis.plot(x_vals, variable.assumed_pwv(mjd_vals), label='Variable')
    axis.plot(x_vals, seasonal.assumed_pwv(mjd_vals), label='Seasonal', linewidth=2.5)
    axis.plot([x_vals[0], x_vals[-1]], [static['pwv'], static['pwv']], label='Static', linewidth=2.5)

    axis.legend(loc='upper right', framealpha=1)
    axis.set_ylabel('PWV (mm)')
    axis.set_xlabel('Date (UTC)')
    axis.set_ylim(0, 25)
    return fig, axis
def plot_transmission_variation(
        pwv1: float, pwv2: float,
        wave_min: float = 6500,
        wave_max: float = 10000,
        resolution: int = 9,
        figsize: Tuple[Numeric, Numeric] = (8, 4)
) -> Tuple[plt.Figure, plt.Axes]:
    """Compare the atmospheric transmission function for two PWV concentrations

    The transmission at the lower PWV is drawn as a line, and the region
    between it and the transmission at the higher PWV is shaded.

    Args:
        pwv1: The first PWV concentration to plot the transmission for
        pwv2: The second PWV concentration to plot the transmission for
        wave_min: Minimum wavelength to plot
        wave_max: Maximum wavelength to plot
        resolution: Bin the atmospheric transmission function to a lower resolution
        figsize: The size of the figure

    Returns:
        The matplotlib figure and axis
    """

    low_pwv = min(pwv1, pwv2)
    high_pwv = max(pwv1, pwv2)

    wave = np.arange(wave_min, wave_max)
    low_transmission = v1_transmission(pwv=low_pwv, wave=wave, res=resolution)
    high_transmission = v1_transmission(pwv=high_pwv, wave=wave, res=resolution)

    fig, axis = plt.subplots(figsize=figsize)
    axis.plot(wave, low_transmission, color='k', label=f'PWV = {low_pwv} mm', linewidth=1.5)
    axis.fill_between(wave, low_transmission, high_transmission, label=f'PWV = {high_pwv} mm', alpha=.75)

    axis.set_title(r'Change in atmospheric PWV transmission flux due to $\Delta$PWV')
    axis.set_xlabel(r'Wavelength ($\AA$)')
    axis.set_ylabel('PWV Transmission')
    axis.set_xlim(wave_min, wave_max)
    axis.set_ylim(.5, 1)

    # Show which curve corresponds to which PWV concentration
    axis.legend(framealpha=1)
    return fig, axis
def plot_flux_variation(
        pwv1: float,
        pwv2: float,
        z: float = .55,
        wave_min: float = 6500,
        wave_max: float = 10000,
        resolution: int = 9,
        figsize: Tuple[Numeric, Numeric] = (8, 4)
) -> Tuple[plt.Figure, plt.Axes]:
    """Compare the PWV absorbed flux of a SN Ia for two PWV concentrations

    The flux at the lower PWV is drawn as a line, and the region between it
    and the flux at the higher PWV is shaded. Fluxes are evaluated at phase 0.

    Args:
        pwv1: The first PWV concentration to plot the flux for
        pwv2: The second PWV concentration to plot the flux for
        z: The redshift of the SN Ia
        wave_min: Minimum wavelength to plot
        wave_max: Maximum wavelength to plot
        resolution: Bin the atmospheric transmission function to a lower resolution
        figsize: The size of the figure

    Returns:
        The matplotlib figure and axis
    """

    low_pwv = min(pwv1, pwv2)
    high_pwv = max(pwv1, pwv2)

    model = models.SNModel('salt2-extended')
    model.add_effect(models.StaticPWVTrans(transmission_res=resolution), '', 'obs')

    wave = np.arange(wave_min, wave_max)
    model.set(z=z, pwv=low_pwv)
    low_pwv_flux = model.flux(0, wave)

    # Re-evaluate the same model with the higher PWV concentration
    model.set(pwv=high_pwv)
    high_pwv_flux = model.flux(0, wave)

    fig, axis = plt.subplots(figsize=figsize)
    axis.plot(wave, low_pwv_flux, color='k', label=f'PWV = {low_pwv} mm', linewidth=1.5)
    axis.fill_between(wave, low_pwv_flux, high_pwv_flux, label=f'PWV = {high_pwv} mm', alpha=.75)

    axis.set_title(r'Change in SN Ia flux due to $\Delta$PWV')
    axis.set_xlabel(r'Wavelength ($\AA$)')
    axis.set_ylabel('SN Ia Flux')
    axis.set_xlim(wave_min, wave_max)

    # Show which curve corresponds to which PWV concentration
    axis.legend(framealpha=1)
    return fig, axis
def plot_delta_sn_flux(
        pwv: float = 4, wave_min: float = 6500, wave_max: float = 10000, figsize: Tuple[Numeric, Numeric] = (8, 6)
) -> Tuple[plt.Figure, np.array]:
    """Plot the change to spectroscopic SN Ia flux over wavelength and redshift

    An imshow style plot with the change in flux along the color axis. The
    change is the flux with ``pwv`` minus the flux of the same model with zero
    PWV. Panels above and below show the PWV transmission and LSST bandpasses.

    Args:
        pwv: The PWV concentration to use when determining the change in flux
        wave_min: Minimum wavelength to plot
        wave_max: Maximum wavelength to plot
        figsize: The size of the figure

    Returns:
        The matplotlib figure and an array with the middle (flux) and bottom (bandpass) axes
    """

    sn_model = models.SNModel('salt2-extended')
    sn_model.add_effect(models.StaticPWVTrans(transmission_res=10), '', 'obs')
    transmission_model = PWVTransmissionModel(resolution=5)

    wave = np.arange(wave_min, wave_max)
    delta_flux = []
    for z in np.arange(0.0001, 1.01, .1):
        # Flux including PWV absorption
        sn_model.set(z=z, pwv=pwv)
        flux = sn_model.flux(0, wave)

        # Flux of the same model without PWV, used as the baseline
        sn_model.set(pwv=0)
        flux_model = sn_model.flux(0, wave)

        delta_flux.append(flux - flux_model)

    fig, (top_axis, middle_axis, bottom_ax) = plt.subplots(
        nrows=3, figsize=figsize, sharex='col',
        gridspec_kw={'height_ratios': [1.75, 4, 1.75]})

    # Plot the atmospheric transmission as a percentage
    trans = transmission_model.calc_transmission(pwv, wave)
    top_axis.plot(wave, 100 * trans)

    # Rows of ``delta_flux`` are redshifts, so draw them bottom-up along the y axis
    middle_axis.imshow(
        np.array(delta_flux),
        origin='lower',
        aspect='auto',
        extent=[wave_min, wave_max, 0, 1],
    )

    # Plot the band passes
    for b in 'rizy':
        band = sncosmo.get_bandpass(f'lsst_total_{b}')
        bottom_ax.plot(band.wave, band.trans, label=f'{b}')

    # Format each axis
    top_axis.set_title(f'Change in spectral SN Ia flux due to PWV ({pwv} mm)')
    top_axis.set_ylabel('Transmission (%)')
    middle_axis.set_ylabel('Redshift (z)')

    bottom_ax.set_ylim(0, 1)
    bottom_ax.set_xlim(wave_min, wave_max)
    bottom_ax.set_xlabel(r'Wavelength $\AA$')
    bottom_ax.set_xticks(np.arange(wave_min, wave_max + 1, 500))
    bottom_ax.set_yticks([0, .25, .5, .75])
    bottom_ax.legend(loc='lower left', framealpha=1)

    return fig, np.array([middle_axis, bottom_ax])