# Source code for medusa.bci.mi_paradigms

"""Created on Monday February 14 09:27:14 2022

In this module you will find useful functions and classes to operate with data
recorded using motor imagery paradigms, which are widely used by the BCI
community. Enjoy!

@author: Sergio Pérez-Velasco
"""
# Built-in imports
import copy, warnings
from abc import abstractmethod

# External imports
import numpy as np
from tqdm import tqdm
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.metrics import classification_report

# Medusa imports
from medusa import IIRFilter, car
from medusa.epoching import normalize_epochs
from medusa import get_epochs_of_events, resample_epochs
from medusa import components
from medusa import meeg
from medusa.spatial_filtering import CSP, LaplacianFilter
from medusa.deep_learning_models import EEGSym


class MIData(components.ExperimentData):  # TODO: Check everything
    """Class with the necessary attributes to define motor imagery (MI)
    experiments. It provides compatibility with high-level functions for
    MI paradigms in the BCI module.
    """

    def __init__(self, mode, onsets, w_trial_t, mi_result=None,
                 calibration_t=None, mi_labels=None, mi_labels_info=None,
                 w_rest_t=None, **kwargs):
        """MIData constructor

        Parameters
        ----------
        mode : str {"train"|"test"|"guided_test"}
            Mode of this run. It is lower-cased before being stored.
        onsets : list or numpy.ndarray [n_stim x 1]
            Timestamp of each stimulation.
        w_trial_t : list [start, end]
            Temporal window of the motor imagery with respect to each onset
            in ms. For example, if w_trial_t = [500, 4000] the subject was
            performing the motor imagery task from 500 ms to 4000 ms after
            the onset.
        mi_result : list or numpy.ndarray [n_mi_labels x 1]
            Result of this run. Each position contains the data of the
            selected target at each trial.
        calibration_t : int
            Time duration of the calibration, normally recorded at the
            beginning of each run. For example, if calibration_t = 5000 the
            subject was in resting state for the first 5 seconds of the run.
        mi_labels : list or numpy.ndarray [n_mi_labels x 1]
            Required in train mode. Contains the MI label of each
            stimulation, as many as classes in the experiment.
        mi_labels_info : dict
            Contains the description of the MI labels.
            Example:
                mi_labels_info =
                    {0: "Rest",
                    1: "Left_hand",
                    2: "Right_hand"}
        w_rest_t : list [start, end]
            Temporal window of the rest with respect to each onset in ms. For
            example, if w_rest_t = [-1000, 0] the subject was resting from
            1000 ms before the onset until the onset.
        kwargs : kwargs
            Custom arguments that will also be saved in the class
            (e.g., timings, calibration gaps, etc.)

        Raises
        ------
        ValueError
            If mode is "train" and mi_labels is None.
        """
        # Check errors
        mode = mode.lower()
        if mode == 'train' and mi_labels is None:
            raise ValueError('Attributes mi_labels, '
                             'should be provided in train mode')

        # Standard attributes
        self.mode = mode
        self.onsets = np.array(onsets)
        self.w_trial_t = np.array(w_trial_t)
        # NOTE(review): unlike mi_result / mi_labels, calibration_t and
        # w_rest_t are always wrapped, so a None value is stored as the 0-d
        # object array np.array(None) instead of None.
        self.calibration_t = np.array(calibration_t)
        self.mi_result = np.array(mi_result) if mi_result is not None \
            else None
        self.mi_labels = np.array(mi_labels) if mi_labels is not None \
            else None
        self.mi_labels_info = mi_labels_info
        self.w_rest_t = np.array(w_rest_t)

        # Optional attributes
        for key, value in kwargs.items():
            setattr(self, key, value)

    def to_serializable_obj(self):
        """Returns a dict with the attributes of this object, with every
        numpy.ndarray converted to (nested) lists so it can be serialized.

        A new dict is built, so the attributes of this instance are left
        untouched (previously the instance's own ``__dict__`` was modified
        in place).

        Returns
        -------
        rec_dict : dict
            Attribute name -> serializable value.
        """
        return {key: value.tolist() if isinstance(value, np.ndarray)
                else value
                for key, value in self.__dict__.items()}

    @staticmethod
    def from_serializable_obj(dict_data):
        """Creates an MIData instance from a dict such as the one returned
        by to_serializable_obj. Every key is forwarded to the constructor
        as a keyword argument.

        Parameters
        ----------
        dict_data : dict
            Serialized representation of an MIData object.

        Returns
        -------
        MIData
        """
        return MIData(**dict_data)
