# Source code for medusa.bci.mi_paradigms
"""Created on Monday February 14 09:27:14 2022
In this module you will find useful functions and classes to operate with data
recorded using motor imagery paradigms, which are widely used by the BCI
community. Enjoy!
@author: Sergio Pérez-Velasco
"""
# Built-in imports
import copy, warnings
from abc import abstractmethod
# External imports
import numpy as np
from tqdm import tqdm
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.metrics import classification_report
# Medusa imports
from medusa import IIRFilter, car
from medusa.epoching import normalize_epochs
from medusa import get_epochs_of_events, resample_epochs
from medusa import components
from medusa import meeg
from medusa.spatial_filtering import CSP, LaplacianFilter
from medusa.deep_learning_models import EEGSym
class MIData(components.ExperimentData):
    """Experiment data of a motor imagery (MI) run.

    Holds the attributes that describe a MI experiment (mode, onsets,
    temporal windows, labels and results) so that the high-level functions of
    this module can work with it.
    """
    # TODO: Check everything

    def __init__(self, mode, onsets, w_trial_t, mi_result=None,
                 calibration_t=None, mi_labels=None, mi_labels_info=None,
                 w_rest_t=None, **kwargs):
        """MIData constructor

        Parameters
        ----------
        mode : str {"train"|"test"|"guided_test"}
            Mode of this run. It is stored in lower case.
        onsets : list or numpy.ndarray [n_stim x 1]
            Timestamp of each stimulation
        w_trial_t: list [start, end]
            Temporal window of the motor imagery with respect to each onset in
            ms. For example, if  w_trial_t = [500, 4000] the subject was
            performing the motor imagery task from 500ms after to 4000ms after
            the onset.
        mi_result : list or numpy.ndarray [n_mi_labels x 1] or None
            Result of this run. Each position contains the data of the
            selected target at each trial.
        calibration_t: int or None
            Time duration of the calibration recorded normally at the
            beginning of each run. For example, if calibration_t = 5000 the
            subject was in resting state for the first 5 seconds of the run.
        mi_labels : list or numpy.ndarray [n_mi_labels x 1] or None
            Mandatory in train mode. Contains the mi labels of each
            stimulation, as many as classes in the experiment.
        mi_labels_info : dict or None
            Contains the description of the mi labels.
            Example:
                mi_labels_info =
                    {0: "Rest",
                    1: "Left_hand",
                    2: "Right_hand"}
        w_rest_t: list [start, end] or None
            Temporal window of the rest with respect to each onset in ms. For
            example, if w_rest_t = [-1000, 0] the subject was resting from
            1000ms before to the onset.
        kwargs : kwargs
            Custom arguments that will also be saved as attributes of the
            instance (e.g., timings, calibration gaps, etc.)

        Raises
        ------
        ValueError
            If mode is "train" and mi_labels is None.
        """
        # Check errors
        mode = mode.lower()
        if mode == 'train' and mi_labels is None:
            raise ValueError('Attributes mi_labels, '
                             'should be provided in train mode')
        # Standard attributes
        self.mode = mode
        self.onsets = np.array(onsets)
        self.w_trial_t = np.array(w_trial_t)
        # NOTE(review): calibration_t and w_rest_t are wrapped unconditionally,
        # so a None value becomes a 0-d object array instead of staying None
        # (unlike mi_result and mi_labels). Kept as is because other code may
        # rely on these attributes being arrays - confirm before changing.
        self.calibration_t = np.array(calibration_t)
        self.mi_result = \
            np.array(mi_result) if mi_result is not None else mi_result
        self.mi_labels = \
            np.array(mi_labels) if mi_labels is not None else mi_labels
        self.mi_labels_info = mi_labels_info
        self.w_rest_t = np.array(w_rest_t)
        # Optional attributes
        for key, value in kwargs.items():
            setattr(self, key, value)

    def to_serializable_obj(self):
        """Returns a dict with the attributes of the instance where the numpy
        arrays have been converted to lists.

        A new dict is built, so the attributes of the instance are not
        modified by this call.

        Returns
        -------
        rec_dict : dict
            Attribute names mapped to their values, with every numpy.ndarray
            replaced by the result of its tolist() method.
        """
        return {
            key: value.tolist() if isinstance(value, np.ndarray) else value
            for key, value in self.__dict__.items()
        }
