# Set the matplotlib backend so plots can run on GitHub Actions (no display available)
import os
import matplotlib as mpl
if os.environ.get('DISPLAY') == 'test':
    # DISPLAY == 'test' is used as the "no real display" marker (presumably set
    # by the CI workflow — confirm); fall back to the non-interactive Agg backend.
    print('no display found. Using non-interactive Agg backend')
    mpl.use('Agg')
import pylab as plt
import numpy as np
from scipy.interpolate import interp2d
from matplotlib.ticker import ScalarFormatter, NullFormatter
from ..utils import none_val, hp, updated_dict
########################################################################
########## Monkeypatching matplotlib ###################################
########################################################################
from ..utils import smooth_line
def _plot_smooth(self, *args, scheme=None, n_increase=0, **kwargs):
    """Smooth the (x, y) values with ``smooth_line`` and plot them with ``self.plot``.

    Parameters
    ----------
    self: object
        Object providing a ``plot`` method (e.g. a matplotlib Axes, or pyplot).
    *args
        Positional arguments for ``self.plot``. The first two (x and y values)
        are smoothed before plotting; the remaining ones are passed unchanged.
    scheme: list, None
        Scheme to be used for smoothening. Newton's binomial coefficients work
        better. If None, ``[1, 2, 1]`` is used.
    n_increase: int
        Number of loops for the smoothing algorithm.
    **kwargs
        Keyword arguments passed to ``self.plot``.

    Returns
    -------
    output of self.plot
    """
    # Build the default per call: a list default would be a single object
    # shared between every call.
    scheme = [1, 2, 1] if scheme is None else scheme
    return self.plot(
        *smooth_line(*np.array(args[:2]), scheme=scheme, n_increase=n_increase),
        *args[2:], **kwargs)
# pyplot-style call, ``plt.plot_smooth(x, y, ...)``: draw through ``plt.plot``
# (previously ``x`` itself was used as the plotting object, which fails).
# An Axes given as the first argument is still honored for backward compatibility.
plt.plot_smooth = lambda *args, **kwargs: (
    _plot_smooth(*args, **kwargs)
    if args and isinstance(args[0], plt.Axes)
    else _plot_smooth(plt, *args, **kwargs))
# Method-style call, ``ax.plot_smooth(x, y, ...)``: ``ax`` is bound as ``self``.
plt.Axes.plot_smooth = lambda *args, **kwargs: _plot_smooth(*args, **kwargs)
########################################################################
########################################################################
########################################################################
def add_grid(ax, major_lw=0.5, minor_lw=0.1):
    """
    Turn on major and minor grid lines for both axes of ``ax``.

    Parameters
    ----------
    ax: matplotlib.axes
        Ax to add the grid to
    major_lw: float
        Line width of the major grid lines
    minor_lw: float
        Line width of the minor grid lines
    """
    # Same call order as before: x/y major first, then x/y minor.
    for which, width in (('major', major_lw), ('minor', minor_lw)):
        for axis in (ax.xaxis, ax.yaxis):
            axis.grid(True, which=which, lw=width)
def plot_hist_line(hist_values, bins, ax, shape='steps', rotate=False, **kwargs):
    """
    Plot histogram values as a line. Can be in steps or continuous.

    Parameters
    ----------
    hist_values: array
        Values of each bin in the histogram
    bins: array
        Edges of the histogram bins (one more element than ``hist_values``)
    ax: matplotlib.axes
        Ax to add plot
    shape: str
        Shape of the line. Can be 'steps' or 'line'.
    rotate: bool
        Invert x-y axes in plot
    kwargs: parameters
        Additional parameters for plt.plot.
        It also includes the possibility of smoothening the line with `n_increase, scheme`
        arguments. See `clevar.utils.smooth_line` for details.

    Raises
    ------
    ValueError
        If ``shape`` is neither 'steps' nor 'line'.
    """
    # Convert to arrays so plain lists work: with lists, ``bins[:-1]+bins[1:]``
    # would concatenate instead of adding element-wise.
    bins = np.asanyarray(bins)
    hist_values = np.asanyarray(hist_values)
    if shape == 'steps':
        # Each bin contributes two points, (left edge, value) and (right edge, value).
        data = (np.transpose([bins[:-1], bins[1:]]).flatten(),
                np.transpose([hist_values, hist_values]).flatten())
    elif shape == 'line':
        # One point per bin, placed at the bin center.
        data = (0.5*(bins[:-1]+bins[1:]), hist_values)
    else:
        raise ValueError(f"shape ({shape}) must be 'steps' or 'line'")
    if rotate:
        data = data[::-1]
    ax.plot_smooth(*data, **kwargs)
