Source code for medusa.bci.nft_paradigms

"""
This module provides functions and classes to apply online Neurofeedback
models. Each model is based on a different kind of feature that is used as
the training target. Enjoy!

@author: Diego Marcos-Martínez
"""

# Built-in imports
from abc import ABC, abstractmethod
import concurrent

# External imports
import numpy as np
import scipy.signal

# Medusa imports
import medusa as mds
from medusa import components
from medusa import meeg
from medusa.spatial_filtering import LaplacianFilter, car
from medusa.connectivity.phase_connectivity import phase_connectivity
from medusa.connectivity.amplitude_connectivity import __aec_ort_cpu as aec
from medusa.graph_theory import degree
from medusa.artifact_removal import reject_noisy_epochs
from medusa.epoching import get_epochs_of_events
from medusa.local_activation.spectral_parameteres import absolute_band_power


class SignalPreprocessing(components.ProcessingMethod):
    """
    Common preprocessing applied in Neurofeedback applications. It is composed
    of a frequency IIR filter followed by optional spatial filters (CAR and
    Laplacian). Functions are adapted to filter the signal in more than one
    frequency range, if necessary.
    """

    def __init__(self, filter_dict=None, montage=None, target_channels=None,
                 laplacian=False, car=False, n_cha_lp=None):
        """
        Class constructor.

        Parameters
        ----------
        filter_dict: list of dict
            One dict per filter group. Each dict must contain the keys
            "cutoff" (iterable of bands, one IIR filter is created per band)
            and "type" ("training" or "artifact").
        montage: object
            Channel set of the recording. It must expose ``l_cha`` and
            ``get_cha_idx_from_labels``.
        target_channels: list or None
            Labels of the target channels. Mandatory if ``laplacian`` is True.
        laplacian: bool
            If True, a Laplacian spatial filter is applied.
        car: bool
            If True, common average reference is applied.
        n_cha_lp: int or None
            Currently unused by this class.

        Raises
        ------
        ValueError
            If ``filter_dict`` or ``montage`` are missing or malformed, or if
            the Laplacian filter is requested without target channels.
        """
        super().__init__(prep_fit_transform=['signal'],
                         prep_transform=['signal'],
                         narrow_transform=['signal'])

        # Error check
        if not filter_dict:
            raise ValueError('[SignalPreprocessing] Filter dict parameter '
                             '"filter_dict" must be a list containing all '
                             'necessary information to perform the filtering!. '
                             'The information should be: type and cutoff')
        for filt_spec in filter_dict:
            if not isinstance(filt_spec, dict):
                raise ValueError('[SignalPreprocessing] Each filter must '
                                 'be a dict()!')
            if 'cutoff' not in filt_spec or 'type' not in filt_spec:
                raise ValueError('[SignalPreprocessing] Each filter must '
                                 'be a dict() containing the following keys: '
                                 '"cutoff"and "type"!')
            if filt_spec['type'] not in ('training', 'artifact'):
                raise ValueError('[SignalPreprocessing] "type" must be "training"'
                                 'or "artifact".')
        if not montage:
            raise ValueError('[SignalPreprocessing] Pre-processing parameter'
                             '"montage" must be a dict containing all'
                             'labels of channels and montage standard key')
        if laplacian and target_channels is None:
            raise ValueError('[SignalPreprocessing] Laplacian filter needs to '
                             'define "target_channels" parameter.')

        # Parameters
        self.filter_dict = filter_dict
        self.l_cha = montage.l_cha
        self.target_channels = target_channels
        self.montage = montage
        self.perform_car = car
        self.perform_laplacian = laplacian

        # Variables
        self.artifact_iir_filters = []
        self.target_iir_filters = []
        self.offset_line_removal = None
        self.laplacian_filter = None

    def _n_filtered_channels(self):
        """Number of channels handled by the band-specific IIR filters."""
        if self.target_channels is None:
            return len(self.l_cha)
        return len(self.target_channels)

    @staticmethod
    def _apply_filter_bank(iir_filters, signal, output, parallel_computing):
        """Filters ``signal`` with every filter and stores result k in
        ``output[k]``. Each filter runs in its own thread if requested."""
        if parallel_computing:
            filt_threads = []
            for iir in iir_filters:
                t = components.ThreadWithReturnValue(target=iir.transform,
                                                     args=(signal,))
                filt_threads.append(t)
                t.start()
            for filt_idx, thread in enumerate(filt_threads):
                output[filt_idx, :, :] = thread.join()
        else:
            for filt_idx, iir in enumerate(iir_filters):
                output[filt_idx, :, :] = iir.transform(signal)
        return output

    def fit(self, fs):
        """
        Fits the IIR filters and the Laplacian spatial filter (if selected)
        for the signal preprocessing stage.

        Parameters
        ----------
        fs: float
            Sampling rate in Hz.
        """
        # Start from empty filter banks so that fitting again (e.g., a second
        # calibration) replaces the filters instead of duplicating them
        self.artifact_iir_filters = []
        self.target_iir_filters = []

        # Fit Spectral Filters (Predefined to be optimal)
        self.offset_line_removal = mds.IIRFilter(order=3,
                                                 cutoff=[0.5, 40],
                                                 btype='bandpass',
                                                 filt_method='sosfiltfilt')
        self.offset_line_removal.fit(fs, len(self.l_cha))

        # Define filters for filtering over epochs
        n_cha = self._n_filtered_channels()
        for filt_spec in self.filter_dict:
            for band in filt_spec['cutoff']:
                iir = mds.IIRFilter(order=1,
                                    cutoff=band,
                                    btype='bandpass',
                                    filt_method='sosfiltfilt')
                iir.fit(fs, n_cha)
                if filt_spec['type'] == 'artifact':
                    self.artifact_iir_filters.append(iir)
                elif filt_spec['type'] == 'training':
                    self.target_iir_filters.append(iir)

        # Fit Laplacian Filter. With fewer than 5 channels the filter is not
        # created and prep_transform only selects the target channels
        if self.perform_laplacian and len(self.montage.l_cha) >= 5:
            self.laplacian_filter = LaplacianFilter(self.montage,
                                                    mode='auto')
            self.laplacian_filter.fit_lp(
                l_cha_to_filter=self.target_channels)

    def prep_transform(self, signal, parallel_computing=True):
        """
        Transforms an EEG signal applying the IIR filter. It also applies CAR
        and Laplacian spatial filter sequentially if desired.

        Parameters
        ----------
        signal: numpy.ndarray
            Signal to transform. Shape [n_samples x n_channels]
        parallel_computing: bool
            If true, the artifact-related filters run concurrently, one thread
            per filter.

        Returns
        -------
        signal_: numpy.ndarray
            Signal after the 0.5-40 Hz band-pass filter and the selected
            spatial filters.
        signal_artifacts: numpy.ndarray or None
            Signal filtered in each artifact-related band
            [n_artifact_bands, n_samples, n_channels]. None if no artifact
            filter was fitted.
        """
        # Initialize variable
        n_samples = signal.shape[0]
        if len(self.artifact_iir_filters) == 0:
            signal_artifacts = None
        else:
            signal_artifacts = np.empty((len(self.artifact_iir_filters),
                                         n_samples,
                                         self._n_filtered_channels()))

        signal_ = self.offset_line_removal.transform(signal)

        # Spatial filtering
        if self.perform_car:
            signal_ = car(signal_)

        if self.perform_laplacian:
            # Check if surface laplacian filter cannot be performed
            if self.laplacian_filter is not None:
                signal_ = self.laplacian_filter.apply_lp(signal_)
            else:
                signal_ = signal_[:, self.montage.get_cha_idx_from_labels(
                    self.target_channels)]

        if signal_artifacts is not None:
            artifact_input = signal_.copy()
            # Filter only the target channels in Power Based models
            if artifact_input.shape[1] != signal_artifacts.shape[2]:
                artifact_input = artifact_input[
                    :, self.montage.get_cha_idx_from_labels(
                        self.target_channels)]

            # Frequency filtering on artifact-related bands. Both execution
            # modes receive the same 2-D signal so they yield the same output
            self._apply_filter_bank(self.artifact_iir_filters,
                                    artifact_input, signal_artifacts,
                                    parallel_computing)

        return signal_, signal_artifacts

    def prep_fit_transform(self, fs, signal):
        """
        Fits the filters and transforms an EEG signal applying the IIR filter
        and the spatial filters sequentially.

        Parameters
        ----------
        fs: float
            Sampling rate in Hz.
        signal: numpy.ndarray
            Signal to transform. Shape [n_samples x n_channels]

        Returns
        -------
        signal: numpy.ndarray
            Signal after the 0.5-40 Hz band-pass filter and the selected
            spatial filters.
        signal_artifacts: numpy.ndarray or None
            Signal filtered in each artifact-related band
            [n_artifact_bands, n_samples, n_channels]. None if no artifact
            filter was defined.
        """
        self.fit(fs)
        return self.prep_transform(signal)

    def narrow_transform(self, signal, parallel_computing=True):
        """
        Applies the IIR filters for narrow band (training band) filtering.

        Parameters
        ----------
        signal: numpy.ndarray
            Signal to transform. Shape [n_samples x n_channels].
        parallel_computing: bool
            If true, the filters run concurrently, one thread per filter.

        Returns
        -------
        signal: numpy.ndarray
            Signal filtered in each training band, with singleton dimensions
            removed ([n_samples, n_channels] for a single training filter).
        """
        f_signals = np.empty((len(self.target_iir_filters),
                              signal.shape[0], signal.shape[1]))
        self._apply_filter_bank(self.target_iir_filters, signal, f_signals,
                                parallel_computing)
        return np.squeeze(f_signals)