class MIDataset(components.Dataset):
    """This class inherits from components.Dataset, increasing its
    functionality for datasets with data from MI experiments. It provides
    common ground for the rest of functions in the module.
    """

    def __init__(self, channel_set, fs=None, biosignal_att_key='eeg',
                 experiment_att_key='midata', experiment_mode=None,
                 track_attributes=None):
        """ Constructor

        Parameters
        ----------
        channel_set : meeg.EEGChannelSet
            EEG channel set. Only these channels will be kept in the dataset,
            the others will be discarded. Also, the signals will be rearranged,
            keeping the same channel order, avoiding errors in future stages of
            the signal processing pipeline
        fs : int, float or None
            Sample rate of the recordings. If there are recordings with
            different sample rates, the consistency of the dataset can be
            still assured using resampling
        biosignal_att_key : str
            Name of the attribute containing the target biosignal that will be
            used to extract the features. It has to be the same in all
            recordings (e.g., 'eeg', 'meg').
        experiment_att_key : str or None
            Name of the attribute containing the target experiment that will be
            used to extract the features. It has to be the same in all
            recordings (e.g., 'mi_left_right', 'rest_mi'). It is
            mandatory when a recording of the dataset contains more than 1
            experiment data
        experiment_mode : str {'train'|'test'|'guided test'|'guided_test'|None}
            Mode of the experiment. If this dataset will be used to fit a model,
            set to train to avoid errors. Both spellings of the guided test
            mode are accepted. Attribute 'mi_labels' is tracked in train and
            guided test modes, and 'mi_result' in test mode.
        track_attributes: dict of dicts or None
            This parameter indicates custom attributes that must be tracked in
            feature extraction functions and how. The keys are the name of the
            attributes, whereas the values are dicts indicating the tracking
            mode {'concatenate'|'append'} and parent. Option concatenate is
            only available for attributes of type list or numpy arrays,
            forming a 1 dimensional array with the data from all recordings.
            Option append is used to save all kind of objects for each
            recording, forming a list whose length will be the number of
            recordings in the dataset. A set of default attributes is defined,
            so this parameter will be None in most cases. Custom entries
            override the default ones with the same key. Example to track 2
            custom attributes (i.e., date and experiment_equipment):
                track_attributes = {
                    'date': {
                        'track_mode': 'append',
                        'parent': None
                    },
                    'experiment_equipment': {
                        'track_mode': 'append',
                        'parent': experiment_att_key
                    }
                }

        Raises
        ------
        ValueError
            If experiment_mode is not one of the accepted values.
        """
        # Check errors. The guided test mode is spelled 'guided_test' in
        # MIData and 'guided test' in earlier versions of this check, so both
        # spellings are accepted to keep them interoperable
        guided_modes = ('guided test', 'guided_test')
        if experiment_mode is not None:
            if experiment_mode not in ('train', 'test') + guided_modes:
                raise ValueError('Parameter experiment_mode must be '
                                 '{train|test|guided test|guided_test|None}')
        # Default track attributes
        default_track_attributes = {
            'subject_id': {
                'track_mode': 'append',
                'parent': None
            },
            'w_trial_t': {
                'track_mode': 'concatenate',
                'parent': experiment_att_key
            },
            'w_rest_t': {
                'track_mode': 'concatenate',
                'parent': experiment_att_key
            },
            'calibration_t': {
                'track_mode': 'append',
                'parent': experiment_att_key
            },
            'onsets': {
                'track_mode': 'concatenate',
                'parent': experiment_att_key
            },
            'mi_labels_info': {
                'track_mode': 'append',
                'parent': experiment_att_key
            }
        }
        if experiment_mode == 'train' or experiment_mode in guided_modes:
            # Labels are known in train and guided test runs
            default_track_attributes['mi_labels'] = {
                'track_mode': 'concatenate',
                'parent': experiment_att_key
            }
        elif experiment_mode == 'test':
            # NOTE(review): the original branch also listed 'guided_test'
            # here, but it was shadowed by the branch above. Confirm whether
            # guided test runs should track mi_result as well.
            default_track_attributes['mi_result'] = {
                'track_mode': 'concatenate',
                'parent': experiment_att_key
            }
        if track_attributes is None:
            track_attributes = default_track_attributes
        else:
            track_attributes = {**default_track_attributes,
                                **track_attributes}
        # Class attributes
        self.channel_set = channel_set
        self.fs = fs
        self.biosignal_att_key = biosignal_att_key
        self.experiment_att_key = experiment_att_key
        self.experiment_mode = experiment_mode
        self.track_attributes = track_attributes
        # Consistency checker (it reads the attributes assigned above)
        checker = self.__get_consistency_checker()
        super().__init__(consistency_checker=checker)

    def __get_consistency_checker(self):
        """Creates a standard consistency checker for MI datasets

        The rules check the existence and type of the biosignal and the
        experiment, the channels, the sample rate (only if self.fs is not
        None) and the tracked attributes.

        Returns
        -------
        checker : components.ConsistencyChecker
            Standard consistency checker for MI feature extraction
        """
        # Create consistency checker
        checker = components.ConsistencyChecker()
        # Check that the biosignal exists
        checker.add_consistency_rule(
            rule='check-attribute',
            rule_params={'attribute': self.biosignal_att_key}
        )
        checker.add_consistency_rule(
            rule='check-attribute-type',
            rule_params={'attribute': self.biosignal_att_key,
                         'type': meeg.EEG}
        )
        # Check channels
        checker.add_consistency_rule(
            rule='check-values-in-attribute',
            rule_params={'attribute': 'channels',
                         'values': self.channel_set.channels},
            parent=self.biosignal_att_key + '.channel_set'
        )
        # Check sample rate
        if self.fs is not None:
            checker.add_consistency_rule(rule='check-attribute-value',
                                         rule_params={'attribute': 'fs',
                                                      'value': self.fs},
                                         parent=self.biosignal_att_key)
        else:
            warnings.warn('Parameter fs is None. The consistency of the '
                          'dataset cannot be assured. Still, you can use '
                          'target_fs parameter for feature extraction '
                          'and everything should be fine.')
        # Check experiment
        checker.add_consistency_rule(
            rule='check-attribute',
            rule_params={'attribute': self.experiment_att_key}
        )
        checker.add_consistency_rule(
            rule='check-attribute-type',
            rule_params={'attribute': self.experiment_att_key,
                         'type': MIData}
        )
        # Check mode
        # if self.experiment_mode is not None:
        #     checker.add_consistency_rule(
        #         rule='check-attribute-value',
        #         rule_params={'attribute': 'mode',
        #                      'value': self.experiment_mode},
        #         parent=self.experiment_att_key
        #     )
        # Check track_attributes
        if self.track_attributes is not None:
            for key, value in self.track_attributes.items():
                checker.add_consistency_rule(
                    rule='check-attribute',
                    rule_params={'attribute': key},
                    parent=value['parent']
                )
                # Concatenated attributes must be sequences
                if value['track_mode'] == 'concatenate':
                    checker.add_consistency_rule(
                        rule='check-attribute-type',
                        rule_params={'attribute': key,
                                     'type': [list, np.ndarray]},
                        parent=value['parent']
                    )
        return checker

    def custom_operations_on_recordings(self, recording):
        """Changes the channel set of the biosignal of the recording to the
        channel set of the dataset.

        Parameters
        ----------
        recording : object
            Recording whose attribute self.biosignal_att_key is the biosignal
            to adapt.

        Returns
        -------
        recording : object
            The same recording received as parameter.
        """
        # Select channels
        eeg = getattr(recording, self.biosignal_att_key)
        eeg.change_channel_set(self.channel_set)
        return recording