class MIDataset(components.Dataset):
    """This class inherits from medusa.data_structures.Dataset, increasing
    its functionality for datasets with data from MI experiments. It
    provides common ground for the rest of functions in the module.
    """

    def __init__(self, channel_set, fs=None, biosignal_att_key='eeg',
                 experiment_att_key='midata', experiment_mode=None,
                 track_attributes=None):
        """Constructor

        Parameters
        ----------
        channel_set : meeg.EEGChannelSet
            EEG channel set. Only these channels will be kept in the dataset,
            the others will be discarded. Also, the signals will be
            rearranged, keeping the same channel order, avoiding errors in
            future stages of the signal processing pipeline.
        fs : int, float or None
            Sample rate of the recordings. If there are recordings with
            different sample rates, the consistency of the dataset can still
            be assured using resampling.
        biosignal_att_key : str
            Name of the attribute containing the target biosignal that will
            be used to extract the features. It has to be the same in all
            recordings (e.g., 'eeg', 'meg').
        experiment_att_key : str or None
            Name of the attribute containing the target experiment that will
            be used to extract the features. It has to be the same in all
            recordings (e.g., 'mi_left_right', 'rest_mi'). It is mandatory
            when a recording of the dataset contains more than 1 experiment
            data.
        experiment_mode : str {'train'|'test'|'guided test'|'guided_test'|None}
            Mode of the experiment. If this dataset will be used to fit a
            model, set to train to avoid errors. 'guided test' and
            'guided_test' are equivalent. Train and guided test modes track
            mi_labels; test and guided test modes track mi_result.
        track_attributes : dict of dicts or None
            This parameter indicates custom attributes that must be tracked
            in feature extraction functions and how. The keys are the names
            of the attributes, whereas the values are dicts indicating the
            tracking mode {'concatenate'|'append'} and parent. Option
            concatenate is only available for attributes of type list or
            numpy arrays, forming a 1 dimensional array with the data from
            all recordings. Option append is used to save all kind of
            objects for each recording, forming a list whose length will be
            the number of recordings in the dataset. A set of default
            attributes is defined, so this parameter will be None in most
            cases. Example to track 2 custom attributes (i.e., date and
            experiment_equipment):
                track_attributes = {
                    'date': {
                        'track_mode': 'append',
                        'parent': None
                    },
                    'experiment_equipment': {
                        'track_mode': 'append',
                        'parent': experiment_att_key
                    }
                }

        Raises
        ------
        ValueError
            If experiment_mode is not one of the accepted values.
        """
        # Check errors
        guided_modes = ('guided test', 'guided_test')
        if experiment_mode is not None:
            if experiment_mode not in ('train', 'test') + guided_modes:
                raise ValueError('Parameter experiment_mode must be '
                                 '{train|test|guided test|guided_test|None}')

        # Default track attributes
        default_track_attributes = {
            'subject_id': {
                'track_mode': 'append',
                'parent': None
            },
            'w_trial_t': {
                'track_mode': 'concatenate',
                'parent': experiment_att_key
            },
            'w_rest_t': {
                'track_mode': 'concatenate',
                'parent': experiment_att_key
            },
            'calibration_t': {
                'track_mode': 'append',
                'parent': experiment_att_key
            },
            'onsets': {
                'track_mode': 'concatenate',
                'parent': experiment_att_key
            },
            'mi_labels_info': {
                'track_mode': 'append',
                'parent': experiment_att_key
            }
        }
        # Labels are tracked in train and guided test modes, results in test
        # and guided test modes. A guided test needs both, so these checks
        # are independent (an if/elif chain would skip mi_result).
        if experiment_mode == 'train' or experiment_mode in guided_modes:
            default_track_attributes['mi_labels'] = {
                'track_mode': 'concatenate',
                'parent': experiment_att_key
            }
        if experiment_mode == 'test' or experiment_mode in guided_modes:
            default_track_attributes['mi_result'] = {
                'track_mode': 'concatenate',
                'parent': experiment_att_key
            }

        # User-defined attributes override the defaults with the same key
        track_attributes = \
            default_track_attributes if track_attributes is None else \
            {**default_track_attributes, **track_attributes}

        # Class attributes
        self.channel_set = channel_set
        self.fs = fs
        self.biosignal_att_key = biosignal_att_key
        self.experiment_att_key = experiment_att_key
        self.experiment_mode = experiment_mode
        self.track_attributes = track_attributes

        # Consistency checker
        checker = self.__get_consistency_checker()
        super().__init__(consistency_checker=checker)

    def __get_consistency_checker(self):
        """Creates a standard consistency checker for MI datasets

        Returns
        -------
        checker : data_structures.ConsistencyChecker
            Standard consistency checker for MI feature extraction
        """
        # Create consistency checker
        checker = components.ConsistencyChecker()
        # Check that the biosignal exists and is an EEG object
        checker.add_consistency_rule(
            rule='check-attribute',
            rule_params={'attribute': self.biosignal_att_key}
        )
        checker.add_consistency_rule(
            rule='check-attribute-type',
            rule_params={'attribute': self.biosignal_att_key,
                         'type': meeg.EEG}
        )
        # Check channels
        checker.add_consistency_rule(
            rule='check-values-in-attribute',
            rule_params={'attribute': 'channels',
                         'values': self.channel_set.channels},
            parent=self.biosignal_att_key + '.channel_set'
        )
        # Check sample rate
        if self.fs is not None:
            checker.add_consistency_rule(
                rule='check-attribute-value',
                rule_params={'attribute': 'fs', 'value': self.fs},
                parent=self.biosignal_att_key
            )
        else:
            warnings.warn('Parameter fs is None. The consistency of the '
                          'dataset cannot be assured. Still, you can use '
                          'target_fs parameter for feature extraction '
                          'and everything should be fine.')
        # Check experiment
        checker.add_consistency_rule(
            rule='check-attribute',
            rule_params={'attribute': self.experiment_att_key}
        )
        checker.add_consistency_rule(
            rule='check-attribute-type',
            rule_params={'attribute': self.experiment_att_key,
                         'type': MIData}
        )
        # NOTE: a rule checking that the mode of each recording matches
        # self.experiment_mode ('check-attribute-value' on 'mode') is
        # currently disabled.
        # Check track_attributes
        if self.track_attributes is not None:
            for key, value in self.track_attributes.items():
                checker.add_consistency_rule(
                    rule='check-attribute',
                    rule_params={'attribute': key},
                    parent=value['parent']
                )
                if value['track_mode'] == 'concatenate':
                    checker.add_consistency_rule(
                        rule='check-attribute-type',
                        rule_params={'attribute': key,
                                     'type': [list, np.ndarray]},
                        parent=value['parent']
                    )
        return checker

    def custom_operations_on_recordings(self, recording):
        """Keeps only the channels of self.channel_set in the biosignal of
        the given recording, in that order.

        Parameters
        ----------
        recording : components.Recording
            Recording to process. Its biosignal object is modified in place.

        Returns
        -------
        recording : components.Recording
            The same recording object.
        """
        eeg = getattr(recording, self.biosignal_att_key)
        eeg.change_channel_set(self.channel_set)
        return recording
