# External imports
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from scipy.signal import welch
# Medusa imports
from medusa.plots.head_plots import plot_head, plot_topography
from medusa.meeg.meeg import EEGChannelSet
from medusa import epoching
from medusa.plots.timeplot import time_plot
from medusa.components import SerializableComponent
def reject_noisy_epochs(epochs, signal_mean, signal_std, k=4, n_samp=2,
n_cha=1):
"""Simple thresholding method to reject noisy epochs. It discards epochs
with n_samp samples greater than k*std in n_cha channels
Parameters
----------
epochs : list or numpy.ndarray
Epochs of signal with dimensions [n_epochs x samples x channels]
signal_mean : float
Mean of the signal
signal_std : float
Standard deviation of the signal
k : float
Standard deviation multiplier to calculate threshold
n_samp : int
Minimum number of samples that have to be over the threshold in each
epoch to be discarded
n_cha : int
Minimum number of channels that have to have n_samples over the
threshold in each epoch to be discarded
Returns
-------
float
Percentage of rejected epochs
numpy.ndarray
Clean epochs
numpy.ndarray
Boolean indexes of the rejected epochs. True for discarded epochs
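
Examples
--------
Illustrative usage with random data (shapes and thresholds are arbitrary):

>>> import numpy as np
>>> epochs = np.random.randn(10, 200, 8)
>>> signal_mean = np.mean(epochs, axis=(0, 1))
>>> signal_std = np.std(epochs, axis=(0, 1))
>>> pct, clean_epochs, idx = reject_noisy_epochs(
...     epochs, signal_mean, signal_std, k=4, n_samp=2, n_cha=1)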
"""
# Check errors
if len(epochs.shape) != 3:
raise Exception('Malformed epochs array. It must be of dimensions '
'[epochs x samples x channels]')
if signal_std.shape[0] != epochs.shape[2]:
raise Exception('Array signal_std does not match the epochs size. '
'It must have the same number of channels')
if signal_mean.shape[0] != epochs.shape[2]:
raise Exception('Array signal_mean does not match the epochs size. '
'It must have the same number of channels')
epochs_abs = np.abs(epochs)
cmp = epochs_abs > np.abs(signal_mean) + k * signal_std
idx = np.sum((np.sum(cmp, axis=1) >= n_samp), axis=1) >= n_cha
pct_rejected = (np.sum(idx) / epochs.shape[0]) * 100
return pct_rejected, epochs[~idx, :, :], idx
class ICA:
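"""Independent component analysis (ICA) of EEG signals.

Wraps sklearn's FastICA with a PCA pre-whitening step. It allows fitting
the decomposition, inspecting and plotting the components, and rebuilding
the signal with selected components excluded (e.g., ocular artifacts).

Examples
--------
Illustrative sketch; the signal array, channel labels, sampling rate and
excluded component are placeholders:

>>> ica = ICA(random_state=0)
>>> ica.fit(signal, l_cha=['F3', 'F4', 'C3', 'C4'], fs=256, n_components=4)
>>> sources = ica.get_sources(signal)
>>> clean_signal = ica.rebuild(signal, exclude=[0])
"""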
def __init__(self, random_state=None):
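"""Class constructor.

Parameters
----------
random_state : int or None
Seed passed to sklearn's FastICA for reproducibility
"""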
self.n_components = None
self.random_state = random_state
self.pre_whitener = None
self.unmixing_matrix = None
self.mixing_matrix = None
self._ica_n_iter = None
self.components_excluded = None
self.ica_labels = None
# PCA attributes
self.n_pca_components = None
self._pca_mean = None
self._pca_explained_variance = None
self._pca_explained_variance_ratio = None
self._pca_components = None
# Signal attributes
self.l_cha = None
self.channel_set = None
self.fs = None
def fit(self, signal, l_cha, fs, n_components):
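"""Fits the ICA decomposition to the given signal.

Parameters
----------
signal : numpy.ndarray
Signal with dimensions [n_samples x n_channels] or
[n_epochs x n_samples x n_channels]
l_cha : list
Labels of the channels, used to set a standard EEG montage
fs : int or float
Sampling rate of the signal in Hz
n_components : int, float or None
Number of ICA components to keep. If float, it is the fraction of
explained variance to retain in the PCA step; if None, the components
explaining up to 99.999% of the variance are kept
"""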
from sklearn.decomposition import FastICA
from scipy import linalg
# Set arguments
self.n_components = n_components
self.l_cha = l_cha
self.fs = fs
# Define MEEG Channel Set
self.channel_set = EEGChannelSet()
self.channel_set.set_standard_montage(self.l_cha)
signal, _ = self._check_signal_dimensions(signal.copy())
self._get_pre_whitener(signal)
signal_pca = self._whitener(signal.copy())
ica = FastICA(whiten=False, random_state=self.random_state,
max_iter=1000)
ica.fit(signal_pca[:, :self.n_components])
self.unmixing_matrix = ica.components_
self._ica_n_iter = ica.n_iter_
assert self.unmixing_matrix.shape == (self.n_components,) * 2
# Whitening unmixing matrix
norm = np.sqrt(self._pca_explained_variance[:self.n_components])
norm[norm == 0] = 1.
self.unmixing_matrix /= norm
self.mixing_matrix = linalg.pinv(self.unmixing_matrix)
# Sort ica components from greater to lower explained variance
self._sort_components(signal)
self.ica_labels = [f"ICA_{n}" for n in range(self.n_components)]
def get_sources(self, signal):
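"""Computes the source activations of the given signal.

Parameters
----------
signal : numpy.ndarray
Signal with dimensions [n_samples x n_channels] or
[n_epochs x n_samples x n_channels]

Returns
-------
numpy.ndarray
Sources with dimensions [n_samples x n_components]
"""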
if self.mixing_matrix is None:
raise RuntimeError('ICA has not been fitted yet. Please fit ICA first.')
signal, _ = self._check_signal_dimensions(signal.copy())
signal = self._pre_whiten(signal)
if self._pca_mean is not None:
signal -= self._pca_mean
# Transform signal to PCA space and then apply unmixing matrix
pca_transform = np.dot(self._pca_components[:self.n_components],
signal.T)
sources = np.dot(self.unmixing_matrix, pca_transform)
return sources.T
def rebuild(self, signal, exclude=None):
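"""Rebuilds the signal in channel space, optionally excluding components.

Parameters
----------
signal : numpy.ndarray
Signal with dimensions [n_samples x n_channels] or
[n_epochs x n_samples x n_channels]
exclude : int, list or None
Indexes of the ICA components to exclude from the rebuilt signal

Returns
-------
numpy.ndarray
Rebuilt signal with the same dimensions as the input
"""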
signal, n_epo_samples = self._check_signal_dimensions(signal.copy())
signal = self._pre_whiten(signal)
if exclude is not None:
if isinstance(exclude, int):
exclude = [exclude]
exclude = np.array(list(set(exclude)))
if np.any(exclude >= self.n_components):
raise ValueError("One or more ICA component indexes marked to "
"exclude from the signal rebuild are greater "
"than the total number of ICA components.")
self.components_excluded = exclude
# Apply PCA
if self._pca_mean is not None:
signal -= self._pca_mean
# Determine ica components to keep in signal rebuild
c_to_keep = np.setdiff1d(np.arange(self.n_components), exclude)
# Define projection matrix
proj = np.dot(np.dot(self._pca_components[:self.n_components].T,
self.mixing_matrix[:, c_to_keep]),
np.dot(self.unmixing_matrix[c_to_keep, :],
self._pca_components[:self.n_components]))
# Apply projection to signal
signal_rebuilt = np.transpose(np.dot(proj, signal.T))
if self._pca_mean is not None:
signal_rebuilt += self._pca_mean
signal_rebuilt *= self.pre_whitener
# Restore epochs if original signal was divided in epochs
if n_epo_samples is not None:
signal_rebuilt = np.reshape(
signal_rebuilt,
(int(signal_rebuilt.shape[0] / n_epo_samples),
n_epo_samples, signal_rebuilt.shape[1]))
return signal_rebuilt
def get_components(self):
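"""Returns the ICA components projected onto the channel space.

Returns
-------
numpy.ndarray
Components with dimensions [n_channels x n_components]
"""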
# Transform
components = np.dot(self.mixing_matrix[:, :self.n_components].T,
self._pca_components[:self.n_components]).T
return components
def save(self, path):
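"""Saves the fitted ICA model to a BSON file through ICAData.

Parameters
----------
path : str
Path of the output file
"""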
if self.mixing_matrix is None:
raise RuntimeError('ICA has not been fitted yet. Please fit ICA first.')
ica_data = ICAData(self.pre_whitener, self.unmixing_matrix,
self.mixing_matrix, self.n_components,
self._pca_components, self._pca_mean,
self.components_excluded,
self.random_state)
ica_data.save(path, 'bson')
def load(self, path):
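"""Loads a previously saved ICA model from a BSON file.

Parameters
----------
path : str
Path of the file created by the save method
"""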
# Load ICAData instance
ica_data = ICAData().load_from_bson(path)
# Update ICA arguments
self.pre_whitener = ica_data.pre_whitener
self.unmixing_matrix = np.array(ica_data.unmixing_matrix)
self.mixing_matrix = np.array(ica_data.mixing_matrix)
self.n_components = ica_data.n_components
self._pca_components = np.array(ica_data.pca_components)
self._pca_mean = np.array(ica_data.pca_mean)
self.components_excluded = np.array(ica_data.components_excluded)
self.random_state = ica_data.random_state
# ------------------------- AUXILIARY METHODS ------------------------------
@staticmethod
def _check_signal_dimensions(signal):
n_epo_samples = None
# Check input dimensions
if len(signal.shape) == 3:
# Stack the epochs
n_epo_samples = signal.shape[1]
signal = np.vstack(signal)
elif len(signal.shape) == 1:
raise ValueError("signal input only has one dimension, but two"
"dimensions (n_samples, n_channels) arer nedded at"
"least")
return signal, n_epo_samples
def _get_pre_whitener(self, signal):
self.pre_whitener = np.std(signal)
def _pre_whiten(self, signal):
signal /= self.pre_whitener
return signal
def _whitener(self, signal):
from sklearn.decomposition import PCA
signal_pw = self._pre_whiten(signal)
pca = PCA(n_components=None, whiten=True)
signal_pca = pca.fit_transform(signal_pw)
self._pca_mean = pca.mean_
self._pca_components = pca.components_
self._pca_explained_variance = pca.explained_variance_
self._pca_explained_variance_ratio = pca.explained_variance_ratio_
self.n_pca_components = pca.n_components_
del pca
# Check a correct input of n_components parameter
if self.n_components is None:
self.n_components = min(self.n_pca_components,
self._exp_var_ncomp(
self._pca_explained_variance_ratio,
0.99999)[0])
elif isinstance(self.n_components, float):
self.n_components = self._exp_var_ncomp(
self._pca_explained_variance_ratio, self.n_components)[0]
if self.n_components == 1:
raise RuntimeError(
'A single PCA component captures most of the '
'explained variance, so your threshold results '
'in only 1 component. You should select a '
'higher value.')
else:
if not isinstance(self.n_components, int):
raise ValueError(
f'n_components={self.n_components} must be None, '
f'a float or an int value')
if self.n_components > self.n_pca_components:
raise ValueError(f'The number of ICA components '
f'(n_components={self.n_components}) cannot '
f'exceed the number of PCA components '
f'({self.n_pca_components}).')
return signal_pca
def _sort_components(self, signal):
sources = self.get_sources(signal)
meanvar = np.sum(self.mixing_matrix ** 2, axis=0) * \
np.sum(sources ** 2, axis=0) / \
(sources.shape[0] * sources.shape[1] - 1)
c_order = np.argsort(meanvar)[::-1]
self.unmixing_matrix = self.unmixing_matrix[c_order, :]
self.mixing_matrix = self.mixing_matrix[:, c_order]
@staticmethod
def _exp_var_ncomp(var, n):
cvar = np.asarray(var, dtype=np.float64)
cvar = cvar.cumsum()
cvar /= cvar[-1]
# We allow 1., which would give us N+1
n = min((cvar <= n).sum() + 1, len(cvar))
return n, cvar[n - 1]
# --------------------------- PLOT METHODS ---------------------------------
def plot_components(self, cmap='bwr'):
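"""Plots the topography of each ICA component.

Parameters
----------
cmap : str
Matplotlib colormap used for the topography plots

Returns
-------
matplotlib.figure.Figure
Figure with one topography per component
"""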
# Get ICA components
components = self.get_components()
# Define subplot parameters
n_components = components.shape[1]
if n_components <= 5:
cols = n_components
rows = 1
else:
cols = 5
rows = np.ceil(n_components / 5)
# Define subplot
fig, axes = plt.subplots(int(rows), int(cols))
if len(axes.shape) == 1:
axes = np.array([axes])
# Topo plots
ic_c = 0
for r in axes:
for c in r:
if ic_c < n_components:
plot_head(axes=c, channel_set=self.channel_set,
interp_points=300, linewidth=1.5)
plot_topography(axes=c, channel_set=self.channel_set,
values=components[:, ic_c],
interp_points=300, cmap=cmap,
show_colorbar=False, plot_extra=0)
c.set_title(self.ica_labels[ic_c])
ic_c += 1
else:
c.set_axis_off()
fig.show()
return fig
def plot_sources(self, signal, sources_to_show=None, time_to_show=None,
ch_offset=None):
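"""Plots the time course of the ICA sources.

Parameters
----------
signal : numpy.ndarray
Signal with dimensions [n_samples x n_channels] or
[n_epochs x n_samples x n_channels]
sources_to_show : int or None
Number of sources to display
time_to_show : float or None
Time window to display, in seconds
ch_offset : float or None
Vertical offset between sources. If None, it is estimated from the
first source
"""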
sources = self.get_sources(signal)
if ch_offset is None:
ch_offset = np.max(np.abs(sources[:, 0]))
fig, ax = plt.subplots(1, 1)
time_plot(sources, self.fs, self.ica_labels,
ch_to_show=sources_to_show,
time_to_show=time_to_show,
ch_offset=ch_offset, show=False, fig=fig,
axes=ax)
def plot_summary(self, signal, component, psd_freq_range=[1, 70],
psd_window='hamming', time_to_show=2, cmap='bwr'):
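"""Plots a summary of each selected ICA component: topography, power
spectral density, stacked source segments and source time course.

Parameters
----------
signal : numpy.ndarray
Signal with dimensions [n_samples x n_channels] or
[n_epochs x n_samples x n_channels]
component : int or list
Index or indexes of the ICA components to summarize
psd_freq_range : list
Frequency range [min, max] in Hz shown in the PSD plot
psd_window : str
Window passed to scipy.signal.welch
time_to_show : float
Length in seconds of each stacked segment when the signal is not
divided in epochs
cmap : str
Matplotlib colormap used for the topography and the stacked segments

Returns
-------
matplotlib.figure.Figure
Figure of the last summarized component
"""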
# Check error
if isinstance(component, int):
component = np.array([component])
elif isinstance(component, list):
component = np.array(component)
else:
raise ValueError("Component parameter must be an int or a list of "
"ints indicating ICA components.")
if np.any(component >= self.n_components):
raise ValueError("One or more component indexes exceed the total "
f"number of ICA components: {self.n_components}.")
# Check if signal is epoched
n_samples_epoch = None
if len(signal.shape) == 3:
n_samples_epoch = signal.shape[1]
n_stacks = signal.shape[0]
time_to_show = int(n_samples_epoch / self.fs)
sources = self.get_sources(signal)[:, component]
components = self.get_components()
if n_samples_epoch is None:
n_samples_epoch = int(time_to_show * self.fs)
n_stacks = (int(len(sources[:, 0]) / n_samples_epoch))
for ii in range(len(component)):
fig = plt.figure()
ax_1 = fig.add_subplot(3, 4, (1, 6))
ax_2 = fig.add_subplot(3, 4, (3, 4))
ax_3 = fig.add_subplot(3, 4, (7, 8))
ax_4 = fig.add_subplot(3, 1, 3)
# Topoplot
plot_head(axes=ax_1, channel_set=self.channel_set,
interp_points=300, linewidth=1.5)
plot_topography(axes=ax_1, channel_set=self.channel_set,
values=components[:, component[ii]],
interp_points=300, cmap=cmap,
show_colorbar=False, plot_extra=0)
ax_1.set_title(self.ica_labels[component[ii]])
stacked_source = np.reshape(
sources[:(n_stacks * n_samples_epoch), ii],
(n_stacks, n_samples_epoch))
# PSD
f, psd = welch(stacked_source, self.fs, window=psd_window)
f_range = np.logical_and(f >= psd_freq_range[0],
f <= psd_freq_range[1])
psd_mean = np.mean(10 * np.log10(psd), axis=0)
psd_std = np.std(10 * np.log10(psd), axis=0)
ax_2.fill_between(f[f_range], (psd_mean - psd_std)[f_range],
(psd_mean + psd_std)[f_range],
color='k', alpha=0.3)
ax_2.plot(f[f_range], psd_mean[f_range], 'k')
ax_2.set_xlim(f[f_range][0], f[f_range][-1])
ax_2.set_xlabel('Frequency (Hz)')
ax_2.set_ylabel('Power/Hz (dB)')
ax_2.set_title('Power spectral density')
# Stacked data image
ax_3.pcolormesh(np.linspace(0, time_to_show, n_samples_epoch),
np.arange(n_stacks), stacked_source, cmap=cmap,
shading='gouraud')
ax_3.set_xlabel('Time (s)')
ax_3.set_ylabel('Segments')
ax_3.set_title('Stacked source segments')
# Time plot
time_plot(sources[:, ii], self.fs, [self.ica_labels[component[ii]]],
time_to_show=time_to_show, fig=fig, axes=ax_4)
ax_4.set_title('Source time plot')
fig.tight_layout(pad=1)
return fig
def show_exclusion(self, signal, exclude=None, ch_to_show=None,
time_to_show=None, ch_offset=None):
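"""Plots the original signal and the signal rebuilt without the excluded
components on the same axes for visual comparison.

Parameters
----------
signal : numpy.ndarray
Signal with dimensions [n_samples x n_channels] or
[n_epochs x n_samples x n_channels]
exclude : int, list or None
Indexes of the ICA components to exclude in the rebuilt signal
ch_to_show : int or None
Number of channels to display
time_to_show : float or None
Time window to display, in seconds
ch_offset : float or None
Vertical offset between channels. If None, it is estimated from the
signal amplitude
"""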
if ch_offset is None:
ch_offset = np.max(np.abs(signal.copy()))
# Check if signal is divided in epochs
if len(signal.shape) == 3:
n_epochs = signal.shape[0]
else:
n_epochs = 1
fig, ax = plt.subplots(1, 1)
time_plot(signal, self.fs, self.l_cha, time_to_show,
ch_to_show, ch_offset, axes=ax, fig=fig)
signal_rebuilt = self.rebuild(signal, exclude)
time_plot(signal_rebuilt, self.fs, self.l_cha, time_to_show,
ch_to_show, ch_offset, fig=fig, axes=ax, color='b',
show_epoch_lines=False)
# Create legend
handles = [fig.axes[0].lines[0],
fig.axes[0].lines[signal.shape[-1] + n_epochs - 1]]
fig.axes[0].legend(handles=handles, labels=['Pre-ICA', 'Post-ICA'],
loc='upper center', bbox_to_anchor=(0.5, 1.15),
ncol=3, fancybox=True, shadow=True)
class ICAData(SerializableComponent):
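"""Serializable container for a fitted ICA model.

Stores the whitening, unmixing and PCA parameters needed to restore an
ICA instance, and inherits the save/load machinery from
SerializableComponent.
"""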
def __init__(self, pre_whitener=None, unmixing_matrix=None,
mixing_matrix=None,
n_components=None, pca_components=None, pca_mean=None,
components_excluded=None, random_state=None):
# General parameters
self.n_components = n_components
self.random_state = random_state
self.pre_whitener = pre_whitener
# ICA
self.unmixing_matrix = unmixing_matrix
self.mixing_matrix = mixing_matrix
self.components_excluded = components_excluded
# PCA
self.pca_components = pca_components
self.pca_mean = pca_mean
def to_serializable_obj(self):
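"""Converts the numpy arrays of the instance dict to lists so that the
object can be serialized."""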
rec_dict = self.__dict__
for key in rec_dict.keys():
if isinstance(rec_dict[key], np.ndarray):
rec_dict[key] = rec_dict[key].tolist()
return rec_dict
@classmethod
def from_serializable_obj(cls, dict_data):
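"""Recreates an ICAData instance from a dict produced by
to_serializable_obj."""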
return cls(**dict_data)