class StandardPreprocessing(components.ProcessingMethod):
    """Common preprocessing stage for MI-based BCI: a frequency IIR filter
    followed by a common average reference (CAR) spatial filter.
    """

    def __init__(self, order=5, cutoff=[0.05, 63], btype='bandpass',
                 temp_filt_method='sosfiltfilt'):
        """Class constructor

        Parameters
        ----------
        order : int
            Value passed as order to the IIRFilter.
        cutoff : list
            Value passed as cutoff to the IIRFilter.
        btype : str
            Value passed as btype to the IIRFilter.
        temp_filt_method : str
            Value passed as filt_method to the IIRFilter.
        """
        super().__init__(fit_transform_signal=['signal'],
                         fit_transform_dataset=['dataset'])
        # Filter configuration
        self.order, self.cutoff = order, cutoff
        self.btype, self.temp_filt_method = btype, temp_filt_method
        # The filter is created when the method is fitted
        self.iir_filter = None

    def _build_iir_filter(self):
        """Creates a new, unfitted IIRFilter with the stored parameters."""
        return IIRFilter(order=self.order,
                         cutoff=self.cutoff,
                         btype=self.btype,
                         filt_method=self.temp_filt_method)

    def fit(self, fs, n_cha=None):
        """Creates and fits the IIR filter.

        Parameters
        ----------
        fs: float
            Sample rate of the signal.
        n_cha: int
            Number of channels. Used to compute the initial conditions of the
            frequency filter. Only required with sosfilt filtering method
            (online filtering)
        """
        self.iir_filter = self._build_iir_filter()
        self.iir_filter.fit(fs, n_cha=n_cha)

    def transform_signal(self, signal):
        """Applies the fitted IIR filter and then CAR to an EEG signal.

        Parameters
        ----------
        signal: np.array or list
            Signal to transform. Shape [n_samples x n_channels]

        Returns
        -------
        signal : output of car applied to the filtered signal
        """
        filtered = self.iir_filter.transform(signal)
        return car(filtered)

    def fit_transform_signal(self, signal, fs):
        """Creates and fits a new IIR filter, and transforms an EEG signal
        applying IIR filtering and CAR sequentially.

        Parameters
        ----------
        signal: np.array or list
            Signal to transform. Shape [n_samples x n_channels]
        fs: float
            Sample rate of the signal.

        Returns
        -------
        signal : output of car applied to the filtered signal
        """
        self.iir_filter = self._build_iir_filter()
        filtered = self.iir_filter.fit_transform(signal, fs)
        return car(filtered)

    def fit_transform_dataset(self, dataset, show_progress_bar=True):
        """Preprocesses every recording of a dataset with
        fit_transform_signal. The filter is fitted again for each recording
        using its own sample rate, and the signals are replaced in place.

        Parameters
        ----------
        dataset: MIDataset
            MIDataset with the recordings to be preprocessed.
        show_progress_bar: bool
            Show progress bar

        Returns
        -------
        dataset : MIDataset
            The same dataset received as parameter.
        """
        recordings = dataset.recordings
        pbar = tqdm(total=len(recordings), desc='Preprocessing') \
            if show_progress_bar else None
        att_key = dataset.biosignal_att_key
        for recording in recordings:
            biosignal = getattr(recording, att_key)
            biosignal.signal = self.fit_transform_signal(biosignal.signal,
                                                         biosignal.fs)
            setattr(recording, att_key, biosignal)
            if show_progress_bar:
                pbar.update(1)
        if show_progress_bar:
            pbar.close()
        return dataset