class StandardPreprocessing(components.ProcessingMethod):
    """Just the common preprocessing applied in MI-based BCI. Simple, quick
    and effective: frequency IIR filter followed by common average reference
    (CAR) spatial filter.
    """

    def __init__(self, order=5, cutoff=(0.05, 63), btype='bandpass',
                 temp_filt_method='sosfiltfilt'):
        """Class constructor

        Parameters
        ----------
        order : int
            Order of the IIR filter.
        cutoff : list or tuple
            Cutoff frequencies of the IIR filter in Hz. The default is an
            immutable tuple so it cannot be shared and mutated across
            instances.
        btype : str
            Filter type forwarded to IIRFilter (e.g., 'bandpass').
        temp_filt_method : str
            Filtering method forwarded to IIRFilter as filt_method
            (e.g., 'sosfiltfilt' or 'sosfilt').
        """
        super().__init__(fit_transform_signal=['signal'],
                         fit_transform_dataset=['dataset'])
        # Parameters
        self.order = order
        self.cutoff = cutoff
        self.btype = btype
        self.temp_filt_method = temp_filt_method
        # IIR filter instance, created by fit / fit_transform_signal
        self.iir_filter = None

    def _build_iir_filter(self):
        """Creates a new (unfitted) IIRFilter from the stored parameters."""
        return IIRFilter(order=self.order,
                         cutoff=self.cutoff,
                         btype=self.btype,
                         filt_method=self.temp_filt_method)

    def fit(self, fs, n_cha=None):
        """Fits the IIR filter.

        Parameters
        ----------
        fs: float
            Sample rate of the signal.
        n_cha: int
            Number of channels. Used to compute the initial conditions of the
            frequency filter. Only required with sosfilt filtering method
            (online filtering)
        """
        self.iir_filter = self._build_iir_filter()
        self.iir_filter.fit(fs, n_cha=n_cha)

    def transform_signal(self, signal):
        """Transforms an EEG signal applying IIR filtering and CAR
        sequentially. fit must have been called first.

        Parameters
        ----------
        signal: np.array or list
            Signal to transform. Shape [n_samples x n_channels]

        Returns
        -------
        signal : np.ndarray
            Filtered and re-referenced signal.
        """
        signal = self.iir_filter.transform(signal)
        signal = car(signal)
        return signal

    def fit_transform_signal(self, signal, fs):
        """Fits the IIR filter and transforms an EEG signal applying IIR
        filtering and CAR sequentially

        Parameters
        ----------
        signal: np.array or list
            Signal to transform. Shape [n_samples x n_channels]
        fs: float
            Sample rate of the signal.

        Returns
        -------
        signal : np.ndarray
            Filtered and re-referenced signal.
        """
        self.iir_filter = self._build_iir_filter()
        signal = self.iir_filter.fit_transform(signal, fs)
        signal = car(signal)
        return signal

    def fit_transform_dataset(self, dataset, show_progress_bar=True):
        """Fits the IIR filter and transforms an EEG signal applying the
        filter and CAR sequentially. Each recording is preprocessed
        independently, taking into account possible differences in sample
        rate.

        Parameters
        ----------
        dataset: MIDataset
            MIDataset with the recordings to be preprocessed. The biosignal
            of each recording is replaced in place.
        show_progress_bar: bool
            Show progress bar

        Returns
        -------
        dataset : MIDataset
            The same dataset object with preprocessed signals.
        """
        pbar = tqdm(total=len(dataset.recordings), desc='Preprocessing',
                    disable=not show_progress_bar)
        for rec in dataset.recordings:
            eeg = getattr(rec, dataset.biosignal_att_key)
            # The filter is refitted per recording to use its own fs
            eeg.signal = self.fit_transform_signal(eeg.signal, eeg.fs)
            setattr(rec, dataset.biosignal_att_key, eeg)
            pbar.update(1)
        pbar.close()
        return dataset
