"""
In this module you will find useful functions and classes to operate with data
recorded using spellers based on code-modulated visual evoked potentials
(c-VEP), which are widely used by the BCI community. Enjoy!
@author: Víctor Martínez-Cagigal
"""
import medusa as mds
from medusa import components
from medusa import meeg
from medusa import spatial_filtering as sf
from medusa import epoching as ep
import copy, warnings
import itertools
from abc import ABC, abstractmethod
import numpy as np
from tqdm import tqdm
# Lookup table of LFSR polynomial coefficients, nested as
#   LFSR_PRIMITIVE_POLYNOMIALS['base'][<base>]['order'][<order>]
# Each leaf is a list of integer coefficients; in every entry below the list
# length equals the order, and the values stay in the range [0, base - 1].
# Bases covered: 2 (orders 2-30), 3 (orders 2-7), 5 (orders 2-4), and
# 7, 11, 13 (order 2 only).
# NOTE(review): going by the name, these are presumably primitive-polynomial
# feedback taps for linear-feedback shift registers; the tap convention (which
# list position maps to which register stage) is not visible in this section
# and the values have not been re-derived here -- confirm against the routine
# that consumes this table before editing any entry.
LFSR_PRIMITIVE_POLYNOMIALS = \
{
'base': {
2: {
'order': {
2: [1, 1],
3: [1, 0, 1],
4: [1, 0, 0, 1],
5: [0, 1, 0, 0, 1],
6: [0, 0, 0, 0, 1, 1],
7: [0, 0, 0, 0, 0, 1, 1],
8: [1, 1, 0, 0, 0, 0, 1, 1],
9: [0, 0, 0, 1, 0, 0, 0, 0, 1],
10: [0, 0, 1, 0, 0, 0, 0, 0, 0, 1],
11: [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1],
12: [0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1],
13: [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1],
14: [0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1],
15: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1],
16: [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1],
17: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
18: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
19: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1,
1],
20: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
0, 0],
21: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 1, 0],
22: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 1],
23: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1, 0, 0, 0, 0],
24: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
0, 0, 0, 0, 1, 1],
25: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 1, 0, 0],
26: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 1, 0, 0, 0, 1, 1],
27: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 1, 0, 0, 1, 1],
28: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
29: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
30: [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1]
}
},
3: {
'order': {
2: [2, 1],
3: [0, 1, 2],
4: [0, 0, 2, 1],
5: [0, 0, 0, 1, 2],
6: [0, 0, 0, 0, 2, 1],
7: [0, 0, 0, 0, 2, 1, 2],
}
},
5: {
'order': {
2: [4, 3],
3: [0, 2, 3],
4: [0, 4, 3, 3],
}
},
7: {
'order': {
2: [1, 4]
}
},
11: {
'order': {
2: [1, 3]
}
},
13: {
'order': {
2: [1, 11]
}
}
}
}
# --------------------------- c-VEP DATA MANAGEMENT -------------------------- #
class CVEPSpellerData(components.ExperimentData):
    """Experiment info class for c-VEP-based spellers. It supports nested
    multi-level paradigms. This unified class can be used to represent a run
    of every c-VEP stimulation paradigm designed to date, and is the expected
    class for feature extraction and command decoding functions of the module
    medusa.bci.cvep_paradigms. It is complicated, but powerful so.. use it well!
    """

    def __init__(self, mode, paradigm_conf, commands_info, onsets, command_idx,
                 unit_idx, level_idx, matrix_idx, cycle_idx, trial_idx,
                 cvep_model, spell_result, fps_resolution, spell_target=None,
                 raster_events=None, **kwargs):
        """Stores the information of one c-VEP speller run.

        Parameters
        ----------
        mode : str {'train'|'test'}
            Mode of the run. Any other value raises ValueError.
        paradigm_conf : object
            Paradigm configuration. Stored as given.
        commands_info : object
            Information of the commands. Elsewhere in this module it is read
            as ``commands_info[matrix][str(command)]['sequence']``.
        onsets : list or numpy.ndarray
            Onsets of the stimulation events; used as epoch onsets by the
            classifier of this module.
        command_idx, unit_idx, level_idx, matrix_idx, cycle_idx, trial_idx :
                list or numpy.ndarray
            Index vectors describing each onset. Stored as given.
        cvep_model : object
            Model information associated to the run. Stored as given.
        spell_result : object
            Decoding result of the run. Stored as given.
        fps_resolution : int or float
            Stimulus presentation rate; the classifier of this module divides
            the sequence length by it to obtain the epoch duration.
        spell_target : object or None
            Target commands (required by the fitting stage of this module).
        raster_events : dict or None
            Raster latency events; when latency correction is enabled the
            classifier reads its 'onset' and 'event' entries.
        **kwargs
            Any extra keyword is stored as an attribute with the same name.

        Raises
        ------
        ValueError
            If ``mode`` is not 'train' or 'test'.
        """
        # Check errors
        if mode not in ('train', 'test'):
            raise ValueError('Unknown mode. Possible values {train, test}')

        # Standard attributes
        self.mode = mode
        self.paradigm_conf = paradigm_conf
        self.commands_info = commands_info
        self.onsets = onsets
        self.command_idx = command_idx
        self.unit_idx = unit_idx
        self.level_idx = level_idx
        self.matrix_idx = matrix_idx
        self.cycle_idx = cycle_idx
        self.trial_idx = trial_idx
        self.cvep_model = cvep_model
        self.spell_result = spell_result
        self.fps_resolution = fps_resolution
        self.spell_target = spell_target
        self.raster_events = raster_events

        # Optional attributes
        for key, value in kwargs.items():
            setattr(self, key, value)

    def to_serializable_obj(self):
        """Returns the attributes of the instance as a dict in which every
        numpy array has been converted to a (nested) list.

        The dict is a shallow copy of the instance attributes: the instance
        itself keeps its numpy arrays untouched, so it remains usable for
        numerical work after being serialized.

        Returns
        -------
        rec_dict : dict
            Attribute name -> value, with top-level numpy arrays as lists.
        """
        # Shallow copy: replacing values must not write back into __dict__
        rec_dict = dict(self.__dict__)
        for key, value in rec_dict.items():
            if isinstance(value, np.ndarray):
                rec_dict[key] = value.tolist()
        return rec_dict

    @classmethod
    def from_serializable_obj(cls, dict_data):
        """Builds an instance using the items of ``dict_data`` as keyword
        arguments of the constructor.

        Parameters
        ----------
        dict_data : dict
            Dict such as the one returned by ``to_serializable_obj``.
        """
        return cls(**dict_data)