def ignore_noisy_windows(signals, thresholds, pct_tol):
    """
    Checks whether a signal segment stays below the pre-established noise
    thresholds in every artifact-related band.

    Parameters
    ----------
    signals: numpy.ndarray
        Array containing the signal filtered in the frequency bands
        associated to artifacts to avoid.
        [n_artifact_bands, n_samples, n_channels].
    thresholds: numpy.ndarray
        Array containing the thresholds related to artifacts to avoid. They
        are compared against the channel-averaged standard deviation of each
        band.
    pct_tol: float or numpy.ndarray
        Relative increase over the thresholds that is tolerated.

    Returns
    -------
    bool
        True if no band reaches ``(1 + pct_tol) * thresholds`` (the window
        can be used); False if at least one band does.
    """
    # Channel-averaged standard deviation of each artifact band
    band_deviation = np.std(signals, axis=1).mean(axis=-1)
    tolerated_limit = (1 + pct_tol) * thresholds
    return not np.any(band_deviation >= tolerated_limit)
def make_windows(signal, fs, update_feature_window, update_rate, n_cha=1,
                 n_samp=2, k=4, reject=True):
    """
    Splits a signal into windows of ``update_feature_window`` seconds taken
    every ``update_rate`` seconds and, optionally, discards the noisy ones.

    Parameters
    ----------
    signal: numpy.ndarray
        Signal to be converted into epochs. [n_samples, n_channels]. A 1-D
        signal is treated as a single channel.
    fs: int or float
        Sampling rate.
    update_feature_window: int or float
        Time window taken for the calculation of the characteristic (in
        seconds).
    update_rate: int or float
        Feedback update time in online mode (in seconds).
    n_cha: int
        Threshold number of channels meeting the rejection condition to
        reject the epoch.
    n_samp: int
        Threshold number of samples meeting the rejection condition to
        reject the epoch.
    k: int
        Factor forwarded to ``reject_noisy_epochs`` together with the mean
        and the standard deviation of the signal to define the rejection
        criterion.
    reject: bool
        If true, noisy epochs are discarded. Else, the whole windowed signal
        is returned.

    Returns
    -------
    If ``reject`` is False, only the windowed signal is returned. Otherwise:

    good_epochs: numpy.ndarray
        Array containing the signal divided into epochs that are not noisy.
        [n_epochs, n_samples, n_channels].
    ind: numpy.ndarray
        Array containing bools. True for epochs that were rejected and False
        for epochs that were not.
    """
    if signal.ndim == 1:
        signal = signal[:, np.newaxis]

    # Define necessary parameters
    n_samples = signal.shape[0]
    s_duration = n_samples / fs

    # Build the timestamps from integer sample indexes: a float-step arange
    # does not guarantee exactly one timestamp per sample
    timestamps = np.arange(n_samples) / fs

    # Set onsets vector
    onsets = np.arange(0, s_duration - update_feature_window, update_rate)

    s_windowed = get_epochs_of_events(timestamps, signal, onsets, fs,
                                      [0, update_feature_window * 1000])

    # Return windows without discarding
    if not reject:
        return s_windowed

    # Statistics of the whole signal, only needed by the rejection criterion
    s_mean = np.mean(signal, axis=0)
    s_std = np.std(signal, axis=0)
    _, good_epochs, idx = reject_noisy_epochs(s_windowed, s_mean, s_std,
                                              k, n_samp, n_cha)
    return good_epochs, idx