class StandardFeatureExtraction(components.ProcessingMethod):
    """Standard feature extraction method for MI-based spellers. Basically,
    it gets the raw epoch for each MI event.
    """

    def __init__(self, w_epoch_t=(0, 3000), target_fs=128,
                 baseline_mode='trial', w_baseline_t=(-1500, -500), norm='z',
                 concatenate_channels=False, safe_copy=True):
        """Class constructor

        Parameters
        ----------
        w_epoch_t : list
            Temporal window in ms for each epoch relative to the event onset
            (e.g., [0, 3000])
        target_fs : float or None
            Target sample rate of each epoch. If None, all the recordings
            must have the same sample rate, so it is strongly recommended to
            set this parameter to a suitable value to avoid problems and
            save time
        baseline_mode : {'run', 'trial', None}
            If "run", a single baseline window of the run (see w_baseline_t)
            is used to normalize all epochs. If "trial", a baseline window
            relative to each event onset is used for each epoch. If None,
            no baseline window is used.
        w_baseline_t : list, np.ndarray, "auto" or None
            Temporal window in ms to be used for baseline normalization.
            If baseline_mode = "trial", the window is relative to each event
            onset (e.g., [-1500, -500]); "auto" uses w_epoch_t.
            If baseline_mode = "run", the window is relative to the first
            sample of the signal; "auto" uses [0, 5000].
            If baseline_mode = None, it must be None.
        norm : str {'z'|'dc'}
            Type of baseline normalization. Set to 'z' for Z-score
            normalization or 'dc' for DC normalization
        concatenate_channels : bool
            This parameter controls the shape of the feature array. If True,
            all channels will be concatenated, returning an array of shape
            [n_events x (samples x channels)]. If false, the array will have
            shape [n_events x samples x channels]
        safe_copy : bool
            Makes a safe copy of the signal to avoid changing the original
            samples due to references
        """
        super().__init__(transform_signal=['x'],
                         transform_dataset=['x', 'x_info'])
        self.w_epoch_t = w_epoch_t
        self.target_fs = target_fs
        self.baseline_mode = baseline_mode
        self.w_baseline_t = w_baseline_t
        self.norm = norm
        self.concatenate_channels = concatenate_channels
        self.safe_copy = safe_copy

    @staticmethod
    def _is_auto(w_baseline_t):
        """True if w_baseline_t is the string 'auto'. Comparing an ndarray
        with a string directly would be elementwise and ambiguous in if."""
        return isinstance(w_baseline_t, str) and w_baseline_t == 'auto'

    @staticmethod
    def _get_parent(rec, parent_path):
        """Returns the object reached from rec by following the dotted
        attribute path parent_path (rec itself if parent_path is None)."""
        parent = rec
        if parent_path is not None:
            for att_name in parent_path.split('.'):
                parent = getattr(parent, att_name)
        return parent

    def transform_signal(self, times, signal, fs, onsets):
        """Function to extract MI features from raw signal. It returns a 3D
        feature array with shape [n_events x n_samples x n_channels]. This
        function does not track any other attributes. Use for online
        processing and custom higher level functions.

        Parameters
        ----------
        times : list or numpy.ndarray
            1D numpy array [n_samples]. Timestamps of each sample. If they
            are not available, generate them artificially. Nevertheless, all
            signals and events must have the same temporal origin
        signal : list or numpy.ndarray
            2D numpy array [n_samples x n_channels]. EEG samples (the units
            should be defined using kwargs)
        fs : int or float
            Sample rate of the recording.
        onsets : list or numpy.ndarray [n_events x 1]
            Timestamp of each event

        Returns
        -------
        features : np.ndarray [n_events x n_samples x n_channels]
            Feature array with the epochs of signal (or
            [n_events x (n_samples x n_channels)] if concatenate_channels)

        Raises
        ------
        ValueError
            If baseline_mode is invalid, or if baseline_mode is None while
            w_baseline_t is not None.
        """
        # Avoid changes in the original signal (this may not be necessary)
        if self.safe_copy:
            signal = signal.copy()

        # Baseline options (explicit checks instead of assert, which is
        # stripped under python -O)
        if self.baseline_mode not in ('run', 'trial', None):
            raise ValueError(
                'Parameter baseline_mode must be {"run", "trial", None}')
        auto_baseline = self._is_auto(self.w_baseline_t)
        norm = self.norm
        if self.baseline_mode is None:
            if self.w_baseline_t is not None:
                raise ValueError('If baseline_mode is None, '
                                 'parameter w_baseline_t must be None')
            w_baseline_trial_t = None
        elif self.baseline_mode == 'trial':
            w_baseline_trial_t = \
                self.w_epoch_t if auto_baseline else self.w_baseline_t
        else:  # 'run': normalization is done below with the run baseline
            norm = None
            w_baseline_trial_t = None

        # Extract features
        features = get_epochs_of_events(timestamps=times, signal=signal,
                                        onsets=onsets, fs=fs,
                                        w_epoch_t=self.w_epoch_t,
                                        w_baseline_t=w_baseline_trial_t,
                                        norm=norm)

        if self.baseline_mode == 'run':
            norm_epoch_t = [0, 5000] if auto_baseline else self.w_baseline_t
            # Convert the [start, end] window from ms to sample indices and
            # take the contiguous block of samples between them
            start_s, end_s = \
                (np.asarray(norm_epoch_t) * fs / 1000).astype(int)
            norm_epoch = np.expand_dims(signal[start_s:end_s], axis=0)
        else:
            norm_epoch = None
        # NOTE(review): for 'trial' and None modes this still runs with
        # norm_epochs=None, i.e. a second normalization after the one done
        # in get_epochs_of_events. Confirm this is intended.
        features = normalize_epochs(features, norm_epochs=norm_epoch,
                                    norm=self.norm)

        # Resample each epoch to the target frequency
        if self.target_fs is not None:
            if self.target_fs > fs:
                warnings.warn('Target fs is greater than data fs')
            features = resample_epochs(features,
                                       self.w_epoch_t,
                                       self.target_fs)

        # Flatten samples and channels of each epoch. Unlike reshape +
        # squeeze, this keeps the event axis when there is a single event.
        if self.concatenate_channels:
            features = features.reshape((features.shape[0], -1))
        return features

    def transform_dataset(self, dataset, show_progress_bar=True):
        # TODO: Review with modifications of Eduardo in erp_spellers (SERGIO)
        """High level function to easily extract features from EEG
        recordings and save useful info for later processing. Nevertheless,
        the provided functionality has several limitations and it will not
        be suitable for all cases and processing pipelines. If it does not
        fit your needs, create a custom function iterating the recordings
        and using transform_signal, a much more low-level and general
        function. This function does not apply any preprocessing to the
        signals, this must be done before.

        Parameters
        ----------
        dataset: MIDataset
            Dataset with the recordings. This function assumes that its
            consistency has already been checked.
        show_progress_bar: bool
            Show progress bar

        Returns
        -------
        features : numpy.ndarray or None
            Array with the biosignal samples arranged in epochs (None if the
            dataset has no recordings)
        track_info : dict
            Dictionary with tracked information across all recordings

        Raises
        ------
        ValueError
            If dataset.fs and target_fs are both None, or if a track mode
            is unknown.
        """
        # Avoid changes in the original recordings (this may not be
        # necessary)
        if self.safe_copy:
            dataset = copy.deepcopy(dataset)
        # Avoid consistency problems
        if dataset.fs is None and self.target_fs is None:
            raise ValueError('The consistency of the features is not assured '
                             'since dataset.fs and target_fs are both None. '
                             'Specify one of these parameters')

        track_attributes = dataset.track_attributes

        # Initialization
        rec_features = []
        track_info = dict()
        for key, value in track_attributes.items():
            if value['track_mode'] == 'append':
                track_info[key] = list()
            elif value['track_mode'] == 'concatenate':
                track_info[key] = None
            else:
                raise ValueError('Unknown track mode')

        pbar = tqdm(total=len(dataset.recordings),
                    desc='Extracting features',
                    disable=not show_progress_bar)

        # Compute features
        for rec in dataset.recordings:
            # Extract recording experiment and biosignal
            rec_exp = getattr(rec, dataset.experiment_att_key)
            rec_sig = getattr(rec, dataset.biosignal_att_key)
            rec_features.append(self.transform_signal(
                times=rec_sig.times,
                signal=rec_sig.signal,
                fs=rec_sig.fs,
                onsets=rec_exp.onsets))

            # Track experiment info
            for key, value in track_attributes.items():
                att = getattr(self._get_parent(rec, value['parent']), key)
                if value['track_mode'] == 'append':
                    track_info[key].append(att)
                elif value['track_mode'] == 'concatenate':
                    track_info[key] = att if track_info[key] is None else \
                        np.concatenate((track_info[key], att), axis=0)
                else:
                    raise ValueError('Unknown track mode')
            pbar.update(1)
        pbar.close()

        # Concatenate once at the end instead of once per recording
        features = np.concatenate(rec_features, axis=0) \
            if rec_features else None
        return features, track_info