class StandardFeatureExtraction(components.ProcessingMethod):
    """Standard feature extraction method for MI-based spellers. Basically,
    it gets the raw epoch for each MI event.
    """

    def __init__(self, w_epoch_t=(0, 3000), target_fs=128,
                 baseline_mode='trial',
                 w_baseline_t=(-1500, -500), norm='z',
                 concatenate_channels=False, safe_copy=True):
        """Class constructor

        Parameters
        ----------
        w_epoch_t : list
            Temporal window in ms for each epoch relative to the event onset
            (e.g., [0, 3000])
        target_fs : float or None
            Target sample rate of each epoch. If None, all the recordings must
            have the same sample rate, so it is strongly recommended to set this
            parameter to a suitable value to avoid problems and save time
        baseline_mode : {'run', 'trial', None}
            If "run", the baseline is taken from a window of the signal
            measured from its first sample.
            If "trial", the baseline window is relative to each onset.
            If None, no baseline window is used and w_baseline_t must be None.
        w_baseline_t : list, np.ndarray, "auto" or None
            Temporal window in ms to be used for baseline normalization.
            If baseline_mode = "run", the window is relative to the start of
            the signal and "auto" selects [0, 5000].
            If baseline_mode = "trial", the window is relative to each event
            onset (e.g., [-1500, -500]) and "auto" selects w_epoch_t.
        norm : str {'z'|'dc'}
            Type of baseline normalization. Set to 'z' for Z-score normalization
            or 'dc' for DC normalization
        concatenate_channels : bool
            This parameter controls the shape of the feature array. If True, all
            channels will be concatenated, returning an array of shape [n_events
            x (samples x channels)]. If false, the array will have shape
            [n_events x samples x channels]
        safe_copy : bool
            Makes a safe copy of the signal to avoid changing the original
            samples due to references
        """
        super().__init__(transform_signal=['x'],
                         transform_dataset=['x', 'x_info'])
        self.w_epoch_t = w_epoch_t
        self.target_fs = target_fs
        self.baseline_mode = baseline_mode
        self.w_baseline_t = w_baseline_t
        self.norm = norm
        self.concatenate_channels = concatenate_channels
        self.safe_copy = safe_copy

    def transform_signal(self, times, signal, fs, onsets):
        """Function to extract MI features from raw signal. It returns a 3D
        feature array with shape [n_events x n_samples x n_channels], or a 2D
        array [n_events x (n_samples x n_channels)] if concatenate_channels is
        True. This function does not track any other attributes. Use for
        online processing and custom higher level functions.

        Parameters
        ----------
        times : list or numpy.ndarray
            1D numpy array [n_samples]. Timestamps of each sample. If they
            are not available, generate them artificially. Nevertheless,
            all signals and events must have the same temporal origin
        signal : list or numpy.ndarray
            2D numpy array [n_samples x n_channels]. EEG samples (the units
            should be defined using kwargs)
        fs : int or float
            Sample rate of the recording.
        onsets : list or numpy.ndarray [n_events x 1]
            Timestamp of each event

        Returns
        -------
        features : np.ndarray
            Feature array with the epochs of signal

        Raises
        ------
        ValueError
            If baseline_mode is not valid, or if baseline_mode is None and
            w_baseline_t is not None.
        """
        # Avoid changes in the original signal (this may not be necessary)
        if self.safe_copy:
            signal = signal.copy()
        # Validate with exceptions: assert statements are removed when Python
        # runs with optimizations enabled
        if self.baseline_mode not in ('run', 'trial', None):
            raise ValueError(
                'Parameter baseline_mode must be {"run", "trial", None}')
        if self.baseline_mode is None and self.w_baseline_t is not None:
            raise ValueError('If baseline_mode is None, parameter '
                             'w_baseline_t must be None')
        # Compare with 'auto' only for strings. Comparing a numpy array with
        # a string is element-wise and cannot be used as a condition
        auto_baseline = isinstance(self.w_baseline_t, str) and \
            self.w_baseline_t == 'auto'
        # Baseline options for the epoch extraction
        norm = self.norm
        w_baseline_trial_t = None
        if self.baseline_mode == 'trial':
            # NOTE(review): 'auto' selects the epoch window itself as
            # baseline. Confirm that this is the intended default window.
            w_baseline_trial_t = \
                self.w_epoch_t if auto_baseline else self.w_baseline_t
        elif self.baseline_mode == 'run':
            # The normalization is applied afterwards with the run baseline
            norm = None
        # Extract features
        features = get_epochs_of_events(timestamps=times, signal=signal,
                                        onsets=onsets, fs=fs,
                                        w_epoch_t=self.w_epoch_t,
                                        w_baseline_t=w_baseline_trial_t,
                                        norm=norm)
        norm_epoch = None
        if self.baseline_mode == 'run':
            norm_epoch_t = [0, 5000] if auto_baseline else self.w_baseline_t
            # Window limits from ms to samples. The conversion to array is
            # required because list and tuple windows do not support the
            # arithmetic. The window is counted from the first sample of the
            # signal, so it is expected to be non-negative
            start, end = (np.asarray(norm_epoch_t, dtype=float)
                          * fs / 1000).astype(int)
            # Take the whole range of samples (not only both limits) as a
            # single baseline epoch [1 x n_samples x n_channels]
            norm_epoch = np.expand_dims(np.asarray(signal)[start:end],
                                        axis=0)
        # NOTE(review): this call is executed in every baseline mode, also
        # after the baseline normalization of get_epochs_of_events in trial
        # mode. Confirm the behavior of normalize_epochs when norm_epochs is
        # None.
        features = normalize_epochs(features, norm_epochs=norm_epoch,
                                    norm=self.norm)
        # Resample each epoch to the target frequency
        if self.target_fs is not None:
            if self.target_fs > fs:
                # warnings.warn returns None, so it must not be raised
                warnings.warn('Target fs is greater than data fs')
            features = resample_epochs(features,
                                       self.w_epoch_t,
                                       self.target_fs)
        # Reshape epochs and concatenate the channels. The events axis is
        # always kept, even if there is a single event
        if self.concatenate_channels:
            features = features.reshape((features.shape[0], -1))
        return features

    def transform_dataset(self, dataset, show_progress_bar=True):
        # TODO: Review with modifications of Eduardo in erp_spellers (SERGIO)
        """High level function to easily extract features from EEG recordings
        and save useful info for later processing. Nevertheless, the provided
        functionality has several limitations and it will not be suitable for
        all cases and processing pipelines. If it does not fit your needs,
        create a custom function iterating the recordings and using
        transform_signal, a much more low-level and general function. This
        function does not apply any preprocessing to the signals, this must
        be done before

        Parameters
        ----------
        dataset: MIDataset
            Dataset with the recordings. Attributes fs, track_attributes,
            recordings, experiment_att_key and biosignal_att_key are read
            from it. If safe_copy is True, the function works on a deep copy
            of the dataset.
        show_progress_bar: bool
            Show progress bar

        Returns
        -------
        features : numpy.ndarray
            Array with the biosignal samples arranged in epochs. The epochs
            of all recordings are concatenated in the first axis. It is None
            if the dataset has no recordings.
        track_info : dict
            Dictionary with tracked information across all recordings

        Raises
        ------
        ValueError
            If dataset.fs and target_fs are both None, or if a tracked
            attribute has an unknown track mode.
        """
        # Avoid changes in the original recordings (this may not be necessary)
        if self.safe_copy:
            dataset = copy.deepcopy(dataset)
        # Avoid consistency problems
        if dataset.fs is None and self.target_fs is None:
            raise ValueError('The consistency of the features is not assured '
                             'since dataset.fs and target_fs are both None. '
                             'Specify one of these parameters')
        # Additional track attributes
        track_attributes = dataset.track_attributes
        # track_attributes['run_idx'] = {
        #     'track_mode': 'concatenate',
        #     'parent': dataset.experiment_att_key
        # }
        # Initialization
        features = None
        track_info = dict()
        for key, value in track_attributes.items():
            if value['track_mode'] == 'append':
                track_info[key] = list()
            elif value['track_mode'] == 'concatenate':
                track_info[key] = None
            else:
                raise ValueError('Unknown track mode')
        # Init progress bar
        pbar = None
        if show_progress_bar:
            pbar = tqdm(total=len(dataset.recordings),
                        desc='Extracting features')
        # Compute features
        for rec in dataset.recordings:
            # Extract recording experiment and biosignal
            rec_exp = getattr(rec, dataset.experiment_att_key)
            rec_sig = getattr(rec, dataset.biosignal_att_key)
            # Get features
            rec_feat = self.transform_signal(
                times=rec_sig.times,
                signal=rec_sig.signal,
                fs=rec_sig.fs,
                onsets=rec_exp.onsets
                # ,
                # calibration_t=rec_exp.calibration_t
            )
            features = np.concatenate((features, rec_feat), axis=0) \
                if features is not None else rec_feat
            # Special attributes that need tracking across runs to assure the
            # consistency of the dataset
            # rec_exp.run_idx = run_counter * np.ones_like(rec_exp.trial_idx)
            # Track experiment info
            for key, value in track_attributes.items():
                # The parent is a dotted path of attributes starting in the
                # recording. None means the recording itself
                parent = rec
                if value['parent'] is not None:
                    for p in value['parent'].split('.'):
                        parent = getattr(parent, p)
                att = getattr(parent, key)
                if value['track_mode'] == 'append':
                    track_info[key].append(att)
                elif value['track_mode'] == 'concatenate':
                    track_info[key] = np.concatenate(
                        (track_info[key], att), axis=0
                    ) if track_info[key] is not None else att
                else:
                    raise ValueError('Unknown track mode')
            if show_progress_bar:
                pbar.update(1)
        if show_progress_bar:
            pbar.close()
        return features, track_info
[docs]class CSPFeatureExtraction(StandardFeatureExtraction):
    """Common Spatial Patterns (CSP) feature extraction method for MI-based
    spellers.
    Processing pipeline:
    - Use of StandardFeatureExtraction to get the raw epoch of each MI event.
    - Extract CSP features of those MI events.
    """
