"""Defines the individual data processing nodes used to construct complete
data analysis pipelines.

.. note:: Nodes are built on the ``egon`` framework. For more information see the
official `Egon Documentation <https://mwvgroup.github.io/Egon/>`_.
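
Nodes are chained together by connecting one node's output connector to the
next node's input connector. The sketch below is illustrative only: it assumes
``egon`` exposes an ``Output.connect`` method (see the Egon documentation for
the exact API) and that a PLAsTICC data access object is available locally.

.. code-block:: python

   cadence = LoadPlasticcCadence(plasticc_dao, iter_lim=10)
   simulate = SimulateLightCurves(sn_model, num_processes=2)

   # Hypothetical connection call
   cadence.output.connect(simulate.input)
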
Module Docs
-----------
"""
from __future__ import annotations
import warnings
from copy import copy
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import pandas as pd
from astropy.cosmology.core import Cosmology
from egon.connectors import Input, Output
from egon.nodes import Node, Source, Target
from .. import constants as const
from ..models import LightCurve, ObservedCadence, SNFitResult, SNModel, VariableCatalog
from ..pipeline.data_model import PipelinePacket
from ..plasticc import PLAsTICC


class LoadPlasticcCadence(Source):
"""Pipeline node for loading PLAsTICC cadence data from disk
Connectors:
output: Emits a pipeline packet decorated with the snid, simulation parameters, and cadence
"""

    def __init__(
self,
plasticc_dao: PLAsTICC,
            iter_lim: float = float('inf'),
override_zp: float = 30,
verbose: bool = True,
num_processes: int = 1
) -> None:
"""Source node for loading PLAsTICC cadence data from disk
This node can only be run using a single process. This can be the main
process (``num_processes=0``) or a single forked process (``num_processes=1``.)
Args:
plasticc_dao: A PLAsTICC data access object
iter_lim: Exit after loading the given number of light-curves
override_zp: Overwrite the zero-point used by plasticc with this number
verbose: Display a progress bar
num_processes: Number of processes to allocate to the node (must be 0 or 1 for this node)
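
        Example:

            A hypothetical construction, assuming local PLAsTICC data (the
            cadence name and model number below are placeholders):

            .. code-block:: python

               dao = PLAsTICC('alt_sched', 11)  # Placeholder cadence / model number
               node = LoadPlasticcCadence(dao, iter_lim=100, verbose=False)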
"""
if num_processes not in (0, 1):
raise RuntimeError('Number of processes for ``LoadPlasticcCadence`` must be 0 or 1.')
self.cadence = plasticc_dao
self.iter_lim = iter_lim
self.override_zp = override_zp
self.verbose = verbose
# Node connectors
self.output = Output('Loading Cadence Output')
super().__init__(num_processes=num_processes)

    def action(self) -> None:
"""Load PLAsTICC cadence data from disk"""
for snid, params, cadence in self.cadence.iter_cadence(iter_lim=self.iter_lim, verbose=self.verbose):
cadence.zp = np.full_like(cadence.zp, self.override_zp)
self.output.put(PipelinePacket(snid, params, cadence))


class SimulateLightCurves(Node):
"""Pipeline node for simulating light-curves based on PLAsTICC cadences
Connectors:
input: A Pipeline Packet
success_output: Emits pipeline packets successfully decorated with a simulated light-curve
failure_output: Emits pipeline packets for cases where the simulation procedure failed
"""

    def __init__(
self,
sn_model: SNModel,
catalog: VariableCatalog = None,
num_processes: int = 1,
add_scatter: bool = True,
fixed_snr: Optional[float] = None,
abs_mb: float = const.betoule_abs_mb,
cosmo: Cosmology = const.betoule_cosmo
) -> None:
"""Fit light-curves using multiple processes and combine results into an output file
Args:
sn_model: Model to use when simulating light-curves
catalog: Optional reference start catalog to calibrate simulated flux values to
num_processes: Number of processes to allocate to the node
abs_mb: The absolute B-band magnitude of the simulated SNe
cosmo: Cosmology to assume in the simulation
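
        Example:

            A minimal sketch, assuming an ``SNModel`` built from a salt2-like
            source (the source name is a placeholder):

            .. code-block:: python

               model = SNModel('salt2-extended')  # Placeholder source
               node = SimulateLightCurves(model, num_processes=4, add_scatter=True)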
"""
self.sim_model = copy(sn_model)
self.catalog = catalog
self.add_scatter = add_scatter
self.fixed_snr = fixed_snr
self.abs_mb = abs_mb
self.cosmo = cosmo
# Node connectors
self.input = Input('Simulated Cadence')
self.success_output = Output('Simulation Success')
self.failure_output = Output('Simulation Failure')
super().__init__(num_processes=num_processes)

    def simulate_lc(self, params: Dict[str, float], cadence: ObservedCadence) -> Tuple[LightCurve, SNModel]:
"""Duplicate a plastic light-curve using the simulation model
Args:
params: The simulation parameters to use with ``self.model``
cadence: The observed cadence of the returned light-curve
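
        Example:

            A hypothetical call, where the parameter values are placeholders
            for values drawn from a PLAsTICC cadence:

            .. code-block:: python

               params = {'z': 0.5, 't0': 59580.0, 'ra': 30.0, 'dec': -45.0}  # Illustrative values
               light_curve, model = node.simulate_lc(params, cadence)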
"""
# Set model parameters and scale the source brightness to the desired intrinsic brightness
model_for_sim = copy(self.sim_model)
model_for_sim.update({p: v for p, v in params.items() if p in model_for_sim.param_names})
model_for_sim.set_source_peakabsmag(self.abs_mb, 'standard::b', 'AB', cosmo=self.cosmo)
# Simulate the light-curve. Make sure to include model parameters as meta data
duplicated = model_for_sim.simulate_lc(cadence, scatter=self.add_scatter, fixed_snr=self.fixed_snr)
# Rescale the light-curve using the reference star catalog if provided
if self.catalog is not None:
duplicated = self.catalog.calibrate_lc(duplicated, ra=params['ra'], dec=params['dec'])
return duplicated, model_for_sim

    def action(self) -> None:
"""Simulate light-curves with atmospheric effects"""
for packet in self.input.iter_get():
try:
light_curve, model = self.simulate_lc(packet.sim_params, packet.cadence)
except Exception as excep:
packet.message = f'{self.__class__.__name__}: {repr(excep)}'
self.failure_output.put(packet)
else:
packet.light_curve = light_curve
packet.sim_params['x0'] = model['x0']
self.success_output.put(packet)


class FitLightCurves(Node):
"""Pipeline node for fitting simulated light-curves
Connectors:
input: A Pipeline Packet
success_output: Emits pipeline packets with successful fit results
failure_output: Emits pipeline packets for cases where the fitting procedure failed
"""

    def __init__(
            self, sn_model: SNModel, vparams: List[str], bounds: Optional[Dict] = None, num_processes: int = 1
) -> None:
"""Fit light-curves using multiple processes and combine results into an output file
Args:
sn_model: Model to use when fitting light-curves
vparams: List of parameter names to vary in the fit
bounds: Bounds to impose on ``fit_model`` parameters when fitting light-curves
num_processes: Number of processes to allocate to the node
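
        Example:

            A minimal sketch (the parameter names below assume a salt2-like
            model and are illustrative):

            .. code-block:: python

               node = FitLightCurves(
                   sn_model, vparams=['x0', 'x1', 'c'], bounds={'x1': (-5, 5), 'c': (-1, 1)})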
"""
self.sn_model = sn_model
self.vparams = vparams
self.bounds = bounds
# Node Connectors
self.input = Input('Simulated Light-Curve')
self.success_output = Output('Fitting Success')
self.failure_output = Output('Fitting Failure')
        super().__init__(num_processes=num_processes)

    def fit_lc(self, light_curve: LightCurve, initial_guess: Dict[str, float]) -> SNFitResult:
"""Fit the given light-curve
Args:
light_curve: The light-curve to fit
initial_guess: Parameters to use as the initial guess in the chi-squared minimization

        Returns:
            The result of the chi-squared minimization as an ``SNFitResult`` object
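
        Example:

            Bounds on ``t0`` are interpreted relative to the simulated value;
            the values below are illustrative:

            .. code-block:: python

               # With bounds={'t0': (-5, 5)} and a simulated t0 of 59580,
               # the fit is constrained to 59575 <= t0 <= 59585
               result = node.fit_lc(light_curve, initial_guess=packet.sim_params)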
"""
# Use the true light-curve parameters as the initial guess
model = copy(self.sn_model)
model.update({k: v for k, v in initial_guess.items() if k in self.sn_model.param_names})
# Ensure any bounds applied to `t0` are applied relative to the simulated value
bounds = copy(self.bounds)
        if bounds and 't0' in bounds:
model_t0 = model['t0']
lower_t0, upper_t0 = bounds['t0']
bounds['t0'] = (model_t0 + lower_t0, model_t0 + upper_t0)
return model.fit_lc(
light_curve, self.vparams, bounds=bounds,
guess_t0=False, guess_amplitude=False, guess_z=False)

    def action(self) -> None:
"""Fit light-curves"""
for packet in self.input.iter_get():
try:
packet.fit_result = self.fit_lc(packet.light_curve, packet.sim_params)
packet.covariance = packet.fit_result.salt_covariance_linear()
except Exception as excep:
packet.message = f'{self.__class__.__name__}: {repr(excep)}'
self.failure_output.put(packet)
else:
packet.message = f'{self.__class__.__name__}: {packet.fit_result.message}'
self.success_output.put(packet)


class WritePipelinePacket(Target):
"""Pipeline node for writing pipeline packets to disk
Connectors:
input: A pipeline packet
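
    Data is written in HDF5 format. A sketch of how the stored keys (taken
    from this node's write logic) might be read back with pandas; the file
    name below is illustrative:

    .. code-block:: python

       import pandas as pd

       with pd.HDFStore('output_fn0.h5', mode='r') as store:
           sim_params = store['simulation/params']
           status = store['message']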
"""

    def __init__(self, out_path: Union[str, Path], write_lc_sims: bool = False, num_processes: int = 1) -> None:
        """Output node for writing HDF5 data to disk

        When run with multiple processes, each process writes to a separate,
        uniquely named file to avoid write collisions.

        Args:
            out_path: Path to write data to in HDF5 format
            write_lc_sims: Whether to include simulated light-curves in the data written to disk
            num_processes: Number of processes to allocate to the node
"""
        self.input = Input('Data To Write')
        self.write_lc_sims = write_lc_sims
        self.debug = False  # Set to True to raise errors instead of converting them to warnings
self.out_path = Path(out_path)
self.file_store: Optional[pd.HDFStore] = None
self._num_results_per_file = 10_000
self._num_results_in_current_file = 0
self._current_file_id = 0
super().__init__(num_processes=num_processes)

    def _rotate_output_file(self) -> None:
"""Have the running process close the current output file and start writing to a new one
Once files get too large the write performance starts to suffer.
We address this by closing the current file, incrementing
a number in the output file path, and writing data to that new path
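
        For example, an output path of ``results_fn0.h5`` becomes
        ``results_fn1.h5`` after the first rotation.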
"""
if self._num_results_in_current_file < self._num_results_per_file:
return
if self.file_store is not None:
self.file_store.close()
# Update output file path
old_id = self._current_file_id
self._current_file_id += 1
new_stem = self.out_path.stem.replace(f'_fn{old_id}', f'_fn{self._current_file_id}')
self.out_path = self.out_path.with_stem(new_stem)
# noinspection PyTypeChecker
self.file_store = pd.HDFStore(self.out_path, mode='w')
self._num_results_in_current_file = 0

    def _write_packet(self, packet: PipelinePacket) -> None:
"""Write a pipeline packet to the output file"""
self._rotate_output_file()
# We are taking the simulated parameters as guaranteed to exist
self.file_store.append('simulation/params', packet.sim_params_to_pandas())
self.file_store.append('message', packet.packet_status_to_pandas(), min_itemsize={'message': 250})
if self.write_lc_sims and packet.light_curve is not None:
self.file_store.put(f'simulation/lcs/{packet.snid}', packet.light_curve.to_pandas())
if packet.fit_result is not None:
self.file_store.append('fitting/params', packet.fitted_params_to_pandas())
if packet.covariance is not None:
self.file_store.put(f'fitting/covariance/{packet.snid}', packet.covariance)
self._num_results_in_current_file += 1

    def setup(self) -> None:
"""Open a file accessor object"""
# If we are writing data to disk in parallel, add the process id to
# prevent multiple processes writing to the same file
if self.num_processes > 1:
import multiprocessing
            pid = multiprocessing.current_process().pid  # Unique per process; prevents filename collisions
self.out_path = self.out_path.with_suffix(f'.{pid}.h5')
self.out_path = self.out_path.with_stem(self.out_path.stem + f'_fn{self._current_file_id}')
# noinspection PyTypeChecker
self.file_store = pd.HDFStore(self.out_path, mode='w')

    def teardown(self) -> None:
"""Close any open file accessors"""
self.file_store.close()
self.file_store = None

    def action(self) -> None:
"""Write data from the input connector to disk"""
for packet in self.input.iter_get():
try:
self._write_packet(packet)
except Exception as excep:
if self.debug:
raise
warnings.warn(f'{self.__class__.__name__}: {repr(excep)}')