Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.
def diagnose_arclength(self):
"""Returns a diagnostic plot which visualizes arclength vs flux
from most recent call to `correct()`."""
max_plot = 5
with plt.style.context(MPLSTYLE):
_, axs = plt.subplots(int(np.ceil(self.windows/max_plot)), max_plot,
figsize=(10, int(np.ceil(self.windows/max_plot)*2)),
sharex=True, sharey=True)
axs = np.atleast_2d(axs)
axs[0, 2].set_title('Arclength Plot/Window')
plt.subplots_adjust(hspace=0, wspace=0)
lower_idx = np.asarray(np.append(0, self.window_points), int)
upper_idx = np.asarray(np.append(self.window_points, len(self.lc.time)), int)
if hasattr(self, 'additional_design_matrix'):
name = self.additional_design_matrix.name
f = (self.lc.flux - self.diagnostic_lightcurves['spline'].flux
- self.diagnostic_lightcurves[name].flux)
else:
f = (self.lc.flux - self.diagnostic_lightcurves['spline'].flux)
'''Plot the CBVs for a given list of CBVs
Parameters
----------
cbvs : list of ints
The list of cotrending basis vectors to fit to the data. For example,
[1, 2] will fit the first two basis vectors.
ax : matplotlib.pyplot.Axes.AxesSubplot
Matplotlib axis object. If `None`, one will be generated.
Returns
-------
ax : matplotlib.pyplot.Axes.AxesSubplot
Matplotlib axis object
'''
with plt.style.context(MPLSTYLE):
clip = np.in1d(np.arange(1, len(self.cbv_array)+1), np.asarray(cbvs))
time_clip = np.in1d(self.cbv_cadenceno, self.lc.cadenceno)
if ax is None:
_, ax = plt.subplots(1)
for idx, cbv in enumerate(self.cbv_array[clip, :][:, time_clip]):
ax.plot(self.cbv_cadenceno[time_clip], cbv+idx/10., label='{}'.format(idx + 1))
ax.set_yticks([])
ax.set_xlabel('Time (MJD)')
module, output = channel_to_module_output(self.lc.channel)
if self.lc.mission == 'Kepler':
ax.set_title('Kepler CBVs (Module : {}, Output : {}, Quarter : {})'
''.format(module, output, self.lc.quarter))
elif self.lc.mission == 'K2':
ax.set_title('K2 CBVs (Module : {}, Output : {}, Campaign : {})'
''.format(module, output, self.lc.campaign))
'''Plot the CBVs for a given list of CBVs
Parameters
----------
cbvs : list of ints
The list of cotrending basis vectors to fit to the data. For example,
[1, 2] will fit the first two basis vectors.
ax : matplotlib.pyplot.Axes.AxesSubplot
Matplotlib axis object. If `None`, one will be generated.
Returns
-------
ax : matplotlib.pyplot.Axes.AxesSubplot
Matplotlib axis object
'''
with plt.style.context(MPLSTYLE):
clip = np.in1d(np.arange(1, len(self.cbv_array)+1), np.asarray(cbvs))
time_clip = np.in1d(self.cbv_cadenceno, self.lc.cadenceno)
if ax is None:
_, ax = plt.subplots(1)
for idx, cbv in enumerate(self.cbv_array[clip, :][:, time_clip]):
ax.plot(self.cbv_cadenceno[time_clip], cbv+idx/10., label='{}'.format(idx + 1))
ax.set_yticks([])
ax.set_xlabel('Time (MJD)')
module, output = channel_to_module_output(self.lc.channel)
if self.lc.mission == 'Kepler':
ax.set_title('Kepler CBVs (Module : {}, Output : {}, Quarter : {})'
''.format(module, output, self.lc.quarter))
elif self.lc.mission == 'K2':
ax.set_title('K2 CBVs (Module : {}, Output : {}, Campaign : {})'
''.format(module, output, self.lc.campaign))
def _create_plot(self, method='plot', ax=None, normalize=False,
xlabel=None, ylabel=None, title='', style='lightkurve',
show_colorbar=True, colorbar_label='',
**kwargs):
"""Implements `plot()`, `scatter()`, and `errorbar()` to avoid code duplication.
Returns
-------
ax : `~matplotlib.axes.Axes`
The matplotlib axes object.
"""
# Configure the default style
if style is None or style == 'lightkurve':
style = MPLSTYLE
# Default xlabel
if xlabel is None:
if self.time_format == 'bkjd':
xlabel = 'Time - 2454833 [BKJD days]'
elif self.time_format == 'btjd':
xlabel = 'Time - 2457000 [BTJD days]'
elif self.time_format == 'jd':
xlabel = 'Time [JD]'
else:
xlabel = 'Time'
# Default ylabel
if ylabel is None:
if normalize or (self.flux_unit == u.dimensionless_unscaled):
ylabel = 'Normalized Flux'
elif self.flux_unit is None:
ylabel = 'Flux'
Parameters
----------
ax : `~matplotlib.axes.Axes`
A matplotlib axes object to plot into. If no axes is provided,
a new one will be created.
**kwargs : dict
Dictionary of arguments to be passed to matplotlib's `~matplotlib.pyplot.plot`.
Returns
-------
ax : `~matplotlib.axes.Axes`
The matplotlib axes object.
"""
with plt.style.context(MPLSTYLE):
if ax is None:
_, ax = plt.subplots()
for kwarg in ['c', 'color', 'label', 'normalize']:
if kwarg in kwargs:
kwargs.pop(kwarg)
labels = np.asarray([lcf.label for lcf in self])
try:
unique_labels = np.sort(np.unique(labels))
except TypeError:
unique_labels = [None]
for idx, targetid in enumerate(unique_labels):
jdxs = np.where(labels == targetid)[0]
if not hasattr(jdxs, '__iter__'):
jdxs = [jdxs]
for jdx in jdxs:
NOTE: When plotting, we exclude the first two frequency lag bins, to
make the relevant features on the plot clearer, as these bins are close to
the spectrum correlated with itself and therefore much higher than the rest
of the bins.
Parameters
----------
deltanu : `.SeismologyResult` object
The object returned by `estimate_deltanu_acf2d()`.
Returns
-------
ax : `~matplotlib.axes.Axes`
The matplotlib axes object.
"""
with plt.style.context(MPLSTYLE):
fig, axs = plt.subplots(2, figsize=(8.485, 8))
ax = axs[0]
periodogram.plot(ax=ax, label='')
ax.axvline(deltanu.diagnostics['numax'].value, c='r', linewidth=1,
alpha=.4, ls=':')
ax.text(deltanu.diagnostics['numax'].value, periodogram.power.value.max()*0.45,
'{} ({:.1f} {})'.format(r'$\nu_{\rm max}$', deltanu.diagnostics['numax'].value, deltanu.diagnostics['numax'].unit.to_string('latex')),
rotation=90, ha='right', color='r', alpha=0.5, fontsize=8)
ax.text(.025, .9, 'Input Power Spectrum', horizontalalignment='left',
transform=ax.transAxes, fontsize=11)
window_width = 2*int(np.floor(utils.get_fwhm(periodogram, deltanu.diagnostics['numax'].value)))
frequency_spacing = np.median(np.diff(periodogram.frequency.value))
spread = int(window_width/2/frequency_spacing) # spread in indices
# The exact end point is therefore n_columns*n_rows away from the start
end = start + n_columns*n_rows
ep = np.reshape(pp[start : end], (n_rows, n_columns))
if scale=='log':
ep = np.log10(ep)
# Reshape the frequencies into n_rows rows of n_columns columns and create arrays
ef = np.reshape(ff[start : end], (n_rows, n_columns))
x_f = ((ef[0,:]-ef[0,0]) % deltanu)
y_f = (ef[:,0])
#Plot the echelle diagram
with plt.style.context(MPLSTYLE):
if ax is None:
fig, ax = plt.subplots()
extent = (x_f[0].value, x_f[-1].value, y_f[0].value, y_f[-1].value)
figsize = plt.rcParams['figure.figsize']
a = figsize[1] / figsize[0]
b = (extent[3] - extent[2]) / (extent[1] - extent[0])
vmin = np.nanpercentile(ep.value, 1)
vmax = np.nanpercentile(ep.value, 99)
im = ax.imshow(ep.value, cmap=cmap, aspect=a/b, origin='lower',
extent=extent, vmin=vmin, vmax=vmax)
cbar = plt.colorbar(im, ax=ax, extend='both', pad=.01)
if isinstance(self.periodogram, SNRPeriodogram):
def _create_plot(self, method='plot', flux_types=None, style='lightkurve',
**kwargs):
"""Implements `plot()`, `scatter()`, and `errorbar()` to avoid code duplication.
Returns
-------
ax : `~matplotlib.axes.Axes`
The matplotlib Axes object.
"""
if style is None or style == 'lightkurve':
style = MPLSTYLE
with plt.style.context(style):
if not ('ax' in kwargs):
fig, ax = plt.subplots(1)
kwargs['ax'] = ax
if flux_types is None:
flux_types = self._flux_types()
if isinstance(flux_types, str):
flux_types = [flux_types]
for idx, ft in enumerate(flux_types):
lc = self.get_lightcurve(ft)
kwargs['color'] = np.asarray(mpl.rcParams['axes.prop_cycle'])[idx]['color']
if method == 'plot':
lc.plot(label=ft, **kwargs)
elif method == 'scatter':
lc.scatter(label=ft, **kwargs)
elif method == 'errorbar':
-------
ax : `~matplotlib.axes.Axes`
The matplotlib axes object.
"""
if isinstance(unit, u.quantity.Quantity):
unit = unit.unit
view = self._validate_view(view)
if unit is None:
unit = self.frequency.unit
if view == 'period':
unit = self.period.unit
if style is None or style == 'lightkurve':
style = MPLSTYLE
if ylabel is None:
ylabel = "Power"
if self.power.unit.to_string() != '':
unit_label = self.power.unit.to_string('latex')
# The line below is a workaround for AstroPy bug #9218.
# It can be removed once the fix for that issue is widespread.
# See https://github.com/astropy/astropy/pull/9218
unit_label = re.sub(r"\^{([^}]+)}\^{([^}]+)}", r"^{\g<1>^{\g<2>}}", unit_label)
ylabel += " [{}]".format(unit_label)
# This will need to be fixed during housekeeping: `self.label` currently doesn't exist.
if ('label' not in kwargs) and ('label' in dir(self)):
kwargs['label'] = self.label
with plt.style.context(style):
if ax is None:
mask_color : str
Color to show the aperture mask
style : str
Path or URL to a matplotlib style file, or name of one of
matplotlib's built-in stylesheets (e.g. 'ggplot').
Lightkurve's custom stylesheet is used by default.
kwargs : dict
Keywords arguments passed to `lightkurve.utils.plot_image`.
Returns
-------
ax : `~matplotlib.axes.Axes`
The matplotlib axes object.
"""
if style == 'lightkurve' or style is None:
style = MPLSTYLE
if cadenceno is not None:
try:
frame = np.argwhere(cadenceno == self.cadenceno)[0][0]
except IndexError:
raise ValueError("cadenceno {} is out of bounds, "
"must be in the range {}-{}.".format(
cadenceno, self.cadenceno[0], self.cadenceno[-1]))
try:
if bkg and np.any(np.isfinite(self.flux_bkg[frame])):
pflux = self.flux[frame] + self.flux_bkg[frame]
else:
pflux = self.flux[frame]
except IndexError:
raise ValueError("frame {} is out of bounds, must be in the range "
"0-{}.".format(frame, self.shape[0]))
with plt.style.context(style):