[docs]    def __init__(self, n_filters=4, w_epoch_t=(500, 4000), target_fs=60,
                 baseline_mode='trial',
                 w_baseline_t=(-1500, -500), norm='z',
                 concatenate_channels=False, safe_copy=True):
        """Class constructor
        n_filter : int or None
            Number of most discriminant CSP filters to decompose the signal
            into (must be less or equal to the number of channels in your
            signal).
            If int it will return that number of filters.
            If None it will return the all calculated filters.
        w_epoch_t : list
            Temporal window in ms for each epoch relative to the event onset
            (e.g., [0, 3000])
        target_fs : float of None
            Target sample rate of each epoch. If None, all the recordings must
            have the same sample rate, so it is strongly recommended to set this
            parameter to a suitable value to avoid problems and save time
        baseline_mode : {'run', 'trial', None}
            If "run" selected it will use the w_baseline_t of the beginning of
            each run.
            If "trial" selected it will use the w_baseline_t previous to each
            epoch extracted.
            If None it will not perform the baseline extraction.
        w_baseline_t : list, np.ndarray or "auto"
            Temporal window in ms to be used for baseline normalization.
            If baseline_mode = "run" it will use the temporal window since start
            of trial.
            If baseline_mode = "trial", for each epoch it will use the time
            relative to the event onset (e.g., [0, 3000]).
            If "auto" selected:
                If baseline_mode = "run" will use the time between the
                MedusaData.experiment.start and MedusaData.experiment.calibration
                information stored in the run.
                If baseline_mode = "trial" will use the time between
                [-1500,-500]ms takes the baseline from -1500ms to -500ms before
                each onset (0 ms represents the onset).
                If baseline_mode = None it will not be used.
        norm : str {'z'|'dc'}
            Type of baseline normalization. Set to 'z' for Z-score normalization
            or 'dc' for DC normalization
        concatenate_channels : bool
            This parameter controls the shape of the feature array. If True, all
            channels will be concatenated, returning an array of shape [n_events
            x (samples x channels)]. If false, the array will have shape
            [n_events x samples x channels]
        safe_copy : bool
            Makes a safe copy of the signal to avoid changing the original
            samples due to references
        After using the function self.fit(), the attributes are computed:
        Attributes
        ----------
        CSP : CSP class with attributes filters, eigenvalues, patterns and
            methods fit and project.
            filters : {(…, M, M) numpy.ndarray, (…, M, M) matrix}
                Mixing matrix (spatial filters are stored in columns).
            eigenvalues : (…, M) numpy.ndarray
                Eigenvalues of w.
            patterns : numpy.ndarray
                De-mixing matrix (activation patterns are stored in columns).
        """
        super().__init__(w_epoch_t=w_epoch_t, target_fs=target_fs,
                         baseline_mode=baseline_mode,
                         w_baseline_t=w_baseline_t, norm=norm,
                         concatenate_channels=concatenate_channels,
                         safe_copy=safe_copy)
        self.CSP = CSP(n_filters=n_filters)
[docs]    def fit(self, X, y):
        """Train Common Spatial Patterns (CSP) filters
        Train Common Spatial Patterns (CSP) filters with support to >2
        classes based on support multiclass CSP by means of approximate
        joint diagonalization. In this case, the spatial filter selection
        is achieved according to [1].
        Adapted from http://github.com/alexandrebarachant/pyRiemann
        Parameters
        ----------
        X : numpy.ndarray, [n_trials, samples, channels]
            Epoched data of shape (n_trials, samples, channels)
        y : numpy.ndarray, [n_trials,]
            Labels for epoched data of shape (n_trials,)
        References
        ----------
        [1] Grosse-Wentrup, Moritz, and Martin Buss. "Multiclass common
        spatial patterns and information theoretic feature extraction."
        Biomedical Engineering, IEEE Transactions on 55, no. 8 (2008):
        1991-2000.
        """
        self.CSP.fit(X=X, y=y)
[docs]    def project(self, times, signal, fs, onsets):
        """Function to extract MI features from raw signal. It returns a 3D
        feature array with shape [n_events x n_samples x n_channels]. This
        function does not track any other attributes. Use for online processing
        and custom higher level functions.
        Parameters
        ----------
        times : list or numpy.ndarray
                1D numpy array [n_samples]. Timestamps of each sample. If they
                are not available, generate them artificially. Nevertheless,
                all signals and events must have the same temporal origin
        signal : list or numpy.ndarray
            2D numpy array [n_samples x n_channels]. EEG samples (the units
            should be defined using kwargs)
        fs : int or float
            Sample rate of the recording.
        onsets : list or numpy.ndarray [n_events x 1]
                Timestamp of each event
        Returns
        -------
        features : numpy.ndarray [n_events x n_samples x n_channels]
            Feature array with the epochs of signal projected in the CSP space
        """
        features = super().transform_signal(times, signal, fs, onsets)
        features = self.CSP.project(X=features)
        return features
[docs]    def extract_log_var_features(self, X):
        """This method computes the standard motor imagery log-variance features
        given the CSP .
        Parameters
        ----------
        X : numpy.ndarray, [n_trials, samples, channels]
            Epoched data of shape (n_trials, samples, channels)
        Returns
        -------
        features : numpy.ndarray [n_events x n_channels]
            Feature array with the epochs of signal projected in the CSP space
        """
        features = self.CSP.project(X=X)
        # Obtain log-variance features given the CSP features
        features = np.log(np.var(features, axis=-1))
        return features
[docs]    def transform_dataset(self, dataset, show_progress_bar=True):
        features, track_info = super().transform_dataset(
            dataset=dataset, show_progress_bar=show_progress_bar)
        # Fit CSP filter
        self.fit(X=features, y=track_info['mi_labels'])
        # Project features into CSP space and obtain log-variance features given
        # the CSP features
        features = self.extract_log_var_features(X=features)
        return features, track_info
