Source code for medusa.plots.erp_plots

"""Created on Friday October 01 10:09:11 2021

In this module you will find useful functions and classes to plot event-related
potentials (ERPs). This module is not finished and has many points for
improvement, but it can be useful for a quick plot. Enjoy!

@author: Eduardo Santamaría-Vázquez
"""
from medusa.bci import erp_spellers
import numpy as np
import copy


def _resolve_channel_index(channel_set, channel):
    """Maps ``channel`` to an index along the channel axis of the epochs.

    Non-string values are returned unchanged and used directly as an index
    (the previous behaviour). Strings are looked up among the labels of
    ``channel_set``.

    Raises
    ------
    ValueError
        If ``channel`` is a string that cannot be found in ``channel_set``.
    """
    if not isinstance(channel, str):
        return channel
    # NOTE(review): assumes the channel set exposes its labels, in the same
    # order as the channel axis of the extracted epochs, through the ``l_cha``
    # attribute — confirm against the channel-set class.
    labels = getattr(channel_set, 'l_cha', None)
    if labels is None or channel not in labels:
        raise ValueError('Channel %s not found in the channel set of the '
                         'dataset' % channel)
    return list(labels).index(channel)


def plot_erp_from_erp_speller_dataset(erp_speller_dataset, channel, axes,
                                      window=(0, 1000), error_measure="C95"):
    """Plots Event-Related Potentials (ERPs) from an ERP Speller dataset.

    The dataset is deep-copied and then processed. Preprocessing uses
    ``erp_spellers.StandardPreprocessing(cutoff=(1, 30))``. Epochs are
    extracted with ``erp_spellers.StandardFeatureExtraction(
    concatenate_channels=False, target_fs=None)``. Epochs labelled 1 in
    ``x_info['erp_labels']`` are plotted as ERP activity and epochs labelled
    0 as non-ERP activity, via :func:`plot_erp`.

    Parameters
    ----------
    erp_speller_dataset : ERPSpellerDataset
        The ERP Speller dataset containing EEG data and relevant information.
        Its ``fs`` and ``channel_set`` must be defined and its
        ``experiment_mode`` must be ``'train'``.
    channel : int or str
        Channel to plot. An int is used directly as the index along the
        channel axis of the extracted epochs. A str is looked up among the
        channel labels of ``erp_speller_dataset.channel_set``.
    axes : matplotlib.axes.Axes
        The Matplotlib axes on which the ERP plot will be displayed.
    window : tuple, optional
        (start_time, end_time) in milliseconds used to build the time axis of
        the plot. Default is (0, 1000).
    error_measure : str, optional
        Error measure for the shaded bands, forwarded to :func:`plot_erp`
        (see ``compute_dev_epochs``). Default is "C95".

    Returns
    -------
    dict
        The dictionary returned by :func:`plot_erp`, with keys
        "trials_erp_mean", "trials_erp_dev", "trials_noerp_mean" and
        "trials_noerp_dev".

    Raises
    ------
    ValueError
        If ``fs`` or ``channel_set`` is missing, if the dataset mode is not
        ``'train'``, or if a channel label is not found.

    Examples
    --------
    >>> import matplotlib.pyplot as plt
    >>> from medusa.plots.erp_plots import plot_erp_from_erp_speller_dataset
    >>> # Assuming erp_speller_dataset is a training ERP speller dataset
    >>> fig, ax = plt.subplots()
    >>> plot_erp_from_erp_speller_dataset(erp_speller_dataset, channel='Fz',
    ...                                   axes=ax)
    """
    # Check errors
    if erp_speller_dataset.fs is None:
        raise ValueError('Define the fs of the dataset')
    if erp_speller_dataset.channel_set is None:
        raise ValueError('Define the channel set of the dataset')
    if erp_speller_dataset.experiment_mode != 'train':
        raise ValueError('The dataset mode must be train')
    # Resolve the channel before the processing so a wrong label fails fast
    channel_idx = _resolve_channel_index(erp_speller_dataset.channel_set,
                                         channel)
    # Work on a deep copy so the caller's dataset is never handed to the
    # processing pipelines
    dataset = copy.deepcopy(erp_speller_dataset)
    # Standard preprocessing
    preprocessing_pipeline = erp_spellers.StandardPreprocessing(
        cutoff=(1, 30))
    dataset = preprocessing_pipeline.fit_transform_dataset(dataset)
    # Standard feature extraction; channels are not concatenated so that a
    # single channel can be selected afterwards
    feat_extraction_pipeline = erp_spellers.StandardFeatureExtraction(
        concatenate_channels=False, target_fs=None)
    x, x_info = feat_extraction_pipeline.transform_dataset(dataset)
    erp_labels = np.array(x_info['erp_labels'])
    # NOTE(review): ``window`` only labels the time axis; it is not passed to
    # the feature extraction, so it should match the epoch window that
    # StandardFeatureExtraction uses by default — confirm.
    return plot_erp(axes=axes,
                    erp_epochs=x[erp_labels == 1, :, :],
                    noerp_epochs=x[erp_labels == 0, :, :],
                    channel=channel_idx,
                    window=window,
                    error_measure=error_measure)
def _channel_stats(epochs, channel, error_measure):
    """Returns (mean, pos_dev, neg_dev) of one channel across epochs."""
    data = epochs[:, :, channel]
    mean = np.mean(data, 0)
    pos, neg = compute_dev_epochs(data, measure=error_measure)
    return mean, pos, neg