def get_bin_label(edge_lower, edge_higher,
                  format_func=lambda v: v,
                  prefix=''):
    """
    Build a LaTeX label describing the range of a bin.

    Parameters
    ----------
    edge_lower: float
        Lower value of the bin
    edge_higher: float
        Higher value of the bin
    format_func: function
        Function applied to each edge value before inserting it in the label
    prefix: str
        Prefix to add to the label

    Returns
    -------
    str
        Label of the bin, in the form ``$<prefix>[low$ : $high]$``
    """
    low_txt = format_func(edge_lower)
    high_txt = format_func(edge_higher)
    return f'${prefix}[{low_txt}$ : ${high_txt}]$'
def add_panel_bin_label(axes, edges_lower, edges_higher,
                        format_func=lambda v: v, prefix=''):
    """
    Write the bin range of each panel as a label on top of that panel.

    Parameters
    ----------
    axes: matplotlib.axes
        Axes with the panels (array, flattened in order)
    edges_lower: array
        Lower values of bins, one per panel
    edges_higher: array
        Higher values of bins, one per panel
    format_func: function
        Function to format the values of the bins
    prefix: str
        Prefix to add to labels
    """
    panels = axes.flatten()
    for panel, low, high in zip(panels, edges_lower, edges_higher):
        # A twin x-axis carries the label at the top; its ticks are hidden.
        top_axis = panel.twiny()
        top_axis.set_xticks([])
        top_axis.set_xlabel(get_bin_label(low, high, format_func, prefix))
def get_density_colors(x, y, xbins, ybins, ax_rotation=0, rotation_resolution=30, xscale='linear',
                       yscale='linear'):
    """
    Get colors of point based on density

    The points (optionally log10-scaled) are rotated by ``ax_rotation``,
    binned in a 2D histogram, and the counts are interpolated with a cubic
    spline at each point.

    Parameters
    ----------
    x: array
        Values for x coordinate
    y: array
        Values for y coordinate
    xbins: array, int
        Bins for x
    ybins: array, int
        Bins for y
    ax_rotation: float
        Angle (in degrees) for rotation of axis of binning. Overwrites use of xbins, ybins
    rotation_resolution: int
        Number of bin edges to be used on each axis when ax_rotation!=0.
    xscale: str
        Scale xaxis ('linear' or 'log'; 'log' bins log10(x)).
    yscale: str
        Scale yaxis ('linear' or 'log'; 'log' bins log10(y)).

    Returns
    -------
    ndarray
        Density value at location of each point
    """
    # Local import: scipy.interpolate.interp2d was removed in SciPy 1.14;
    # RectBivariateSpline with s=0 is its documented replacement on a regular grid.
    from scipy.interpolate import RectBivariateSpline
    scalefuncs = {'linear': lambda v: v, 'log': lambda v: np.log10(v)}
    xs = np.array(scalefuncs[xscale](x))
    ys = np.array(scalefuncs[yscale](y))
    # Rotate points around the angle. Both rotated coordinates must be built
    # from the *unrotated* values (y2 previously used the already-rotated x2).
    sr, cr = np.sin(np.radians(ax_rotation)), np.cos(np.radians(ax_rotation))
    x2 = xs*cr - ys*sr
    y2 = xs*sr + ys*cr
    if ax_rotation == 0:
        bins = (xbins, ybins)
    else:
        bins = (np.linspace(x2.min(), x2.max(), rotation_resolution),
                np.linspace(y2.min(), y2.max(), rotation_resolution))
    # Compute 2D (rotated) histogram; counts are indexed as [x bin, y bin].
    hist, xedges, yedges = np.histogram2d(x2, y2, bins=bins)
    xm = .5*(xedges[:-1] + xedges[1:])
    ym = .5*(yedges[:-1] + yedges[1:])
    # Cubic interpolating spline (no smoothing), as interp2d(kind='cubic') did.
    # RectBivariateSpline takes z with shape (len(xm), len(ym)), i.e. the
    # histogram without the transpose interp2d required.
    spline = RectBivariateSpline(xm, ym, hist, kx=3, ky=3, s=0)
    return np.asarray(spline.ev(x2, y2))