class CVEPSpellerDataset(components.Dataset):
    """ Dataset specialization for recordings of c-VEP-based spellers. It
    fixes the channel set, the biosignal/experiment attribute names and the
    attributes to track, and registers the consistency rules that every
    recording of the dataset must fulfil.
    """

    def __init__(self, channel_set, fs=None, biosignal_att_key='eeg',
                 experiment_att_key='cvepspellerdata', experiment_mode=None,
                 track_attributes=None):
        """ Constructor

        Parameters
        ----------
        channel_set : meeg.EEGChannelSet
            EEG channel set. Only these channels will be kept in the dataset,
            the others will be discarded. Also, the signals will be rearranged,
            keeping the same channel order, avoiding errors in future stages of
            the signal processing pipeline
        fs : int, float or None
            Sample rate of the recordings. If None, no sample rate rule is
            registered and a warning is issued instead
        biosignal_att_key : str
            Name of the recording attribute that holds the target biosignal.
            It has to be the same in all recordings (e.g., 'eeg', 'meg').
        experiment_att_key : str or None
            Name of the recording attribute that holds the target experiment.
            It has to be the same in all recordings (e.g., 'rcp_data',
            'cake_paradigm_data'). It is mandatory when a recording of the
            dataset contains more than 1 experiment data
        experiment_mode : str {'train'|'test'|None}
            Mode of the experiment. If this dataset will be used to fit a model,
            set to train to avoid errors
        track_attributes: dict of dicts or None
            Custom attributes that must be tracked in feature extraction
            functions and how. The keys are the names of the attributes,
            whereas the values are dicts indicating the tracking mode
            {'concatenate'|'append'} and parent. Option concatenate is
            only available for attributes of type list or numpy arrays,
            forming a 1 dimensional array with the data from all recordings.
            Option append is used to save all kind of objects for each
            recording, forming a list whose length will be the number of
            recordings in the dataset. The given entries are merged on top
            of a set of default attributes, so this parameter will be None
            in most cases. Example to track 2 custom attributes (i.e., date
            and experiment_equipment):

                track_attributes = {
                    'date': {
                        'track_mode': 'append',
                        'parent': None
                    },
                    'experiment_equipment': {
                        'track_mode': 'append',
                        'parent': experiment_att_key
                    }
                }
        """
        # Validate the mode
        if experiment_mode not in (None, 'train', 'test'):
            raise ValueError('Parameter experiment_mode must be '
                             '{train|test|None}')

        # Default attributes, listed in the order in which their rules are
        # registered: (name, track mode), all but the first one hanging from
        # the experiment attribute
        experiment_defaults = [
            ('paradigm_conf', 'append'),
            ('commands_info', 'append'),
            ('onsets', 'concatenate'),
            ('command_idx', 'concatenate'),
            ('unit_idx', 'concatenate'),
            ('level_idx', 'concatenate'),
            ('matrix_idx', 'concatenate'),
            ('cycle_idx', 'concatenate'),
            ('trial_idx', 'concatenate'),
            ('spell_result', 'append'),
        ]
        if experiment_mode == 'train':
            # Targets are only tracked by default for training datasets
            experiment_defaults.append(('spell_target', 'append'))

        tracked = {'subject_id': {'track_mode': 'append', 'parent': None}}
        for att_name, att_mode in experiment_defaults:
            tracked[att_name] = {'track_mode': att_mode,
                                 'parent': experiment_att_key}

        # Custom attributes override / extend the defaults
        if track_attributes is not None:
            tracked = {**tracked, **track_attributes}

        # Class attributes
        self.channel_set = channel_set
        self.fs = fs
        self.biosignal_att_key = biosignal_att_key
        self.experiment_att_key = experiment_att_key
        self.experiment_mode = experiment_mode
        self.track_attributes = tracked

        # Consistency checker
        super().__init__(
            consistency_checker=self.__get_consistency_checker())

    def __get_consistency_checker(self):
        """Creates the consistency checker for c-VEP datasets from the
        attributes of the instance.

        Returns
        -------
        checker : components.ConsistencyChecker
            Checker with the rules for the biosignal, channels, sample rate
            (if given), experiment, mode (if given) and tracked attributes.
        """
        biosignal_key = self.biosignal_att_key
        experiment_key = self.experiment_att_key
        checker = components.ConsistencyChecker()

        # The biosignal must exist and be an EEG
        checker.add_consistency_rule(
            rule='check-attribute',
            rule_params={'attribute': biosignal_key}
        )
        checker.add_consistency_rule(
            rule='check-attribute-type',
            rule_params={'attribute': biosignal_key, 'type': meeg.EEG}
        )

        # The channels of the dataset must be present in the recording
        checker.add_consistency_rule(
            rule='check-values-in-attribute',
            rule_params={'attribute': 'channels',
                         'values': self.channel_set.channels},
            parent=biosignal_key + '.channel_set'
        )

        # Sample rate, only when it is known
        if self.fs is None:
            warnings.warn('Parameter fs is None. The consistency of the '
                          'dataset cannot be assured. Still, you can use '
                          'target_fs parameter for feature extraction '
                          'and everything should be fine.')
        else:
            checker.add_consistency_rule(
                rule='check-attribute-value',
                rule_params={'attribute': 'fs', 'value': self.fs},
                parent=biosignal_key
            )

        # The experiment must exist and be a CVEPSpellerData
        checker.add_consistency_rule(
            rule='check-attribute',
            rule_params={'attribute': experiment_key}
        )
        checker.add_consistency_rule(
            rule='check-attribute-type',
            rule_params={'attribute': experiment_key,
                         'type': CVEPSpellerData}
        )

        # Mode, only when it is restricted
        if self.experiment_mode is not None:
            checker.add_consistency_rule(
                rule='check-attribute-value',
                rule_params={'attribute': 'mode',
                             'value': self.experiment_mode},
                parent=experiment_key
            )

        # Tracked attributes must exist; concatenated ones must be sequences
        for att_name, att_conf in (self.track_attributes or {}).items():
            checker.add_consistency_rule(
                rule='check-attribute',
                rule_params={'attribute': att_name},
                parent=att_conf['parent']
            )
            if att_conf['track_mode'] != 'concatenate':
                continue
            checker.add_consistency_rule(
                rule='check-attribute-type',
                rule_params={'attribute': att_name,
                             'type': [list, np.ndarray]},
                parent=att_conf['parent']
            )
        return checker

    def custom_operations_on_recordings(self, recording):
        """Changes the channel set of the biosignal of the recording to the
        channel set of the dataset and returns the same recording."""
        biosignal = getattr(recording, self.biosignal_att_key)
        biosignal.change_channel_set(self.channel_set)
        return recording