class CSPFeatureExtraction(StandardFeatureExtraction):
    """Common Spatial Patterns (CSP) feature extraction method for MI-based
    spellers.

    Processing pipeline:
    - Use of StandardFeatureExtraction to get the raw epoch of each MI
      event.
    - Extract CSP features of those MI events.
    """

    def __init__(self, n_filters=4, w_epoch_t=(500, 4000), target_fs=60,
                 baseline_mode='trial', w_baseline_t=(-1500, -500), norm='z',
                 concatenate_channels=False, safe_copy=True):
        """Class constructor

        Parameters
        ----------
        n_filters : int or None
            Number of most discriminant CSP filters to decompose the signal
            into (must be less or equal to the number of channels in your
            signal). If int it will return that number of filters. If None
            it will return all the calculated filters.
        w_epoch_t : list
            Temporal window in ms for each epoch relative to the event onset
            (e.g., [500, 4000])
        target_fs : float or None
            Target sample rate of each epoch. If None, all the recordings
            must have the same sample rate, so it is strongly recommended to
            set this parameter to a suitable value to avoid problems and
            save time
        baseline_mode : {'run', 'trial', None}
            Baseline mode forwarded to StandardFeatureExtraction. "trial"
            uses a window relative to each event onset, "run" a single
            window of the run, None no baseline window.
        w_baseline_t : list, np.ndarray, "auto" or None
            Baseline window in ms forwarded to StandardFeatureExtraction
            (e.g., [-1500, -500] relative to each onset in "trial" mode).
        norm : str {'z'|'dc'}
            Type of baseline normalization. Set to 'z' for Z-score
            normalization or 'dc' for DC normalization
        concatenate_channels : bool
            This parameter controls the shape of the feature array. If True,
            all channels will be concatenated, returning an array of shape
            [n_events x (samples x channels)]. If false, the array will have
            shape [n_events x samples x channels]
        safe_copy : bool
            Makes a safe copy of the signal to avoid changing the original
            samples due to references

        Attributes
        ----------
        CSP : CSP
            CSP instance (medusa.spatial_filtering.CSP) built with
            n_filters. After calling fit it provides the filters, eigenvalues
            and patterns, and methods fit and project:
            - filters : numpy.ndarray, spatial filters (stored in columns).
            - eigenvalues : numpy.ndarray, eigenvalues of the filters.
            - patterns : numpy.ndarray, activation patterns (stored in
              columns).
        """
        super().__init__(w_epoch_t=w_epoch_t, target_fs=target_fs,
                         baseline_mode=baseline_mode,
                         w_baseline_t=w_baseline_t, norm=norm,
                         concatenate_channels=concatenate_channels,
                         safe_copy=safe_copy)
        self.CSP = CSP(n_filters=n_filters)

    def fit(self, X, y):
        """Trains the Common Spatial Patterns (CSP) filters.

        Problems with more than 2 classes are supported through multiclass
        CSP by means of approximate joint diagonalization; in that case the
        spatial filter selection follows [1]. Adapted from
        http://github.com/alexandrebarachant/pyRiemann

        Parameters
        ----------
        X : numpy.ndarray, [n_trials, samples, channels]
            Epoched data of shape (n_trials, samples, channels)
        y : numpy.ndarray, [n_trials,]
            Labels for epoched data of shape (n_trials,)

        References
        ----------
        [1] Grosse-Wentrup, Moritz, and Martin Buss. "Multiclass common
        spatial patterns and information theoretic feature extraction."
        Biomedical Engineering, IEEE Transactions on 55, no. 8 (2008):
        1991-2000.
        """
        self.CSP.fit(X=X, y=y)

    def project(self, times, signal, fs, onsets):
        """Extracts the epochs of a raw signal (see
        StandardFeatureExtraction.transform_signal) and projects them with
        the fitted CSP filters. This function does not track any other
        attributes. Use for online processing and custom higher level
        functions.

        Parameters
        ----------
        times : list or numpy.ndarray
            1D numpy array [n_samples]. Timestamps of each sample. If they
            are not available, generate them artificially. Nevertheless, all
            signals and events must have the same temporal origin
        signal : list or numpy.ndarray
            2D numpy array [n_samples x n_channels]. EEG samples (the units
            should be defined using kwargs)
        fs : int or float
            Sample rate of the recording.
        onsets : list or numpy.ndarray [n_events x 1]
            Timestamp of each event

        Returns
        -------
        features : numpy.ndarray
            Epochs of signal projected in the CSP space, with the shape
            returned by CSP.project.
        """
        features = super().transform_signal(times, signal, fs, onsets)
        features = self.CSP.project(X=features)
        return features

    def extract_log_var_features(self, X):
        """Computes the standard motor imagery log-variance features of the
        epochs projected with the fitted CSP filters.

        Parameters
        ----------
        X : numpy.ndarray, [n_trials, samples, channels]
            Epoched data of shape (n_trials, samples, channels)

        Returns
        -------
        features : numpy.ndarray
            Log-variance of the CSP-projected epochs, one row per trial.
        """
        features = self.CSP.project(X=X)
        # NOTE(review): the variance is taken over the last axis, which
        # assumes CSP.project returns [n_trials x n_filters x n_samples].
        # Confirm against CSP.project.
        features = np.log(np.var(features, axis=-1))
        return features

    def transform_dataset(self, dataset, show_progress_bar=True):
        """Extracts the epochs of all recordings, fits the CSP filters with
        them and returns their log-variance CSP features.

        Note that this method (re)fits self.CSP as a side effect and needs
        track_info to contain 'mi_labels' (i.e., a dataset in train or
        guided test mode).

        Parameters
        ----------
        dataset: MIDataset
            Dataset with the recordings.
        show_progress_bar: bool
            Show progress bar

        Returns
        -------
        features : numpy.ndarray
            Log-variance CSP features
        track_info : dict
            Dictionary with tracked information across all recordings
        """
        features, track_info = super().transform_dataset(
            dataset=dataset, show_progress_bar=show_progress_bar)
        # Fit CSP filter
        self.fit(X=features, y=track_info['mi_labels'])
        # Project features into CSP space and obtain log-variance features
        features = self.extract_log_var_features(X=features)
        return features, track_info