class ConnectivityExtraction(components.ProcessingMethod):
    """
    Functional Connectivity-based features to extract from user's EEG.
    """

    def __init__(self, l_baseline_t=5, fs=250, update_feature_window=2,
                 update_rate=0.25, fc_measure=None, mode=None, montage=None,
                 target_channels=None, pct_tol=0.9):
        """
        Class constructor.

        Parameters
        ----------
        l_baseline_t: int
            Time employed to calculate the number of samples to obtain
            baseline connectivity parameter. In seconds.
        fs: int or float
            Sample rate of the recording.
        update_feature_window: int or float
            Length in seconds of the temporal window applied to calculate the
            feature.
        update_rate: int or float
            Feedback update time in online mode.
        fc_measure: str
            "WPLI" or "AECORT". Measure of Functional Connectivity to
            calculate.
        mode: str
            "Global coupling", "Strength" or "Coupling". Information extracted
            from adjacency matrix.
        montage: EEGChannelSet
            Channel set of the recording.
        target_channels: list or None
            List containing the labels of the target channels.
        pct_tol: float or numpy.ndarray
            Relative increase over the artifact thresholds that is tolerated.

        Raises
        ------
        ValueError
            If ``montage`` is missing or ``fc_measure`` / ``mode`` are not
            among the available options.

        Warns
        -----
        UserWarning
            If "Strength" or "Coupling" modes are selected without target
            channels. "Strength" then averages all channels and "Coupling"
            falls back to "Global coupling".
        """
        import warnings

        super().__init__(ext_feature=['conn_value'])

        # Check errors
        if not montage:
            raise ValueError('[ConnectivityExtraction] "montage parameter"'
                             ' must be a dict containing all'
                             'labels of channels and montage standard key')
        if fc_measure != "WPLI" and fc_measure != "AECORT":
            raise ValueError('[ConnectivityExtraction] Invalid functional '
                             'connectivity measure. Available measures are '
                             '"WPLI" and "AECORT".')
        if mode not in ("Global coupling", "Strength", "Coupling"):
            raise ValueError('[ConnectivityExtraction] Invalid mode. '
                             'Available modes are '
                             '"Global coupling", "Strength" and "Coupling".')

        if target_channels is None:
            # These situations are recoverable, so they are reported as
            # warnings instead of aborting the construction
            if mode == "Strength":
                warnings.warn(
                    '[ConnectivityExtraction] Using "Strength" mode'
                    'without defining target channels. Average strength'
                    'of all channels will be returned instead.',
                    UserWarning, stacklevel=2)
            elif mode == "Coupling":
                mode = "Global coupling"
                warnings.warn(
                    '[ConnectivityExtraction] Using "Coupling" mode'
                    'without defining target channels. Global coupling of all '
                    'channels will be returned instead.',
                    UserWarning, stacklevel=2)
            self.target_channels = None
        else:
            # From here on target channels are handled as channel indexes
            self.target_channels = montage.get_cha_idx_from_labels(
                target_channels)

        self.fc_measure = fc_measure
        self.mode = mode
        self.fs = fs
        self.l_baseline_t = l_baseline_t
        self.montage = montage
        self.update_feature_window = update_feature_window
        self.update_rate = update_rate
        self.w_signal_samples = int(update_feature_window * self.fs)
        self.w_signal_samples_calibration = int(self.l_baseline_t * self.fs)
        self.pct_tol = pct_tol
        self.thresholds = None
        self.baseline_value = None

    def set_baseline(self, signal, signal_artifacts, filtered_signal):
        """
        This function establishes the baseline value.

        Parameters
        ----------
        signal: numpy.ndarray
            Original signal pre-processed [n_samples, n_channels]. It is only
            used to decide which epochs are noisy.
        signal_artifacts: numpy.ndarray or None
            Array containing the signal filtered in each frequency band
            associated to the artifacts to avoid.
            [n_artifact_bands, n_samples, n_channels].
        filtered_signal: numpy.ndarray
            Signal filtered in the narrow band for training.
            [n_samples, n_channels].

        Returns
        -------
        baseline_value: float
            Mean feature value over the epochs that were not rejected.

        Raises
        ------
        ValueError
            If the signal is not longer than the calibration window.
        """
        # Delete borders to avoid effect borders
        dif_window = round(
            (signal.shape[0] - self.w_signal_samples_calibration) / 2)
        if dif_window <= 0:
            raise ValueError('The duration of the signal for the baseline '
                             'calculation should be slightly longer than the '
                             'window time taken for the baseline calculation. ')

        # Noisy epochs are detected on the pre-processed signal and then
        # removed from the narrow-band filtered epochs
        _, rejected = make_windows(signal[dif_window:-dif_window, :],
                                   self.fs, self.update_feature_window,
                                   self.update_rate)
        filtered_epochs = make_windows(
            filtered_signal[dif_window:-dif_window, :], self.fs,
            self.update_feature_window, self.update_rate,
            reject=False)[~rejected, :, :]

        adj_mat = self.calculate_adj_mat(filtered_epochs)

        # Parallel computing baseline values
        feat_threads = []
        for epoch_mat in adj_mat:
            t = components.ThreadWithReturnValue(
                target=self.calculate_feature, args=(epoch_mat,))
            feat_threads.append(t)
            t.start()
        baseline_values = [thread.join() for thread in feat_threads]
        self.baseline_value = np.mean(baseline_values)

        # Define artifact related thresholds
        # NOTE(review): thresholds average all channels, whereas ext_feature
        # checks only the target channels - confirm this is intended
        if signal_artifacts is not None:
            self.thresholds = np.std(
                signal_artifacts[:, -self.w_signal_samples_calibration:, :],
                axis=1).mean(axis=-1)
        return self.baseline_value

    def ext_feature(self, signal, signal_artifacts):
        """
        Function for extracting FC values in online mode.

        Parameters
        ----------
        signal: numpy.ndarray
            Signal filtered in the narrow band for training.
            [n_samples, n_channels].
        signal_artifacts: numpy.ndarray or None
            Array containing the signal filtered in each frequency band
            associated to the artifacts to avoid.
            [n_artifact_bands, n_samples, n_channels].

        Returns
        -------
        c_value: float or None
            Feature value minus the baseline value. None if the window
            exceeds the artifact thresholds.

        Raises
        ------
        ValueError
            If the baseline has not been set or the signal is not longer than
            the feature window.
        """
        if self.baseline_value is None:
            raise ValueError(
                '[ConnectivityExtraction] Calibration not performed.')

        # Delete borders
        dif_window = round((signal.shape[0] - self.w_signal_samples) / 2)
        if dif_window <= 0:
            raise ValueError('The duration of the signal for feedback '
                             'calculation should be slightly longer than the '
                             'window time taken for the feedback calculation. ')

        adj_mat = self.calculate_adj_mat(signal[dif_window:-dif_window, :])
        c_value = self.calculate_feature(np.squeeze(adj_mat)) \
            - self.baseline_value

        # Check if artifact bands are defined
        if signal_artifacts is not None:
            artifact_window = signal_artifacts[:, dif_window:-dif_window, :]
            # Indexing with None would insert a new axis and make the
            # threshold comparison broadcast across bands, so the channel
            # selection is only applied when target channels exist
            if self.target_channels is not None:
                artifact_window = artifact_window[:, :, self.target_channels]
            if not ignore_noisy_windows(artifact_window, self.thresholds,
                                        self.pct_tol):
                return None
        return c_value

    def calculate_adj_mat(self, signal):
        """
        This function calculates the adjacency matrix depending on the FC
        measure chosen.

        Parameters
        ----------
        signal: numpy.ndarray
            Signal filtered in the narrow band for training.
            [n_epochs, n_samples, n_channels] or [n_samples, n_channels].

        Returns
        -------
        adj_mat: numpy.ndarray
            [n_epochs, n_channels, n_channels].
        """
        if self.fc_measure == "WPLI":
            return phase_connectivity(signal, 'wpli')
        # This is under development
        if self.fc_measure == "AECORT":
            return aec(signal)
        return None

    def calculate_feature(self, adj_mat):
        """
        Calculates the graph metric from an adjacency matrix depending on the
        mode chosen.

        Parameters
        ----------
        adj_mat: numpy.ndarray
            [n_channels, n_channels].

        Returns
        -------
        float
            "Global coupling": NaN-ignoring mean of the lower triangle.
            "Strength": mean degree of the target channels (all channels if
            no target channels were defined). "Coupling": mean connectivity
            between the target channels.
        """
        if self.mode == "Global coupling":
            tri_l_idx = np.tril_indices(adj_mat.shape[0], -1)
            return np.nanmean(np.asarray(adj_mat)[tri_l_idx])
        elif self.mode == "Strength":
            node_degree = degree.degree(np.asarray(adj_mat), 'CPU')
            if self.target_channels is None:
                return np.mean(node_degree)
            return np.mean(node_degree[self.target_channels])
        elif self.mode == "Coupling":
            return self.mean_coupling(adj_mat)

    def mean_coupling(self, adj_mat):
        """
        This function averages the connectivity values between every pair of
        target channels.

        Parameters
        ----------
        adj_mat: numpy.ndarray
            [n_channels, n_channels].

        Returns
        -------
        float
            Mean connectivity value over all the target channel pairs.
        """
        pair_values = [
            np.array(adj_mat[ch_ind_1, ch_ind_2])
            for ind, ch_ind_1 in enumerate(self.target_channels[:-1])
            for ch_ind_2 in self.target_channels[ind + 1:]
        ]
        return np.mean(pair_values)