# ---------------------------------- MODELS ---------------------------------- #
class CVEPModelCircularShifting(components.Algorithm):
    """c-VEP decoding pipeline made up of three methods: 'prep_method'
    (StandardPreprocessing or FilterBankPreprocessing), 'clf_method'
    (CircularShiftingClassifier) and 'es_method'
    (CircularShiftingEarlyStopping).
    """

    def __init__(self, bpf=[[7, (1.0, 30.0)]], notch=[7, (49.0, 51.0)],
                 art_rej=None, correct_raster_latencies=False,
                 *args, **kwargs):
        """Class constructor

        Parameters
        ----------
        bpf : list of [order, (low cutoff, high cutoff)]
            Band-pass filters. With a single entry StandardPreprocessing is
            used; otherwise every entry becomes a band-pass filter of a
            FilterBankPreprocessing.
        notch : [order, (low cutoff, high cutoff)] or None
            Notch filter. If None, no notch parameters are passed to the
            preprocessing method.
        art_rej : float or None
            Forwarded to CircularShiftingClassifier.
        correct_raster_latencies : bool
            Forwarded to CircularShiftingClassifier.

        Notes
        -----
        The list defaults are only read, never modified, so sharing them
        between instances is harmless; they are kept to preserve the
        signature.
        """
        super().__init__()

        # Notch parameters are shared by both preprocessing variants
        if notch is None:
            notch_order, notch_cutoff = None, None
        else:
            notch_order, notch_cutoff = notch[0], notch[1]

        # Preprocessing: single band-pass filter or filter bank
        if len(bpf) == 1:
            prep_method = StandardPreprocessing(
                bpf_order=bpf[0][0], bpf_cutoff=bpf[0][1],
                notch_order=notch_order, notch_cutoff=notch_cutoff)
        else:
            filter_bank = [{'order': band[0],
                            'cutoff': band[1],
                            'btype': 'bandpass'} for band in bpf]
            prep_method = FilterBankPreprocessing(
                filter_bank=filter_bank,
                notch_order=notch_order, notch_cutoff=notch_cutoff)
        self.add_method('prep_method', prep_method)

        # Feature extraction and classification (circular shifting)
        self.add_method('clf_method', CircularShiftingClassifier(
            art_rej=art_rej,
            correct_raster_latencies=correct_raster_latencies
        ))

        # Early stopping
        self.add_method('es_method', CircularShiftingEarlyStopping())

    def check_predict_feasibility(self, dataset):
        """Returns True if the epochs of every recording of the dataset can
        be extracted (delegates to the classifier)."""
        return self.get_inst('clf_method')._is_predict_feasible(dataset)

    def check_predict_feasibility_signal(self, times, onsets, fs):
        """Returns True if the epochs of the given onsets can be extracted
        from a signal with these timestamps (delegates to the classifier)."""
        return self.get_inst('clf_method')._is_predict_feasible_signal(
            times, onsets, fs)

    def fit_dataset(self, dataset, **kwargs):
        """Preprocesses a deep copy of the dataset and fits the classifier.

        Parameters
        ----------
        dataset : CVEPSpellerDataset
            Training dataset. It is not modified.
        **kwargs
            Currently unused.

        Returns
        -------
        fitted_info : dict
            Dict returned by CircularShiftingClassifier.fit_dataset.
        """
        # Safe copy
        data = copy.deepcopy(dataset)

        # Preprocessing
        data = self.get_inst('prep_method').fit_transform_dataset(
            dataset=data,
            show_progress_bar=True
        )

        # Feature extraction and classification
        fitted_info = self.get_inst('clf_method').fit_dataset(
            dataset=data,
            std_epoch_rejection=3.0,
            show_progress_bar=True
        )
        return fitted_info

    def predict_dataset(self, dataset):
        """Preprocesses a deep copy of the dataset and decodes every trial.

        Returns
        -------
        cmd_decoding : dict
            'spell_result': list with the label of the first sorted command
            of each trial, using all its cycles.
            'spell_result_per_cycle': list with one dict per trial that maps
            the cycle index to the label decoded using the cycles up to it.
            'items_by_no_cycle': raw output of the classifier.
        """
        # Safe copy
        data = copy.deepcopy(dataset)

        # Preprocessing
        data = self.get_inst('prep_method').transform_dataset(
            dataset=data,
            show_progress_bar=True
        )

        # Feature extraction and classification
        pred_items = self.get_inst('clf_method').predict_dataset(
            dataset=data,
            show_progress_bar=True
        )

        # Extract the selected items using the maximum number of cycles
        selected_seq = 0  # todo: several sequences in the same matrix
        spell_result = [
            item[-1][selected_seq]['sorted_cmds'][0]['label']
            for item in pred_items
        ]

        # Extract the selected items depending on the number of cycles
        spell_result_per_cycle = [
            {nc: pred[selected_seq]['sorted_cmds'][0]['label']
             for nc, pred in enumerate(item)}
            for item in pred_items
        ]

        # Create the decoding dictionary
        cmd_decoding = {
            'spell_result': spell_result,
            'spell_result_per_cycle': spell_result_per_cycle,
            'items_by_no_cycle': pred_items
        }
        return cmd_decoding

    def predict(self, times, signal, trial_idx, exp_data, sig_data):
        """Decodes a single trial. All arguments are deep-copied first, so
        the caller's objects are not modified.

        Returns
        -------
        cmd_decoding : dict
            Same keys as in predict_dataset, but for one trial:
            'spell_result' is a label and 'spell_result_per_cycle' a dict
            that maps the cycle index to the decoded label.
        """
        # Safe copy
        times_ = copy.deepcopy(times)
        signal_ = copy.deepcopy(signal)
        trial_idx_ = copy.deepcopy(trial_idx)
        exp_data_ = copy.deepcopy(exp_data)
        sig_data_ = copy.deepcopy(sig_data)

        # Preprocessing
        signal_ = self.get_inst('prep_method').transform_signal(
            signal=signal_
        )

        # Feature extraction and classification
        pred_item_by_no_cycles = self.get_inst('clf_method').predict(
            times_, signal_, trial_idx_, exp_data_, sig_data_
        )

        # Extract the selected label using the maximum number of cycles
        selected_seq = 0  # todo: several sequences in the same matrix
        spell_result = pred_item_by_no_cycles[-1][selected_seq][
            'sorted_cmds'][0]['label']

        # Extract the selected label depending on the number of cycles
        spell_result_per_cycle = {
            nc: pred[selected_seq]['sorted_cmds'][0]['label']
            for nc, pred in enumerate(pred_item_by_no_cycles)
        }

        # Create the decoding dictionary
        cmd_decoding = {
            'spell_result': spell_result,
            'spell_result_per_cycle': spell_result_per_cycle,
            'items_by_no_cycle': pred_item_by_no_cycles
        }
        return cmd_decoding

    def must_stop(self, corr_vector, std=3.0):
        """Checks the early stopping criterion for a vector of sorted
        correlations.

        Returns
        -------
        must_stop, probs
            Output of CircularShiftingEarlyStopping.check_early_stop.
        """
        # Safe copy: the early stopping method works on the copy, never on
        # the caller's vector
        corr_vector_ = copy.deepcopy(corr_vector)
        return self.get_inst('es_method').check_early_stop(
            corr_vector=corr_vector_,
            std=std
        )
# ------------------------------- ALGORITHMS -------------------------------- #
class StandardPreprocessing(components.ProcessingMethod):
    """Common preprocessing for c-VEP-based spellers: one IIR band-pass
    filter plus an optional IIR band-stop (notch) filter.
    """

    def __init__(self, bpf_order=7, bpf_cutoff=(0.5, 60.0), notch_order=7,
                 notch_cutoff=(49.0, 51.0)):
        """Class constructor

        Parameters
        ----------
        bpf_order : int
            Order of the band-pass filter.
        bpf_cutoff : tuple
            Cutoff frequencies of the band-pass filter.
        notch_order : int or None
            Order of the notch filter.
        notch_cutoff : tuple or None
            Cutoff frequencies of the notch filter. If None, no notch filter
            is created when fitting.
        """
        super().__init__(fit_transform_signal=['signal'],
                         fit_transform_dataset=['dataset'])

        # Filter instances, created by fit()
        self.bpf_iir_filter = None
        self.notch_iir_filter = None

        # Filter design parameters
        self.filt_method = 'sosfiltfilt'
        self.bpf_order, self.bpf_cutoff = bpf_order, bpf_cutoff
        self.notch_order, self.notch_cutoff = notch_order, notch_cutoff

    def fit(self, fs):
        """Creates and fits the IIR filters for the given sample rate.

        Parameters
        ----------
        fs: float
            Sample rate of the signal.
        """
        # Band-pass filter (always)
        self.bpf_iir_filter = mds.IIRFilter(
            order=self.bpf_order,
            cutoff=self.bpf_cutoff,
            btype='bandpass',
            filt_method=self.filt_method
        )
        self.bpf_iir_filter.fit(fs)

        # Notch filter (only if a cutoff was given)
        if self.notch_cutoff is None:
            return
        self.notch_iir_filter = mds.IIRFilter(
            order=self.notch_order,
            cutoff=self.notch_cutoff,
            btype='bandstop',
            filt_method=self.filt_method
        )
        self.notch_iir_filter.fit(fs)