def nice_panel(axes, xlabel=None, ylabel=None, xscale='linear', yscale='linear'):
    """
    Add nice labels and ticks to panel plot

    The tick positions of every panel are read before the scale is changed.
    When a scale is 'log', the positive ticks are kept and labelled in
    scientific notation (mantissa times a power of ten).

    Parameters
    ----------
    axes: array
        Axes with the panels (1D or 2D array).
    xlabel: str
        Label of x axis. Set on the bottom row (on every panel if ``axes`` is 1D).
    ylabel: str
        Label of y axis. Set on the first column (on the first panel if ``axes`` is 1D).
    xscale: str
        Scale xaxis.
    yscale: str
        Scale yaxis.

    Returns
    -------
    None
    """
    def _log_tick_labels(log_ticks):
        # Split t = log10(tick) into exponent floor(t) and mantissa
        # 10**(t-floor(t)). floor is used for both parts: the previous int(t)
        # truncated toward zero, mislabelling ticks below 1 (e.g. 0.5).
        return [f'${10**(t-np.floor(t)):.0f}\\times 10^{{{np.floor(t):.0f}}}$'
                for t in log_ticks]

    is_grid = len(axes.shape) > 1
    panels = axes.flatten()
    # x axis: read ticks (positive only, so they can be logged) before rescaling.
    log_xticks = [np.log10(ax.get_xticks()[ax.get_xticks()>0])
                  for ax in panels]
    for ax in (axes[-1, :] if is_grid else axes):
        ax.set_xlabel(xlabel)
        ax.set_xscale(xscale)
    if xscale == 'log':
        for ax, xticks in zip(panels, log_xticks):
            ax.xaxis.set_major_formatter(ScalarFormatter())
            ax.xaxis.set_minor_formatter(NullFormatter())
            ax.set_xticks(10**xticks)
            ax.set_xticklabels(_log_tick_labels(xticks), rotation=-45)
    # y axis: same procedure, read after the x axis has been handled.
    log_yticks = [np.log10(ax.get_yticks()[ax.get_yticks()>0])
                  for ax in panels]
    for ax in (axes[:, 0] if is_grid else axes[:1]):
        ax.set_ylabel(ylabel)
        ax.set_yscale(yscale)
    if yscale == 'log':
        for ax, yticks in zip(panels, log_yticks):
            ax.yaxis.set_major_formatter(ScalarFormatter())
            ax.yaxis.set_minor_formatter(NullFormatter())
            ax.set_yticks(10**yticks)
            ax.set_yticklabels(_log_tick_labels(yticks), rotation=-45)
    return