class MIModel(components.Algorithm):
    """Skeleton class for MI-based BCIs models. This class inherits from
    components.Algorithm. Therefore, it can be used to create standalone
    algorithms that can be used in compatible apps from medusa-platform
    for online experiments. See components.Algorithm to know more about this
    functionality.

    The expected sequence of calls is configure, build, fit_dataset and
    predict. The attributes is_configured, is_built and is_fit keep track of
    the current state, and each step raises a ValueError if the previous one
    has not been executed.

    Related tutorials:
        - Overview of mi_paradigms module [LINK]
        - Create standalone models for MI-based BCIs compatible with
            Medusa platform [LINK]
    """

    def __init__(self):
        """Class constructor

        Declares the 'fit_dataset' and 'predict' entries in the parent
        components.Algorithm, initializes the settings and the channel set
        to None, calls configure without arguments and resets the state
        flags.
        """
        super().__init__(fit_dataset=['mi_target', 'mi_result', 'accuracy'],
                         predict=['mi_result'])
        # Settings
        self.settings = None
        self.channel_set = None
        self.configure()
        # Configuration
        # NOTE(review): the flags are reset after the call to configure, so
        # configure must be called again explicitly before build. Confirm
        # that this is the intended behaviour.
        self.is_configured = False
        self.is_built = False
        self.is_fit = False

    @abstractmethod
    def configure(self, **kwargs):
        """This function must be used to configure the model before calling
        build method. Class attribute settings attribute must be set with a
        dict.

        Parameters
        ----------
        kwargs: key-value arguments
            Settings of the specific implementation of the model
        """
        # Update state
        self.is_configured = True
        self.is_built = False
        self.is_fit = False

    @abstractmethod
    def build(self, *args, **kwargs):
        """This function builds the model, adding all the processing methods
        to the pipeline. It must be called after configure.

        Raises
        ------
        ValueError
            If the model has not been configured
        """
        # Check errors
        if not self.is_configured:
            raise ValueError('Function configure must be called first!')
        # Update state
        self.is_built = True
        self.is_fit = False

    def fit_dataset(self, dataset, **kwargs):
        """Function that receives an MIDataset and uses its data to
        fit the model. By default, executes pipeline 'fit_dataset'. Override
        method for other behaviour.

        Parameters
        ----------
        dataset: MIDataset
            Dataset with recordings from an MI-based BCI experiment
        kwargs: key-value arguments
            Optional parameters depending on the specific implementation of
            the model. The default implementation does not forward them to
            the pipeline.

        Returns
        -------
        fit_results: dict
            Dict with the information of the fit process. For command
            decoding models, at least it has to contain keys
            mi_target, mi_result and accuracy, which contain the target MI,
            the decoded MI and the decoding accuracy in the analysis.

        Raises
        ------
        ValueError
            If the model has not been built
        """
        # Check errors
        if not self.is_built:
            raise ValueError('Function build must be called first!')
        # Execute pipeline
        output = self.exec_pipeline('fit_dataset', dataset=dataset)
        # Set channels
        self.channel_set = dataset.channel_set
        # Update state
        self.is_fit = True
        return output

    def predict(self, times, signal, fs, l_cha, x_info, **kwargs):
        """Function that receives EEG signal and experiment info from an
        MI-based trial to decode the user's intentions. Used in online
        experiments. By default, executes pipeline 'predict'. Override method
        for other behaviour.

        Parameters
        ----------
        times: list or numpy.ndarray
            Timestamps of the EEG samples
        signal: list or numpy.ndarray
            EEG samples with shape [n_samples x n_channels]
        fs: float
            Sample rate of the EEG signal
        l_cha: list
            List of channel labels
        x_info: dict
            Dict with the needed experiment info to decode the commands. It
            has to contain keys: mode, onsets, w_trial_t.
            See MIData to know how are defined these variables.
        kwargs: key-value arguments
            Optional parameters depending on the specific implementation of
            the model

        Returns
        -------
        output
            Output of the pipeline 'predict'

        Raises
        ------
        ValueError
            If the model has not been fitted
        """
        # Check errors
        if not self.is_fit:
            raise ValueError('Function fit_dataset must be called first!')
        # Check channels
        self.channel_set.check_channels_labels(l_cha, strict=True)
        # Execute predict pipeline
        return self.exec_pipeline('predict', times=times, signal=signal,
                                  fs=fs, x_info=x_info, **kwargs)