class FilterBankPreprocessing(components.ProcessingMethod):
    """Filter bank preprocessing for c-VEP-based spellers: a bank of IIR
    filters plus an optional IIR band-stop (notch) filter.
    """

    # Keys that every filter description of the bank must contain
    _FILTER_KEYS = ('order', 'cutoff', 'btype')

    def __init__(self, filter_bank=None, notch_order=7,
                 notch_cutoff=(49.0, 51.0)):
        """Class constructor

        Parameters
        ----------
        filter_bank : list of dict or None
            One dict per filter with the keys "order", "cutoff" and "btype".
            If None, three order-7 band-pass filters are used, with cutoffs
            (8, 60), (12, 60) and (30, 60).
        notch_order : int or None
            Order of the notch filter.
        notch_cutoff : tuple or None
            Cutoff frequencies of the notch filter. If None, no notch filter
            is created when fitting.

        Raises
        ------
        ValueError
            If the bank is empty, or any of its filters is not a dict with
            the required keys.
        """
        super().__init__(fit_transform_signal=['signal'],
                         fit_transform_dataset=['dataset'])

        # Default bank
        if filter_bank is None:
            filter_bank = [
                {'order': 7, 'cutoff': (low_cut, 60.0), 'btype': 'bandpass'}
                for low_cut in (8.0, 12.0, 30.0)
            ]

        # Error check
        if not filter_bank:
            raise ValueError('[FilterBankPreprocessing] Filter bank parameter '
                             '"filter_bank" must be a list containing all '
                             'necessary information to perform the filtering!')
        for filter_conf in filter_bank:
            if not isinstance(filter_conf, dict):
                raise ValueError('[FilterBankPreprocessing] Each filter must '
                                 'be a dict()!')
            if any(key not in filter_conf for key in self._FILTER_KEYS):
                raise ValueError('[FilterBankPreprocessing] Each filter must '
                                 'be a dict() containing the following keys: '
                                 '"order", "cutoff" and "btype"!')

        # Filter instances, created by fit()
        self.filter_bank_iir_filters = None
        self.notch_iir_filter = None

        # Filter design parameters
        self.filt_method = 'sosfiltfilt'
        self.filter_bank = filter_bank
        self.notch_order, self.notch_cutoff = notch_order, notch_cutoff

    def fit(self, fs):
        """Creates and fits the IIR filters for the given sample rate.

        Parameters
        ----------
        fs: float
            Sample rate of the signal.
        """
        # Filter bank: the attribute is filled filter by filter
        fitted_bank = self.filter_bank_iir_filters = []
        for filter_conf in self.filter_bank:
            bank_filter = mds.IIRFilter(order=filter_conf['order'],
                                        cutoff=filter_conf['cutoff'],
                                        btype=filter_conf['btype'],
                                        filt_method=self.filt_method)
            bank_filter.fit(fs)
            fitted_bank.append(bank_filter)

        # Notch filter (only if a cutoff was given)
        if self.notch_cutoff is None:
            return
        self.notch_iir_filter = mds.IIRFilter(
            order=self.notch_order,
            cutoff=self.notch_cutoff,
            btype='bandstop',
            filt_method=self.filt_method
        )
        self.notch_iir_filter.fit(fs)