class PowerExtraction(components.ProcessingMethod):
    """
    Power-based features to extract from user's EEG.
    """

    def __init__(self, l_baseline_t=5, fs=250, update_feature_window=2,
                 update_rate=0.25, pct_tol=0.9, f_dict=None, mode=None):
        """
        Class constructor.

        Parameters
        ----------
        l_baseline_t: int
            Time employed to calculate the number of samples to obtain
            baseline power parameter. In seconds.
        fs: int or float
            Sample rate of the recording.
        update_feature_window: int or float
            Length in seconds of the temporal window applied to calculate the
            feature.
        update_rate: int or float
            Feedback update time in online mode.
        pct_tol: float or numpy.ndarray
            Relative increase over the artifact thresholds that is tolerated.
        f_dict: list of dict
            Dicts containing the frequency bands ("cutoff") and their "type".
            Only the entries whose type is "training" are used here.
        mode: str
            "single" or "ratio".
        """
        super().__init__(band_power=['band power'])
        self.mode = mode
        self.fs = fs
        self.l_baseline_t = l_baseline_t
        self.update_feature_window = update_feature_window
        self.update_rate = update_rate
        self.w_signal_samples = int(update_feature_window * self.fs)
        self.w_signal_samples_calibration = int(self.l_baseline_t * self.fs)
        self.f_dict = f_dict
        self.pct_tol = pct_tol
        self.baseline_power = None
        self.thresholds = None

    def _feature_from_powers(self, powers):
        """Combines the training band powers according to the mode: first
        band power ("single") or first/second band ratio ("ratio")."""
        if self.mode == 'single':
            return powers[0]
        if self.mode == 'ratio':
            return powers[0] / powers[1]
        raise ValueError('[PowerExtraction] Invalid mode. Available modes '
                         'are "single" and "ratio".')

    def set_baseline(self, signal, signal_artifacts):
        """
        This function sets the power baseline, given the already filtered EEG
        containing the calibration phase. Also, takes into account the
        Neurofeedback training mode, so performs different baseline
        calculations depending on the mode set.

        Parameters
        ----------
        signal: numpy.ndarray
            EEG already pre-processed. [n_samples, n_channels].
        signal_artifacts: numpy.ndarray or None
            Signal filtered in the frequency bands associated to the artifacts
            to be avoided. [n_artifact_bands, n_samples, n_channels].

        Returns
        -------
        baseline_power: float

        Raises
        ------
        ValueError
            If the signal is not longer than the calibration window or the
            mode is not "single" or "ratio".
        """
        # Delete borders
        dif_window = round(
            (signal.shape[0] - self.w_signal_samples_calibration) / 2)
        if dif_window <= 0:
            raise ValueError('The duration of the signal for the baseline '
                             'calculation should be slightly longer than the '
                             'window time taken for the baseline calculation.')

        epochs, _ = make_windows(signal[dif_window:-dif_window, :],
                                 self.fs, self.update_feature_window,
                                 self.update_rate)
        _, psd = scipy.signal.welch(epochs, self.fs, 'hamming',
                                    self.w_signal_samples, axis=1,
                                    scaling='density')
        b_power = self.power(psd)

        # Define artifact related thresholds
        if signal_artifacts is not None:
            self.thresholds = np.std(
                signal_artifacts[:, -self.w_signal_samples_calibration:, :],
                axis=1).mean(axis=-1)

        # An unsupported mode raises instead of silently leaving the
        # baseline unset
        self.baseline_power = self._feature_from_powers(b_power)
        return self.baseline_power

    def band_power(self, signal, signal_artifacts):
        """
        This function returns the baseline-corrected band power computed from
        the Power Spectral Density. If signal noise is above the
        pre-established thresholds, this function will return None.

        Parameters
        ----------
        signal: numpy.ndarray
            Signal pre-processed. [n_samples, n_channels].
        signal_artifacts: numpy.ndarray or None
            Signal filtered in the frequency bands associated to the artifacts
            to be avoided. [n_artifact_bands, n_samples, n_channels].

        Returns
        -------
        b_power: float or None

        Raises
        ------
        ValueError
            If the baseline has not been set, the signal is not longer than
            the feature window or the mode is not "single" or "ratio".
        """
        if self.baseline_power is None:
            raise ValueError('[PowerExtraction] Calibration not performed.')

        dif_window = round((signal.shape[0] - self.w_signal_samples) / 2)
        if dif_window <= 0:
            raise ValueError('The duration of the signal for the feedback '
                             'calculation should be slightly longer than the '
                             'window time taken for the feedback calculation. ')

        _, psd = scipy.signal.welch(signal[dif_window:-dif_window], self.fs,
                                    'hamming', self.w_signal_samples,
                                    axis=0, scaling='density')
        b_power_uncorrected = self.power(psd)
        b_power = self._feature_from_powers(b_power_uncorrected) \
            - self.baseline_power

        # Check if artifact bands are defined
        if signal_artifacts is not None:
            if not ignore_noisy_windows(signal_artifacts, self.thresholds,
                                        self.pct_tol):
                return None
        return b_power

    def power(self, psd):
        """
        This function calculates the power of each training band from the
        Power Spectral Density.

        Parameters
        ----------
        psd: numpy.ndarray
            [n_epochs, n_samples, n_channels]. A 2-D array is treated as a
            single epoch.

        Returns
        -------
        powers: numpy.ndarray
            [n_training_bands]. Each value accumulates the mean absolute
            band power of every band listed in the "cutoff" of that training
            entry.
        """
        # Check if psd has epochs dimension
        if psd.ndim == 2:
            psd = psd[np.newaxis, :, :]

        # Extract training bands limits
        training_bands = [spec['cutoff'] for spec in self.f_dict
                          if spec['type'] == 'training']
        powers = np.zeros(len(training_bands))

        # Calculate band power
        for idx, band_group in enumerate(training_bands):
            for band in band_group:
                powers[idx] += np.mean(
                    np.mean(absolute_band_power(psd, self.fs, band), axis=0))
        return powers