class MIModelCSP(MIModel):
    """Decoding model for MI-based BCI applications based on
    Common Spatial Patterns (CSP).

    Dataset features:

    - Sample rate of the signals > 60 Hz. The model can handle recordings
        with different sample rates.
    - Recommended channels: ['C3', 'C4'].

    Processing pipeline:

    - Preprocessing (medusa.bci.mi_paradigms.StandardPreprocessing):
        - IIR Filter (order=5, cutoff=(8, 30) Hz): unlike FIR filters, IIR
            filters are quick and can be applied in small signal chunks. Thus,
            they are the preferred method for frequency filter in online
            systems.
        - Common average reference (CAR): widely used spatial filter that
            increases the signal-to-noise ratio of the MI control signals.
    - Feature extraction (medusa.bci.mi_paradigms.CSPFeatureExtraction):
        - Epochs (window=(500, 4000) ms, resampling to 60 Hz): the epochs of
            signal are extracted for each stimulation. Baseline normalization
            is also applied, taking the window (-1500, -500) ms relative to the
            stimulus onset.
        - CSP projection: Epochs are then projected according to a CSP filter
            previously trained.
    - Feature classification
        (sklearn.discriminant_analysis.LinearDiscriminantAnalysis):
        - Regularized linear discriminant analysis (rLDA): we use the sklearn
            implementation, with eigen solver and auto shrinkage parameters.
            See reference in sklearn doc.
    """

    def configure(self, p_filt_cutoff=(8, 30), f_w_epoch_t=(500, 4000),
                  f_target_fs=60):
        """Stores the settings of the model and resets its state.

        Parameters
        ----------
        p_filt_cutoff : tuple
            Cutoff passed to the preprocessing method in build
        f_w_epoch_t : tuple
            Epoch window passed to the feature extraction method in build
        f_target_fs : int or float
            Target sample rate passed to the feature extraction method in
            build
        """
        self.settings = {
            'p_filt_cutoff': p_filt_cutoff,
            'f_w_epoch_t': f_w_epoch_t,
            'f_target_fs': f_target_fs
        }
        # Update state
        self.is_configured = True
        self.is_built = False
        self.is_fit = False

    def build(self):
        """Adds the preprocessing, feature extraction and classification
        methods to the model. It must be called after configure.

        Raises
        ------
        ValueError
            If the model has not been configured
        """
        # Check errors
        if not self.is_configured:
            raise ValueError('Function configure must be called first!')
        # Preprocessing (default: bandpass IIR filter [8, 30] Hz + CAR)
        self.add_method('prep_method', StandardPreprocessing(
            cutoff=self.settings['p_filt_cutoff']
        ))
        # Feature extraction (default: epochs [500, 4000] ms + resampling to
        # 60 Hz)
        self.add_method('ext_method', CSPFeatureExtraction(
            w_epoch_t=self.settings['f_w_epoch_t'],
            target_fs=self.settings['f_target_fs'],
        ))
        # Feature classification (rLDA)
        clf = components.ProcessingClassWrapper(
            LinearDiscriminantAnalysis(solver='eigen', shrinkage='auto'),
            fit=[], predict_proba=['y_pred']
        )
        self.add_method('clf_method', clf)
        # Update state
        self.is_built = True
        self.is_fit = False

    def fit_dataset(self, dataset, **kwargs):
        """Fits the whole pipeline with a dataset and assesses the result on
        the same data.

        Parameters
        ----------
        dataset : MIDataset
            Dataset with recordings from an MI-based BCI experiment
        kwargs : key-value arguments
            Not used by this implementation

        Returns
        -------
        assessment : dict
            Dict with keys x (features), x_info (tracked information),
            y_proba, y_pred, accuracy and report (output of sklearn
            classification_report as dict). Note that the accuracy is
            computed on the training data.

        Raises
        ------
        ValueError
            If the model has not been built
        """
        # Check errors
        if not self.is_built:
            raise ValueError('Function build must be called first!')
        # Preprocessing
        dataset = self.get_inst('prep_method').fit_transform_dataset(dataset)
        # Extract CSP features
        x, x_info = self.get_inst('ext_method').transform_dataset(dataset)
        mi_labels = x_info['mi_labels']
        # Classification
        clf = self.get_inst('clf_method')
        clf.fit(x, mi_labels)
        y_proba = clf.predict_proba(x)
        y_pred = clf.predict(x)
        # Accuracy
        accuracy = np.sum((y_pred == mi_labels)) / len(y_pred)
        clf_report = classification_report(mi_labels, y_pred,
                                           output_dict=True)
        assessment = {
            'x': x,
            'x_info': x_info,
            'y_proba': y_proba,
            'y_pred': y_pred,
            'accuracy': accuracy,
            'report': clf_report
        }
        # Save info
        self.channel_set = dataset.channel_set
        # Update state
        self.is_fit = True
        return assessment

    def predict(self, times, signal, fs, channel_set, x_info, **kwargs):
        """Decodes the MI class of the epochs defined by x_info['onsets'].

        Parameters
        ----------
        times : list or numpy.ndarray
            Timestamps of the EEG samples
        signal : list or numpy.ndarray
            EEG samples
        fs : float
            Sample rate of the EEG signal
        channel_set : object
            Channel set of the signal. Its attribute channels is compared
            with the channel set used to fit the model, and a warning is
            issued if they differ.
        x_info : dict
            Experiment info. It must contain the key onsets. If it contains
            the key mi_labels and its value is not None, the accuracy and
            the classification report are also computed.
        kwargs : key-value arguments
            Not used by this implementation

        Returns
        -------
        decoding : dict
            Dict with keys x, x_info, y_proba, y_pred, accuracy and report.
            The last 2 are None if the labels are not available.

        Raises
        ------
        ValueError
            If the model has not been fitted
        """
        # Check errors
        if not self.is_fit:
            raise ValueError('Function fit_dataset must be called first!')
        # Check channel set
        if self.channel_set.channels != channel_set.channels:
            warnings.warn('The channel set is not the same that was used to '
                          'fit the model. Be careful!')
        # Preprocessing
        signal = self.get_inst('prep_method').fit_transform_signal(signal, fs)
        # Extract features
        ext_method = self.get_inst('ext_method')
        x = ext_method.transform_signal(times, signal, fs, x_info['onsets'])
        x = ext_method.extract_log_var_features(x)
        # Classification
        clf = self.get_inst('clf_method')
        y_proba = clf.predict_proba(x)
        y_pred = clf.predict(x)
        # Decoding. The labels are optional: in online experiments x_info may
        # not include the key mi_labels at all
        accuracy = None
        clf_report = None
        mi_labels = x_info.get('mi_labels')
        if mi_labels is not None:
            accuracy = np.sum((y_pred == mi_labels)) / len(y_pred)
            clf_report = classification_report(mi_labels, y_pred,
                                               output_dict=True)
        decoding = {
            'x': x,
            'x_info': x_info,
            'y_proba': y_proba,
            'y_pred': y_pred,
            'accuracy': accuracy,
            'report': clf_report
        }
        return decoding