class CircularShiftingClassifier(components.ProcessingMethod):
    """Standard feature extraction method for c-VEP-based spellers. Basically,
    it computes a template for each sequence.

    For every trained sequence (and every filter of the bank) a CCA is
    fitted between the concatenated epochs and their average; the projected
    average is the main template, and the templates of the remaining
    commands are circularly shifted versions of it. Prediction correlates
    the projected average of the test epochs with every template.
    """

    def __init__(self, correct_raster_latencies=False, art_rej=None, **kwargs):
        """ Class constructor

        Parameters
        ----------
        correct_raster_latencies : bool
            If True, predict() shifts each template by the latency found for
            its sequence in ``exp_data.raster_events``.
        art_rej : float or None
            If not None, predict() interpolates the channels of the epochs
            whose standard deviation is greater than ``art_rej`` times the
            standard deviation observed for that channel during fitting.
        """
        super().__init__(fit_dataset=['templates',
                                      'cca_by_seq'])
        self.fitted = dict()
        self.art_rej = art_rej
        self.correct_raster_latencies = correct_raster_latencies

    def _assert_consistency(self, dataset: CVEPSpellerDataset):
        """Checks that the dataset can be used to fit the classifier.

        Returns
        -------
        fs, fps_resolution, len_seq, unique_seqs_by_run, is_filter_bank
            Common sample rate, common presentation rate, common sequence
            length, list with the output of
            get_unique_sequences_from_targets for each recording, and
            whether the signals are lists (i.e., filter banks).

        Raises
        ------
        ValueError
            If the dataset is empty, has no sequences, has any recording
            that is not in train mode or lacks targets, mixes sequence
            lengths, sample rates or presentation rates, contains sequences
            that are shifted versions of each other, or mixes filter bank
            and single signal recordings.
        """
        len_seqs = set()
        fs = set()
        fps_resolution = set()
        unique_seqs_by_run = []
        unique_all_seqs = set()
        is_filter_bank = []
        for rec in dataset.recordings:
            rec_sig = getattr(rec, dataset.biosignal_att_key)
            rec_exp = getattr(rec, dataset.experiment_att_key)
            if rec_exp.mode != 'train':
                raise ValueError('There is at least one CVEPSpellerData '
                                 'instance that was not recording under train '
                                 'mode. Aborting feature extraction...')
            if not hasattr(rec_exp, 'spell_target'):
                raise ValueError('There is at least one CVEPSpellerData '
                                 'instance that has not "spell_target" data. '
                                 'Aborting feature extraction...')
            fps_resolution.add(rec_exp.fps_resolution)
            fs.add(rec_sig.fs)
            unique_seqs = get_unique_sequences_from_targets(rec_exp)
            for seq_ in unique_seqs:
                len_seqs.add(len(seq_))
                unique_all_seqs.add(seq_)
            unique_seqs_by_run.append(unique_seqs)
            # A list of signals means one signal per filter of the bank
            is_filter_bank.append(isinstance(rec_sig.signal, list))

        # Nothing to fit: fail with a clear message instead of an obscure
        # error when popping from the empty sets below
        if not is_filter_bank:
            raise ValueError('The CVEPSpellerDataset instance has no '
                             'recordings! Aborting feature extraction...')
        if not len_seqs:
            raise ValueError('No sequences could be extracted from the '
                             'CVEPSpellerDataset instance! Aborting feature '
                             'extraction...')
        if len(len_seqs) > 1:
            raise ValueError('There are sequences with different lengths in '
                             'the CVEPSpellerDataset instance! Aborting feature'
                             ' extraction...')
        if len(fs) > 1:
            raise ValueError('There are CVEPSpellerData instances with '
                             'different sampling rates! Aborting feature '
                             'extraction...')
        if len(fps_resolution) > 1:
            raise ValueError('There are CVEPSpellerData instances with '
                             'different refresh rates! Aborting feature '
                             'extraction...')

        # Check if some sequences are shifted versions of another (no pairs
        # are generated when there is a single sequence)
        for s1, s2 in itertools.combinations(list(unique_all_seqs), 2):
            if check_if_shifted(s1, s2):
                raise ValueError('There are targets that share shifted '
                                 'versions of the same sequence. Solve that '
                                 'before extracting features!')

        # All the recordings must be of the same kind. Exactly one distinct
        # value is required; a mere "non-empty" test would accept mixtures
        if len(set(is_filter_bank)) != 1:
            raise ValueError('There are recordings that have filter banks and '
                             'other do not.')
        is_filter_bank = is_filter_bank[0]

        len_seq = len_seqs.pop()
        fs = fs.pop()
        fps_resolution = fps_resolution.pop()
        return fs, fps_resolution, len_seq, unique_seqs_by_run, is_filter_bank

    def fit_dataset(self, dataset: CVEPSpellerDataset, std_epoch_rejection=3.0,
                    show_progress_bar=True):
        """Fits the CCA projections and the shifted templates.

        Parameters
        ----------
        dataset : CVEPSpellerDataset
            Training dataset. It is not modified.
        std_epoch_rejection : float or None
            If not None, an epoch is discarded when the standard deviation
            of any of its channels is not within the median plus/minus
            ``std_epoch_rejection`` times the standard deviation of that
            value across epochs.
        show_progress_bar : bool
            Whether tqdm progress bars are shown.

        Returns
        -------
        fitted : dict
            Fitted parameters, also stored in ``self.fitted``. Key
            'sequences' maps each trained sequence to a list with one dict
            per filter, with keys 'cca', 'templates' and 'std_by_channel'.
        """
        # Error checking
        fs, fps_resolution, len_seq, unique_seqs_by_run, is_filter_bank = \
            self._assert_consistency(dataset)

        # Init progress bar for sequences
        if show_progress_bar:
            pbar = tqdm(total=len(dataset.recordings),
                        desc='Extracting unique sequences')

        # Compute sequence length in milliseconds and samples
        len_epoch_ms = len_seq / fps_resolution * 1000
        len_epoch_sam = int(len_seq / fps_resolution * fs)

        # Get the epochs of each sequence
        epochs_by_seq = {}
        for rec_idx, rec in enumerate(dataset.recordings):
            # Extract recording experiment and biosignal
            rec_exp = getattr(rec, dataset.experiment_att_key)
            rec_sig = getattr(rec, dataset.biosignal_att_key)

            # Treat a single signal as a bank with one filter. A local list
            # is used so the recording of the caller is left untouched
            signals = rec_sig.signal if is_filter_bank else [rec_sig.signal]

            # Get unique sequences for this run
            unique_seqs = unique_seqs_by_run[rec_idx]

            # For each filter bank
            for filter_idx, signal in enumerate(signals):
                # Extract epochs
                epochs = mds.get_epochs_of_events(timestamps=rec_sig.times,
                                                  signal=signal,
                                                  onsets=rec_exp.onsets,
                                                  fs=fs,
                                                  w_epoch_t=[0, len_epoch_ms],
                                                  w_baseline_t=None,
                                                  norm=None)
                if len(epochs.shape) == 2:
                    # Create a dummy dimension if we have only one epoch
                    # (same handling as in predict)
                    epochs = np.expand_dims(epochs, 0)

                # Organize epochs by sequence
                for seq_, ep_idxs_ in unique_seqs.items():
                    key = tuple(seq_)
                    seq_epochs = epochs[ep_idxs_, :, :]
                    if key not in epochs_by_seq:
                        epochs_by_seq[key] = [None] * len(signals)
                    if epochs_by_seq[key][filter_idx] is None:
                        epochs_by_seq[key][filter_idx] = seq_epochs
                    else:
                        epochs_by_seq[key][filter_idx] = np.concatenate(
                            (epochs_by_seq[key][filter_idx], seq_epochs),
                            axis=0)
            if show_progress_bar:
                pbar.update(1)

        # Precompute nearest channels for online artifact rejection (the
        # channel set of the last recording is used)
        sorted_dist_ch = rec_sig.channel_set.sort_nearest_channels()

        # New bar: one step per template, i.e., per sequence, filter and lag
        if show_progress_bar:
            pbar.close()
            no_filters = sum(len(banks) for banks in epochs_by_seq.values())
            pbar = tqdm(total=no_filters * len_seq,
                        desc='Creating templates')

        # For each sequence
        seq_dict = dict()
        discarded_epochs = 0
        total_epochs = 0
        for seq_ in epochs_by_seq:
            # For each filter of the bank
            for filter_idx in range(len(epochs_by_seq[seq_])):
                # Offline artifact rejection
                if std_epoch_rejection is not None:
                    # STD per samples: (epochs x channels)
                    epochs_std = np.std(epochs_by_seq[seq_][filter_idx],
                                        axis=1)
                    # Variation of epochs and central value, per channel
                    ch_std = np.std(epochs_std, axis=0)
                    ch_median = np.median(epochs_std, axis=0)
                    margin = std_epoch_rejection * ch_std

                    # Keep only epochs that are suitable for all channels
                    idx_to_keep = np.all(
                        (epochs_std < ch_median + margin) &
                        (epochs_std > ch_median - margin),
                        axis=1
                    )
                    epochs_by_seq[seq_][filter_idx] = \
                        epochs_by_seq[seq_][filter_idx][idx_to_keep, :, :]
                    discarded_epochs += np.sum(~idx_to_keep)
                    total_epochs += len(idx_to_keep)

                # Canonical Correlation Analysis
                cca = sf.CCA()

                # Reference (main template repeated once per epoch)
                main_template = np.mean(epochs_by_seq[seq_][filter_idx],
                                        axis=0)
                R = np.tile(main_template.T,
                            epochs_by_seq[seq_][filter_idx].shape[0]).T

                # Input data (concatenated epochs)
                X = epochs_by_seq[seq_][filter_idx]
                X = X.reshape((X.shape[0] * X.shape[1], X.shape[2]))

                # Fit CCA and project the main template
                cca.fit(X, R)
                main_template = cca.project(main_template, filter_idx=0,
                                            projection='Wy')

                # Create all possible template shifts
                templates = dict()
                for lag in range(len(seq_)):
                    lag_samples = int(np.round(lag / fps_resolution * fs))
                    lagged_seq = np.roll(seq_, -lag, axis=0)
                    lagged_template = np.roll(main_template, -lag_samples,
                                              axis=0)
                    templates[tuple(lagged_seq)] = lagged_template
                    if show_progress_bar:
                        pbar.update(1)

                # STD by channel (useful for online artifact rejection)
                std_by_channel = np.std(X, axis=0)

                # Store data of each trained sequence (one entry per filter)
                seq_dict.setdefault(tuple(seq_), []).append(
                    {'cca': cca,
                     'templates': templates,
                     'std_by_channel': std_by_channel,
                     }
                )

        # Store fitted params
        self.fitted = {'sequences': seq_dict,
                       'fs': fs,
                       'fps_resolution': fps_resolution,
                       'len_epoch_ms': len_epoch_ms,
                       'len_epoch_sam': len_epoch_sam,
                       'std_epoch_rejection': std_epoch_rejection,
                       'no_discarded_epochs': discarded_epochs,
                       'no_total_epochs': total_epochs,
                       'sorted_dist_ch': sorted_dist_ch
                       }
        if show_progress_bar:
            pbar.close()
        return self.fitted

    def _is_predict_feasible(self, dataset):
        """Returns True if ep.check_epochs_feasibility reports 'ok' for the
        fitted epoch window in every recording of the dataset."""
        l_ms = self.fitted['len_epoch_ms']
        for rec in dataset.recordings:
            rec_sig = getattr(rec, dataset.biosignal_att_key)
            rec_exp = getattr(rec, dataset.experiment_att_key)
            feasible = ep.check_epochs_feasibility(timestamps=rec_sig.times,
                                                   onsets=rec_exp.onsets,
                                                   fs=rec_sig.fs,
                                                   t_window=[0, l_ms])
            if feasible != 'ok':
                return False
        return True

    def _is_predict_feasible_signal(self, times, onsets, fs):
        """Returns True if ep.check_epochs_feasibility reports 'ok' for the
        fitted epoch window with the given timestamps and onsets."""
        l_ms = self.fitted['len_epoch_ms']
        feasible = ep.check_epochs_feasibility(timestamps=times,
                                               onsets=onsets,
                                               fs=fs,
                                               t_window=[0, l_ms])
        if feasible != 'ok':
            return False
        return True

    def _interpolate_epoch(self, epoch, channel_set, bad_channels_idx,
                           no_neighbors=3):
        """Returns a copy of the epoch in which every bad channel has been
        replaced by the average of its nearest channels.

        Parameters
        ----------
        epoch : numpy.ndarray
            Epoch, indexed as [samples, channels].
        channel_set : meeg.EEGChannelSet
            Channel set of the epoch.
        bad_channels_idx : sequence of int
            Indexes of the channels to interpolate.
        no_neighbors : int
            Number of nearest channels to average, taken in the order given
            by the sorting computed during fitting.

        Raises
        ------
        Exception
            If the label of a bad channel was not present when fitting.
        """
        interp_epoch = epoch.copy()

        # Are all channels bad?
        if len(bad_channels_idx) == len(channel_set.channels):
            print('> Artifact rejection: Cannot interpolate because all '
                  'channels are bad.')
            return interp_epoch

        # For each bad channel
        bad_labels = [channel_set.channels[i]['label']
                      for i in bad_channels_idx]
        for i, bad_label in enumerate(bad_labels):
            if bad_label not in self.fitted['sorted_dist_ch']:
                raise Exception('Label %s is not present in the EEGChannelSet'
                                ' in which the model was fitted for!'
                                % bad_label)

            # Find the K labels of the nearest neighbors
            sorted_ch = self.fitted['sorted_dist_ch'][bad_label]
            interp_labels = []
            for ch in sorted_ch:
                interp_labels.append(ch["channel"]["label"])
                if len(interp_labels) == no_neighbors:
                    break

            # Interpolate using average
            interp_idxs = channel_set.get_cha_idx_from_labels(interp_labels)
            interp_epoch[:, bad_channels_idx[i]] = \
                np.mean(interp_epoch[:, np.array(interp_idxs)], axis=1)
        print('> Artifact rejection: interpolated %i channels' %
              len(bad_channels_idx))
        return interp_epoch

    def predict(self, times, signal, trial_idx, exp_data, sig_data):
        """Decodes one trial for every possible number of cycles.

        Parameters
        ----------
        times : numpy.ndarray
            Timestamps of the signal.
        signal : numpy.ndarray or list of numpy.ndarray
            Preprocessed signal, or list with one signal per filter of the
            bank. The number of signals must match the fitted one.
        trial_idx : int
            Trial to decode.
        exp_data : CVEPSpellerData
            Experiment data. Note that its ``onsets`` attribute is
            converted to a numpy array in place.
        sig_data : meeg.EEG
            Biosignal data; only its channel set is used, for the
            interpolation of bad channels.

        Returns
        -------
        pred_item_by_no_cycles : list
            One entry per number of cycles. Each entry is a list with one
            dict per fitted sequence, with keys 'sorted_cmds' (output of
            get_items_by_sorted_sequences) and 'fitted_sequence'.

        Raises
        ------
        ValueError
            If the number of signals differs from the number of filters
            used for fitting.
        """
        # Parameters
        len_epoch_ms = self.fitted['len_epoch_ms']
        len_epoch_sam = self.fitted['len_epoch_sam']
        fs = self.fitted['fs']
        exp_data.onsets = np.array(exp_data.onsets)

        # Assert filter bank
        if not isinstance(signal, list):
            signal = [signal]
        for seq_, seq_data_ in self.fitted['sequences'].items():
            if len(seq_data_) != len(signal):
                raise ValueError('[CircularShiftingClassifier] Cannot predict '
                                 'if the signal do not have the same number of '
                                 'filter banks than the fitted one!')

        # For each number of cycles
        pred_item_by_no_cycles = []
        no_cycles = np.max(exp_data.cycle_idx).astype(int) + 1
        for nc in range(no_cycles):
            # Identify what are the epochs that must be processed
            idx = (np.array(exp_data.trial_idx) == trial_idx) & \
                  (np.array(exp_data.cycle_idx) <= nc)

            # Raster latencies? Use the last raster event that happened
            # before the last onset considered
            raster_dict = None
            if self.correct_raster_latencies:
                possible_onsets_idx = np.where(
                    exp_data.raster_events['onset'] <
                    exp_data.onsets[idx][-1]
                )[0]
                if possible_onsets_idx.size > 0:
                    raster_dict = exp_data.raster_events['event'][
                        possible_onsets_idx[-1]]

            # For each fitted sequence
            pred_item = []
            for seq_, seq_data_ in self.fitted['sequences'].items():
                # For each possible filter bank
                f_corrs = []
                for filter_idx, filter_signal in enumerate(signal):
                    # Extract the epochs for that signal, trial and no. cycles
                    epochs = mds.get_epochs_of_events(
                        timestamps=times,
                        signal=filter_signal,
                        onsets=exp_data.onsets[idx],
                        fs=fs,
                        w_epoch_t=[0, len_epoch_ms],
                        w_baseline_t=None,
                        norm=None)
                    if len(epochs.shape) == 2:
                        # Create a dummy dimension if we have only one epoch
                        epochs = np.expand_dims(epochs, 0)

                    # Artifact rejection
                    if self.art_rej is not None:
                        epoch_std_by_channel = \
                            np.std(epochs[:, :len_epoch_sam, :], axis=1)
                        for i in range(epoch_std_by_channel.shape[0]):
                            discard_epoch = epoch_std_by_channel[i, :] > \
                                            seq_data_[filter_idx][
                                                'std_by_channel'] \
                                            * self.art_rej
                            if np.any(discard_epoch):
                                # TODO: precompute distance matrix before
                                epochs[i, :len_epoch_sam, :] = \
                                    self._interpolate_epoch(
                                        epoch=epochs[i, :len_epoch_sam, :],
                                        channel_set=sig_data.channel_set,
                                        bad_channels_idx=
                                        np.where(discard_epoch)[0],
                                        no_neighbors=3
                                    )

                    # Average the epochs
                    avg = np.mean(epochs[:, :len_epoch_sam, :], axis=0)

                    # CCA projection
                    x_ = seq_data_[filter_idx]['cca'].project(
                        avg, filter_idx=0, projection='Wy')

                    # Correlation coefficients between x_ and the templates
                    corrs = []
                    seqs = []
                    for shift_seq_, template_ in \
                            seq_data_[filter_idx]['templates'].items():
                        # Correct template using raster latencies
                        lat_s = 0
                        if raster_dict is not None:
                            if shift_seq_ in raster_dict:
                                lat_s = int(raster_dict[shift_seq_] * fs)
                        tem_ = np.roll(template_, lat_s)
                        temp_p = np.dot(tem_, x_) / np.sqrt(np.dot(np.dot(
                            tem_, tem_), np.dot(x_, x_)))
                        corrs.append(temp_p)
                        seqs.append(shift_seq_)
                    f_corrs.append(corrs)

                # Average the correlations between different filter banks
                corrs = np.mean(np.array(f_corrs), axis=0)
                seqs = np.array(seqs)

                # Sort the sequences by corrs' descending order
                sorted_idx = np.argsort(-corrs)
                sorted_corrs = corrs[sorted_idx]
                sorted_seqs = seqs[sorted_idx, :]

                # Identify the selected command
                sorted_cmds = get_items_by_sorted_sequences(
                    experiment=exp_data,
                    trial_idx=trial_idx,
                    sorted_seqs=sorted_seqs,
                    sorted_corrs=sorted_corrs
                )
                pred_item.append({
                    'sorted_cmds': sorted_cmds,
                    'fitted_sequence': seq_
                })

            # Store the predicted item
            pred_item_by_no_cycles.append(pred_item)
        return pred_item_by_no_cycles

    def predict_dataset(self, dataset: CVEPSpellerDataset,
                        show_progress_bar=True):
        """Decodes every trial of every recording of the dataset.

        Returns
        -------
        pred_items_by_no_cycles : list
            Output of predict() for each trial, in recording order.

        Raises
        ------
        Exception
            If the classifier has not been fitted.
        ValueError
            If the sample rate or the presentation rate of a recording
            differs from the fitted one.
        """
        # Error detection
        if not self.fitted:
            raise Exception(
                'Cannot predict if circular shifting templates and '
                'CCA projections are not fitted before! Aborting...')
        for rec in dataset.recordings:
            rec_sig = getattr(rec, dataset.biosignal_att_key)
            rec_exp = getattr(rec, dataset.experiment_att_key)
            if rec_sig.fs != self.fitted['fs']:
                raise ValueError('The sampling rate of this test recording '
                                 '(%.2f Hz) is not the same as for the fitted '
                                 'recordings! (%.2f Hz)' %
                                 (rec_sig.fs, self.fitted['fs']))
            if rec_exp.fps_resolution != self.fitted['fps_resolution']:
                raise ValueError('The refresh rate of this test recording '
                                 '(%.2f Hz) is not the same as for the fitted '
                                 'recordings! (%.2f Hz)' %
                                 (rec_exp.fps_resolution,
                                  self.fitted['fps_resolution']))
        if show_progress_bar:
            pbar = tqdm(total=len(dataset.recordings),
                        desc='Predicting dataset')

        # For each recording
        pred_items_by_no_cycles = []
        for rec in dataset.recordings:
            rec_sig = getattr(rec, dataset.biosignal_att_key)
            rec_exp = getattr(rec, dataset.experiment_att_key)

            # For each trial
            for t_idx in np.unique(rec_exp.trial_idx):
                decoding_by_no_cycles = \
                    self.predict(rec_sig.times, rec_sig.signal, t_idx,
                                 rec_exp, rec_sig)
                pred_items_by_no_cycles.append(decoding_by_no_cycles)
            if show_progress_bar:
                pbar.update(1)
        if show_progress_bar:
            pbar.close()
        return pred_items_by_no_cycles