def _plot_mean_with_band(axes, t, mean, pos, neg):
    """Plots a mean curve and shades [neg, pos] in the same colour."""
    line, = axes.plot(t, mean)
    # Use the line colour explicitly: lines and fills use separate colour
    # cycles in matplotlib, so the band could otherwise mismatch its line.
    axes.fill_between(t, neg, pos, alpha=0.3, color=line.get_color())


def plot_erp(axes, erp_epochs, noerp_epochs, channel, window=(0, 1000),
             error_measure="C95"):
    """Quickly plots the mean ERP and non-ERP activity of one channel, each
    with a shaded error band.

    This offers limited functionality that will be improved in the future.

    Parameters
    ----------
    axes : matplotlib.axes.Axes
        Matplotlib axes in which the ERP will be displayed.
    erp_epochs : numpy.ndarray
        Epochs that contain ERPs (go epochs), indexed as
        [epochs x samples x channels].
    noerp_epochs : numpy.ndarray
        Epochs that do not contain ERPs (nogo epochs), with the same layout
        as ``erp_epochs``.
    channel : int
        Index along the last (channel) axis of the epochs to plot.
    window : tuple or list
        Lower and upper window time in milliseconds. Only used to build the
        time axis: evenly spaced points between both values (both included),
        one per sample.
    error_measure : str
        Error measure for the bands (default: "C95"). Check the ``measure``
        parameter of compute_dev_epochs() for the accepted values.

    Returns
    -------
    dict
        "trials_erp_mean": numpy.ndarray
            Mean of the go epochs for the selected channel.
        "trials_erp_dev": tuple(numpy.ndarray, numpy.ndarray)
            (positive, negative) error curves of the go epochs.
        "trials_noerp_mean": numpy.ndarray
            Mean of the nogo epochs for the selected channel.
        "trials_noerp_dev": tuple(numpy.ndarray, numpy.ndarray)
            (positive, negative) error curves of the nogo epochs.
    """
    # Calculate mean and deviation curves for both conditions
    erp_mean, erp_pos, erp_neg = _channel_stats(
        erp_epochs, channel, error_measure)
    noerp_mean, noerp_pos, noerp_neg = _channel_stats(
        noerp_epochs, channel, error_measure)
    # Plot the data
    t = np.linspace(window[0], window[1], erp_mean.shape[0])
    _plot_mean_with_band(axes, t, erp_mean, erp_pos, erp_neg)
    _plot_mean_with_band(axes, t, noerp_mean, noerp_pos, noerp_neg)
    # Return data
    return {
        "trials_erp_mean": erp_mean,
        "trials_erp_dev": (erp_pos, erp_neg),
        "trials_noerp_mean": noerp_mean,
        "trials_noerp_dev": (noerp_pos, noerp_neg),
    }
def compute_dev_epochs(epochs, measure="C95"):
    """Computes upper and lower deviation curves of 2D data across
    observations.

    Parameters
    ----------
    epochs : numpy.ndarray
        Data with dimensions [observations x signal]. All statistics are
        computed along axis 0.
    measure : str
        Deviation measure (case-insensitive):

        - 'std': mean +/- standard deviation (ddof=0)
        - 'sem': mean +/- standard error of the mean (sample standard
          deviation, ddof=1, divided by sqrt(number of observations))
        - 'var': mean +/- variance (ddof=0)
        - 'c<P>': central P% interval of the observations, i.e. the
          (100-P)/2 and 100-(100-P)/2 percentiles. E.g. 'c95' gives the
          2.5th and 97.5th percentiles and 'c90' the 5th and 95th. P must be
          in (0, 100) and may be non-integer (e.g. 'c99.5').

    Returns
    -------
    pos_deviation : numpy.ndarray
        1D vector with the upper deviation curve [signal].
    neg_deviation : numpy.ndarray
        1D vector with the lower deviation curve [signal].

    Raises
    ------
    ValueError
        If the measure is unknown or the percentile is outside (0, 100).
    """
    measure = measure.upper()
    epochs = np.asarray(epochs)
    if measure.startswith('C'):
        try:
            percentile = float(measure[1:])
        except ValueError:
            raise ValueError("[compute_dev_epochs] Unknown deviation measure "
                             "%s!" % measure) from None
        # Written as a chained comparison so that NaN is rejected as well
        if not 0 < percentile < 100:
            raise ValueError("[compute_dev_epochs] The confidence interval "
                             "percentile (%g) must be in the range (0, 100)"
                             % percentile)
        # Split the excluded mass evenly between both tails so the band
        # actually covers `percentile` % of the observations
        tail = (100 - percentile) / 2
        pos = np.percentile(epochs, 100 - tail, axis=0)
        neg = np.percentile(epochs, tail, axis=0)
        return pos, neg
    if measure == "STD":
        dev = np.std(epochs, axis=0)
    elif measure == "SEM":
        dev = np.std(epochs, axis=0, ddof=1) / np.sqrt(epochs.shape[0])
    elif measure == "VAR":
        dev = np.var(epochs, axis=0)
    else:
        raise ValueError("[compute_dev_epochs] Unknown deviation measure %s!"
                         % measure)
    mean = np.mean(epochs, axis=0)
    return mean + dev, mean - dev