class MIModel(components.Algorithm):
    """Skeleton class for MI-based BCIs models. This class inherits from
    components.Algorithm. Therefore, it can be used to create standalone
    algorithms that can be used in compatible apps from medusa-platform
    for online experiments. See components.Algorithm to know more about this
    functionality.

    Related tutorials:

        - Overview of mi_paradigms module [LINK]
        - Create standalone models for MI-based BCIs compatible with
            Medusa platform [LINK]
    """

    def __init__(self):
        """Class constructor
        """
        super().__init__(fit_dataset=['mi_target',
                                      'mi_result',
                                      'accuracy'],
                         predict=['mi_result'])
        # Settings
        self.settings = None
        self.channel_set = None
        self.configure()
        # State flags
        # NOTE(review): the flags are reset *after* calling configure(), so
        # is_configured is False once construction finishes and configure()
        # has to be called again before build(). Confirm this is intended.
        self.is_configured = False
        self.is_built = False
        self.is_fit = False

    @abstractmethod
    def configure(self, **kwargs):
        """This function must be used to configure the model before calling
        build method. Class attribute settings must be set with a dict.
        """
        # Update state
        self.is_configured = True
        self.is_built = False
        self.is_fit = False

    @abstractmethod
    def build(self, *args, **kwargs):
        """This function builds the model, adding all the processing methods
        to the pipeline. It must be called after configure.

        Raises
        ------
        ValueError
            If configure has not been called.
        """
        # Check errors
        if not self.is_configured:
            raise ValueError('Function configure must be called first!')
        # Update state
        self.is_built = True
        self.is_fit = False

    def fit_dataset(self, dataset, **kwargs):
        """Function that receives an MIDataset and uses its data to
        fit the model. By default, executes pipeline 'fit_dataset'. Override
        method for other behaviour.

        Parameters
        ----------
        dataset: MIDataset
            Dataset with recordings from an MI-based BCI experiment
        kwargs: key-value arguments
            Optional parameters depending on the specific implementation of
            the model

        Returns
        -------
        fit_results: dict
            Dict with the information of the fit process. For command
            decoding models, at least it has to contain keys mi_target,
            mi_result and accuracy, which contain the target MI, the decoded
            MI and the decoding accuracy in the analysis.

        Raises
        ------
        ValueError
            If build has not been called.
        """
        # Check errors
        if not self.is_built:
            raise ValueError('Function build must be called first!')
        # Execute pipeline
        output = self.exec_pipeline('fit_dataset', dataset=dataset)
        # Set channels
        self.channel_set = dataset.channel_set
        # Update state
        self.is_fit = True
        return output

    def predict(self, times, signal, fs, l_cha, x_info, **kwargs):
        """Function that receives EEG signal and experiment info from an
        MI-based trial to decode the user's intentions. Used in online
        experiments. By default, executes pipeline 'predict'. Override method
        for other behaviour.

        Parameters
        ---------
        times: list or numpy.ndarray
            Timestamps of the EEG samples
        signal: list or numpy.ndarray
            EEG samples with shape [n_samples x n_channels]
        fs: float
            Sample rate of the EEG signal
        l_cha: list
            List of channel labels
        x_info: dict
            Dict with the needed experiment info to decode the commands. It
            has to contain keys: mode, onsets, w_trial_t. See MIData to know
            how these variables are defined.
        kwargs: key-value arguments
            Optional parameters depending on the specific implementation of
            the model

        Raises
        ------
        ValueError
            If fit_dataset has not been called.
        """
        # Check errors
        if not self.is_fit:
            raise ValueError('Function fit_dataset must be called first!')
        # Check channels
        self.channel_set.check_channels_labels(l_cha, strict=True)
        # Execute predict pipeline
        return self.exec_pipeline('predict', times=times, signal=signal,
                                  fs=fs, x_info=x_info, **kwargs)
class MIModelCSP(MIModel):
    """Decoding model for MI-based BCI applications based on Common Spatial
    Patterns (CSP).

    Dataset features:

    - Sample rate of the signals > 60 Hz. The model can handle recordings
        with different sample rates.
    - Recommended channels: ['C3', 'C4'].

    Processing pipeline:
    - Preprocessing (medusa.bci.mi_paradigms.StandardPreprocessing):
        - IIR Filter (order=5, cutoff=(8, 30) Hz): unlike FIR filters, IIR
            filters are quick and can be applied in small signal chunks.
            Thus, they are the preferred method for frequency filter in
            online systems.
        - Common average reference (CAR): widely used spatial filter that
            increases the signal-to-noise ratio of the MI control signals.
    - Feature extraction (medusa.bci.mi_paradigms.CSPFeatureExtraction):
        - Epochs (window=(500, 4000) ms, resampling to 60 Hz): the epochs of
            signal are extracted for each stimulation. Baseline normalization
            is also applied, taking the window (-1500, -500) ms relative to
            the stimulus onset.
        - CSP projection: Epochs are then projected according to a CSP
            filter previously trained, and their log-variance is used as
            features.
    - Feature classification (
    sklearn.discriminant_analysis.LinearDiscriminantAnalysis)
        - Regularized linear discriminant analysis (rLDA): we use the sklearn
            implementation, with eigen solver and auto shrinkage parameters.
            See reference in sklearn doc.
    """

    def __init__(self):
        super().__init__()

    def configure(self, p_filt_cutoff=(8, 30), f_w_epoch_t=(500, 4000),
                  f_target_fs=60):
        """Stores the model settings.

        Parameters
        ----------
        p_filt_cutoff : tuple
            Cutoff frequencies (Hz) of the preprocessing IIR filter.
        f_w_epoch_t : tuple
            Epoch window in ms relative to each onset.
        f_target_fs : float
            Target sample rate of the epochs.
        """
        self.settings = {
            'p_filt_cutoff': p_filt_cutoff,
            'f_w_epoch_t': f_w_epoch_t,
            'f_target_fs': f_target_fs
        }
        # Update state
        self.is_configured = True
        self.is_built = False
        self.is_fit = False

    def build(self):
        """Adds the preprocessing, CSP feature extraction and rLDA methods
        to the pipeline.

        Raises
        ------
        ValueError
            If configure has not been called.
        """
        # Check errors
        if not self.is_configured:
            raise ValueError('Function configure must be called first!')
        # Preprocessing (default: bandpass IIR filter [8, 30] Hz + CAR)
        self.add_method('prep_method', StandardPreprocessing(
            cutoff=self.settings['p_filt_cutoff']
        ))
        # Feature extraction (default: epochs [500, 4000] ms + resampling to
        # 60 Hz)
        self.add_method('ext_method', CSPFeatureExtraction(
            w_epoch_t=self.settings['f_w_epoch_t'],
            target_fs=self.settings['f_target_fs'],
        ))
        # Feature classification (rLDA)
        clf = components.ProcessingClassWrapper(
            LinearDiscriminantAnalysis(solver='eigen', shrinkage='auto'),
            fit=[], predict_proba=['y_pred']
        )
        self.add_method('clf_method', clf)
        # Update state
        self.is_built = True
        self.is_fit = False

    @staticmethod
    def _assess(y_true, y_pred):
        """Returns the accuracy and the sklearn classification report (as a
        dict) of y_pred against y_true."""
        accuracy = np.sum(y_pred == y_true) / len(y_pred)
        clf_report = classification_report(y_true, y_pred, output_dict=True)
        return accuracy, clf_report

    def fit_dataset(self, dataset, **kwargs):
        """Preprocesses the dataset, fits the CSP filters and the rLDA
        classifier, and assesses the classifier on the training data.

        Parameters
        ----------
        dataset: MIDataset
            Dataset with recordings from an MI-based BCI experiment. Its
            tracked info must contain 'mi_labels'.

        Returns
        -------
        assessment : dict
            Keys x, x_info, y_proba, y_pred, accuracy and report.

        Raises
        ------
        ValueError
            If build has not been called.
        """
        # Check errors
        if not self.is_built:
            raise ValueError('Function build must be called first!')
        # Preprocessing
        dataset = self.get_inst('prep_method').fit_transform_dataset(dataset)
        # Extract CSP features
        x, x_info = self.get_inst('ext_method').transform_dataset(dataset)
        # Classification
        self.get_inst('clf_method').fit(x, x_info['mi_labels'])
        y_proba = self.get_inst('clf_method').predict_proba(x)
        y_pred = self.get_inst('clf_method').predict(x)
        # Accuracy
        accuracy, clf_report = self._assess(x_info['mi_labels'], y_pred)
        assessment = {
            'x': x,
            'x_info': x_info,
            'y_proba': y_proba,
            'y_pred': y_pred,
            'accuracy': accuracy,
            'report': clf_report
        }
        # Save info
        self.channel_set = dataset.channel_set
        # Update state
        self.is_fit = True
        return assessment

    def predict(self, times, signal, fs, channel_set, x_info, **kwargs):
        """Decodes the MI trials of a raw signal.

        Parameters
        ----------
        times : list or numpy.ndarray
            Timestamps of the EEG samples
        signal : numpy.ndarray
            EEG samples with shape [n_samples x n_channels]
        fs : float
            Sample rate of the EEG signal
        channel_set : meeg.EEGChannelSet
            Channel set of the signal. A warning is issued if it differs
            from the one used to fit the model.
        x_info : dict
            Must contain 'onsets'. If it contains a non-None 'mi_labels',
            the accuracy and classification report are also computed.

        Returns
        -------
        decoding : dict
            Keys x, x_info, y_proba, y_pred, accuracy and report (accuracy
            and report are None when no labels are available).

        Raises
        ------
        ValueError
            If fit_dataset has not been called.
        """
        # Check errors
        if not self.is_fit:
            raise ValueError('Function fit_dataset must be called first!')
        # Check channel set
        if self.channel_set.channels != channel_set.channels:
            warnings.warn('The channel set is not the same that was used to '
                          'fit the model. Be careful!')
        # Preprocessing
        signal = self.get_inst('prep_method').fit_transform_signal(signal, fs)
        # Extract features
        x = self.get_inst('ext_method').transform_signal(times, signal, fs,
                                                         x_info['onsets'])
        x = self.get_inst('ext_method').extract_log_var_features(x)
        # Classification
        y_proba = self.get_inst('clf_method').predict_proba(x)
        y_pred = self.get_inst('clf_method').predict(x)
        # Assessment, only if labels are available (online trials usually
        # have no 'mi_labels' key at all)
        accuracy = None
        clf_report = None
        mi_labels = x_info.get('mi_labels')
        if mi_labels is not None:
            accuracy, clf_report = self._assess(mi_labels, y_pred)
        decoding = {
            'x': x,
            'x_info': x_info,
            'y_proba': y_proba,
            'y_pred': y_pred,
            'accuracy': accuracy,
            'report': clf_report
        }
        return decoding