class CircularShiftingEarlyStopping(components.ProcessingMethod):
    """Early stopping criterion for the circular shifting classifier."""

    def __init__(self, **kwargs):
        """ Class constructor. Keyword arguments are accepted but unused. """
        super().__init__()

    def check_early_stop(self, corr_vector, std=3.0):
        """ Early stopping method based on normal distributions.

        Parameters
        --------------
        corr_vector: list() or 1D ndarray
            Sorted correlations of the possible commands, where
            corr_vector[0] must be the one of the most probable command.
        std: int or float
            Number of standard deviations over the mean of the rest of
            correlations (i.e., corr_vector[1:]) that the first one must
            exceed. Typical values are: 1 (outside 68% of data), 2 (outside
            95% of data), and 3 (default, outside 99.7% of data).

        Returns
        --------------
        must_stop: bool
            True if corr_vector[0] is greater than the threshold.
        probs: 1D ndarray
            Score of each command (same order as corr_vector), computed as
            1 minus its distance to the threshold divided by the maximum
            distance.
        """
        scores = np.array(corr_vector)

        # Threshold from the distribution of all but the best correlation
        runners_up = scores[1:]
        threshold = np.mean(runners_up) + std * np.std(runners_up)

        # Distance to the threshold, rescaled by the largest one
        distance = threshold - scores
        probs = 1 - (distance / np.max(distance))
        return scores[0] > threshold, probs
