# Source code for lightkurve.correctors.regressioncorrector

"""Defines RegressionCorrector to solve large linear regression problems
with user-defined Gaussian priors in a fast, analytical way.
"""
import logging
import warnings

from astropy.stats import sigma_clip
from astropy import units as u
import matplotlib.pyplot as plt
import numpy as np
from scipy.sparse import issparse, csr_matrix

from .corrector import Corrector
from .designmatrix import (
DesignMatrix,
DesignMatrixCollection,
SparseDesignMatrix,
SparseDesignMatrixCollection,
)
from ..lightcurve import LightCurve, MPLSTYLE

# Names exported by ``from <module> import *``: only the corrector class.
__all__ = ["RegressionCorrector"]

# Module-level logger, named after this module.
log = logging.getLogger(__name__)

class RegressionCorrector(Corrector):
    r"""Remove noise using linear regression against a `.DesignMatrix`.

    .. math::

        \newcommand{\y}{\mathbf{y}}
        \newcommand{\cov}{\boldsymbol\Sigma_\y}
        \newcommand{\w}{\mathbf{w}}
        \newcommand{\covw}{\boldsymbol\Sigma_\w}
        \newcommand{\muw}{\boldsymbol\mu_\w}
        \newcommand{\sigw}{\boldsymbol\sigma_\w}
        \newcommand{\varw}{\boldsymbol\sigma^2_\w}

    Given a column vector of data :math:`\y`
    and a design matrix of regressors :math:`X`,
    we will find the vector of coefficients :math:`\w`
    such that:

    .. math::

        \mathbf{y} = X\mathbf{w} + \mathrm{noise}

    We will assume that the model fits the data within Gaussian uncertainties:

    .. math::

        p(\y | \w) = \mathcal{N}(X\w, \cov)

    We make the regression robust by placing Gaussian priors on :math:`\w`:

    .. math::

        p(\w) = \mathcal{N}(\muw, \sigw)

    We can then find the maximum likelihood solution of the posterior
    distribution :math:`p(\w | \y) \propto p(\y | \w) p(\w)` by solving
    the matrix equation:

    .. math::

        \w = \covw (X^\top \cov^{-1} \y + \boldsymbol\sigma^{-2}_\w I \muw)

    Where :math:`\covw` is the covariance matrix of the coefficients:

    .. math::

        \covw^{-1} = (X^\top \cov^{-1} X + \boldsymbol\sigma^{-2}_\w I)

    Parameters
    ----------
    lc : `.LightCurve`
        The light curve that needs to be corrected.

    Raises
    ------
    ValueError
        If ``lc`` has non-finite values in time or flux, has a mix of finite
        and non-finite values in ``flux_err``, or has finite ``flux_err``
        values that are smaller than or equal to zero.
    """

    def __init__(self, lc):
        # We don't accept NaN in time or flux.
        if np.any([~np.isfinite(lc.time.value), ~np.isfinite(lc.flux)]):
            raise ValueError(
                "Input light curve has NaNs in time or flux. "
                "Please remove NaNs before correction "
                "(e.g. using lc = lc.remove_nans())."
            )
        # We don't accept NaN in flux_err, unless all values are NaN.
        finite_err = np.isfinite(lc.flux_err)
        if np.any(~finite_err) and not np.all(~finite_err):
            raise ValueError(
                "Input light curve has NaNs in flux_err. "
                "Please remove NaNs before correction "
                "(e.g. using lc = lc.remove_nans())."
            )
        if np.any(lc.flux_err[finite_err] <= 0):
            raise ValueError(
                "Input light curve contains flux uncertainties "
                "smaller than or equal to zero. Please remove "
                "these (e.g. using lc = lc[lc.flux_err > 0])."
            )
        self.lc = lc

        # The following attributes are populated by correct().  They are
        # initialised here so that reading them before correct() has been
        # called yields None instead of raising AttributeError.
        self.design_matrix_collection = None
        self.coefficients = None
        self.coefficients_err = None
        self.cadence_mask = None
        self.outlier_mask = None
        self.corrected_lc = None
        self.model_lc = None
        self.diagnostic_lightcurves = None

    def __repr__(self):
        return "RegressionCorrector (ID: {})".format(self.lc.targetid)

    @property
    def dmc(self):
        """Shorthand for self.design_matrix_collection."""
        return self.design_matrix_collection

    def _fit_coefficients(
        self, cadence_mask=None, prior_mu=None, prior_sigma=None, propagate_errors=False
    ):
        """Fit the linear regression coefficients.

        This function will solve a linear regression with Gaussian priors
        on the coefficients, using the design matrix stored in ``self.dmc``
        and the flux stored in ``self.lc``.

        Parameters
        ----------
        cadence_mask : np.ndarray of bool, optional
            Mask, where True indicates a cadence that should be used in the
            fit.  By default all cadences are used.
        prior_mu : np.ndarray, optional
            Means of the Gaussian priors on the coefficients.  Must be given
            together with ``prior_sigma``.
        prior_sigma : np.ndarray, optional
            Standard deviations of the Gaussian priors on the coefficients.
            Must be given together with ``prior_mu``.
        propagate_errors : bool
            If True, also compute the covariance matrix of the coefficients.

        Returns
        -------
        coefficients : np.ndarray
            The best fit model coefficients to the data.
        coefficients_err : np.ndarray
            The inverse of the regularised normal matrix if
            ``propagate_errors`` is True, otherwise a 1D array of NaN with
            one entry per coefficient.

        Raises
        ------
        ValueError
            If only one of ``prior_mu`` and ``prior_sigma`` is given.
        TypeError
            If the design matrix is neither a numpy array nor a scipy sparse
            matrix.
        """
        # The two prior arguments are only meaningful as a pair.
        if (prior_mu is None) != (prior_sigma is None):
            raise ValueError("Please specify both prior_mu and prior_sigma")

        # Default cadence mask: use every cadence.
        if cadence_mask is None:
            cadence_mask = np.ones(len(self.lc.flux), bool)
        else:
            cadence_mask = np.asarray(cadence_mask, dtype=bool)

        # If flux errors are not all finite numbers, then default to array of ones
        if np.all(~np.isfinite(self.lc.flux_err.value)):
            flux_err = np.ones(cadence_mask.sum())
        else:
            flux_err = self.lc.flux_err.value[cadence_mask]

        # Rows of the design matrix (X) and flux values used in this fit.
        X = self.dmc.X[cadence_mask]
        weighted_flux = self.lc.flux.value[cadence_mask] / flux_err ** 2

        if isinstance(X, np.ndarray):
            # X^T cov^-1 X (the prior term is added below)
            sigma_w_inv = X.T.dot(X / flux_err[:, None] ** 2)
            # X^T cov^-1 y (the prior term is added below)
            B = np.dot(X.T, weighted_flux)
        elif issparse(X):
            sigma_f_inv = csr_matrix(1 / flux_err[:, None] ** 2)
            # X^T cov^-1 X, converted to a dense array for np.linalg below
            sigma_w_inv = X.T.dot(X.multiply(sigma_f_inv)).toarray()
            # X^T cov^-1 y
            B = X.T.dot(weighted_flux)
        else:
            raise TypeError(
                "The design matrix must be a numpy array or a scipy sparse "
                "matrix, got {}.".format(type(X).__name__)
            )

        # Add the contribution of the Gaussian priors to both sides.
        if prior_sigma is not None:
            sigma_w_inv = sigma_w_inv + np.diag(1.0 / prior_sigma ** 2)
            B = B + (prior_mu / prior_sigma ** 2)

        # Solve for weights w
        w = np.linalg.solve(sigma_w_inv, B).T
        if propagate_errors:
            w_err = np.linalg.inv(sigma_w_inv)
        else:
            w_err = np.full(len(w), np.nan)

        return w, w_err

    def correct(
        self,
        design_matrix_collection,
        sigma=5,
        niters=5,
        propagate_errors=False,
        cadence_mask=None,
    ):
        """Find the best fit correction for the light curve.

        Besides returning the corrected light curve, this method stores the
        results on the instance: ``design_matrix_collection``,
        ``cadence_mask``, ``outlier_mask``, ``coefficients``,
        ``coefficients_err``, ``model_lc``, ``corrected_lc`` and
        ``diagnostic_lightcurves``.

        Parameters
        ----------
        design_matrix_collection : `.DesignMatrix` or `.DesignMatrixCollection`
            One or more design matrices.  Each matrix must have a shape of
            (time, regressors). The columns contained in each matrix must be
            known to correlate with additive noise components we want to remove
            from the light curve.
        sigma : int (default 5)
            Standard deviation at which to remove outliers from fitting
        niters : int (default 5)
            Number of iterations to fit and remove outliers
        propagate_errors : bool (default False)
            Whether to propagate the uncertainties from the regression. Default is False.
            Setting to True will increase run time, but will sample from multivariate normal
            distribution of weights.
        cadence_mask : np.ndarray of bool, optional
            Mask, where True indicates a cadence that should be used in the
            fit.  By default all cadences are used.

        Returns
        -------
        corrected_lc : `.LightCurve`
            Corrected light curve, with noise removed.

        Raises
        ------
        ValueError
            If ``niters`` is smaller than 1 (no fit would be performed).
        """
        if niters < 1:
            raise ValueError("niters must be at least 1.")

        if not isinstance(design_matrix_collection, DesignMatrixCollection):
            # Wrap a single matrix in the matching collection type.  The sparse
            # type is tested first, as in the original ordering.
            if isinstance(design_matrix_collection, SparseDesignMatrix):
                design_matrix_collection = SparseDesignMatrixCollection(
                    [design_matrix_collection]
                )
            elif isinstance(design_matrix_collection, DesignMatrix):
                design_matrix_collection = DesignMatrixCollection(
                    [design_matrix_collection]
                )

        # Validate the design matrix. Emits a warning if the matrix has low rank.
        design_matrix_collection.validate()
        self.design_matrix_collection = design_matrix_collection

        if cadence_mask is None:
            self.cadence_mask = np.ones(len(self.lc.time), bool)
        else:
            self.cadence_mask = np.asarray(cadence_mask, dtype=bool)

        # Create an outlier mask using iterative sigma clipping
        self.outlier_mask = np.zeros_like(self.cadence_mask)
        for count in range(niters):
            # Cadences used in this iteration: requested and not yet clipped.
            fit_mask = self.cadence_mask & ~self.outlier_mask
            coefficients, coefficients_err = self._fit_coefficients(
                cadence_mask=fit_mask,
                prior_mu=self.dmc.prior_mu,
                prior_sigma=self.dmc.prior_sigma,
                propagate_errors=propagate_errors,
            )
            # Model over all cadences, flagging those excluded from the fit.
            model = np.ma.masked_array(
                data=self.dmc.X.dot(coefficients),
                mask=~fit_mask,
            )
            model = u.Quantity(model, unit=self.lc.flux.unit)
            residuals = self.lc.flux - model
            self.outlier_mask |= sigma_clip(residuals, sigma=sigma).mask
            log.debug(
                "correct(): iteration %s: clipped %s cadences",
                count,
                self.outlier_mask.sum(),
            )

        self.coefficients = coefficients
        self.coefficients_err = coefficients_err

        # The model is shifted to zero median before it is subtracted.
        model_flux = self.dmc.X.dot(coefficients)
        model_flux -= np.median(model_flux)
        if propagate_errors:
            with warnings.catch_warnings():
                # ignore "RuntimeWarning: covariance is not symmetric positive-semidefinite."
                warnings.simplefilter("ignore", RuntimeWarning)
                # Evaluate the model for 100 random draws of the coefficients.
                samples = np.asarray(
                    [
                        self.dmc.X.dot(
                            np.random.multivariate_normal(
                                coefficients, coefficients_err
                            )
                        )
                        for _ in range(100)
                    ]
                ).T
            # Mean distance of the 16th and 84th percentiles from the median.
            model_err = np.abs(
                np.percentile(samples, [16, 84], axis=1)
                - np.median(samples, axis=1)[:, None].T
            ).mean(axis=0)
        else:
            model_err = np.zeros(len(model_flux))
        self.model_lc = LightCurve(
            time=self.lc.time,
            flux=u.Quantity(model_flux, unit=self.lc.flux.unit),
            flux_err=u.Quantity(model_err, unit=self.lc.flux.unit),
        )
        self.corrected_lc = self.lc.copy()
        self.corrected_lc.flux = self.lc.flux - self.model_lc.flux
        # NOTE(review): model_err is a bare array here while lc.flux_err
        # presumably carries a unit -- confirm this sum behaves as intended
        # for non-zero model_err when the flux is not dimensionless.
        self.corrected_lc.flux_err = (self.lc.flux_err ** 2 + model_err ** 2) ** 0.5
        self.diagnostic_lightcurves = self._create_diagnostic_lightcurves()
        return self.corrected_lc

    def _create_diagnostic_lightcurves(self):
        """Returns a dictionary containing all diagnostic light curves.

        The dictionary will provide a light curve for each matrix in the
        design matrix collection, keyed by the name of the matrix.  Each
        light curve holds the product of that matrix and its share of the
        fitted coefficients; its flux uncertainties are set to zero.

        Raises
        ------
        ValueError
            If no coefficients have been fitted yet.
        """
        if self.coefficients is None:
            raise ValueError("you need to call correct() first")

        lcs = {}
        # Index of the first coefficient belonging to the current submatrix.
        firstcol_idx = 0
        for submatrix in self.dmc.matrices:
            ncols = submatrix.shape[1]
            submatrix_coefficients = self.coefficients[
                firstcol_idx : firstcol_idx + ncols
            ]
            firstcol_idx += ncols
            model_flux = u.Quantity(
                submatrix.X.dot(submatrix_coefficients), unit=self.lc.flux.unit
            )
            model_flux_err = u.Quantity(
                np.zeros(len(model_flux)), unit=self.lc.flux.unit
            )
            lcs[submatrix.name] = LightCurve(
                time=self.lc.time,
                flux=model_flux,
                flux_err=model_flux_err,
                label=submatrix.name,
            )
        return lcs

    def _diagnostic_plot(self):
        """Produce diagnostic plots to assess the effectiveness of the correction.

        The upper panel shows the original light curve together with each
        diagnostic light curve shifted to the median flux of the original.
        The lower panel shows the original and corrected light curves, with
        the outlier cadences and the cadences outside the cadence mask marked.

        Note: We need a hidden function so that other correctors can alter the plot.

        Raises
        ------
        ValueError
            If correct() has not been called yet.
        """
        # corrected_lc always exists (set to None in __init__), so test the
        # value rather than the presence of the attribute.
        if self.corrected_lc is None:
            raise ValueError(
                "Please call the correct() method before trying to diagnose."
            )

        with plt.style.context(MPLSTYLE):
            _, axs = plt.subplots(2, figsize=(10, 6), sharex=True)
            ax = axs[0]
            self.lc.plot(ax=ax, normalize=False, label="original", alpha=0.4)
            for diagnostic_lc in self.diagnostic_lightcurves.values():
                (
                    diagnostic_lc
                    - np.median(diagnostic_lc.flux)
                    + np.median(self.lc.flux)
                ).plot(ax=ax)
            ax.set_xlabel("")
            ax = axs[1]
            self.lc.plot(ax=ax, normalize=False, alpha=0.2, label="original")
            self.corrected_lc[self.outlier_mask].scatter(
                normalize=False, c="r", marker="x", s=10, label="outlier_mask", ax=ax
            )
            self.corrected_lc[~self.cadence_mask].scatter(
                normalize=False,
                c="dodgerblue",
                marker="x",
                s=10,
                label="~cadence_mask",
                ax=ax,
            )
            self.corrected_lc.plot(normalize=False, label="corrected", ax=ax, c="k")
        return axs

    def diagnose(self):
        """Returns diagnostic plots to assess the most recent call to correct().

        If correct() has not yet been called, a ValueError will be raised.

        Returns
        -------
        `~matplotlib.axes.Axes`
            The matplotlib axes object.
        """
        return self._diagnostic_plot()

    def diagnose_priors(self):
        """Returns a diagnostic plot visualizing how the best-fit coefficients
        compare against the priors.

        The method will show the results obtained during the most recent call
        to correct().  If correct() has not yet been called, a
        ValueError will be raised.

        Returns
        -------
        `~matplotlib.axes.Axes`
            The matplotlib axes object.
        """
        # corrected_lc always exists (set to None in __init__), so test the
        # value rather than the presence of the attribute.
        if self.corrected_lc is None:
            raise ValueError(
                "Please call the correct() method before trying to diagnose."
            )

        names = [dm.name for dm in self.dmc]
        with plt.style.context(MPLSTYLE):
            _, axs = plt.subplots(
                1, len(names), figsize=(len(names) * 4, 4), sharey=True
            )
            # With a single panel, subplots() returns a bare Axes object.
            if not hasattr(axs, "__iter__"):
                axs = [axs]
            for idx, (ax, X) in enumerate(zip(axs, self.dmc)):
                X.plot_priors(ax=ax)
                # Index of the first coefficient belonging to this matrix.
                firstcol_idx = sum(m.shape[1] for m in self.dmc.matrices[:idx])
                submatrix_coefficients = self.coefficients[
                    firstcol_idx : firstcol_idx + X.shape[1]
                ]
                # Mark every fitted coefficient on top of the prior plot.
                for coefficient in submatrix_coefficients:
                    ax.axvline(coefficient, color="red", zorder=-1)
        return axs