def _set_label_format(kwargs, label_format_key, label_fmt_key, log,
default_fmt='.2f'):
"""
Set function for label formatting from dictionary and removes label_fmt_key.
Parameters
----------
kwargs: dict
Dictionary with the input values
label_format_key: str
Name of the format function entry
label_fmt_key: str
Name of entry with format of values (ex: '.2f').
It is only used if label_format_key not in kwargs.
log: bool
Format labels with 10^log10(val) format.
It is only used if label_format_key not in kwargs.
default_fmt: str
Format of linear values (ex: '.2f') when (label_format_key, label_fmt_key) not in kwargs.
Returns
-------
function
Label format function
"""
label_fmt = kwargs.pop(label_fmt_key, default_fmt)
kwargs[label_format_key] = kwargs.get(label_format_key,
lambda v: f'10^{{%{label_fmt}}}'%np.log10(v) if log else f'%{label_fmt}'%v)
def plot_histograms(
        histogram, edges1, edges2, ax, shape='steps',
        plt_kwargs=None, lines_kwargs_list=None,
        add_legend=True, legend_format=lambda v: v, legend_kwargs=None):
    """
    Plot each row of a 2D histogram as a line binned by edges1, one line per bin of edges2.

    Parameters
    ----------
    histogram: array
        Histogram 2D with dimension (edges2, edges1).
    edges1, edges2: array
        Edges of histogram.
    ax: matplotlib.axes
        Ax to add plot
    shape: str
        Shape of the lines. Can be steps or line.
    plt_kwargs: dict, None
        Additional arguments for pylab.plot
    lines_kwargs_list: list, None
        List of additional arguments for plotting each line (using pylab.plot).
        Must have same size as len(edges2)-1
    add_legend: bool
        Add legend of bins
    legend_format: function
        Function to format the values of the bins in legend
    legend_kwargs: dict, None
        Additional arguments for pylab.legend

    Returns
    -------
    ax
    """
    add_grid(ax)
    # Without a per-line list, iter(callable, sentinel) yields a fresh empty
    # dict for every line (the sentinel 1 is never produced, so it never stops).
    per_line_kwargs = none_val(lines_kwargs_list, iter(lambda: {}, 1))
    bin_ranges = zip(edges2, edges2[1:])
    for row, line_kwargs, (low, high) in zip(histogram, per_line_kwargs, bin_ranges):
        label = get_bin_label(low, high, legend_format) if add_legend else None
        plot_hist_line(row, edges1, ax, shape,
                       **updated_dict({'label': label}, plt_kwargs, line_kwargs))
    if add_legend:
        ax.legend(**updated_dict(legend_kwargs))
    return ax