# ------------------------------- UTILS -------------------------------------- #
def get_unique_sequences_from_targets(experiment: CVEPSpellerData):
    """ Function that returns the unique sequences of all targets.

    Every entry of ``experiment.command_idx`` is mapped to the sequence of
    its command, which is looked up in ``experiment.commands_info`` using
    the matrix index and the command index of that entry.

    Parameters
    ----------
    experiment: CVEPSpellerData
        Experiment data that provides ``command_idx``, ``matrix_idx`` and
        ``commands_info``.

    Returns
    -------
    sequences: dict
        Dictionary whose keys are the unique sequences (as tuples) and
        whose values are the lists of entry indexes that used each
        sequence. If any error arises while processing the entries, it is
        printed and the dictionary built so far is returned.
    """
    sequences = dict()
    try:
        # TODO: command_idx, unit_idx and the like should be handled by
        #  medusa instead of unity
        n_entries = len(experiment.command_idx)
        for idx in range(n_entries):
            # TODO: review the handling of levels. Level and unit indexes
            #  (level_idx, unit_idx) are currently not used to reach the
            #  command through paradigm_conf
            m_ = int(experiment.matrix_idx[idx])
            c_ = int(experiment.command_idx[idx])

            # Note: str(c_) is used because the previous json serialization
            # interprets all dictionary keys as strings.
            seq_key = tuple(
                experiment.commands_info[m_][str(c_)]['sequence'])

            # Add the entry index to its associated sequence
            sequences.setdefault(seq_key, []).append(idx)
        # TODO: check that sequences are not shifted versions of themselves?
    except Exception as e:
        print(e)
    return sequences
def check_if_shifted(seq1, seq2):
    """ Checks if a sequence is a circularly shifted version of another.

    The comparison is exact: seq2 is tested against every circular shift
    of seq1 (including the null shift, so identical sequences are reported
    as shifted versions). A zero-padded cross-correlation is not used
    because it misses circular shifts in which events wrap around the end
    of the sequence, and it also reaches the autocorrelation peak for
    sequences that are not shifted versions (e.g., [1, 0, 1] vs [1, 1, 1]).

    Parameters
    ----------
    seq1: list or 1D ndarray
        First sequence.
    seq2: list or 1D ndarray
        Second sequence.

    Returns
    -------
    is_shifted: bool
        True if both sequences have the same (non-zero) length and seq2
        equals seq1 after some circular shift; False otherwise. A warning
        is printed when True is returned.
    """
    ref = np.asarray(seq1).ravel()
    other = np.asarray(seq2).ravel()

    # Sequences of different length cannot be shifted versions
    if ref.size == 0 or ref.size != other.size:
        return False

    is_shifted = any(
        np.array_equal(np.roll(ref, lag), other) for lag in range(ref.size))
    if is_shifted:
        print('WARNING: Two sequences are shifted versions of themselves!')
    return is_shifted