class MIModelEEGSym(MIModel):
    """Decoding model for MI-based BCI applications based on EEGSym [1], a
    deep convolutional neural network developed for inter-subjects MI
    classification.

    Dataset features:

        - Sample rate of the signals > 128 Hz. The model can handle
          recordings with different sample rates.
        - Recommended channels: ['F7', 'C3', 'Po3', 'Cz', 'Pz', 'F8', 'C4',
          'Po4'].

    Processing pipeline:

        - Preprocessing (medusa.bci.mi_paradigms.StandardPreprocessing):
            - IIR Filter (lowpass, cutoff=49 Hz; the filter order is the
              StandardPreprocessing default): unlike FIR filters, IIR
              filters are quick and can be applied in small signal chunks.
              Thus, they are the preferred method for frequency filtering in
              online systems.
            - Common average reference (CAR): widely used spatial filter that
              increases the signal-to-noise ratio.
        - Feature extraction
          (medusa.bci.mi_paradigms.StandardFeatureExtraction):
            - Epochs (window=(0, 2000) ms, resampling to 128 Hz): the epochs
              of signal are extracted after each onset. Baseline
              normalization is also applied, taking the same epoch window.
        - Feature classification:
            - EEGSym: convolutional neural network [1], configured with 24
              filters per branch, scales_time=(125, 250, 500), dropout 0.4,
              ELU activation, 2 classes and learning rate 1e-4.

    References
    ----------
    [1] Pérez-Velasco, S., Santamaría-Vázquez, E., Martínez-Cagigal, V.,
        Marcos-Martínez, D., & Hornero, R. (2022). EEGSym: Overcoming
        Inter-Subject Variability in Motor Imagery Based BCIs With Deep
        Learning. IEEE Transactions on Neural Systems and Rehabilitation
        Engineering.
    """

    def __init__(self):
        """Create an unconfigured model (see configure and build)."""
        super().__init__()

    def configure(self, cnn_n_cha=8, ch_lateral=3, fine_tuning=False,
                  shuffle_before_fit=True, validation_split=0.4,
                  init_weights_path=None, gpu_acceleration=False,
                  augmentation=False):
        """Store the model settings and reset the build/fit state.

        Parameters
        ----------
        cnn_n_cha : int
            Number of input channels of the EEGSym network (``n_cha``).
        ch_lateral : int
            Passed to EEGSym as ``ch_lateral`` (presumably the number of
            lateral channels per hemisphere - TODO confirm).
        fine_tuning : bool
            Forwarded to the EEGSym ``fit`` call in fit_dataset.
        shuffle_before_fit : bool
            Forwarded to the EEGSym ``fit`` call in fit_dataset.
        validation_split : float
            Forwarded to the EEGSym ``fit`` call in fit_dataset.
        init_weights_path : str or None
            If not None, the weights are loaded in build and the model is
            marked as fitted with the recommended channel set.
        gpu_acceleration : bool
            Forwarded to the EEGSym constructor.
        augmentation : bool
            Forwarded to the EEGSym ``fit`` call in fit_dataset.
        """
        self.settings = {
            'cnn_n_cha': cnn_n_cha,
            'ch_lateral': ch_lateral,
            'fine_tuning': fine_tuning,
            'augmentation': augmentation,
            'shuffle_before_fit': shuffle_before_fit,
            'validation_split': validation_split,
            'init_weights_path': init_weights_path,
            'gpu_acceleration': gpu_acceleration
        }
        # Update state: any previous build/fit is now stale
        self.is_configured = True
        self.is_built = False
        self.is_fit = False

    def build(self):
        """Instantiate the processing methods and the EEGSym network.

        If ``init_weights_path`` was configured, the weights are loaded and
        the model is considered fitted for the recommended channel set.

        Raises
        ------
        ValueError
            If configure has not been called first.
        """
        # Check errors
        if not self.is_configured:
            raise ValueError('Function configure must be called first!')
        # Deep learning model, imported locally because only this model
        # needs it
        from medusa.deep_learning_models import EEGSym
        # Preprocessing (lowpass IIR filter 49 Hz + CAR)
        self.add_method('prep_method',
                        StandardPreprocessing(cutoff=49, btype='lowpass'))
        # Feature extraction (epochs [0, 2000] ms + resampling to 128 Hz)
        self.add_method('ext_method', StandardFeatureExtraction(
            w_epoch_t=(0, 2000), target_fs=128, w_baseline_t=(0, 2000),))
        # Feature classification; input_time and fs must match the epoch
        # window and target sample rate above
        clf = EEGSym(
            input_time=2000,
            fs=128,
            n_cha=self.settings['cnn_n_cha'],
            ch_lateral=self.settings['ch_lateral'],
            filters_per_branch=24,
            scales_time=(125, 250, 500),
            dropout_rate=0.4,
            activation='elu', n_classes=2,
            learning_rate=0.0001,
            gpu_acceleration=self.settings['gpu_acceleration'])
        self.is_fit = False
        if self.settings['init_weights_path'] is not None:
            # Pre-trained weights: usable without fit_dataset, assuming the
            # recommended montage
            clf.load_weights(self.settings['init_weights_path'])
            self.channel_set = meeg.EEGChannelSet()
            standard_lcha = ['F7', 'C3', 'Po3', 'Cz', 'Pz', 'F8', 'C4', 'Po4']
            self.channel_set.set_standard_montage(standard_lcha)
            self.is_fit = True
        self.add_method('clf_method', clf)
        # Update state
        self.is_built = True

    def fit_dataset(self, dataset, **kwargs):
        """Fit the whole pipeline on a dataset and assess it on that same
        (training) data.

        Parameters
        ----------
        dataset : object
            Dataset with a ``channel_set`` attribute, accepted by the
            preprocessing and feature extraction methods (presumably an MI
            dataset of this module - TODO confirm).
        kwargs : kwargs
            Forwarded to the EEGSym ``fit`` call.

        Returns
        -------
        assessment : dict
            Keys 'x', 'x_info', 'y_pred', 'y_prob', 'accuracy' (training
            accuracy) and 'report' (sklearn classification report as dict).

        Raises
        ------
        ValueError
            If build has not been called first.
        """
        # Check errors
        if not self.is_built:
            raise ValueError('Function build must be called first!')
        # Preprocessing
        dataset = self.get_inst('prep_method').fit_transform_dataset(dataset)
        # Extract epochs
        x, x_info = self.get_inst('ext_method').transform_dataset(dataset)
        # Put channels in symmetric order, using the labels of the dataset
        x = self.get_inst('clf_method').symmetric_channels(
            x, dataset.channel_set.l_cha)
        # Classification
        self.get_inst('clf_method').fit(
            x, x_info['mi_labels'],
            fine_tuning=self.settings['fine_tuning'],
            shuffle_before_fit=self.settings['shuffle_before_fit'],
            validation_split=self.settings['validation_split'],
            augmentation=self.settings['augmentation'],
            **kwargs)
        y_prob = self.get_inst('clf_method').predict_proba(x)
        y_pred = y_prob.argmax(axis=-1)
        # Accuracy on the training data
        accuracy = np.sum((y_pred == x_info['mi_labels'])) / len(y_pred)
        clf_report = classification_report(x_info['mi_labels'], y_pred,
                                           output_dict=True)
        assessment = {
            'x': x,
            'x_info': x_info,
            'y_pred': y_pred,
            'y_prob': y_prob,
            'accuracy': accuracy,
            'report': clf_report
        }
        # Save info (used by predict to check the channel set)
        self.channel_set = dataset.channel_set
        # Update state
        self.is_fit = True
        return assessment

    def predict(self, times, signal, fs, channel_set, x_info, **kwargs):
        """Decode the MI trials contained in a signal.

        Parameters
        ----------
        times : list or numpy.ndarray
            Timestamps of the signal samples.
        signal : numpy.ndarray
            EEG signal (presumably [n_samples x n_channels] - TODO confirm).
        fs : float
            Sample rate of the signal in Hz.
        channel_set : object
            Channel set of ``signal`` (presumably a meeg.EEGChannelSet). Its
            labels are used to put the channels in symmetric order. A
            warning is issued if it differs from the one used to fit.
        x_info : dict
            Must contain 'onsets'. If it contains non-None 'mi_labels', the
            accuracy and classification report are also computed.
        kwargs : kwargs
            Not used; accepted for interface compatibility.

        Returns
        -------
        decoding : dict
            Keys 'x', 'x_info', 'y_pred', 'y_prob', 'accuracy' and 'report';
            the last two are None when no labels are available.

        Raises
        ------
        ValueError
            If the model has not been fitted.
        """
        # Check errors
        if not self.is_fit:
            raise ValueError('Function fit_dataset must be called first!')
        # Check channel set
        if self.channel_set.channels != channel_set.channels:
            warnings.warn('The channel set is not the same that was used to '
                          'fit the model. Be careful!')
        # Preprocessing
        signal = self.get_inst('prep_method').fit_transform_signal(signal, fs)
        # Extract features
        x = self.get_inst('ext_method').transform_signal(times, signal, fs,
                                                         x_info['onsets'])
        # Put channels in symmetric order. The labels must describe the
        # channels of `signal`, i.e. those of the channel_set argument (as in
        # fit_dataset), not the fitted ones, which may differ.
        x = self.get_inst('clf_method').symmetric_channels(
            x, channel_set.l_cha)
        # Classification
        y_prob = self.get_inst('clf_method').predict_proba(x)
        y_pred = y_prob.argmax(axis=-1)
        # Decoding. Labels are optional (e.g., online trials), so a missing
        # 'mi_labels' key is treated the same as None instead of raising.
        mi_labels = x_info.get('mi_labels')
        accuracy = None
        clf_report = None
        if mi_labels is not None:
            accuracy = np.sum((y_pred == mi_labels)) / len(y_pred)
            clf_report = classification_report(mi_labels, y_pred,
                                               output_dict=True)
        decoding = {
            'x': x,
            'x_info': x_info,
            'y_pred': y_pred,
            'y_prob': y_prob,
            'accuracy': accuracy,
            'report': clf_report
        }
        return decoding