class ConnectivityBasedNFTModel(components.Algorithm):
    """
    Pipeline for Connectivity-based Neurofeedback training. This class
    inherits from components.Algorithm. Therefore, it can be used to create
    standalone algorithms that can be used in compatible apps from
    medusa-platform for online experiments. See components.Algorithm to know
    more about this functionality.
    """

    def __init__(self, fs, filter_dict, l_baseline_t, update_feature_window,
                 update_rate, montage, target_channels, fc_measure, mode,
                 apply_car, pct_tol_ocular=None, pct_tol_muscular=None):
        """
        Class constructor.

        Parameters
        ----------
        fs: int or float
            Sample rate of the recording.
        filter_dict: list of dict
            Filter definitions ("cutoff" and "type" keys). Exactly one entry
            of type "training" is required.
        l_baseline_t: int
            Time (in seconds) employed to calculate the baseline value.
        update_feature_window: int or float
            Length in seconds of the window applied to calculate the feature.
        update_rate: int or float
            Feedback update time in online mode.
        montage: EEGChannelSet
            Channel set of the recording.
        target_channels: list or None
            Labels of the target channels.
        fc_measure: str
            "WPLI" or "AECORT".
        mode: str
            "Global coupling", "Strength" or "Coupling".
        apply_car: bool
            If True, common average reference is applied.
        pct_tol_ocular: float or None
            Tolerance for the ocular artifact band.
        pct_tol_muscular: float or None
            Tolerance for the muscular artifact band.

        Raises
        ------
        Exception
            If the number of training bands is not exactly one.
        """
        super().__init__(calibration=['baseline_value'],
                         training=['feedback_value'])

        # Settings
        self.fs = fs
        self.filter_dict = filter_dict
        self.l_baseline_t = l_baseline_t
        self.update_feature_window = update_feature_window
        self.montage = montage
        self.target_channels = target_channels
        self.fc_measure = fc_measure
        self.mode = mode
        self.apply_car = apply_car

        # Variables
        self.baseline_value = None

        # Set percentage tolerance to noisy signal: a single value if only
        # one tolerance is given, [ocular, muscular] if both are given
        tolerances = [tol for tol in (pct_tol_ocular, pct_tol_muscular)
                      if tol is not None]
        if len(tolerances) == 2:
            self.pct_tol = np.array(tolerances)
        elif len(tolerances) == 1:
            self.pct_tol = tolerances[0]
        else:
            self.pct_tol = None

        # Check filter dict
        if not self.check_cutoff_settings():
            raise Exception('The number of frequency bands selected does not '
                            'match the Neurofeedback mode.')

        # Add Pre-processing and Feature Extraction methods. All channels
        # are pre-processed; the feature extraction selects the targets
        self.add_method('prep_method',
                        SignalPreprocessing(filter_dict=self.filter_dict,
                                            montage=self.montage,
                                            target_channels=None,
                                            car=self.apply_car))
        self.add_method('feat_ext_method',
                        ConnectivityExtraction(
                            fs=self.fs,
                            l_baseline_t=self.l_baseline_t,
                            update_feature_window=update_feature_window,
                            fc_measure=self.fc_measure,
                            mode=self.mode,
                            montage=self.montage,
                            target_channels=self.target_channels,
                            pct_tol=self.pct_tol,
                            update_rate=update_rate))

    def calibration(self, eeg):
        """
        It pre-processes eeg, gets the signal filtered in artifact-related
        bands and filters the pre-processed eeg in the training band. Then,
        it calculates the baseline value.

        Parameters
        ----------
        eeg: numpy.ndarray
            [n_samples, n_channels]

        Returns
        -------
        baseline_value: float
        """
        prep = self.get_inst('prep_method')
        original_signal, signal_artifacts = prep.prep_fit_transform(
            signal=eeg, fs=self.fs)
        narrow_filtered_signal = prep.narrow_transform(
            signal=original_signal)
        self.baseline_value = self.get_inst('feat_ext_method').set_baseline(
            signal=original_signal,
            signal_artifacts=signal_artifacts,
            filtered_signal=narrow_filtered_signal)
        # Return the documented output instead of only storing it
        return self.baseline_value

    def training(self, eeg):
        """
        It pre-processes eeg, gets the signal filtered in artifact-related
        bands and filters the pre-processed eeg in the training band. Then,
        it calculates the feedback value.

        Parameters
        ----------
        eeg: numpy.ndarray
            [n_samples, n_channels]

        Returns
        -------
        feedback_value: float or None
            None if the window exceeds the artifact thresholds.
        """
        prep = self.get_inst('prep_method')
        original_signal, signal_artifacts = prep.prep_transform(signal=eeg)
        narrow_filtered_signal = prep.narrow_transform(
            signal=original_signal)
        feedback_value = self.get_inst('feat_ext_method').ext_feature(
            signal=narrow_filtered_signal,
            signal_artifacts=signal_artifacts)
        return feedback_value

    def check_cutoff_settings(self):
        """
        Function to check the correct definition of training band dictionary.

        Returns
        -------
        bool
            True if exactly one filter of type "training" is defined.
        """
        target_bands = sum(1 for filt_spec in self.filter_dict
                           if filt_spec['type'] == 'training')
        return target_bands == 1