def get_items_by_sorted_sequences(experiment: CVEPSpellerData,
                                  trial_idx, sorted_seqs, sorted_corrs=None):
    """ Returns the commands that are encoded by the given sequences,
    keeping the order of the sequences.

    Parameters
    ----------
    experiment: CVEPSpellerData
        Experiment data that provides ``trial_idx``, ``matrix_idx``,
        ``level_idx``, ``unit_idx``, ``paradigm_conf`` and
        ``commands_info``.
    trial_idx: int
        Index of the trial. The first entry of the experiment that belongs
        to this trial determines the matrix, level and unit whose commands
        are considered.
    sorted_seqs: 2D ndarray
        Sequences (one per row) in descending order of probability of
        being selected.
    sorted_corrs: list or 1D ndarray
        (Optional) Correlations associated to each row of sorted_seqs.

    Returns
    -------
    sorted_commands: list of dict
        One dictionary per sequence that matched a possible command, in the
        same order as sorted_seqs. Keys: 'item' (entry of commands_info),
        'label', 'coords' ([matrix, level, unit, command]) and
        'correlation' (None if sorted_corrs is not provided). Sequences
        that do not match any command are skipped.

    Raises
    ------
    ValueError
        If the trial is not found in the experiment data.
    """
    # Find the first index of the trial to access matrix_idx, etc
    matches = np.flatnonzero(np.asarray(experiment.trial_idx) == trial_idx)
    if matches.size == 0:
        raise ValueError('[get_items_by_sorted_sequences] Trial with idx %s '
                         'not found in the experiment data!' % (trial_idx,))
    idx = int(matches[0])

    # Get the possible commands
    m_ = int(experiment.matrix_idx[idx])
    l_ = int(experiment.level_idx[idx])
    u_ = int(experiment.unit_idx[idx])
    possible_cmd = experiment.paradigm_conf[m_][l_][u_]

    # For each sequence in descending order of probability of being selected
    sorted_seqs = np.asarray(sorted_seqs)
    sorted_commands = list()
    for i in range(sorted_seqs.shape[0]):
        # For each possible command
        for cmd_id in possible_cmd:
            item = experiment.commands_info[m_][str(cmd_id)]
            # array_equal also copes with sequences of different length,
            # which simply do not match
            if np.array_equal(np.asarray(item['sequence']),
                              sorted_seqs[i, :]):
                # Found!
                correlation = None
                if sorted_corrs is not None:
                    correlation = sorted_corrs[i]
                sorted_commands.append({
                    'item': item,
                    'label': item['label'],
                    'coords': [m_, l_, u_, int(cmd_id)],
                    'correlation': correlation,
                })
                break
    return sorted_commands
def autocorr_zeropad(x):
    """ Autocorrelation with zero padding.

    For real-valued input the values match those of
    np.correlate(x, x, mode='full'), returned as floats.

    Parameters
    ----------
    x: list or 1D ndarray
        Input sequence of N samples.

    Returns
    -------
    rxx: 1D ndarray
        Autocorrelation for the 2N - 1 possible lags.
    """
    n_samples = len(x)

    # Zeros on both sides so every lag can be read as an N-sample window
    padded = np.concatenate(
        (np.zeros((n_samples - 1,)), x, np.zeros((n_samples,))))

    lags = range(2 * n_samples - 1)
    return np.array(
        [np.sum(x * padded[lag:lag + n_samples]) for lag in lags])
def autocorr_circular(x):
    """ Autocorrelation with circular shifts (periodic correlation).

    Parameters
    ----------
    x: list or 1D ndarray
        Input sequence of N samples.

    Returns
    -------
    rxx: 1D ndarray
        Sum of the products between x and its circularly shifted version,
        for the shifts -(N - 1), ..., 0, ..., N - 1 (2N - 1 values).
    """
    n_samples = len(x)
    shifts = range(1 - n_samples, n_samples)
    return np.array([np.sum(x * np.roll(x, shift)) for shift in shifts])
# ----------------------------- CODE GENERATORS ----------------------------- #
class LFSR:
    """ Computes a Linear-Feedback Shift Register (LFSR) sequence.

    Attributes
    ----------
    polynomial: list
        Generator polynomial (taps), as received.
    base: int
        Base of the sequence events.
    seed: list or None
        Initial state, as received (None if the default state was used).
    order: int
        Order of the polynomial (its length).
    N: int
        Length of a maximal length sequence, i.e., base^order - 1.
    sequence: list or 1D ndarray
        Generated sequence (see method lfsr).
    """

    def __init__(self, polynomial, base=2, seed=None, center=False):
        """ Constructor of LFSR. Stores the configuration and generates the
        sequence by calling the method lfsr with the same parameters.
        """
        self.polynomial = polynomial
        self.base = base
        self.seed = seed
        self.order = len(polynomial)
        self.N = base ** self.order - 1
        self.sequence = self.lfsr(polynomial, base, seed, center)

    @staticmethod
    def lfsr(polynomial, base=2, seed=None, center=False):
        """ This method implements a Linear-Feedback Shift Register (LFSR).

        IMPORTANT: maximal length sequences (m-sequences) can be only
        generated if the polynomial (taps) is primitive. Necessary (but not
        sufficient) conditions for that are:
            - the number of taps is even.
            - the set of taps is setwise co-prime (there must be no divisor
            other than 1 common to all taps).
        A list of primitive polynomials in function of the order m can be
        found here:
        https://en.wikipedia.org/wiki/Linear-feedback_shift_register

        NOTE: if the seed is composed by all zeros, the output sequence
        will be zeros.

        Parameters
        ----------
        polynomial: list
            Generator polynomial. E.g. (bias is specified for math
            convention but not used):
                "1 + x^5 + x^6" would be [0, 0, 0, 0, 1, 1].
                "1 + 2x + x^4" would be [2, 0, 0, 1]
        base : int
            (Optional, default: base = 2) Base of the sequence events that
            belongs to the Galois Field of the same base. By default,
            base=2, so only events of type {0,1} (or {-1,1}) are part of
            the returned sequence.
        seed : list, tuple or 1D ndarray
            (Optional) Initial state. If not provided, the default state is
            a one-array with length equal to the order of the polynomial.
            The seed is never modified.
        center : bool
            (Optional, default = False) Determines if a centering over zero
            must be performed in the returned sequence (e.g., {0,1} ->
            {-1,1})

        Returns
        -------
        sequence : list or 1D ndarray
            LFSR m-sequence (maximum length is base^order-1), where order
            is the order of the polynomial. The last events of the sequence
            are the seed; every generated event is placed before the
            previous ones. A list of native Python numbers is returned if
            center is False, and a ndarray otherwise.

        Raises
        ------
        Exception
            If the seed is shorter than the order of the polynomial.
        """
        # Defaults and error detection
        order = len(polynomial)
        if seed is None:
            seed = [1] * order
        if order > len(seed):
            raise Exception('[LSFR] The order of the polynom (%i) is higher '
                            'than the initial state length (%i)!' %
                            (order, len(seed)))

        # LFSR. The register is stored oldest-first so that new events are
        # appended in O(1) (inserting at the front of a list is O(n) per
        # event); it is flipped once at the end. Converting the seed with
        # tolist() accepts tuples/ndarrays and yields native Python numbers
        polynom = np.array(polynomial)
        reversed_seq = np.asarray(seed).tolist()[::-1]
        target_len = base ** order - 1
        while len(reversed_seq) < target_len:
            # Last 'order' events, newest first
            state = reversed_seq[:-order - 1:-1]
            new_bit = (np.matmul(polynom, np.array(state)) % base).item()
            reversed_seq.append(new_bit)
        sequence = reversed_seq[::-1]

        # Map the values to center around zero
        if center:
            if base == 2:
                sequence = np.array(sequence) * 2 - 1
            else:
                sequence = np.array(sequence) - \
                           np.floor(base / 2).astype(int)
        return sequence