"""Defines the `SFFCorrector` class.
`SFFCorrector` enables systematics to be removed from light curves using the
Self Flat-Fielding (SFF) method described in Vanderburg and Johnson (2014).
"""
import logging
import warnings
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from astropy.modeling import models, fitting
from astropy.units import Quantity
from . import DesignMatrix, DesignMatrixCollection, SparseDesignMatrixCollection
from .regressioncorrector import RegressionCorrector
from .designmatrix import create_spline_matrix, create_sparse_spline_matrix
from .. import MPLSTYLE
from ..utils import LightkurveWarning
log = logging.getLogger(__name__)
__all__ = ["SFFCorrector"]
class SFFCorrector(RegressionCorrector):
"""Special case of `.RegressionCorrector` where the `.DesignMatrix` includes
the target's centroid positions.
The design matrix also contains columns representing a spline in time,
designed to capture the intrinsic, long-term variability of the target.
Parameters
----------
lc : `.LightCurve`
The light curve that needs to be corrected.
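Examples
--------
A minimal usage sketch; ``lc`` is assumed to be a K2 `.KeplerLightCurve`
containing ``centroid_col`` and ``centroid_row`` columns (it is not
defined here):

>>> corrector = SFFCorrector(lc)  # doctest: +SKIP
>>> corrected_lc = corrector.correct()  # doctest: +SKIP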
"""
def __init__(self, lc):
if getattr(lc, "mission", "") == "TESS":
warnings.warn(
"The SFF correction method is not suitable for use "
"with TESS data, because the spacecraft motion does "
"not proceed along a consistent arc.",
LightkurveWarning,
)
self.raw_lc = lc
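# If the flux is already dimensionless, we assume the light curve has
# been normalized; otherwise, normalize it before fitting.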
if lc.flux.unit.to_string() == "":
lc = lc.copy()
else:
lc = lc.copy().normalize()
# Set these values to None so we don't get an AttributeError if the
# user tries to access them before calling `correct()`
self.window_points = None
self.windows = None
self.bins = None
self.timescale = None
self.breakindex = None
self.centroid_col = None
self.centroid_row = None
super(SFFCorrector, self).__init__(lc=lc)
def __repr__(self):
return "SFFCorrector (LC: {})".format(self.lc.meta.get("TARGETID"))
def correct(
self,
centroid_col=None,
centroid_row=None,
windows=20,
bins=5,
timescale=1.5,
breakindex=None,
degree=3,
restore_trend=False,
additional_design_matrix=None,
polyorder=None,
sparse=False,
**kwargs
):
"""Find the best fit correction for the light curve.
Parameters
----------
centroid_col : np.ndarray of floats (optional)
Array of centroid column positions. If ``None``, will use the
`centroid_col` attribute of the input light curve by default.
centroid_row : np.ndarray of floats (optional)
Array of centroid row positions. If ``None``, will use the
`centroid_row` attribute of the input light curve by default.
windows : int
Number of windows to split the data into to perform the correction.
Default 20.
bins : int
Number of "knots" to place on the arclength spline. More bins will
increase the number of knots, making the spline more flexible in
arclength. Default 5.
timescale : float
Time scale of the b-spline fit to the light curve in time, in the
time units of the input light curve.
breakindex : None, int, or list of ints (optional)
Optionally, the light curve can be broken into sections. Set
`breakindex` to an index at which to break, or to a list of indices.
degree : int
The degree of polynomials in the splines in time and arclength. Higher
values will create smoother splines. Default 3.
cadence_mask : np.ndarray of bools (optional)
Mask, where True indicates a cadence that should be used.
sigma : int (default 5)
Standard deviation at which to remove outliers from fitting
niters : int (default 5)
Number of iterations to fit and remove outliers
restore_trend : bool (default False)
Whether to restore the long term spline trend to the light curve
propagate_errors : bool (default False)
Whether to propagate the uncertainties from the regression. Default is False.
Setting to True will increase run time, but will sample from multivariate normal
distribution of weights.
additional_design_matrix : `~lightkurve.correctors.DesignMatrix` (optional)
Additional design matrix to remove, e.g. containing background vectors.
polyorder : int
Deprecated as of Lightkurve v1.4. Use ``degree`` instead.
Returns
-------
corrected_lc : `~lightkurve.lightcurve.LightCurve`
Corrected light curve, with noise removed.
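Examples
--------
A hypothetical call, assuming the corrector was constructed from a K2
light curve; ``restore_trend=True`` re-adds the long-term spline trend:

>>> corrected_lc = corrector.correct(windows=10, bins=5,
...                                  timescale=1.5,
...                                  restore_trend=True)  # doctest: +SKIP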
"""
DMC, spline = DesignMatrixCollection, create_spline_matrix
if sparse:
DMC, spline = SparseDesignMatrixCollection, create_sparse_spline_matrix
if polyorder is not None:
warnings.warn(
"`polyorder` is deprecated and no longer used, "
"please use the `degree` keyword instead.",
LightkurveWarning,
)
if centroid_col is None:
self.lc = self.lc.remove_nans(column="centroid_col")
centroid_col = self.lc.centroid_col
if centroid_row is None:
self.lc = self.lc.remove_nans(column="centroid_row")
centroid_row = self.lc.centroid_row
if np.any([~np.isfinite(centroid_row), ~np.isfinite(centroid_col)]):
raise ValueError("Centroids contain NaN values.")
self.window_points = _get_window_points(
centroid_col, centroid_row, windows, breakindex=breakindex
)
self.windows = windows
self.bins = bins
self.timescale = timescale
self.breakindex = breakindex
self.arclength = _estimate_arclength(centroid_col, centroid_row)
lower_idx = np.asarray(np.append(0, self.window_points), int)
upper_idx = np.asarray(np.append(self.window_points, len(self.lc.time)), int)
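# Build one arclength-spline design matrix per window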
dms = []
for idx, a, b in zip(range(len(lower_idx)), lower_idx, upper_idx):
if isinstance(self.arclength, Quantity):
ar = np.copy(self.arclength.value)
else:
ar = np.copy(self.arclength)
# Temporary workaround for issue #1161: AstroPy v5.0
# Masked arrays cannot be passed to `np.in1d` below
if hasattr(ar, "mask"):
ar = ar.unmasked
knots = list(np.percentile(ar[a:b], np.linspace(0, 100, bins + 1)[1:-1]))
ar[~np.in1d(ar, ar[a:b])] = 0
dm = spline(ar, knots=knots, degree=degree).copy()
dm.columns = [
"window{}_bin{}".format(idx + 1, jdx + 1) for jdx in range(dm.shape[1])
]
# I'm putting VERY weak priors on the SFF motion vectors
# (1e-6 is being added to prevent sigma from being zero)
ps = np.ones(dm.shape[1]) * 10000 * self.lc[a:b].flux.std() + 1e-6
dm.prior_sigma = ps
dms.append(dm)
sff_dm = DMC(dms).to_designmatrix(name="sff") # .standardize()
# Long-term trend: a b-spline in time with knots spaced `timescale` apart
n_knots = int((self.lc.time.value[-1] - self.lc.time.value[0]) / timescale)
s_dm = spline(self.lc.time.value, n_knots=n_knots, name="spline")
means = [np.average(chunk) for chunk in np.array_split(self.lc.flux, n_knots)]
# means = [np.average(self.lc.flux, weights=s_dm.values[:, idx]) for idx in range(s_dm.shape[1])]
s_dm.prior_mu = np.asarray(means)
# I'm putting WEAK priors on the spline that it must be around 1
s_dm.prior_sigma = (
np.ones(len(s_dm.prior_mu)) * 1000 * self.lc.flux.std().value + 1e-6
)
# Optional additional design matrix supplied by the user
if additional_design_matrix is not None:
if not isinstance(additional_design_matrix, DesignMatrix):
raise ValueError(
"`additional_design_matrix` must be a DesignMatrix object."
)
self.additional_design_matrix = additional_design_matrix
dm = DMC([s_dm, sff_dm, additional_design_matrix])
else:
dm = DMC([s_dm, sff_dm])
# Correct the light curve using the combined design matrix
clc = super(SFFCorrector, self).correct(dm, **kwargs)
# Optionally restore the long-term trend, then rescale to the raw flux level
if restore_trend:
trend = self.diagnostic_lightcurves["spline"].flux
clc += trend - np.nanmedian(trend)
clc *= self.raw_lc.flux.mean()
return clc
def diagnose(self):
"""Returns a diagnostic plot which visualizes what happened during the
most recent call to `correct()`."""
axs = self._diagnostic_plot()
for t in self.window_points:
axs[0].axvline(self.lc.time.value[t], color="r", ls="--", alpha=0.3)
def diagnose_arclength(self):
"""Returns a diagnostic plot which visualizes arclength vs flux
from most recent call to `correct()`."""
max_plot = 5
with plt.style.context(MPLSTYLE):
_, axs = plt.subplots(
int(np.ceil(self.windows / max_plot)),
max_plot,
figsize=(10, int(np.ceil(self.windows / max_plot) * 2)),
sharex=True,
sharey=True,
)
axs = np.atleast_2d(axs)
axs[0, 2].set_title("Arclength Plot/Window")
plt.subplots_adjust(hspace=0, wspace=0)
lower_idx = np.asarray(np.append(0, self.window_points), int)
upper_idx = np.asarray(
np.append(self.window_points, len(self.lc.time)), int
)
if hasattr(self, "additional_design_matrix"):
name = self.additional_design_matrix.name
f = (
self.lc.flux
- self.diagnostic_lightcurves["spline"].flux
- self.diagnostic_lightcurves[name].flux
)
else:
f = self.lc.flux - self.diagnostic_lightcurves["spline"].flux
m = self.diagnostic_lightcurves["sff"].flux
idx, jdx = 0, 0
for a, b in zip(lower_idx, upper_idx):
ax = axs[idx, jdx]
if jdx == 0:
ax.set_ylabel("Flux")
ax.scatter(self.arclength[a:b], f[a:b], s=1, label="Data")
ax.scatter(
self.arclength[a:b][~self.cadence_mask[a:b]],
f[a:b][~self.cadence_mask[a:b]],
s=10,
marker="x",
c="r",
label="Outliers",
)
s = np.argsort(self.arclength[a:b])
ax.scatter(
self.arclength[a:b][s],
(m[a:b] - np.median(m[a:b]) + np.median(f[a:b]))[s],
c="C2",
s=0.5,
label="Model",
)
jdx += 1
if jdx >= max_plot:
jdx = 0
idx += 1
if b == len(self.lc.time):
ax.legend()
######################
# Helper functions #
######################
def _get_centroid_dm(col, row, name="centroids"):
"""Returns a `.DesignMatrix` containing (col, row) centroid positions
and transformations thereof.
Parameters
----------
col : np.ndarray
Centroid column positions
row : np.ndarray
Centroid row positions
name : str
Name to pass to `.DesignMatrix` (default: 'centroids').
Returns
-------
dm : `.DesignMatrix`
Design matrix with shape len(col) x 10
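Examples
--------
A sketch with synthetic, made-up centroid positions:

>>> col = np.linspace(10.0, 10.5, 100)
>>> row = np.linspace(20.0, 20.3, 100)
>>> dm = _get_centroid_dm(col, row)
>>> dm.shape  # doctest: +SKIP
(100, 10)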
"""
data = [
col,
row,
col ** 2,
row ** 2,
col ** 3,
row ** 3,
col * row,
col ** 2 * row,
col * row ** 2,
col ** 2 * row ** 2,
]
names = [
r"col",
r"row",
r"col^2",
r"row^2",
r"col^3",
r"row^3",
r"col \times row",
r"col^2 \times row",
r"col \times row^2",
r"col^2 \times row^2",
]
df = pd.DataFrame(np.asarray(data).T, columns=names)
return DesignMatrix(df, name=name)
def _get_thruster_firings(arclength):
"""Find locations where K2 fired thrusters
Parameters
----------
arclength : np.ndarray
Arclength as a function of time
Returns
-------
thrusters: np.ndarray of bools
True at times where thrusters were fired.
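Examples
--------
An illustrative call; ``arclength`` is assumed to be the output of
`_estimate_arclength` and is not defined here:

>>> thrusters = _get_thruster_firings(arclength)  # doctest: +SKIP
>>> firing_indices = np.where(thrusters)[0]  # doctest: +SKIP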
"""
if isinstance(arclength, Quantity):
arc = np.copy(arclength.value)
else:
arc = np.copy(arclength)
# Rate of change of rate of change of arclength wrt time
d2adt2 = np.gradient(np.gradient(arc))
# Fit a Gaussian: most points lie in a tight region, while thruster firings are outliers
g = models.Gaussian1D(amplitude=100, mean=0, stddev=0.01)
fitter = fitting.LevMarLSQFitter()
h = np.histogram(
d2adt2[np.isfinite(d2adt2)], np.arange(-0.5, 0.5, 0.0001), density=True
)
xbins = h[1][1:] - np.median(np.diff(h[1]))
g = fitter(g, xbins, h[0], weights=h[0] ** 0.5)
# Depending on the orientation of the roll, it is hard to return
# the point before the firing or the point after the firing.
# This makes sure we always return the same value, no matter the roll orientation.
def _start_and_end(start_or_end):
"""Find points at the start or end of a roll."""
if start_or_end == "start":
thrusters = (d2adt2 < (g.stddev * -5)) & np.isfinite(d2adt2)
if start_or_end == "end":
thrusters = (d2adt2 > (g.stddev * 5)) & np.isfinite(d2adt2)
# Pick the best thruster in each cluster
idx = np.array_split(
np.arange(len(thrusters)),
np.where(np.gradient(np.asarray(thrusters, int)) == 0)[0],
)
m = np.array_split(
thrusters, np.where(np.gradient(np.asarray(thrusters, int)) == 0)[0]
)
th = []
for jdx, _ in enumerate(idx):
if m[jdx].sum() == 0:
th.append(m[jdx])
else:
th.append(
(
np.abs(np.gradient(arc)[idx[jdx]])
== np.abs(np.gradient(arc)[idx[jdx]][m[jdx]]).max()
)
& m[jdx]
)
thrusters = np.hstack(th)
return thrusters
# Get the start and end points
thrusters = np.asarray([_start_and_end("start"), _start_and_end("end")])
thrusters = thrusters.any(axis=0)
# Take just the first point.
thrusters = (np.gradient(np.asarray(thrusters, int)) >= 0) & thrusters
return thrusters
def _get_window_points(
centroid_col, centroid_row, windows, arclength=None, breakindex=None
):
"""Returns indices where thrusters are fired.
Parameters
----------
lc : `.LightCurve` object
Input light curve
windows: int
Number of windows to split the light curve into
arc: np.ndarray
Arclength for the roll motion
breakindex: int
Cadence where there is a natural break. Windows will be automatically put here.
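Examples
--------
A hypothetical call mirroring how `SFFCorrector.correct` uses this helper
(``centroid_col`` and ``centroid_row`` are not defined here):

>>> window_points = _get_window_points(centroid_col, centroid_row, windows=20)  # doctest: +SKIP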
"""
if arclength is None:
arclength = _estimate_arclength(centroid_col, centroid_row)
# Validate break indices
if breakindex is None:
breakindexes = []
elif isinstance(breakindex, int):
breakindexes = [breakindex]
elif (breakindex[0] == 0) and (len(breakindex) == 1):
breakindexes = []
else:
breakindexes = breakindex
if not isinstance(breakindexes, list):
raise ValueError("`breakindex` must be an int or a list")
# If the user asks for break indices we should still return them,
# even if there is only 1 window.
if windows == 1:
return breakindexes
# Find evenly spaced window points
dt = len(centroid_col) / windows
lower_idx = np.append(0, breakindexes)
upper_idx = np.append(breakindexes, len(centroid_col))
window_points = np.hstack(
[np.asarray(np.arange(a, b, dt), int) for a, b in zip(lower_idx, upper_idx)]
)
# Get thruster firings
thrusters = _get_thruster_firings(arclength)
for b in breakindexes:
thrusters[b] = True
thrusters = np.where(thrusters)[0]
# Find the nearest point to each thruster firing, unless it's a user supplied break point
if len(thrusters) > 0:
window_points = [
thrusters[np.argmin(np.abs(thrusters - wp))] + 1
for wp in window_points
if wp not in breakindexes
]
window_points = np.unique(np.hstack([window_points, breakindexes]))
# If the first or last windows are very short (<40% median window length),
# then we add them to the second or penultimate window, respectively,
# by removing their break points.
median_length = np.median(np.diff(window_points))
if window_points[0] < 0.4 * median_length:
window_points = window_points[1:]
if window_points[-1] > (len(centroid_col) - 0.4 * median_length):
window_points = window_points[:-1]
return np.asarray(window_points, dtype=int)
def _estimate_arclength(centroid_col, centroid_row):
"""Estimate the arclength given column and row centroid positions.
We use the approximation that the arclength equals
(row**2 + col**2)**0.5
For this to work, row and column must be correlated, not anticorrelated.
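Examples
--------
A tiny numerical sketch with made-up, positively correlated centroids:

>>> col = np.array([1.0, 1.1, 1.2])
>>> row = np.array([2.0, 2.2, 2.4])
>>> _estimate_arclength(col, row)  # doctest: +SKIP
array([0.       , 0.2236068, 0.4472136])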
"""
col = centroid_col - np.nanmin(centroid_col)
row = centroid_row - np.nanmin(centroid_row)
if np.all((col == 0) & (row == 0)):
raise RuntimeError("Arclength cannot be computed because there is no "
"centroid motion. Make sure that the aperture of "
"the TPF contains at least two pixels.")
# Force col to be correlated, not anticorrelated, with row
if np.polyfit(col.data, row.data, 1)[0] < 0:
col = np.nanmax(col) - col
arclength = (col ** 2 + row ** 2) ** 0.5
return arclength