class MIModelEEGSym(MIModel):
    """Decoding model for MI-based BCI applications based on EEGSym [1], a deep
    convolutional neural network developed for inter-subjects MI classification.

    Dataset features:

    - Sample rate of the signals > 128 Hz. The model can handle recordings
        with different sample rates.
    - Recommended channels: ['F7', 'C3', 'Po3', 'Cz', 'Pz', 'F8', 'C4', 'Po4'].

    Processing pipeline:

    - Preprocessing:
        - IIR Filter (order=4, lowpass=49 Hz): unlike FIR filters, IIR filters
            are quick and can be applied in small signal chunks. Thus,
            they are the preferred method for frequency filter in online
            systems
        - Common average reference (CAR): widely used spatial filter that
            increases the signal-to-noise ratio.
    - Feature extraction:
        - Epochs (window=(0, 2000) ms, resampling to 128 Hz): the epochs of
            signal are extracted after each onset. Baseline normalization
            is also applied, taking the same epoch window.
    - Feature classification
        - EEGSym: convolutional neural network [1].

    References
    ----------
    [1] Pérez-Velasco, S., Santamaría-Vázquez, E., Martínez-Cagigal, V.,
    Marcos-Mateo, D., & Hornero, R. (2020). EEGSym: Overcoming Intersubject
    Variability in Motor Imagery Based BCIs with Deep Learning. ?.
    """

    def configure(self, cnn_n_cha=8, ch_lateral=3, fine_tuning=False,
                  shuffle_before_fit=True, validation_split=0.4,
                  init_weights_path=None, gpu_acceleration=False,
                  augmentation=False):
        """Stores the settings of the model and resets its state.

        Parameters
        ----------
        cnn_n_cha : int
            Passed to EEGSym as n_cha in build
        ch_lateral : int
            Passed to EEGSym as ch_lateral in build
        fine_tuning : bool
            Passed to the fit method of EEGSym in fit_dataset
        shuffle_before_fit : bool
            Passed to the fit method of EEGSym in fit_dataset
        validation_split : float
            Passed to the fit method of EEGSym in fit_dataset
        init_weights_path : str or None
            If not None, the weights stored in this path are loaded in build
            and the model is considered fitted
        gpu_acceleration : bool
            Passed to EEGSym as gpu_acceleration in build
        augmentation : bool
            Passed to the fit method of EEGSym in fit_dataset
        """
        self.settings = {
            'cnn_n_cha': cnn_n_cha,
            'ch_lateral': ch_lateral,
            'fine_tuning': fine_tuning,
            'augmentation': augmentation,
            'shuffle_before_fit': shuffle_before_fit,
            'validation_split': validation_split,
            'init_weights_path': init_weights_path,
            'gpu_acceleration': gpu_acceleration
        }
        # Update state
        self.is_configured = True
        self.is_built = False
        self.is_fit = False

    def build(self):
        """Adds the preprocessing, feature extraction and classification
        methods to the model. It must be called after configure.

        If the setting init_weights_path is not None, the weights are loaded
        into the network, the channel set is initialized with the montage
        ['F7', 'C3', 'Po3', 'Cz', 'Pz', 'F8', 'C4', 'Po4'] and the model is
        marked as fitted, so predict can be called without fit_dataset.

        Raises
        ------
        ValueError
            If the model has not been configured
        """
        # Check errors
        if not self.is_configured:
            raise ValueError('Function configure must be called first!')
        # Only import deep learning models if necessary
        from medusa.deep_learning_models import EEGSym
        # Preprocessing (lowpass IIR filter with cutoff at 49 Hz + CAR)
        self.add_method('prep_method',
                        StandardPreprocessing(cutoff=49, btype='lowpass'))
        # Feature extraction (epochs [0, 2000] ms + resampling to 128 Hz).
        # The baseline window is the epoch window itself
        self.add_method('ext_method', StandardFeatureExtraction(
            w_epoch_t=(0, 2000), target_fs=128, w_baseline_t=(0, 2000)))
        # Feature classification
        clf = EEGSym(
            input_time=2000,
            fs=128,
            n_cha=self.settings['cnn_n_cha'],
            ch_lateral=self.settings['ch_lateral'],
            filters_per_branch=24,
            scales_time=(125, 250, 500),
            dropout_rate=0.4,
            activation='elu', n_classes=2,
            learning_rate=0.0001,
            gpu_acceleration=self.settings['gpu_acceleration'])
        # The model is only considered fitted if initial weights are loaded
        self.is_fit = False
        if self.settings['init_weights_path'] is not None:
            clf.load_weights(self.settings['init_weights_path'])
            self.channel_set = meeg.EEGChannelSet()
            standard_lcha = ['F7', 'C3', 'Po3', 'Cz', 'Pz', 'F8', 'C4', 'Po4']
            self.channel_set.set_standard_montage(standard_lcha)
            self.is_fit = True
        self.add_method('clf_method', clf)
        # Update state
        self.is_built = True

    def fit_dataset(self, dataset, **kwargs):
        """Fits the whole pipeline with a dataset and assesses the result on
        the same data.

        Parameters
        ----------
        dataset : MIDataset
            Dataset with recordings from an MI-based BCI experiment
        kwargs : key-value arguments
            Additional arguments forwarded to the fit method of EEGSym

        Returns
        -------
        assessment : dict
            Dict with keys x (epochs in symmetric channel order), x_info
            (tracked information), y_pred, y_prob, accuracy and report
            (output of sklearn classification_report as dict). Note that the
            accuracy is computed on the training data.

        Raises
        ------
        ValueError
            If the model has not been built
        """
        # Check errors
        if not self.is_built:
            raise ValueError('Function build must be called first!')
        # Preprocessing
        dataset = self.get_inst('prep_method').fit_transform_dataset(dataset)
        # Extract epochs
        x, x_info = self.get_inst('ext_method').transform_dataset(dataset)
        mi_labels = x_info['mi_labels']
        # Put channels in symmetric order
        clf = self.get_inst('clf_method')
        x = clf.symmetric_channels(x, dataset.channel_set.l_cha)
        # Classification
        clf.fit(
            x, mi_labels,
            fine_tuning=self.settings['fine_tuning'],
            shuffle_before_fit=self.settings['shuffle_before_fit'],
            validation_split=self.settings['validation_split'],
            augmentation=self.settings['augmentation'],
            **kwargs)
        y_prob = clf.predict_proba(x)
        y_pred = y_prob.argmax(axis=-1)
        # Accuracy
        accuracy = np.sum((y_pred == mi_labels)) / len(y_pred)
        clf_report = classification_report(mi_labels, y_pred,
                                           output_dict=True)
        assessment = {
            'x': x,
            'x_info': x_info,
            'y_pred': y_pred,
            'y_prob': y_prob,
            'accuracy': accuracy,
            'report': clf_report
        }
        # Save info
        self.channel_set = dataset.channel_set
        # Update state
        self.is_fit = True
        return assessment

    def predict(self, times, signal, fs, channel_set, x_info, **kwargs):
        """Decodes the MI class of the epochs defined by x_info['onsets'].

        Parameters
        ----------
        times : list or numpy.ndarray
            Timestamps of the EEG samples
        signal : list or numpy.ndarray
            EEG samples
        fs : float
            Sample rate of the EEG signal
        channel_set : object
            Channel set of the signal. Its attribute channels is compared
            with the channel set stored in the model, and a warning is
            issued if they differ.
        x_info : dict
            Experiment info. It must contain the key onsets. If it contains
            the key mi_labels and its value is not None, the accuracy and
            the classification report are also computed.
        kwargs : key-value arguments
            Not used by this implementation

        Returns
        -------
        decoding : dict
            Dict with keys x, x_info, y_pred, y_prob, accuracy and report.
            The last 2 are None if the labels are not available.

        Raises
        ------
        ValueError
            If the model is not fitted
        """
        # Check errors
        if not self.is_fit:
            raise ValueError('Function fit_dataset must be called first!')
        # Check channel set
        if self.channel_set.channels != channel_set.channels:
            warnings.warn('The channel set is not the same that was used to '
                          'fit the model. Be careful!')
        # Preprocessing
        signal = self.get_inst('prep_method').fit_transform_signal(signal, fs)
        # Extract features
        x = self.get_inst('ext_method').transform_signal(times, signal, fs,
                                                         x_info['onsets'])
        # Put channels in symmetric order.
        # NOTE(review): the labels of the channel set stored in the model are
        # used here, not the ones of the channel_set argument. Confirm that
        # this is intended when both differ.
        clf = self.get_inst('clf_method')
        x = clf.symmetric_channels(x, self.channel_set.l_cha)
        # Classification
        y_prob = clf.predict_proba(x)
        y_pred = y_prob.argmax(axis=-1)
        # Decoding. The labels are optional: in online experiments x_info may
        # not include the key mi_labels at all
        accuracy = None
        clf_report = None
        mi_labels = x_info.get('mi_labels')
        if mi_labels is not None:
            accuracy = np.sum((y_pred == mi_labels)) / len(y_pred)
            clf_report = classification_report(mi_labels, y_pred,
                                               output_dict=True)
        decoding = {
            'x': x,
            'x_info': x_info,
            'y_pred': y_pred,
            'y_prob': y_prob,
            'accuracy': accuracy,
            'report': clf_report
        }
        return decoding