Source code for medusa.bci.mi_paradigms
"""
In this module you will find useful functions and classes to operate with data
recorded using motor imagery paradigms, which are widely used by the BCI
community. Enjoy!
@author: Sergio Pérez-Velasco & Víctor Martínez-Cagigal
"""
# Built-in imports
import copy, warnings
from abc import abstractmethod
# External imports
import numpy as np
from tqdm import tqdm
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.metrics import classification_report
# Medusa imports
from medusa import IIRFilter, car
from medusa.epoching import normalize_epochs
from medusa import get_epochs_of_events, resample_epochs
from medusa import components
from medusa import meeg
from medusa.spatial_filtering import CSP, LaplacianFilter
from medusa import classification_utils as cu
[docs]class MIData(components.ExperimentData):
# TODO: Check everything
"""Class with the necessary attributes to define motor imagery (MI)
experiments. It provides compatibility with high-level functions for this
MI paradigms in BCI module.
"""
[docs] def __init__(self, mode, onsets, w_trial_t,
calibration_onset_w=None, w_preparation_t=None, w_rest_t=None,
mi_labels=None, mi_labels_info=None, mi_result=None,
paradigm_info=None, **kwargs):
"""MIData constructor
Parameters
----------
mode : str {"train"|"test"|"guided test"}
Mode of this run.
onsets : list or numpy.ndarray [n_stim x 1]
Timestamp of the cue with respect to the EEG signal (just when the
motor imagery starts).
w_trial_t: list [start, end]
Temporal window of the motor imagery with respect to each onset in
ms. For example, if w_trial_t = [500, 4000] the subject was
performing the motor imagery task from 500ms to 4000ms after
the onset.
calibration_onset_w: list [start, end]
Timestamps of the onsets regarding the calibration window,
if exists, respect to the EEG signal.
w_preparation_t: list [start, end]
Temporal window of the preparation time (no motor imagery),
if exists, with respect to each onset in ms.
w_rest_t: list [start, end]
Temporal window of the rest time (no motor imagery),
if exists, with respect to each onset in ms.
mi_labels : list or numpy.ndarray [n_mi_labels x 1]
Only in train mode. Contains the mi labels of each stimulation,
as many as classes in the experiment.
mi_labels_info : dict
Contains the description of the mi labels.
Example:
mi_labels_info =
{0: "Rest",
1: "Left_hand",
2: "Right_hand"}
mi_result : list or numpy.ndarray [n_mi_labels x 1]
Result of this run. Each position contains the data of the
selected target at each trial.
paradigm_info: dict()
Recommended but not mandatory. Use this variable to keep the
information regarding the different timings of the paradigm.
kwargs : kwargs
Custom arguments that will also be saved in the class
(e.g., timings, calibration gaps, etc.)
"""
# Check errors
mode = mode.lower()
if mode == 'train':
if mi_labels is None:
raise ValueError('Attribute "mi_labels" must be provided in '
'train mode')
# Standard attributes
self.mode = mode
self.onsets = np.array(onsets)
self.w_trial_t = np.array(w_trial_t)
self.calibration_onset_w = np.array(calibration_onset_w)
self.w_preparation_t = np.array(w_preparation_t)
self.w_rest_t = np.array(w_rest_t)
self.mi_labels = np.array(mi_labels) if mi_labels is not None else \
mi_labels
self.mi_labels_info = mi_labels_info
self.mi_result = np.array(mi_result) if mi_result is not None else \
mi_result
self.paradigm_info = paradigm_info
# Optional attributes
for key, value in kwargs.items():
setattr(self, key, value)
[docs] def to_serializable_obj(self):
rec_dict = self.__dict__
for key in rec_dict.keys():
if type(rec_dict[key]) == np.ndarray:
rec_dict[key] = rec_dict[key].tolist()
return rec_dict
[docs]class MIDataset(components.Dataset):
"""This class inherits from medusa.data_structures.Dataset, increasing
its functionality for datasets with data from MI experiments. It
provides common ground for the rest of functions in the module.
"""
[docs] def __init__(self, channel_set, fs=None, biosignal_att_key='eeg',
experiment_att_key='midata', experiment_mode=None,
track_attributes=None):
""" Constructor
Parameters
----------
channel_set : meeg.EEGChannelSet
EEG channel set. Only these channels will be kept in the dataset,
the others will be discarded. Also, the signals will be rearranged,
keeping the same channel order, avoiding errors in future stages of
the signal processing pipeline
fs : int, float or None
Sample rate of the recordings. If there are recordings with
different sample rates, the consistency of the dataset can be
still assured using resampling
biosignal_att_key : str
Name of the attribute containing the target biosginal that will be
used to extract the features. It has to be the same in all
recordings (e.g., 'eeg', 'meg').
experiment_att_key : str or None
Name of the attribute containing the target experiment that will be
used to extract the features. It has to be the same in all
recordings (e.g., 'mi_left_right', 'rest_mi'). It is
mandatory when a recording of the dataset contains more than 1
experiment data
experiment_mode : str {'train'|'test'|'guided test'|None}
Mode of the experiment. If this dataset will be used to fit a model,
set to train to avoid errors
track_attributes: dict of dicts or None
This parameter indicates custom attributes that must be tracked in
feature extraction functions and how. The keys are the name of the
attributes, whereas the values are dicts indicating the tracking
mode {'concatenate'|'append'} and parent. Option concatenate is
only available for attributes of type list or numpy arrays,
forming a 1 dimensional array with the data from all recordings.
Option append is used to save all kind of objects for each
recording, forming a list whose length will be the number of
recordings in the dataset. A set of default attributes is defined,
so this parameter will be None in most cases. Example to track 2
custom attributes (i.e., date and experiment_equipment):
track_attributes = {
'date': {
'track_mode': 'append',
'parent': None
},
'experiment_equipment': {
'track_mode': 'append',
'parent': experiment_att_key
}
}
"""
# Check errors
if experiment_mode is not None:
if experiment_mode not in ('train', 'test', 'guided test'):
raise ValueError('Parameter experiment_mode must be '
'{train|test|guided test|None}')
# Default track attributes
default_track_attributes = {
'subject_id': {
'track_mode': 'append',
'parent': None
},
'onsets': {
'track_mode': 'concatenate',
'parent': experiment_att_key
},
'w_trial_t': {
'track_mode': 'append',
'parent': experiment_att_key
},
'calibration_onset_w': {
'track_mode': 'append',
'parent': experiment_att_key
},
'w_preparation_t': {
'track_mode': 'append',
'parent': experiment_att_key
},
'w_rest_t': {
'track_mode': 'append',
'parent': experiment_att_key
},
'mi_labels_info': {
'track_mode': 'append',
'parent': experiment_att_key
},
'paradigm_info': {
'track_mode': 'append',
'parent': experiment_att_key
}
}
if experiment_mode in ['train', 'guided test']:
default_track_attributes_train = {
'mi_labels': {
'track_mode': 'concatenate',
'parent': experiment_att_key
}
}
default_track_attributes = {
**default_track_attributes,
**default_track_attributes_train
}
if experiment_mode in ['test', 'guided test']:
default_track_attributes_test = {
'mi_result': {
'track_mode': 'concatenate',
'parent': experiment_att_key
}
}
default_track_attributes = {
**default_track_attributes,
**default_track_attributes_test
}
track_attributes = \
default_track_attributes if track_attributes is None else \
{**default_track_attributes, **track_attributes}
# Class attributes
self.channel_set = channel_set
self.fs = fs
self.biosignal_att_key = biosignal_att_key
self.experiment_att_key = experiment_att_key
self.experiment_mode = experiment_mode
self.track_attributes = track_attributes
# Consistency checker
checker = self.__get_consistency_checker()
super().__init__(consistency_checker=checker)
def __get_consistency_checker(self):
"""Creates a standard consistency checker for MI datasets
Returns
-------
checker : data_structures.ConsistencyChecker
Standard consistency checker for MI feature extraction
"""
# Create consistency checker
checker = components.ConsistencyChecker()
# Check that the biosignal exists
checker.add_consistency_rule(
rule='check-attribute',
rule_params={'attribute': self.biosignal_att_key}
)
checker.add_consistency_rule(
rule='check-attribute-type',
rule_params={'attribute': self.biosignal_att_key,
'type': meeg.EEG}
)
# Check channels
checker.add_consistency_rule(
rule='check-values-in-attribute',
rule_params={'attribute': 'channels',
'values': self.channel_set.channels},
parent=self.biosignal_att_key + '.channel_set'
)
# Check sample rate
if self.fs is not None:
checker.add_consistency_rule(rule='check-attribute-value',
rule_params={'attribute': 'fs',
'value': self.fs},
parent=self.biosignal_att_key)
else:
warnings.warn('Parameter fs is None. The consistency of the '
'dataset cannot be assured. Still, you can use '
'target_fs parameter for feature extraction '
'and everything should be fine.')
# Check experiment
checker.add_consistency_rule(
rule='check-attribute',
rule_params={'attribute': self.experiment_att_key}
)
checker.add_consistency_rule(
rule='check-attribute-type',
rule_params={'attribute': self.experiment_att_key,
'type': MIData}
)
# Check track_attributes
if self.track_attributes is not None:
for key, value in self.track_attributes.items():
checker.add_consistency_rule(
rule='check-attribute',
rule_params={'attribute': key},
parent=value['parent']
)
if value['track_mode'] == 'concatenate':
checker.add_consistency_rule(
rule='check-attribute-type',
rule_params={'attribute': key,
'type': [list, np.ndarray, dict]},
parent=value['parent']
)
return checker
[docs] def custom_operations_on_recordings(self, recording):
# Select channels
eeg = getattr(recording, self.biosignal_att_key)
eeg.change_channel_set(self.channel_set)
return recording
[docs]class StandardPreprocessing(components.ProcessingMethod):
"""Just the common preprocessing applied in MI-based BCI. Simple,
quick and effective: frequency IIR filter followed by common average
reference (CAR) spatial filter.
"""
[docs] def __init__(self, order=5, cutoff=[8, 30], btype='bandpass',
temp_filt_method='sosfiltfilt'):
super().__init__(fit_transform_signal=['signal'],
fit_transform_dataset=['dataset'])
# Parameters
self.order = order
self.cutoff = cutoff
self.btype = btype
self.temp_filt_method = temp_filt_method
# Variables that
self.iir_filter = None
[docs] def fit(self, fs, n_cha=None):
"""Fits the IIR filter.
Parameters
----------
fs: float
Sample rate of the signal.
n_cha: int
Number of channels. Used to compute the initial conditions of the
frequency filter. Only required with sosfilt filtering method
(online filtering)
"""
self.iir_filter = IIRFilter(order=self.order,
cutoff=self.cutoff,
btype=self.btype,
filt_method=self.temp_filt_method)
self.iir_filter.fit(fs, n_cha=n_cha)
[docs] def transform_signal(self, signal):
"""Transforms an EEG signal applying IIR filtering and CAR sequentially
Parameters
----------
signal: np.array or list
Signal to transform. Shape [n_samples x n_channels]
"""
signal = self.iir_filter.transform(signal)
signal = car(signal)
return signal
[docs] def transform_dataset(self, dataset: MIDataset, show_progress_bar=True,
**kwargs):
"""Transforms an MIDataset applying IIR filtering and CAR sequentially
Parameters
----------
signal: np.array or list
Signal to transform. Shape [n_samples x n_channels]
"""
pbar = None
if show_progress_bar:
pbar = tqdm(total=len(dataset.recordings),
desc='Preprocessing')
for rec in dataset.recordings:
eeg = getattr(rec, dataset.biosignal_att_key)
eeg.signal = self.transform_signal(eeg.signal)
setattr(rec, dataset.biosignal_att_key, eeg)
if show_progress_bar:
pbar.update(1)
if show_progress_bar:
pbar.close()
return dataset
[docs] def fit_transform_signal(self, signal, fs):
"""Fits the IIR filter and transforms an EEG signal applying IIR
filtering and CAR sequentially
Parameters
----------
signal: np.array or list
Signal to transform. Shape [n_samples x n_channels]
fs: float
Sample rate of the signal.
"""
self.iir_filter = IIRFilter(order=self.order,
cutoff=self.cutoff,
btype=self.btype,
filt_method=self.temp_filt_method)
signal = self.iir_filter.fit_transform(signal, fs)
signal = car(signal)
return signal
[docs] def fit_transform_dataset(self, dataset, show_progress_bar=True):
"""Fits the IIR filter and transforms an EEG signal applying the
filter and CAR sequentially. Each recording is preprocessed
independently, taking into account possible differences in sample rate.
Parameters
----------
dataset: MIDataset
MIDataset with the recordings to be preprocessed.
show_progress_bar: bool
Show progress bar
"""
pbar = None
if show_progress_bar:
pbar = tqdm(total=len(dataset.recordings),
desc='Preprocessing')
for rec in dataset.recordings:
eeg = getattr(rec, dataset.biosignal_att_key)
eeg.signal = self.fit_transform_signal(eeg.signal, eeg.fs)
setattr(rec, dataset.biosignal_att_key, eeg)
if show_progress_bar:
pbar.update(1)
if show_progress_bar:
pbar.close()
return dataset
[docs]class StandardFeatureExtraction(components.ProcessingMethod):
"""
Standard feature extraction method for MI-based spellers. Basically,
it gets the raw epoch for each MI event.
"""
[docs] def __init__(self, safe_copy=True):
"""
Class constructor. All parameters except "safe_copy" must be specified
in each method.
Parameters
-----------
safe_copy : bool
Makes a safe copy of the signal to avoid changing the original
samples due to references.
"""
super().__init__(transform_signal=['x'],
transform_dataset=['x', 'x_info'])
self.safe_copy = safe_copy
[docs] def transform_signal(self, times, signal, fs, onsets, w_epoch_t=(0, 3000),
baseline_mode="trial", w_baseline_t=(-1500, -500),
norm='z', target_fs=128, concatenate_channels=False,
**kwargs):
"""
Method to extract epochs from an individual signal.
Parameters
-----------------
times : ndarray (n_samples,)
Timestamp array
signal: ndarray (n_samples x n_channels)
Signal data.
fs : int
Sampling frequency.
onsets: ndarray (n_trials,)
Timestamps for the MI event onsets.
w_epoch_t : list, tuple, or ndarray
Temporal window in ms for each epoch relative to the event onset
(e.g., [0, 3000])
baseline_mode : basestring {'run', 'trial', None}
If "run", the baseline will be extracted from the very beginning
of the run (i.e., the first samples of the signal).
If "trial" (default), the baseline will be extracted for each
trial relative to the onset.
If None, no baseline extraction will be performed.
w_baseline_t : list, tuple, or ndarray
Temporal window in ms to be used for baseline normalization.
If baseline_mode = "run", the window is relative to the start of
the signal.
If baseline_mode = "trial", the window is relative to each trial
onset (e.g., [-1500, -500]).
norm : str {'z'|'dc'}
Type of baseline normalization. Set to 'z' for Z-score
normalization (subtract the mean and divide by the std),
or 'dc' for DC normalization (subtract the mean).
target_fs : float of None
Target sample rate of each epoch. If None, no resampling will be
applied. Please note that, in this case, all the recordings must
have the same sample rate.
concatenate_channels : bool
This parameter controls the shape of the feature array. If True, all
channels will be concatenated, returning an array of shape [n_events
x (samples x channels)]. If false, the array will have shape
[n_events x samples x channels].
Returns
-----------
features : ndarray
MI epochs extracted with shape [n_events x samples x channels], or
[n_events x (samples x channels)] if concatenate_channels == True.
"""
# Avoid changes in the original signal (this may not be necessary)
if self.safe_copy:
signal = signal.copy()
# Baseline options
if baseline_mode == "sliding":
baseline_mode = "trial"
assert baseline_mode in ('run', 'trial', None), \
ValueError('Parameter baseline_mode must be {"run", "trial", None}')
if baseline_mode in ('run', None):
w_baseline_t = None
norm = None
# Extract features
features = get_epochs_of_events(
timestamps=times, signal=signal, onsets=onsets, fs=fs,
w_epoch_t=w_epoch_t,
w_baseline_t=w_baseline_t, norm=norm
)
# Common baseline for the entire run
if baseline_mode == "run":
# Here baseline is taken relative to the start of the signal
norm_epoch_s = np.array(w_baseline_t * fs / 1000, dtype=int)
norm_epoch = np.expand_dims(signal[norm_epoch_s, :], axis=0)
features = normalize_epochs(features, norm_epochs=norm_epoch,
norm=self.norm)
# Apply resampling if required
if target_fs is not None:
if target_fs > fs:
raise warnings.warn('Target fs is greater than data fs')
features = resample_epochs(features, w_epoch_t, target_fs)
# Channel concatenation if desired
if concatenate_channels:
features = np.squeeze(features.reshape(
(features.shape[0], features.shape[1] * features.shape[2], 1)))
return features
[docs] def transform_dataset(self, dataset, show_progress_bar=True,
w_epoch_t=(0, 3000), baseline_mode="trial",
w_baseline_t=(-1500, -500), norm='z',
target_fs=128, concatenate_channels=False,
sliding_w_lims_t=None, sliding_t_step=None,
sliding_win_len=None, **kwargs):
"""
Method to extract epochs from an entire MIDataset
High level function to easily extract features from EEG recordings
and save useful info for later processing. Nevertheless, the provided
functionality has several limitations and it will not be suitable for
all cases and processing pipelines. If it does not fit your needs,
create a custom function iterating the recordings and using
extract_erp_features, a much more low-level and general function. This
function does not apply any preprocessing to the signals, this must
be done before.
Most parameters are shared with
StandardFeatureExtraction.transform_signal() method, check it for
more information. Only the new parameters or those that behave
differently than the aforementioned method are detailed below. These
new parameters are related to a continuous extraction of epochs using a
sliding window between a desired epoch range, instead of extracting only
one epoch for each onset.
Parameters
----------
dataset: MIDataset
MIDataset instance containing the MIData recordings.
show_progress_bar: bool
Boolean to show (or not) the progress bar info.
sliding_w_lims_t: list, tuple, ndarray, or None
2D window that indicates the range (start, end) for the sliding
window approach. If None, no sliding window would be applied.
This parameter delimits the number of windows to be extracted for
each onset, so none of them can fall outside the range (excluding
baseline).
sliding_t_step: int, None
Step in samples that is used to separate the sliding windows. If
None, no sliding window would be applied.
sliding_win_len: int, None
Length in samples of the sliding windows. Please note that the
w_epoch_t parameter would not be used if the sliding window approach
is used. If None, no sliding window would be applied.
baseline_mode : {'run', 'trial', None, 'sliding'}
The baseline_mode has an additional feature when sliding window is
used:
- "run": common baseline extracted from the start of the run.
- "trial": common baseline for the trial, i.e., all the windows
that belongs to an onset; extracted relative to the onset.
- None: no baseline.
- "sliding": baseline is applied to each sliding window,
and it is relative to the start of each sliding window.
Returns
-------
features : numpy.ndarray
Array with the biosignal samples arranged in epochs.
track_info : dict
Dictionary with tracked information across all recordings.
"""
# Continuous settings
use_continuous = False
if sliding_w_lims_t is not None or sliding_t_step is not None or \
sliding_win_len is not None:
assert sliding_w_lims_t is not None and sliding_t_step is not \
None and sliding_win_len is not None, \
ValueError("All these three parameters {sliding_w_lims_t, "
"sliding_t_step, sliding_win_len} must be "
"specified to use sliding windows. If not desired, "
"then put all three to None. Aborting.")
use_continuous = True
# Avoid changes in the original recordings (this may not be necessary)
if self.safe_copy:
dataset = copy.deepcopy(dataset)
# Avoid consistency problems
if dataset.fs is None and target_fs is None:
raise ValueError('The consistency of the features is not assured '
'since dataset.fs and target_fs are both None. '
'Specify one of these parameters')
# Additional track attributes
track_attributes = dataset.track_attributes
# Initialization
track_info = dict()
for key, value in track_attributes.items():
if value['track_mode'] == 'append':
track_info[key] = list()
elif value['track_mode'] == 'concatenate':
track_info[key] = None
else:
raise ValueError('Unknown track mode')
# Init progress bar
pbar = None
if show_progress_bar:
pbar = tqdm(total=len(dataset.recordings),
desc='Extracting features')
# Compute features
features = None
for run_counter, rec in enumerate(dataset.recordings):
# Extract recording experiment and biosignal
rec_exp = getattr(rec, dataset.experiment_att_key)
rec_sig = getattr(rec, dataset.biosignal_att_key)
# If continuous: get different windows relative to the onsets
if use_continuous:
# Get windows
start, stop = sliding_w_lims_t # respect to each onset
step_array = np.arange(0, (stop - start - sliding_win_len) //
sliding_t_step + 1)
win_0 = start + step_array * sliding_t_step
win_1 = win_0 + sliding_win_len
list_w_epoch_t = [(w0, w1) for w0, w1 in zip(win_0, win_1)
if w1 <= stop]
# Get baselines
if baseline_mode == "sliding":
bas_0 = start + w_baseline_t[0] + step_array * sliding_t_step
bas_1 = start + w_baseline_t[1] + step_array * sliding_t_step
list_w_bas_t = [(w0, w1) for w0, w1 in zip(bas_0, bas_1)]
elif baseline_mode == "trial":
bas_0 = start + w_baseline_t[0] + step_array * 0
bas_1 = start + w_baseline_t[1] + step_array * 0
list_w_bas_t = [(w0, w1) for w0, w1 in zip(bas_0, bas_1)]
elif baseline_mode == "run":
list_w_bas_t = [w_baseline_t for _ in range(len(
list_w_epoch_t))]
else:
list_w_bas_t = [None for _ in range(len(list_w_epoch_t))]
# Modify the mi_labels
if hasattr(rec.midata, "mi_labels"):
rec.midata.mi_labels = \
np.tile(rec.midata.mi_labels,
(len(list_w_epoch_t), 1)).flatten()
else:
list_w_epoch_t = [w_epoch_t]
list_w_bas_t = [w_baseline_t]
# For each window
for i in range(len(list_w_epoch_t)):
# Get features
rec_feat = self.transform_signal(
times=rec_sig.times,
signal=rec_sig.signal,
fs=rec_sig.fs,
onsets=rec_exp.onsets,
w_epoch_t=list_w_epoch_t[i],
w_baseline_t=list_w_bas_t[i],
baseline_mode=baseline_mode,
norm=norm,
target_fs=target_fs,
concatenate_channels=concatenate_channels
)
features = np.concatenate((features, rec_feat), axis=0) \
if features is not None else rec_feat
# Track experiment info
for key, value in track_attributes.items():
if value['parent'] is None:
parent = rec
else:
parent = rec
for p in value['parent'].split('.'):
parent = getattr(parent, p)
att = getattr(parent, key)
if value['track_mode'] == 'append':
track_info[key].append(att)
elif value['track_mode'] == 'concatenate':
track_info[key] = np.concatenate(
(track_info[key], att), axis=0
) if track_info[key] is not None else att
else:
raise ValueError('Unknown track mode')
if show_progress_bar:
pbar.update(1)
if show_progress_bar:
pbar.close()
return features, track_info
[docs]class CSPFeatureExtraction(components.ProcessingMethod):
"""Common Spatial Patterns (CSP) feature extraction method for MI-based
spellers.
Processing pipeline:
- Use of StandardFeatureExtraction to get the raw epoch of each MI event.
- Extract CSP features of those MI events.
- Log-var features
"""
[docs] def __init__(self, n_filters=4, safe_copy=True,
normalize_log_vars=True, **kwargs):
"""Class constructor.
n_filters : int or None
Number of most discriminant CSP filters to decompose the signal
into (must be less or equal to the number of channels in your
signal). If None, all filters will be used.
safe_copy : bool
Makes a safe copy of the signal to avoid changing the original
samples due to references
normalize_log_vars : bool
If true, log-var features are normalized.
**kwargs : dict()
Parameters from StandardFeatureExtraction are not detailed here.
Please refer to the StandardFeatureExtraction documentation to know
more, as they are passed through kwargs.
After using the function self.fit(), the attributes are computed:
Attributes
----------
CSP : CSP class with attributes filters, eigenvalues, patterns and
methods fit and project.
filters : {(…, M, M) numpy.ndarray, (…, M, M) matrix}
Mixing matrix (spatial filters are stored in columns).
eigenvalues : (…, M) numpy.ndarray
Eigenvalues of w.
patterns : numpy.ndarray
De-mixing matrix (activation patterns are stored in columns).
"""
self.normalize_log_vars = normalize_log_vars
self.feature_extractor = StandardFeatureExtraction(safe_copy=safe_copy)
self.CSP = CSP(n_filters=n_filters, selection="extremes")
[docs] def fit_signal(self, labels, times, signal, fs, onsets, **kwargs):
# Standard extraction: just epoching (epochs x samples x channels)
features = self.feature_extractor.transform_signal(
times=times, signal=signal, fs=fs, onsets=onsets, **kwargs
)
# Fit the CSP filter
self.CSP.fit(X=features, y=labels)
return features
[docs] def fit_dataset(self, dataset, show_progress_bar=True, **kwargs):
# Standard extraction: just epoching (epochs x samples x channels)
features, track_info = self.feature_extractor.transform_dataset(
dataset=dataset, show_progress_bar=show_progress_bar, **kwargs
)
# Fit the CSP filter
self.CSP.fit(X=features, y=track_info['mi_labels'])
return features, track_info
[docs] def transform_signal(self, times, signal, fs, onsets, **kwargs):
if not self.CSP.is_fitted:
raise ValueError('CSPFeatureExtraction must be fitted before '
'predict!')
# Standard extraction: just epoching (epochs x samples x channels)
features = self.feature_extractor.transform_signal(
times=times, signal=signal, fs=fs, onsets=onsets, **kwargs
)
# Project using CSP
projection = self.CSP.project(X=features) # trials x samples x projects
# Log-variance features across samples
log_var = self._get_log_vars(projection, self.normalize_log_vars)
return log_var
[docs] def transform_dataset(self, dataset, show_progress_bar=True, **kwargs):
if not self.CSP.is_fitted:
raise ValueError('CSPFeatureExtraction must be fitted before '
'predict!')
# Standard extraction: just epoching (epochs x samples x channels)
features, track_info = self.feature_extractor.transform_dataset(
dataset=dataset, show_progress_bar=show_progress_bar, **kwargs
)
# Project using CSP
projection = self.CSP.project(X=features)
# Log-variance features
log_var = self._get_log_vars(projection, self.normalize_log_vars)
return log_var, track_info
[docs] def fit_transform_signal(self, labels, times, signal, fs, onsets, **kwargs):
# Fit
features = self.fit_signal(labels, times, signal, fs, onsets, **kwargs)
# Project using CSP
projection = self.CSP.project(X=features) # trials x samples x projects
# Log-variance features across samples
log_var = self._get_log_vars(projection, self.normalize_log_vars)
return log_var
[docs] def fit_transform_dataset(self, dataset, show_progress_bar=True, **kwargs):
# Fit
features, track_info = self.fit_dataset(
dataset, show_progress_bar=show_progress_bar, **kwargs)
# Project using CSP
projection = self.CSP.project(X=features)
# Log-variance features
log_var = self._get_log_vars(projection, self.normalize_log_vars)
return log_var, track_info
@staticmethod
def _get_log_vars(projection, normalize):
if normalize:
return np.log(
np.var(projection, axis=1) /
np.tile(np.sum(np.var(projection, axis=1), axis=1),
(projection.shape[2], 1)).T
)
else:
return np.log(np.var(projection, axis=1))
[docs]class MIModel(components.Algorithm):
"""Skeleton class for MI-based BCIs models. This class inherits from
components.Algorithm. Therefore, it can be used to create standalone
algorithms that can be used in compatible apps from medusa-platform
for online experiments. See components.Algorithm to know more about this
functionality.
Related tutorials:
- Overview of mi_paradigms module [LINK]
- Create standalone models for MI-based BCIs compatible with
Medusa platform [LINK]
"""
[docs] def __init__(self):
"""Class constructor
"""
super().__init__(fit_dataset=['mi_target', 'mi_result', 'accuracy'],
predict=['mi_result'])
# Settings
self.settings = None
self.channel_set = None
self.configure()
# Configuration
self.is_configured = False
self.is_built = False
self.is_fit = False
[docs] @abstractmethod
def configure(self, **kwargs):
"""This function must be used to configure the model before calling
build method. Class attribute settings attribute must be set with a dict
"""
# Update state
self.is_configured = True
self.is_built = False
self.is_fit = False
[docs] @abstractmethod
def build(self, *args, **kwargs):
"""This function builds the model, adding all the processing methods
to the pipeline. It must be called after configure.
"""
# Check errors
if not self.is_configured:
raise ValueError('Function configure must be called first!')
# Update state
self.is_built = True
self.is_fit = False
[docs] def fit_dataset(self, dataset, **kwargs):
"""Function that receives an MIDataset and uses its data to
fit the model. By default, executes pipeline 'fit_dataset'. Override
method for other behaviour.
Parameters
----------
dataset: MIDataset
Dataset with recordings from an MI-based BCI experiment
kwargs: key-value arguments
Optional parameters depending on the specific implementation of
the model
Returns
-------
fit_results: dict
Dict with the information of the fit process. For command
decoding models, at least it has to contain keys
mi_target, mi_result and accuracy, which contain the target MI,
the decoded MI and the decoding accuracy in the analysis.
"""
# Check errors
if not self.is_built:
raise ValueError('Function build must be called first!')
# Execute pipeline
output = self.exec_pipeline('fit_dataset', dataset=dataset)
# Set channels
self.channel_set = dataset.channel_set
# Update state
self.is_fit = True
return output
[docs] def predict(self, times, signal, fs, l_cha, x_info, **kwargs):
"""Function that receives EEG signal and experiment info from an
MI-based trial to decode the user's intentions. Used in online
experiments. By default, executes pipeline 'predict'. Override method
for other behaviour.
Parameters
---------
times: list or numpy.ndarray
Timestamps of the EEG samples
signal: list or numpy.ndarray
EEG samples with shape [n_samples x n_channels]
fs: float
Sample rate of the EEG signal
l_cha: list
List of channel labels
x_info: dict
Dict with the needed experiment info to decode the commands. It
has to contain keys: mode, onsets, w_trial_t.
See MIData to know how are defined these variables.
kwargs: key-value arguments
Optional parameters depending on the specific implementation of
the model
"""
# Check errors
if not self.is_fit:
raise ValueError('Function fit_dataset must be called first!')
# Check channels
self.channel_set.check_channels_labels(l_cha, strict=True)
# Execute predict pipeline
return self.exec_pipeline('predict', times=times, signal=signal,
fs=fs, x_info=x_info, **kwargs)
[docs]class MIModelCSP(MIModel):
"""Decoding model for MI-based BCI applications based on
Common Spatial Patterns (CSP).
It is strongly recommended to update the default settings by passing
arguments through **kwargs. See the different algorithms to know more
about the possible parameters: StandardPreprocessing,
CSPFeatureExtraction and StandardFeatureExtraction.
Dataset features:
- Sample rate of the signals > 60 Hz. The model can handle recordings
with different sample rates.
- Recommended channels to be present: ['C3', 'C4'].
Processing pipeline:
- Preprocessing (medusa.bci.erp_spellers.StandardPreprocessing):
- IIR Filter (order=5, cutoff=(8, 30) Hz: unlike FIR filters, IIR
filters are quick and can be applied in small signal chunks. Thus,
they are the preferred method for frequency filter in online
systems.
- Common average reference (CAR): widely used spatial filter that
increases the signal-to-noise ratio of the MI control signals.
- Feature extraction (medusa.bci.mi_paradigms.CSPFeatureExtraction):
- Epochs (window=(0, 2000) ms, resampling to 60 HZ): the epochs of
signal are extracted for each stimulation. Baseline normalization
is also applied, taking the window (-1000, 0) ms relative to the
stimulus onset.
- CSP projection: Epochs are then projected according to a CSP filter
previously trained.
- Log variance: Log variance features are extracted from the CSP
projection.
- Feature classification (
sklearn.discriminant_analysis.LinearDiscriminantAnalysis)
- Regularized linear discriminant analysis (rLDA): we use the sklearn
implementation, with eigen solver and auto shrinkage paramers. See
reference in sklearn doc.
"""
[docs] def configure(self, p_filt_cutoff=(8, 30), w_epoch_t=(0, 2000),
w_baseline_t=(-1000, 0), target_fs=60, **kwargs):
""" Configures the default settings. """
self.settings = {
# StandardPreprocessing
'p_filt_cutoff': p_filt_cutoff,
# CSPFeatureExtraction
'n_filters': 4,
'normalize_log_vars': False,
# StandardFeatureExtraction
'w_epoch_t': w_epoch_t,
'baseline_mode': 'trial',
'w_baseline_t': w_baseline_t,
'norm': 'z',
'target_fs': target_fs,
'concatenate_channels': False,
'sliding_w_lims_t': None,
'sliding_t_step': None,
'sliding_win_len': None
}
self.settings = dict(self.settings, **kwargs)
# Update state
self.is_configured = True
self.is_built = False
self.is_fit = False
[docs] def build(self):
""" Initializes the different methods that comprise the MIModelCSP
pipeline.
"""
# Check errors
if not self.is_configured:
raise ValueError('Function configure must be called first!')
# Preprocessing (default: bandpass IIR filter [8, 30] Hz + CAR)
self.add_method('prep_method', StandardPreprocessing(
cutoff=self.settings['p_filt_cutoff']
))
# Feature extraction (default: epochs [500, 4000] ms + resampling to 80
# Hz)
self.add_method('ext_method', CSPFeatureExtraction(**self.settings))
# Feature classification (rLDA)
clf = components.ProcessingClassWrapper(
LinearDiscriminantAnalysis(solver='eigen', shrinkage='auto'),
fit=[], predict_proba=['y_pred']
)
self.add_method('clf_method', clf)
# Update state
self.is_built = True
self.is_fit = False
[docs] def fit_dataset(self, dataset, get_training_accuracy=True,
k_fold=5, **kwargs):
""" Function to fit a dataset using MIModelCSP.
Parameters
-------------
dataset : MIDataset
MI dataset used for training.
get_training_accuracy : bool
Whether to perform an estimation on training accuracy after fitting.
k_fold : int
Number of k-folds used in the k-fold cross-validation procedure
to estimate the training accuracy.
**kwargs : dict
These parameters will be overwritten over self.settings.
Returns
-------------
assessment : dict
Dictionary containing the details of the training accuracy
estimation.
"""
# Check errors
if not self.is_built:
raise ValueError('Function build must be called first!')
# Merge settings
settings = dict(self.settings, **kwargs)
self.channel_set = dataset.channel_set
# Preprocessing
dataset = self.get_inst('prep_method').fit_transform_dataset(dataset)
# Extract CSP features
x, x_info = self.get_inst('ext_method').fit_transform_dataset(
dataset, **settings)
# Training accuracy using a k-fold cross-validation
assessment = None
if get_training_accuracy:
folds_decoding = list()
k_fold_iter = cu.k_fold_split(x, x_info['mi_labels'], k_fold)
k_fold_acc = 0
for iter in k_fold_iter:
self.get_inst('clf_method').fit(
iter["x_train"], iter["y_train"])
y_test_pred = self.get_inst('clf_method').predict(
iter["x_test"])
y_test_prob = self.get_inst('clf_method').predict_proba(
iter["x_test"])
fold_acc = np.sum(y_test_pred == iter["y_test"]) / \
len(iter["y_test"])
k_fold_acc += fold_acc
folds_decoding.append({
"y_pred": y_test_pred,
"y_prob": y_test_prob,
"accuracy": fold_acc,
})
k_fold_acc /= len(k_fold_iter)
assessment = {
'x': x,
'x_info': x_info,
'k-fold': folds_decoding,
'accuracy': k_fold_acc
}
# Fit classifier with all the data
self.get_inst('clf_method').fit(x, x_info['mi_labels'])
# Save info
self.channel_set = dataset.channel_set
# Update state
self.is_fit = True
return assessment
[docs] def predict(self, times, signal, fs, channel_set, x_info, **kwargs):
""" Function to predict an individual signal in MIModelCSP.
Parameters
--------------
times : ndarray (n_samples,)
Timestamp array
signal: ndarray (n_samples x n_channels)
Signal data.
fs : int
Sampling frequency.
channel_set : EEGChannelSet or similar
Channel montage.
x_info : dict
Dictionary containing the trial "onsets" and "mi_labels". If the
latter is not specified, accuracy is not calculated.
**kwargs : dict
These parameters will be overwritten over self.settings.
Returns
-------------
decoding : dict
Dictionary containing the decoding.
"""
# Check errors
if not self.is_fit:
raise ValueError('Function fit_dataset must be called first!')
# Merge settings
settings = dict(self.settings, **kwargs)
# Check channel set
if self.channel_set.l_cha != channel_set.l_cha:
warnings.warn('The channel set is not the same that was used to '
'fit the model. Be careful!')
# Preprocessing
signal = self.get_inst('prep_method').transform_signal(signal)
# Extract features
x = self.get_inst('ext_method').transform_signal(
times, signal, fs, x_info['onsets'], **settings)
# Classification
y_prob = self.get_inst('clf_method').predict_proba(x)
y_pred = self.get_inst('clf_method').predict(x)
# Decoding
accuracy = None
clf_report = None
if x_info['mi_labels'] is not None:
accuracy = np.sum((y_pred == x_info['mi_labels'])) / len(y_pred)
clf_report = classification_report(x_info['mi_labels'], y_pred,
output_dict=True)
decoding = {
'x': x,
'x_info': x_info,
'y_prob': y_prob,
'y_pred': y_pred,
'accuracy': accuracy,
'report': clf_report
}
return decoding
[docs] def predict_dataset(self, dataset, **kwargs):
""" Function to predict a dataset using MIModelCSP.
Parameters
-------------
dataset : MIDataset
Test dataset.
**kwargs : dict
These parameters will be overwritten over self.settings.
Returns
-------------
decoding : dict
Dictionary containing the decoding.
"""
# Check errors
if not self.is_fit:
raise ValueError('Function fit_dataset must be called first!')
# Check channel set
if self.channel_set.l_cha != dataset.channel_set.l_cha:
warnings.warn('The channel set is not the same that was used to '
'fit the model. Be careful!')
# Merge settings
settings = dict(self.settings, **kwargs)
# Preprocessing
dataset = self.get_inst('prep_method').transform_dataset(dataset,
**settings)
# Extract features
x, x_info = self.get_inst('ext_method').transform_dataset(
dataset, **settings)
# Classification
y_prob = self.get_inst('clf_method').predict_proba(x)
y_pred = self.get_inst('clf_method').predict(x)
# Decoding
accuracy = None
clf_report = None
if x_info['mi_labels'] is not None:
accuracy = np.sum((y_pred == x_info['mi_labels'])) / len(y_pred)
clf_report = classification_report(x_info['mi_labels'], y_pred,
output_dict=True)
decoding = {
'x': x,
'x_info': x_info,
'y_prob': y_prob,
'y_pred': y_pred,
'accuracy': accuracy,
'report': clf_report
}
return decoding
[docs]class MIModelEEGSym(MIModel):
"""Decoding model for MI-based BCI applications based on EEGSym [1], a deep
convolutional neural network developed for inter-subjects MI classification.
Dataset features:
- Sample rate of the signals > 128 Hz. The model can handle recordings
with different sample rates.
- Recommended channels: ['F7', 'C3', 'Po3', 'Cz', 'Pz', 'F8', 'C4', 'Po4'].
Processing pipeline:
- Preprocessing:
- IIR Filter (order=4, lowpass=49 Hz: unlike FIR filters, IIR filters
are quick and can be applied in small signal chunks. Thus,
they are the preferred method for frequency filter in online systems
- Common average reference (CAR): widely used spatial filter that
increases the signal-to-noise ratio.
- Feature extraction:
- Epochs (window=(0, 2000) ms, resampling to 128 HZ): the epochs of
signal are extracted after each onset. Baseline normalization
is also applied, taking the same epoch window.
- Feature classification
- EEGSym: convolutional neural network [1].
References
----------
[1] Pérez-Velasco, S., Santamaría-Vázquez, E., Martínez-Cagigal, V.,
Marcos-Mateo, D., & Hornero, R. (2020). EEGSym: Overcoming Intersubject
Variability in Motor Imagery Based BCIs with Deep Learning. ?.
"""
[docs] def configure(self, cnn_n_cha=8, ch_lateral=3, fine_tuning=False,
shuffle_before_fit=True, validation_split=0.4,
init_weights_path=None, gpu_acceleration=False,
augmentation=False):
self.settings = {
'cnn_n_cha': cnn_n_cha,
'ch_lateral': ch_lateral,
'fine_tuning': fine_tuning,
'augmentation': augmentation,
'shuffle_before_fit': shuffle_before_fit,
'validation_split': validation_split,
'init_weights_path': init_weights_path,
'gpu_acceleration': gpu_acceleration
}
# Update state
self.is_configured = True
self.is_built = False
self.is_fit = False
[docs] def build(self):
# Check errors
if not self.is_configured:
raise ValueError('Function configure must be called first!')
# Only import deep learning models if necessary
from medusa.deep_learning_models import EEGSym
# Preprocessing (bandpass IIR filter [0.5, 45] Hz + CAR)
self.add_method('prep_method',
StandardPreprocessing(cutoff=49, btype='lowpass'))
# Feature extraction (epochs [0, 2000] ms + resampling to 128 Hz)
self.add_method('ext_method', StandardFeatureExtraction(safe_copy=True))
# Feature classification
clf = EEGSym(
input_time=2000,
fs=128,
n_cha=self.settings['cnn_n_cha'],
ch_lateral=self.settings['ch_lateral'],
filters_per_branch=24,
scales_time=(125, 250, 500),
dropout_rate=0.4,
activation='elu', n_classes=2,
learning_rate=0.0001,
gpu_acceleration=self.settings['gpu_acceleration'])
self.is_fit = False
if self.settings['init_weights_path'] is not None:
clf.load_weights(self.settings['init_weights_path'])
self.channel_set = meeg.EEGChannelSet()
standard_lcha = ['F7', 'C3', 'Po3', 'Cz', 'Pz', 'F8', 'C4', 'Po4']
self.channel_set.set_standard_montage(standard_lcha)
self.is_fit = True
self.add_method('clf_method', clf)
# Update state
self.is_built = True
# self.is_fit = False
[docs] def fit_dataset(self, dataset, continuous=False, **kwargs):
# Check errors
if not self.is_built:
raise ValueError('Function build must be called first!')
# Preprocessing
dataset = self.get_inst('prep_method').fit_transform_dataset(dataset)
# Extract features
# TODO: perform sliding window? if so, we need the lims
# TODO: hardcoded?
x, x_info = self.get_inst('ext_method').transform_dataset(
dataset, w_epoch_t=(0, 2000), baseline_mode="trial",
w_baseline_t=(0, 2000), norm="z", target_fs=128,
sliding_w_lims_t=None, sliding_t_step=None, sliding_win_len=None
)
# Put channels in symmetric order
x, _ = self.get_inst('clf_method').symmetric_channels(
x, dataset.channel_set.l_cha)
# Classification
self.get_inst('clf_method').fit(
x, x_info['mi_labels'],
fine_tuning=self.settings['fine_tuning'],
shuffle_before_fit=self.settings['shuffle_before_fit'],
validation_split=self.settings['validation_split'],
augmentation=self.settings['augmentation'],
**kwargs)
y_prob = self.get_inst('clf_method').predict_proba(x)
y_pred = y_prob.argmax(axis=-1)
# Accuracy
accuracy = np.sum((y_pred == x_info['mi_labels'])) / len(y_pred)
clf_report = classification_report(x_info['mi_labels'], y_pred,
output_dict=True)
assessment = {
'x': x,
'x_info': x_info,
'y_pred': y_pred,
'y_prob': y_prob,
'accuracy': accuracy,
'report': clf_report
}
# Save info
self.channel_set = dataset.channel_set
# Update state
self.is_fit = True
return assessment
[docs] def predict(self, times, signal, fs, channel_set, x_info, **kwargs):
# Check errors
if not self.is_fit:
raise ValueError('Function fit_dataset must be called first!')
# Check channel set
if self.channel_set.channels != channel_set.channels:
warnings.warn('The channel set is not the same that was used to '
'fit the model. Be careful!')
# Preprocessing
signal = self.get_inst('prep_method').fit_transform_signal(signal, fs)
# Extract features
# TODO: hardcoded?
x = self.get_inst('ext_method').transform_signal(
times=times, signal=signal, fs=fs, onsets=x_info['onsets'],
w_epoch_t=(0, 2000), baseline_mode="trial", w_baseline_t=(0, 2000),
norm="z", target_fs=128
)
# Put channels in symmetric order
x, _ = self.get_inst('clf_method').symmetric_channels(
x, self.channel_set.l_cha)
# Classification
y_prob = self.get_inst('clf_method').predict_proba(x)
y_pred = y_prob.argmax(axis=-1)
# Decoding
accuracy = None
clf_report = None
if x_info['mi_labels'] is not None:
accuracy = np.sum((y_pred == x_info['mi_labels'])) / len(y_pred)
clf_report = classification_report(x_info['mi_labels'], y_pred,
output_dict=True)
decoding = {
'x': x,
'x_info': x_info,
'y_pred': y_pred,
'y_prob': y_prob,
'accuracy': accuracy,
'report': clf_report
}
return decoding