# External imports
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from scipy.signal import welch
# Medusa imports
from medusa.plots.head_plots import TopographicPlot
from medusa.meeg.meeg import EEGChannelSet
from medusa import epoching
from medusa.plots.timeplot import time_plot
from medusa.components import SerializableComponent, ProcessingMethod
def reject_noisy_epochs(epochs, signal_mean, signal_std, k=4, n_samp=2,
                        n_cha=1):
    """Simple thresholding method to reject noisy epochs. It discards epochs
    with at least n_samp samples whose absolute value exceeds
    abs(mean) + k*std in at least n_cha channels
Parameters
----------
epochs : list or numpy.ndarray
Epochs of signal with dimensions [n_epochs x samples x channels]
    signal_mean : numpy.ndarray
        Mean of the signal for each channel
    signal_std : numpy.ndarray
        Standard deviation of the signal for each channel
k : float
Standard deviation multiplier to calculate threshold
n_samp : int
Minimum number of samples that have to be over the threshold in each
epoch to be discarded
    n_cha : int
        Minimum number of channels that must have n_samp samples over the
        threshold for the epoch to be discarded
Returns
-------
    float
        Percentage of rejected epochs
    numpy.ndarray
        Clean epochs
    numpy.ndarray
        Boolean indexes of the rejected epochs. True for discarded epochs
"""
    # Check errors
    epochs = np.asarray(epochs)
    if len(epochs.shape) != 3:
        raise Exception('Malformed epochs array. It must be of dimensions '
                        '[epochs x samples x channels]')
if signal_std.shape[0] != epochs.shape[2]:
raise Exception('Array signal_std does not match with epochs size. '
'It must have the same number of channels')
if signal_mean.shape[0] != epochs.shape[2]:
raise Exception('Array signal_mean does not match with epochs size. '
'It must have the same number of channels')
epochs_abs = np.abs(epochs)
cmp = epochs_abs > np.abs(signal_mean) + k * signal_std
idx = np.sum((np.sum(cmp, axis=1) >= n_samp), axis=1) >= n_cha
pct_rejected = (np.sum(idx) / epochs.shape[0]) * 100
return pct_rejected, epochs[~idx, :, :], idx
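
# Minimal usage sketch for reject_noisy_epochs (illustrative only; the arrays
# below are hypothetical placeholders for real epoched data and its per-channel
# baseline statistics):
#
#   epochs = np.random.randn(100, 200, 16)    # [n_epochs x samples x channels]
#   signal_mean = np.zeros(16)                # per-channel mean
#   signal_std = np.ones(16)                  # per-channel standard deviation
#   pct, clean_epochs, rejected = reject_noisy_epochs(
#       epochs, signal_mean, signal_std, k=4, n_samp=2, n_cha=1)
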
class ICA:
    """Independent component analysis (ICA) for artifact removal in EEG
    signals, implemented as PCA whitening followed by FastICA."""

    def __init__(self, random_state=None):
self.n_components = None
self.random_state = random_state
self.pre_whitener = None
self.unmixing_matrix = None
self.mixing_matrix = None
self._ica_n_iter = None
self.components_excluded = None
self.ica_labels = None
# PCA attributes
self.n_pca_components = None
self._pca_mean = None
self._pca_explained_variance = None
self._pca_explained_variance_ratio = None
self._pca_components = None
# Signal attributes
self.l_cha = None
self.channel_set = None
self.fs = None
    def fit(self, signal, l_cha, fs, n_components):
        """Fit the ICA model: pre-whiten the signal, apply PCA and run
        FastICA. ``n_components`` may be an int, a float in (0, 1)
        interpreted as an explained-variance threshold, or None for
        automatic selection."""
from sklearn.decomposition import FastICA
from scipy import linalg
# Set arguments
self.n_components = n_components
self.l_cha = l_cha
self.fs = fs
# Define MEEG Channel Set
self.channel_set = EEGChannelSet()
self.channel_set.set_standard_montage(self.l_cha)
signal, _ = self._check_signal_dimensions(signal.copy())
self._get_pre_whitener(signal)
signal_pca = self._whitener(signal.copy())
ica = FastICA(whiten=False, random_state=self.random_state,
max_iter=1000)
ica.fit(signal_pca[:, :self.n_components])
self.unmixing_matrix = ica.components_
self._ica_n_iter = ica.n_iter_
assert self.unmixing_matrix.shape == (self.n_components,) * 2
# Whitening unmixing matrix
norm = np.sqrt(self._pca_explained_variance[:self.n_components])
norm[norm == 0] = 1.
self.unmixing_matrix /= norm
self.mixing_matrix = linalg.pinv(self.unmixing_matrix)
# Sort ica components from greater to lower explained variance
self._sort_components(signal)
self.ica_labels = [f"ICA_{n}" for n in range(self.n_components)]
    def get_sources(self, signal):
        """Return the ICA source activations of ``signal`` with dimensions
        [n_samples x n_components]."""
        if self.mixing_matrix is None:
            raise RuntimeError('ICA has not been fitted yet. Please, fit ICA.')
signal, _ = self._check_signal_dimensions(signal.copy())
signal = self._pre_whiten(signal)
if self._pca_mean is not None:
signal -= self._pca_mean
# Transform signal to PCA space and then apply unmixing matrix
pca_transform = np.dot(self._pca_components[:self.n_components],
signal.T)
sources = np.dot(self.unmixing_matrix, pca_transform)
return sources.T
    def rebuild(self, signal, exclude=None):
        """Rebuild ``signal`` from its ICA sources, optionally removing the
        components listed in ``exclude``."""
signal, n_epo_samples = self._check_signal_dimensions(signal.copy())
signal = self._pre_whiten(signal)
if exclude is not None:
if isinstance(exclude, int):
exclude = [exclude]
exclude = np.array(list(set(exclude)))
            if len(np.where(exclude >= self.n_components)[0]) > 0:
                raise ValueError("One or more ICA component keys that you "
                                 "have marked to exclude from signal rebuild "
                                 "are out of range for the total number of "
                                 "ICA components.")
self.components_excluded = exclude
# Apply PCA
if self._pca_mean is not None:
signal -= self._pca_mean
# Determine ica components to keep in signal rebuild
c_to_keep = np.setdiff1d(np.arange(self.n_components), exclude)
# Define projection matrix
proj = np.dot(np.dot(self._pca_components[:self.n_components].T,
self.mixing_matrix[:, c_to_keep]),
np.dot(self.unmixing_matrix[c_to_keep, :],
self._pca_components[:self.n_components]))
# Apply projection to signal
signal_rebuilt = np.transpose(np.dot(proj, signal.T))
if self._pca_mean is not None:
signal_rebuilt += self._pca_mean
signal_rebuilt *= self.pre_whitener
# Restore epochs if original signal was divided in epochs
        if n_epo_samples is not None:
            signal_rebuilt = np.reshape(
                signal_rebuilt,
                (int(signal_rebuilt.shape[0] / n_epo_samples),
                 n_epo_samples, signal_rebuilt.shape[1]))
return signal_rebuilt
    def get_components(self):
        """Return the spatial maps (channel weights) of the ICA components."""
# Transform
components = np.dot(self.mixing_matrix[:, :self.n_components].T,
self._pca_components[:self.n_components]).T
return components
    def save(self, path):
        """Serialize the fitted ICA parameters to ``path`` in BSON format."""
        if self.mixing_matrix is None:
            raise RuntimeError('ICA has not been fitted yet. Please, fit ICA.')
ica_data = ICAData(self.pre_whitener, self.unmixing_matrix,
self.mixing_matrix, self.n_components,
self._pca_components, self._pca_mean,
self.components_excluded,
self.random_state)
ica_data.save(path, 'bson')
    def load(self, path):
        """Load previously saved ICA parameters from ``path``."""
# Load ICAData instance
ica_data = ICAData().load_from_bson(path)
# Update ICA arguments
self.pre_whitener = ica_data.pre_whitener
self.unmixing_matrix = np.array(ica_data.unmixing_matrix)
self.mixing_matrix = np.array(ica_data.mixing_matrix)
self.n_components = ica_data.n_components
self._pca_components = np.array(ica_data.pca_components)
self._pca_mean = np.array(ica_data.pca_mean)
self.components_excluded = np.array(ica_data.components_excluded)
self.random_state = ica_data.random_state
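    # Persistence sketch (illustrative only; ``fitted_ica`` and the file path
    # are hypothetical):
    #   fitted_ica.save('ica_model.bson')
    #   restored = ICA()
    #   restored.load('ica_model.bson')
    # Note that the signal attributes (l_cha, fs, channel_set) are not
    # serialized, so they must be set again before using the plotting methods.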
# ------------------------- AUXILIARY METHODS ------------------------------
@staticmethod
def _check_signal_dimensions(signal):
n_epo_samples = None
# Check input dimensions
if len(signal.shape) == 3:
# Stack the epochs
n_epo_samples = signal.shape[1]
signal = np.vstack(signal)
        elif len(signal.shape) == 1:
            raise ValueError("signal input only has one dimension, but at "
                             "least two dimensions (n_samples, n_channels) "
                             "are needed")
return signal, n_epo_samples
def _get_pre_whitener(self, signal):
self.pre_whitener = np.std(signal)
def _pre_whiten(self, signal):
signal /= self.pre_whitener
return signal
def _whitener(self, signal):
from sklearn.decomposition import PCA
signal_pw = self._pre_whiten(signal)
pca = PCA(n_components=None, whiten=True)
signal_pca = pca.fit_transform(signal_pw)
self._pca_mean = pca.mean_
self._pca_components = pca.components_
self._pca_explained_variance = pca.explained_variance_
self._pca_explained_variance_ratio = pca.explained_variance_ratio_
self.n_pca_components = pca.n_components_
del pca
# Check a correct input of n_components parameter
if self.n_components is None:
self.n_components = min(self.n_pca_components,
self._exp_var_ncomp(
self._pca_explained_variance_ratio,
0.99999)[0])
elif isinstance(self.n_components, float):
self.n_components = self._exp_var_ncomp(
self._pca_explained_variance_ratio, self.n_components)[0]
if self.n_components == 1:
                raise RuntimeError(
                    'One PCA component captures most of the explained '
                    'variance, so your threshold results in a single '
                    'component. You should select a higher value.')
else:
            if not isinstance(self.n_components, int):
                raise ValueError(
                    f'n_components={self.n_components} must be None, '
                    f'a float or an int value')
            if self.n_components > self.n_pca_components:
                raise ValueError(f'The number of ICA components '
                                 f'(n_components={self.n_components}) must be '
                                 f'lower than the number of PCA components '
                                 f'({self.n_pca_components}).')
return signal_pca
def _sort_components(self, signal):
sources = self.get_sources(signal)
meanvar = np.sum(self.mixing_matrix ** 2, axis=0) * \
np.sum(sources ** 2, axis=0) / \
(sources.shape[0] * sources.shape[1] - 1)
c_order = np.argsort(meanvar)[::-1]
self.unmixing_matrix = self.unmixing_matrix[c_order, :]
self.mixing_matrix = self.mixing_matrix[:, c_order]
@staticmethod
def _exp_var_ncomp(var, n):
cvar = np.asarray(var, dtype=np.float64)
cvar = cvar.cumsum()
cvar /= cvar[-1]
# We allow 1., which would give us N+1
n = min((cvar <= n).sum() + 1, len(cvar))
return n, cvar[n - 1]
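    # Worked example for _exp_var_ncomp (illustrative values): with
    # var = [0.7, 0.2, 0.05, 0.05] and n = 0.9, the normalized cumulative
    # variance is [0.7, 0.9, 0.95, 1.0], so (cvar <= 0.9).sum() + 1 = 3 and
    # the method returns (3, 0.95), i.e. three components explaining 95 %
    # of the variance.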
# --------------------------- PLOT METHODS ---------------------------------
    def plot_components(self, cmap='bwr'):
        """Plot the topographic map of each fitted ICA component."""
# Get ICA components
components = self.get_components()
# Define subplot parameters
n_components = components.shape[1]
if n_components <= 5:
cols = n_components
rows = 1
else:
cols = 5
rows = np.ceil(n_components / 5)
# Define subplot
fig, axes = plt.subplots(int(rows), int(cols))
if len(axes.shape) == 1:
axes = np.array([axes])
# Topo plots
ic_c = 0
for r in axes:
for c in r:
if ic_c < n_components:
                    topo = TopographicPlot(axes=c,
                                           channel_set=self.channel_set,
                                           interp_points=300,
                                           head_line_width=1.5,
                                           cmap=cmap, extra_radius=0)
topo.update(values=components[:, ic_c])
c.set_title(self.ica_labels[ic_c])
ic_c += 1
else:
c.set_axis_off()
fig.show()
return fig
    def plot_sources(self, signal, sources_to_show=None, time_to_show=None,
                     ch_offset=None):
        """Plot the time course of the ICA sources of ``signal``."""
sources = self.get_sources(signal)
if ch_offset is None:
ch_offset = np.max(np.abs(sources[:, 0]))
fig, ax = plt.subplots(1,1)
time_plot(sources, self.fs, self.ica_labels,
ch_to_show=sources_to_show,
time_to_show=time_to_show,
ch_offset=ch_offset,show=False,fig=fig,
axes=ax)
    def plot_summary(self, signal, component, psd_freq_range=(1, 70),
                     psd_window='hamming', time_to_show=2, cmap='bwr'):
        """Plot a summary (topography, power spectral density, stacked
        segments and time course) for each selected ICA component."""
        # Check errors
        if isinstance(component, int):
            component = np.array([component])
        elif isinstance(component, list):
            component = np.array(component)
        else:
            raise ValueError("Component parameter must be an int or a list "
                             "of ints indexing the ICA components.")
        if np.any(component >= self.n_components):
            raise ValueError("There is a component index out of range for "
                             "the total number of ICA components: "
                             f"{self.n_components}.")
# Check if signal is epoched
n_samples_epoch = None
if len(signal.shape) == 3:
n_samples_epoch = signal.shape[1]
n_stacks = signal.shape[0]
time_to_show = int(n_samples_epoch/self.fs)
sources = self.get_sources(signal)[:,component]
components = self.get_components()
if n_samples_epoch is None:
n_samples_epoch = int(time_to_show * self.fs)
n_stacks = (int(len(sources[:, 0]) / n_samples_epoch))
for ii in range(len(component)):
fig = plt.figure()
ax_1 = fig.add_subplot(3,4,(1,6))
ax_2 = fig.add_subplot(3, 4, (3,4))
ax_3 = fig.add_subplot(3, 4, (7,8))
ax_4 =fig.add_subplot(3, 1, 3)
# Topoplot
topo = TopographicPlot(axes=ax_1, channel_set=self.channel_set,
interp_points=300, head_line_width=1.5,
cmap=cmap, extra_radius=0)
topo.update(values=components[:, component[ii]])
ax_1.set_title(self.ica_labels[component[ii]])
stacked_source = np.reshape(
sources[:(n_stacks * n_samples_epoch), ii],
(n_stacks, n_samples_epoch))
# PSD
f, psd = welch(stacked_source, self.fs, window=psd_window, )
f_range = np.logical_and(f>=psd_freq_range[0],f<=psd_freq_range[1])
psd_mean = np.mean(10*np.log10(psd),axis=0)
psd_std = np.std(10*np.log10(psd),axis=0)
ax_2.fill_between(f[f_range], (psd_mean-psd_std)[f_range],
(psd_mean+psd_std)[f_range],
color='k',alpha=0.3)
ax_2.plot(f[f_range],psd_mean[f_range],'k')
ax_2.set_xlim(f[f_range][0],f[f_range][-1])
ax_2.set_xlabel('Frequency (Hz)')
ax_2.set_ylabel('Power/Hz (dB)')
ax_2.set_title('Power spectral density')
# Stacked data image
ax_3.pcolormesh(np.linspace(0,time_to_show,n_samples_epoch),
np.arange(n_stacks),stacked_source,cmap=cmap,
shading='gouraud')
ax_3.set_xlabel('Time (s)')
ax_3.set_ylabel('Segments')
ax_3.set_title('Stacked source segments')
# Time plot
time_plot(sources[:,ii],self.fs,[self.ica_labels[component[ii]]],
time_to_show=time_to_show,fig=fig,axes=ax_4)
ax_4.set_title('Source time plot')
fig.tight_layout(pad=1)
return fig
    def show_exclusion(self, signal, exclude=None, ch_to_show=None,
                       time_to_show=None, ch_offset=None):
        """Plot the original signal overlaid with the signal rebuilt after
        excluding the given ICA components."""
if ch_offset is None:
ch_offset = np.max(np.abs(signal.copy()))
# Check if signal is divided in epochs
if len(signal.shape) == 3:
n_epochs = signal.shape[0]
else:
n_epochs = 1
fig, ax = plt.subplots(1,1)
time_plot(signal,self.fs,self.l_cha,time_to_show,
ch_to_show,ch_offset,axes=ax,fig=fig)
signal_rebuilt = self.rebuild(signal,exclude)
time_plot(signal_rebuilt,self.fs,self.l_cha,time_to_show,
ch_to_show,ch_offset,fig=fig,axes=ax,color='b',
show_epoch_lines=False)
        # Create legend
handles = [fig.axes[0].lines[0],fig.axes[0].lines[signal.shape[-1]
+ n_epochs -1]]
fig.axes[0].legend(handles=handles,labels=['Pre-ICA','Post-ICA'],
loc='upper center', bbox_to_anchor=(0.5, 1.15),
ncol=3, fancybox=True, shadow=True)
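
# Typical ICA workflow (minimal sketch; ``signal``, ``l_cha`` and ``fs`` are
# hypothetical and would come from the recording being cleaned):
#
#   ica = ICA(random_state=0)
#   ica.fit(signal, l_cha=l_cha, fs=fs, n_components=None)
#   ica.plot_components()                      # inspect component topographies
#   ica.plot_summary(signal, component=[0])    # inspect a suspicious component
#   clean_signal = ica.rebuild(signal, exclude=[0])
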
class ICAData(SerializableComponent):
    """Serializable container for the parameters of a fitted ICA model."""

    def __init__(self, pre_whitener=None, unmixing_matrix=None,
mixing_matrix=None,
n_components=None, pca_components=None, pca_mean=None,
components_excluded=None, random_state=None):
# General parameters
self.n_components = n_components
self.random_state = random_state
self.pre_whitener = pre_whitener
# ICA
self.unmixing_matrix = unmixing_matrix
self.mixing_matrix = mixing_matrix
self.components_excluded = components_excluded
# PCA
self.pca_components = pca_components
self.pca_mean = pca_mean
    def to_serializable_obj(self):
rec_dict = self.__dict__
for key in rec_dict.keys():
            if isinstance(rec_dict[key], np.ndarray):
rec_dict[key] = rec_dict[key].tolist()
return rec_dict
    @classmethod
def from_serializable_obj(cls, dict_data):
return cls(**dict_data)
class ArtifactRegression(ProcessingMethod):
    def __init__(self):
"""
Initialize the artifact regression method.
This class implements a method to remove artifacts (e.g., EOG signals)
from time series data (e.g., EEG signals). The artifacts must be
recorded simultaneously with the signal, and this method attempts to
regress out the artifacts from the main signal.
Attributes:
- self.coefs: Stores the regression coefficients after fitting the
model.
Notes:
The input signals (sig and art_sig) must be preprocessed (e.g.,
band-pass filtering) before using them in this class for better
performance and accuracy.
Alternative implementation:
self.coefs = np.linalg.lstsq(
art_sig.T @ art_sig, art_sig.T, rcond=None)[0] @ sig
sig = (sig.T - self.coefs.T @ art_sig.T).T
"""
# Calling the superclass constructor and defining the inputs that
# are transformed and fit by the method.
super().__init__(transform=['signal', 'artefacts_signal'],
fit_transform=['signal', 'artefacts_signal'])
# Initialize the coefficient matrix (will be filled after fitting)
self.coefs = None
self.is_fit = False
    def fit(self, sig, art_sig):
"""
Fits the artifact regression model by computing regression coefficients
for removing the artifacts from the signal.
This method performs a linear regression for each signal channel to
estimate how much of the artifact is present in each channel. The
regression coefficients (self.coefs) are computed using the least
squares solution for each channel.
Steps
1. Remove the mean from the artifact signal for normalization.
2. Compute the covariance matrix of the artifact signal.
3. For each signal channel:
- Remove the mean from the signal.
- Perform a least squares fit to estimate the regression
coefficients.
        Parameters
        ----------
        sig : numpy.ndarray
            The main signal (e.g., EEG) that contains the artifacts, with
            dimensions [n_samples x n_channels].
        art_sig : numpy.ndarray
            The artifact signal (e.g., EOG) recorded alongside the main
            signal, with dimensions [n_samples x n_artifact_channels].
"""
        # Mean-center the artifact signal (remove the temporal mean of each
        # artifact channel)
        art_sig = art_sig - np.mean(art_sig, axis=0, keepdims=True)
        # Compute the (unnormalized) covariance of the artifact signal
        cov_ref = art_sig.T @ art_sig
# Regression coefficients for each signal channel
n_sig_cha = sig.shape[1]
n_art_sig_cha = art_sig.shape[1]
coefs = np.zeros((n_sig_cha, n_art_sig_cha))
# Process each signal channel separately to reduce memory load
for c in range(n_sig_cha):
# Mean-center the signal channel
sig_cha = sig[:, c]
sig_cha = sig_cha - np.mean(sig_cha, -1, keepdims=True)
sig_cha = sig_cha.reshape(1, -1)
# Perform the least squares regression to estimate coefficients
coefs[c] = np.linalg.lstsq(
cov_ref, art_sig.T @ sig_cha.T, rcond=None)[0].T
# Store the regression coefficients
self.coefs = coefs
self.is_fit = True
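
    # Usage sketch (illustrative only; ``eeg`` and ``eog`` are hypothetical
    # arrays with dimensions [n_samples x n_channels], ideally band-pass
    # filtered beforehand as recommended above):
    #   regression = ArtifactRegression()
    #   regression.fit(eeg, eog)
    # The fitted coefficients in ``regression.coefs`` can then be used to
    # regress the EOG contribution out of the EEG.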