def plot_healpix_map(healpix_map, nest=True, auto_lim=False, bad_val=None,
                     ra_lim=None, dec_lim=None, fig=None, figsize=None, **kwargs):
    """
    Plot healpix map.

    Parameters
    ----------
    healpix_map: numpy array
        Healpix map (must have 12*nside**2 pixels).
    nest: bool
        If ordering is nested
    auto_lim: bool
        Set automatic limits for ra/dec from the valid pixels (not equal to
        bad_val and not NaN), requires bad_val.
    bad_val: float, None
        Values for pixels outside footprint.
    ra_lim: None, list
        Min/max RA for plot. Takes precedence over the automatic limits.
    dec_lim: None, list
        Min/max DEC for plot. Takes precedence over the automatic limits.
    fig: matplotlib.figure.Figure, None
        Matplotlib figure object. If not provided a new one is created.
    figsize: tuple
        Width, height in inches (float, float). Default value from hp.cartview.
    **kwargs:
        Extra arguments for hp.cartview (defaults set here: flip='geo',
        title=None, cbar=True, nest=nest):

            * xsize (int) : The size of the image. Default: 800
            * title (str) : The title of the plot. Default: None
            * min (float) : The minimum range value
            * max (float) : The maximum range value
            * remove_dip (bool) : If :const:`True`, remove the dipole+monopole
            * remove_mono (bool) : If :const:`True`, remove the monopole
            * gal_cut (float, scalar) : Symmetric galactic cut for the
              dipole/monopole fit. Removes points in latitude range
              [-gal_cut, +gal_cut]
            * format (str) : The format of the scale label. Default: '%g'
            * cbar (bool) : Display the colorbar. Default: True
            * notext (bool) : If True, no text is printed around the map
            * norm ({'hist', 'log', None}) : Color normalization,
              hist= histogram equalized color mapping, log= logarithmic color
              mapping, default: None (linear color mapping)
            * cmap (a color map) : The colormap to use (see matplotlib.cm)
            * badcolor (str) : Color to use to plot bad values
            * bgcolor (str) : Color to use for background
            * margins (None or sequence) : Either None, or a sequence
              (left,bottom,right,top) giving the margins on left,bottom,right
              and top of the axes. Values are relative to figure (0-1).
              Default: None

    Returns
    -------
    fig: matplotlib.pyplot.figure
        Figure of the plot.
    ax: matplotlib.axes
        Ax with the map
    cb: matplotlib.axes, None
        Axes of the colorbar (None when cbar=False).

    Raises
    ------
    ValueError
        If auto_lim=True and no pixel is valid, or if only one of the RA/DEC
        limits ends up defined.
    """
    nside = hp.npix2nside(len(healpix_map))
    kwargs_ = updated_dict({'flip':'geo', 'title':None, 'cbar':True, 'nest':nest}, kwargs)
    if auto_lim:
        # Pixels inside the footprint: not flagged with bad_val and not NaN.
        valid = (healpix_map!=bad_val)*~np.isnan(healpix_map)
        if not np.any(valid):
            # Fail clearly instead of with a zero-size min/max reduction error.
            raise ValueError(
                f'auto_lim=True but healpix_map has no valid pixels '
                f'(all equal to bad_val={bad_val} or NaN).')
        ra, dec = hp.pix2ang(nside, np.arange(len(healpix_map))[valid],
                             nest=nest, lonlat=True)
        if ra.min()<180. and ra.max()>180.:
            # The footprint has pixels on both sides of RA=180: compare the gap
            # around RA=0/360 with the gap around RA=180. If the latter is
            # larger, the footprint wraps through RA=0, so RA>180 is shifted to
            # negative values to keep the limits contiguous.
            gap_ra = 360.-(ra.max()-ra.min())
            gap_ra2 = ra[ra>180.].min()-ra[ra<180.].max()
            if gap_ra2>gap_ra:
                ra[ra>180.] -= 360.
        # Margin of two pixel resolutions (converted from arcmin to degrees).
        edge = 2*(hp.nside2resol(nside, arcmin=True)/60)
        kwargs_['lonra'] = [max(-360, ra.min()-edge), min(360, ra.max()+edge)]
        kwargs_['latra'] = [max(-90, dec.min()-edge), min(90, dec.max()+edge)]
    kwargs_['lonra'] = ra_lim if ra_lim else kwargs_.get('lonra')
    kwargs_['latra'] = dec_lim if dec_lim else kwargs_.get('latra')
    if (kwargs_['lonra'] is None)!=(kwargs_['latra'] is None):
        raise ValueError('When auto_lim=False, ra_lim and dec_lim must be provided together.')
    if fig is None:
        fig = plt.figure()
    # NOTE(review): with hold=True cartview draws on pyplot's current figure;
    # a `fig` passed in that is not the current figure may not receive the
    # map — confirm.
    hp.cartview(healpix_map, hold=True, **kwargs_)
    # Assumes cartview appended the map axes, then the colorbar axes (if cbar).
    ax = fig.axes[-2 if kwargs_['cbar'] else -1]
    ax.axis('on')
    # Keep the map limits so they can be restored after resizing.
    xlim, ylim = ax.get_xlim(), ax.get_ylim()
    if figsize:
        ax.set_aspect('auto')
        fig.set_size_inches(figsize)
    cb = None
    if kwargs_['cbar']:
        cb = fig.axes[-1]
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    # Wrap tick labels back into [0, 360) and show integers when possible.
    xticks = ax.get_xticks()
    xticks[xticks>=360] -= 360
    if all(int(i)==i for i in xticks):
        xticks = np.array(xticks, dtype=int)
    # Pin the current tick positions so the custom labels line up with them.
    ax.set_xticks(ax.get_xticks())
    ax.set_xticklabels(xticks)
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    ax.set_xlabel('RA')
    ax.set_ylabel('DEC')
    return fig, ax, cb