class PowerBasedNFTModel(components.Algorithm):
    """
    Pipeline for Power-based Neurofeedback training. This class inherits from
    components.Algorithm. Therefore, it can be used to create standalone
    algorithms that can be used in compatible apps from medusa-platform for
    online experiments. See components.Algorithm to know more about this
    functionality.
    """

    def __init__(self, fs, filter_dict, l_baseline_t, update_feature_window,
                 update_rate, montage, target_channels, mode, apply_car,
                 apply_laplacian, pct_tol_ocular=None, pct_tol_muscular=None):
        """
        Class constructor.

        Parameters
        ----------
        fs: int or float
            Sample rate of the recording.
        filter_dict: list of dict
            Filter definitions ("cutoff" and "type" keys). One entry of type
            "training" is required in "single" mode and two in "ratio" mode.
        l_baseline_t: int
            Time (in seconds) employed to calculate the baseline value.
        update_feature_window: int or float
            Length in seconds of the window applied to calculate the feature.
        update_rate: int or float
            Feedback update time in online mode.
        montage: EEGChannelSet
            Channel set of the recording.
        target_channels: list or None
            Labels of the target channels.
        mode: str
            "single" or "ratio".
        apply_car: bool
            If True, common average reference is applied.
        apply_laplacian: bool
            If True, a Laplacian spatial filter is applied.
        pct_tol_ocular: float or None
            Tolerance for the ocular artifact band.
        pct_tol_muscular: float or None
            Tolerance for the muscular artifact band.

        Raises
        ------
        Exception
            If the number of training bands does not match the mode.
        """
        super().__init__(calibration=['baseline_parameters'],
                         training=['feedback_value'])

        # Settings
        self.fs = fs
        self.filter_dict = filter_dict
        self.l_baseline_t = l_baseline_t
        self.update_feature_window = update_feature_window
        self.montage = montage
        self.target_channels = target_channels
        self.mode = mode
        self.apply_car = apply_car
        self.apply_laplacian = apply_laplacian

        # Init variables
        self.baseline_value = None

        # Set percentage tolerance to noisy signal: a single value if only
        # one tolerance is given, [ocular, muscular] if both are given
        tolerances = [tol for tol in (pct_tol_ocular, pct_tol_muscular)
                      if tol is not None]
        if len(tolerances) == 2:
            self.pct_tol = np.array(tolerances)
        elif len(tolerances) == 1:
            self.pct_tol = tolerances[0]
        else:
            self.pct_tol = None

        # Check correct filter dict definition
        if not self.check_cutoff_settings():
            raise Exception('The number of frequency bands selected does not '
                            'match the Neurofeedback mode.')

        self.add_method('prep_method',
                        SignalPreprocessing(
                            filter_dict=self.filter_dict,
                            montage=self.montage,
                            target_channels=self.target_channels,
                            laplacian=self.apply_laplacian,
                            car=self.apply_car))
        self.add_method('feat_ext_method',
                        PowerExtraction(
                            fs=fs,
                            l_baseline_t=l_baseline_t,
                            update_feature_window=update_feature_window,
                            update_rate=update_rate,
                            mode=self.mode,
                            pct_tol=self.pct_tol,
                            f_dict=filter_dict))

    def calibration(self, eeg, **kwargs):
        """
        It pre-processes eeg and gets the signal filtered in artifact-related
        bands. Then, it calculates the baseline value.

        Parameters
        ----------
        eeg: numpy.ndarray
            [n_samples, n_channels]
        kwargs:
            Accepted for compatibility; not used.

        Returns
        -------
        baseline_value: float
        """
        original_signal, signal_artifacts = \
            self.get_inst('prep_method').prep_fit_transform(signal=eeg,
                                                            fs=self.fs)
        self.baseline_value = self.get_inst('feat_ext_method').set_baseline(
            signal=original_signal, signal_artifacts=signal_artifacts)
        # Return the documented output instead of only storing it
        return self.baseline_value

    def training(self, eeg):
        """
        It pre-processes eeg and gets the signal filtered in artifact-related
        bands. Then, it calculates the feedback value.

        Parameters
        ----------
        eeg: numpy.ndarray
            [n_samples, n_channels]

        Returns
        -------
        feedback_value: float or None
            None if the window exceeds the artifact thresholds.
        """
        original_signal, signal_artifacts = \
            self.get_inst('prep_method').prep_transform(signal=eeg)
        feedback_value = self.get_inst('feat_ext_method').band_power(
            signal=original_signal, signal_artifacts=signal_artifacts)
        return feedback_value

    def check_cutoff_settings(self):
        """
        Function to check the correct definition of training band dictionary.

        Returns
        -------
        bool
            True if the number of filters of type "training" matches the
            mode (1 for "single", 2 for "ratio"). False otherwise, including
            unknown modes.
        """
        target_bands = sum(1 for filt_spec in self.filter_dict
                           if filt_spec['type'] == 'training')
        if self.mode == 'single':
            return target_bands == 1
        if self.mode == 'ratio':
            return target_bands == 2
        return False
class NeurofeedbackData(components.ExperimentData):
    """Experiment info class for Neurofeedback training experiments. It
    records the important events that take place during a Neurofeedback run,
    allowing offline analysis."""

    def __init__(self, run_onsets, run_durations, run_success, run_pauses,
                 run_restarts, medusa_nft_app_settings, nft_values,
                 nft_times, nft_baseline):
        """
        Class constructor. Every parameter is stored as an attribute with the
        same name.
        """
        self.run_onsets = run_onsets
        self.run_durations = run_durations
        self.run_success = run_success
        self.run_pauses = run_pauses
        self.run_restarts = run_restarts
        self.medusa_nft_app_settings = medusa_nft_app_settings
        self.nft_values = nft_values
        self.nft_times = nft_times
        self.nft_baseline = nft_baseline

    def to_serializable_obj(self):
        """
        Returns the attributes of the instance as a dict, with numpy arrays
        converted to lists.

        Returns
        -------
        rec_dict: dict
            New dict, so the attributes of the instance are left untouched.
        """
        return {
            key: value.tolist() if isinstance(value, np.ndarray) else value
            for key, value in self.__dict__.items()
        }

    @classmethod
    def from_serializable_obj(cls, dict_data):
        """
        Creates an instance from a dict whose keys are the constructor
        parameters.
        """
        return cls(**dict_data)