import copy
import inspect

import matplotlib.pyplot as plt
import numpy as np
import scipy.signal as scisig
from scipy.ndimage import uniform_filter1d
from scipy.stats import mannwhitneyu
from scipy.stats import wilcoxon
from statsmodels.stats.multitest import fdrcorrection

from medusa import frequency_filtering as ff
from medusa import spatial_filtering as sf
# from medusa.storage.medusa_data import MedusaData
from medusa.components import Recording
# from medusa.bci.mi_feat_extraction import extract_mi_trials_from_midata
from medusa.local_activation import statistics
# from medusa.bci.mi_models import MIModelSettings
from medusa.plots import head_plots
from medusa.bci.mi_paradigms import StandardPreprocessing, \
    StandardFeatureExtraction, MIDataset
class MIPlots:
    """Plots for two-class motor imagery (MI) datasets.

    Usage: create the object, optionally tune the ``filter_*`` attributes,
    ``apply_car`` or the sizes (``set_sizes``), call ``set_dataset`` with the
    MI files and then call any of the ``plot_*`` methods. Before each plot a
    copy of the raw dataset is temporally filtered (IIR if ``filter_type`` is
    "IIR", FIR otherwise), optionally re-referenced with a common average
    reference and epoched.

    Trials labelled 0 and 1 in ``track_info["mi_labels"]`` are compared as
    the two classes.
    """
    # TODO: Currently only 2 classes are supported and hardcoded, extend!

    # Line colours of the first (label 0) and second (label 1) class
    _COLOR_C1 = (255 / 255, 174 / 255, 0 / 255)
    _COLOR_C2 = (24 / 255, 255 / 255, 73 / 255)
    # Significance level used for the FDR correction and the p-value plots
    _ALPHA = 0.05

    def __init__(self):
        # Temporal filter settings
        self.filter_type = "FIR"
        self.filter_order = 1000
        self.filter_btype = "bandpass"
        self.filter_filt_method = "filtfilt"
        # Apply a common average reference after the temporal filter
        self.apply_car = True
        # Generated by set_dataset()
        self.raw_dataset = None
        self.fs = None
        self.channel_set = None
        self.set_sizes()

    def set_sizes(self, label_size=6, axes_size=5, line_width=1):
        """Sets the font sizes and line width used by the plots.

        Parameters
        ----------
        label_size: float
            Font size of titles, axis labels and legends.
        axes_size: float
            Font size of the tick labels.
        line_width: float
            Width of the plotted curves.
        """
        self.label_size = label_size
        self.axes_size = axes_size
        self.line_width = line_width

    def set_dataset(self, files):
        """ Call this method to configure a dataset. It must be called before
        plotting anything.

        Parameters
        ----------
        files: list()
            List of mi.bson files. The sampling rate and the channel set are
            taken from the first file.

        Raises
        ------
        ValueError
            If ``files`` is empty.
        """
        if len(files) == 0:
            raise ValueError("At least one file is required!")
        first = Recording.load(files[0])
        fs = first.eeg.fs
        channel_set = first.eeg.channel_set
        dataset = MIDataset(channel_set=channel_set, fs=fs,
                            experiment_att_key='midata',
                            biosignal_att_key='eeg',
                            experiment_mode='train')
        dataset.add_recordings(first)
        for file in files[1:]:
            dataset.add_recordings(Recording.load(file))
        # Only publish the configuration once every file loaded correctly
        self.fs = fs
        self.channel_set = channel_set
        self.raw_dataset = dataset

    # ----------------------------------------------------------- helpers
    def _check_dataset(self):
        """Raises if set_dataset() has not been called yet."""
        if self.raw_dataset is None:
            raise RuntimeError("Call MiPlots.set_dataset() before plotting!")

    def _channel_index(self, channel):
        """Returns the index of ``channel`` in the dataset channel set.

        Raises ValueError if the channel is not present.
        """
        l_cha = self.raw_dataset.channel_set.l_cha
        if channel not in l_cha:
            raise ValueError('Channel ' + channel + ' is missing!')
        return l_cha.index(channel)

    def _filter_and_epoch(self, cutoff, w_epoch_t):
        """Filters (and optionally re-references) a copy of the raw dataset
        and extracts the epochs in ``w_epoch_t`` (ms, relative to onset).

        Returns the (trials, track_info) pair produced by
        StandardFeatureExtraction.transform_dataset.
        """
        filter_cls = ff.IIRFilter if self.filter_type == "IIR" \
            else ff.FIRFilter
        temp_filter = filter_cls(order=self.filter_order,
                                 cutoff=cutoff,
                                 btype=self.filter_btype,
                                 filt_method=self.filter_filt_method)
        temp_filter.fit(fs=self.fs)
        # Work on a copy so the raw dataset can be re-filtered for other plots
        dataset = copy.deepcopy(self.raw_dataset)
        for rec in dataset.recordings:
            eeg = getattr(rec, dataset.biosignal_att_key)
            eeg.signal = temp_filter.transform(signal=eeg.signal)
            if self.apply_car:
                eeg.signal = sf.car(signal=eeg.signal)
            setattr(rec, dataset.biosignal_att_key, eeg)
        # NOTE(review): here the epoching options go to transform_dataset,
        # while the module-level _extract_erd_ers_features passes them to the
        # constructor. Confirm which StandardFeatureExtraction API is current.
        feature_extractor = StandardFeatureExtraction()
        return feature_extractor.transform_dataset(
            dataset=dataset,
            w_epoch_t=w_epoch_t,
            target_fs=None, baseline_mode=None,
            safe_copy=False, w_baseline_t=None, norm=None
        )

    @staticmethod
    def _split_classes(data, labels):
        """Splits ``data`` along its first axis into (label 0, label 1)."""
        labels = np.asarray(labels)
        return data[labels == 0], data[labels == 1]

    @classmethod
    def _class_pvalues(cls, data_c1, data_c2):
        """Two-sided Mann-Whitney U (Wilcoxon rank-sum) p-values between the
        classes along the trial axis (axis 0). They are FDR-corrected
        (Benjamini-Hochberg) independently for each channel (last axis).

        A rank-sum test is used because trials of both classes are
        independent samples (possibly of different size). A signed-rank test
        would wrongly pair them.
        """
        p_values = mannwhitneyu(data_c1, data_c2, axis=0).pvalue
        p_fdr = np.zeros(p_values.shape)
        for j in range(p_values.shape[-1]):
            _, p_fdr[:, j] = fdrcorrection(p_values[:, j], alpha=cls._ALPHA,
                                           is_sorted=False)
        return p_fdr

    @staticmethod
    def _new_axes(n_cha, keys, height_ratios):
        """Creates a figure with one column per channel and one row per key,
        and returns the axes as a list of {key: axes} dicts."""
        fig = plt.figure(figsize=(7.5, 3), dpi=300)
        gs = fig.add_gridspec(len(keys), n_cha, wspace=0.2, hspace=0.2,
                              height_ratios=height_ratios)
        return [{key: fig.add_subplot(gs[row, col])
                 for row, key in enumerate(keys)}
                for col in range(n_cha)]

    @staticmethod
    def _check_axes(axs_to_plot, ch_to_plot):
        """Ensures there is one axes dictionary per channel."""
        if len(axs_to_plot) != len(ch_to_plot):
            raise ValueError("axs_to_plot must contain one dictionary per "
                             "channel in ch_to_plot!")

    @staticmethod
    def _style_context():
        """Returns a seaborn-like style context that exists in the installed
        matplotlib. The 'seaborn' style was renamed 'seaborn-v0_8' in
        matplotlib 3.6 and the old name was removed in 3.8."""
        available = plt.style.available
        for name in ('seaborn-v0_8', 'seaborn'):
            if name in available:
                return plt.style.context(name)
        return plt.style.context('default')

    # ------------------------------------------------------------ plots
    def plot_spectrogram(self, ch_to_plot, axs_to_plot=None,
                         f_lims=(8, 30),
                         t_trial_window=(1000, 6000),
                         t_reference_window=(-1000, 0),
                         welch_seg_len_pct=30,
                         welch_overlap_pct=96):
        """ Plots the time-frequency ERD/ERS (%) map of each class.

        Parameters
        -----------
        ch_to_plot: list(basestring)
            List of channels to be plotted (commonly: ["C3", "C4"]).
        axs_to_plot: list(dict())
            One dictionary per channel in ch_to_plot with the items:
                - "spec_c1" (matplotlib.axes): map of the first class.
                - "spec_c2" (matplotlib.axes): map of the second class.
            Missing keys are skipped. If None, a new figure is created.
        f_lims: tuple()
            Filter cutoff and displayed frequency range (Hz).
        t_trial_window: tuple()
            Trial window to be considered in ms (relative to the onset).
        t_reference_window: tuple()
            Reference window in ms (relative to the onset). Each trial is
            expressed relative to its mean power inside this window.
        welch_seg_len_pct: int
            Percentage of the epoch used as spectrogram segment length.
        welch_overlap_pct: int
            Percentage of the segment length used to overlap segments.

        Returns
        -----------
        list(dict()):
            Modified axes

        Raises
        ------
        ValueError
            If a channel is missing, if axs_to_plot does not match
            ch_to_plot, or if no spectrogram segment falls inside the
            reference window.
        """
        # TODO: in process (colorbars are still missing)
        self._check_dataset()
        if axs_to_plot is None:
            axs_to_plot = self._new_axes(len(ch_to_plot),
                                         ('spec_c1', 'spec_c2'), [1, 1])
        self._check_axes(axs_to_plot, ch_to_plot)
        idx = [self._channel_index(ch) for ch in ch_to_plot]

        # Epochs spanning both the reference and the trial windows
        t1 = min(*t_trial_window, *t_reference_window)
        t2 = max(*t_trial_window, *t_reference_window)
        full_trials, track_info = self._filter_and_epoch(f_lims, (t1, t2))
        labels_info = track_info["mi_labels_info"][0]
        trials_c1, trials_c2 = self._split_classes(full_trials,
                                                   track_info["mi_labels"])

        # Spectrogram of every trial. With axis=1 on [n_trials x n_samples x
        # n_cha], scipy returns [n_trials x n_freqs x n_cha x n_times]. The
        # values are already a power spectral density.
        seg_len = int(np.round(welch_seg_len_pct / 100 *
                               full_trials.shape[1]))
        overlap = int(np.round(welch_overlap_pct / 100 * seg_len))
        spec_kwargs = dict(axis=1, fs=self.fs, nperseg=seg_len,
                           noverlap=overlap, nfft=seg_len)
        freqs, times, spec_c1 = scisig.spectrogram(trials_c1, **spec_kwargs)
        _, _, spec_c2 = scisig.spectrogram(trials_c2, **spec_kwargs)
        # Segment centres are in seconds from the epoch start -> ms from onset
        times = times * 1000 + t1
        ref_mask = (times >= t_reference_window[0]) & \
                   (times <= t_reference_window[1])
        if not np.any(ref_mask):
            raise ValueError("No spectrogram segment falls inside the "
                             "reference window; reduce welch_seg_len_pct.")

        def _erd_ers(spec):
            # ERD/ERS (%) of each trial w.r.t. its own reference power
            ref = np.mean(spec[..., ref_mask], axis=-1, keepdims=True)
            return 100 * (spec - ref) / ref

        # Class averages: [n_freqs x n_cha x n_times]
        erders_c1_avg = np.mean(_erd_ers(spec_c1), axis=0)
        erders_c2_avg = np.mean(_erd_ers(spec_c2), axis=0)

        # Colour limits symmetric around 0% (diverging colormap), at least
        # +-10%, shared by every plotted map
        selected = np.concatenate((erders_c1_avg[:, idx, :].ravel(),
                                   erders_c2_avg[:, idx, :].ravel()))
        c_lim = max(10, np.max(np.abs(selected)))

        panels = (('spec_c1', erders_c1_avg, labels_info[str(0)]),
                  ('spec_c2', erders_c2_avg, labels_info[str(1)]))
        for n, ch in enumerate(ch_to_plot):
            i = idx[n]
            with self._style_context():
                for key, erders_avg, class_name in panels:
                    if key not in axs_to_plot[n]:
                        continue
                    ax = axs_to_plot[n][key]
                    ax.minorticks_on()
                    ax.pcolormesh(times, freqs, erders_avg[:, i, :],
                                  cmap='RdBu_r', vmin=-c_lim, vmax=c_lim)
                    ax.set_title('%s (%s)' % (ch, class_name),
                                 fontsize=self.label_size)
                    ax.set_ylabel('Frequency (Hz)', fontsize=self.label_size)
                    ax.set_xlabel('Time (ms)', fontsize=self.label_size)
                    ax.tick_params(axis='x', labelsize=self.axes_size)
                    ax.tick_params(axis='y', labelsize=self.axes_size)
                    ax.set_ylim(f_lims)
                    # TODO: add colorbar
        return axs_to_plot

    def plot_erd_ers_freq(self, ch_to_plot, axs_to_plot=None,
                          f_lims=(5, 40),
                          f_sel=(8, 13),
                          t_trial_window=(1000, 6000),
                          welch_seg_len_pct=50,
                          welch_overlap_pct=75,
                          mov_mean_hz=0):
        """ Plots the ERD/ERS in the frequency domain.

        Parameters
        -----------
        ch_to_plot: list(basestring)
            List of channels to be plotted (commonly: ["C3", "C4"])
        axs_to_plot: list(dict())
            List of dictionaries. Each dictionary belongs to each channel to
            be plotted, thus the number of dictionaries must be equal to
            len(ch_to_plot). Dictionary may have the following items:
                - "freq" (matplotlib.axes): axes for the PSD plot.
                - "r2" (matplotlib.axes): axes for the statistic r2 plot.
                - "pval" (matplotlib.axes): axes for the p-values plot
                (p-values are computed using a two-sided Mann-Whitney U
                test and correcting FDR using Benjamini-Hochberg).
            If None, a new figure is created.
        f_lims: tuple()
            Tuple containing the frequency limits of the PSD (i.e.,
            the filtering cutoff)
        f_sel: tuple()
            Tuple containing the desired frequency band for visualization
            purposes (it will be indicated with a shaded area)
        t_trial_window: tuple()
            Trial window to be considered in ms.
        welch_seg_len_pct: int
            Percentage of the trial window used to extract segments for the
            PSD Welch estimation
        welch_overlap_pct: int
            Percentage of the segment length used to overlap segments.
        mov_mean_hz : int
            Width (Hz) of a moving average applied to the PSDs along the
            frequency axis (use 0 to avoid the smoothing).

        Returns
        -----------
        list(dict()):
            Modified axes
        """
        self._check_dataset()
        if axs_to_plot is None:
            axs_to_plot = self._new_axes(len(ch_to_plot),
                                         ('freq', 'r2', 'pval'),
                                         [1, 0.1, 0.1])
        self._check_axes(axs_to_plot, ch_to_plot)
        idx = [self._channel_index(ch) for ch in ch_to_plot]

        # Get features: [n_trials x n_samples x n_cha]
        features, track_info = self._filter_and_epoch(f_lims, t_trial_window)
        labels_info = track_info["mi_labels_info"][0]

        # PSD of every trial at once: [n_trials x n_freqs x n_cha]
        seg_len = int(np.round(welch_seg_len_pct / 100 * features.shape[1]))
        overlap = int(np.round(welch_overlap_pct / 100 * seg_len))
        freqs, trials_psd = scisig.welch(features, fs=self.fs,
                                         nperseg=seg_len, noverlap=overlap,
                                         nfft=seg_len, axis=1)
        # Smoothing along the frequency axis
        if mov_mean_hz != 0 and len(freqs) > 1:
            size = max(1, int(mov_mean_hz / (freqs[1] - freqs[0])))
            trials_psd = uniform_filter1d(trials_psd, size, axis=1)

        # Statistics between classes: [n_freqs x n_cha]
        psd_c1, psd_c2 = self._split_classes(trials_psd,
                                             track_info["mi_labels"])
        trials_r2 = statistics.signed_r2(psd_c1, psd_c2, signed=False,
                                         axis=0)
        p_fdr = self._class_pvalues(psd_c1, psd_c2)
        m_psd_c1 = np.mean(psd_c1, axis=0)
        m_psd_c2 = np.mean(psd_c2, axis=0)

        # Plot
        for n, ch in enumerate(ch_to_plot):
            i = idx[n]
            axs = axs_to_plot[n]
            with self._style_context():
                # Averaged curves
                if "freq" in axs:
                    ax1 = axs['freq']
                    ax1.minorticks_on()
                    ax1.grid(visible=True, which='minor', color='#ededed',
                             linestyle='--')
                    ax1.grid(visible=True, which='major')
                    # Selected band box
                    mi_ = min(np.min(m_psd_c1[:, i]), np.min(m_psd_c2[:, i]))
                    ma_ = max(np.max(m_psd_c1[:, i]), np.max(m_psd_c2[:, i]))
                    off_ = 0.1 * (ma_ - mi_)
                    tx = (f_sel[0], f_sel[1], f_sel[1], f_sel[0])
                    ty = (mi_ - off_, mi_ - off_, ma_ + off_, ma_ + off_)
                    ax1.fill(tx, ty, edgecolor=None, facecolor="#D1F0FF55",
                             label='_nolegend_')
                    # Lines
                    ax1.plot(freqs, m_psd_c1[:, i], linewidth=self.line_width,
                             color=self._COLOR_C1, label=labels_info[str(0)])
                    ax1.plot(freqs, m_psd_c2[:, i], linewidth=self.line_width,
                             color=self._COLOR_C2, label=labels_info[str(1)])
                    ax1.set_xlim(f_lims)
                    ax1.set_ylim([mi_ - off_, ma_ + off_])
                    ax1.set_title(ch, fontsize=self.label_size)
                    ax1.set_ylabel(r'PSD ($uV^2/Hz$)',
                                   fontsize=self.label_size)
                    ax1.legend(fontsize=self.label_size)
                    ax1.set_xlabel('Frequency (Hz)', fontsize=self.label_size)
                    ax1.tick_params(axis='x', labelsize=self.axes_size)
                    ax1.tick_params(axis='y', labelsize=self.axes_size)
                # Signed-r2
                if "r2" in axs:
                    ax2 = axs['r2']
                    ax2.pcolormesh(freqs, range(2),
                                   np.tile(trials_r2[:, i], reps=[2, 1]),
                                   cmap='YlOrRd', vmin=0)
                    ax2.set_xlim(f_lims)
                    ax2.set_ylabel('$r^2$', fontsize=self.label_size)
                    ax2.set_xlabel('Frequency (Hz)', fontsize=self.label_size)
                    ax2.tick_params(axis='x', labelsize=self.axes_size)
                    ax2.tick_params(axis='y', labelsize=self.axes_size)
                    ax2.get_yaxis().set_ticks([])
                # Significant (FDR-corrected p <= alpha) bins in black
                if "pval" in axs:
                    ax3 = axs['pval']
                    ax3.minorticks_on()
                    significant = (p_fdr[:, i] <= self._ALPHA).astype(float)
                    ax3.pcolormesh(freqs, range(2),
                                   np.tile(significant, reps=[2, 1]),
                                   cmap='binary', vmin=0, vmax=1)
                    ax3.set_xlim(f_lims)
                    ax3.set_ylabel('$p$-val', fontsize=self.label_size)
                    ax3.set_xlabel('Frequency (Hz)', fontsize=self.label_size)
                    ax3.tick_params(axis='x', labelsize=self.axes_size)
                    ax3.tick_params(axis='y', labelsize=self.axes_size)
                    ax3.get_yaxis().set_ticks([])
        return axs_to_plot

    def plot_erd_ers_time(self, ch_to_plot, axs_to_plot=None,
                          t_trial_window=(1000, 6000),
                          t_reference_window=(-1000, 0),
                          f_cutoff=(8, 13),
                          mov_mean_ms=1000):
        """ Plots the ERD/ERS in the temporal domain.

        Parameters
        -----------
        ch_to_plot: list(basestring)
            List of channels to be plotted (commonly: ["C3", "C4"]).
        axs_to_plot: list(dict())
            List of dictionaries. Each dictionary belongs to each channel to
            be plotted, thus the number of dictionaries must be equal to
            len(ch_to_plot). Dictionary may have the following items:
                - "time" (matplotlib.axes): axes for the ERD/ERS curves.
                - "r2" (matplotlib.axes): axes for the statistic r2 plot.
                - "pval" (matplotlib.axes): axes for the p-values plot
                (p-values are computed using a two-sided Mann-Whitney U
                test and correcting FDR using Benjamini-Hochberg).
            If None, a new figure is created.
        t_trial_window: tuple()
            Trial window to be considered in ms (relative to the onset).
        t_reference_window: tuple()
            Reference window to be considered in ms (relative to the onset).
        f_cutoff: tuple()
            Filter cutoff to select a desired band
        mov_mean_ms : int
            Width (ms) of a moving average applied to the ERD/ERS curves
            (use 0 to avoid the smoothing).

        Returns
        -----------
        list(dict()):
            Modified axes
        """
        self._check_dataset()
        if axs_to_plot is None:
            axs_to_plot = self._new_axes(len(ch_to_plot),
                                         ('time', 'r2', 'pval'),
                                         [1, 0.1, 0.1])
        self._check_axes(axs_to_plot, ch_to_plot)
        idx = [self._channel_index(ch) for ch in ch_to_plot]

        # Epochs spanning both the reference and the trial windows
        t1 = min(*t_trial_window, *t_reference_window)
        t2 = max(*t_trial_window, *t_reference_window)
        full_trials, track_info = self._filter_and_epoch(f_cutoff, (t1, t2))
        labels_info = track_info["mi_labels_info"][0]
        n_samples = full_trials.shape[1]
        # Sample indexes of the reference window inside the epoch
        r_win = ((np.array(t_reference_window) - t1) * n_samples /
                 (t2 - t1)).astype(int)
        trials_c1, trials_c2 = self._split_classes(full_trials,
                                                   track_info["mi_labels"])

        def _erd_ers(trials):
            # ERD/ERS (%) of each sample w.r.t. the class reference power
            # (mean over trials and reference samples, one value per channel)
            power = np.power(trials, 2)
            ref = np.mean(power[:, r_win[0]:r_win[1], :], axis=(0, 1))
            return 100 * (power - ref) / ref

        erders_c1 = _erd_ers(trials_c1)
        erders_c2 = _erd_ers(trials_c2)
        # Smoothing along time
        if mov_mean_ms != 0:
            size = max(1, int(mov_mean_ms * self.fs / 1000))
            erders_c1 = uniform_filter1d(erders_c1, size, axis=1)
            erders_c2 = uniform_filter1d(erders_c2, size, axis=1)

        # Statistics between classes: [n_samples x n_cha]
        trials_r2 = statistics.signed_r2(erders_c1, erders_c2, signed=False,
                                         axis=0)
        p_fdr = self._class_pvalues(erders_c1, erders_c2)

        # Plot
        erders_c1_avg = np.mean(erders_c1, axis=0)
        erders_c2_avg = np.mean(erders_c2, axis=0)
        times = np.linspace(t1, t2, n_samples)
        for n, ch in enumerate(ch_to_plot):
            i = idx[n]
            axs = axs_to_plot[n]
            with self._style_context():
                # Averaged curves
                if "time" in axs:
                    ax1 = axs['time']
                    ax1.minorticks_on()
                    ax1.grid(visible=True, which='minor', color='#ededed',
                             linestyle='--')
                    ax1.grid(visible=True, which='major')
                    # Reference box
                    m_ = max(np.max(np.abs(erders_c1_avg[:, i])),
                             np.max(np.abs(erders_c2_avg[:, i])))
                    rx = (t_reference_window[0], t_reference_window[1],
                          t_reference_window[1], t_reference_window[0])
                    ry = (-0.2 * m_, -0.2 * m_, 0.2 * m_, 0.2 * m_)
                    ax1.fill(rx, ry, edgecolor=None, facecolor="#FFBEF055",
                             label='_nolegend_')
                    # Trial box
                    mi_ = min(np.min(erders_c1_avg[:, i]),
                              np.min(erders_c2_avg[:, i]))
                    ma_ = max(np.max(erders_c1_avg[:, i]),
                              np.max(erders_c2_avg[:, i]))
                    off_ = 0.1 * (ma_ - mi_)
                    tx = (t_trial_window[0], t_trial_window[1],
                          t_trial_window[1], t_trial_window[0])
                    ty = (mi_ - off_, mi_ - off_, ma_ + off_, ma_ + off_)
                    ax1.fill(tx, ty, edgecolor=None, facecolor="#D1F0FF55",
                             label='_nolegend_')
                    # Onset line
                    ax1.plot((0, 0), (mi_ - off_, ma_ + off_), '--k',
                             linewidth=self.line_width / 2,
                             label='_nolegend_')
                    # ERD/ERS lines
                    ax1.plot(times, erders_c1_avg[:, i],
                             linewidth=self.line_width,
                             color=self._COLOR_C1, label=labels_info[str(0)])
                    ax1.plot(times, erders_c2_avg[:, i],
                             linewidth=self.line_width,
                             color=self._COLOR_C2, label=labels_info[str(1)])
                    ax1.set_xlim([t1, t2])
                    ax1.set_ylim([mi_ - off_, ma_ + off_])
                    ax1.set_title(ch, fontsize=self.label_size)
                    ax1.set_ylabel(r'ERD/ERS (%)', fontsize=self.label_size)
                    ax1.legend(fontsize=self.label_size, loc='upper left')
                    ax1.set_xlabel('Time (ms)', fontsize=self.label_size)
                    ax1.tick_params(axis='x', labelsize=self.axes_size)
                    ax1.tick_params(axis='y', labelsize=self.axes_size)
                # Signed-r2
                if "r2" in axs:
                    ax2 = axs['r2']
                    ax2.minorticks_on()
                    ax2.pcolormesh(times, range(2),
                                   np.tile(trials_r2[:, i], reps=[2, 1]),
                                   cmap='YlOrRd', vmin=0)
                    ax2.set_xlim([t1, t2])
                    ax2.set_ylabel('$r^2$', fontsize=self.label_size)
                    ax2.set_xlabel('Time (ms)', fontsize=self.label_size)
                    ax2.tick_params(axis='x', labelsize=self.axes_size)
                    ax2.tick_params(axis='y', labelsize=self.axes_size)
                    ax2.get_yaxis().set_ticks([])
                # Significant (FDR-corrected p <= alpha) samples in black
                if "pval" in axs:
                    ax3 = axs['pval']
                    ax3.minorticks_on()
                    significant = (p_fdr[:, i] <= self._ALPHA).astype(float)
                    ax3.pcolormesh(times, range(2),
                                   np.tile(significant, reps=[2, 1]),
                                   cmap='binary', vmin=0, vmax=1)
                    ax3.set_xlim([t1, t2])
                    ax3.set_ylabel('$p$-val', fontsize=self.label_size)
                    ax3.set_xlabel('Time (ms)', fontsize=self.label_size)
                    ax3.tick_params(axis='x', labelsize=self.axes_size)
                    ax3.tick_params(axis='y', labelsize=self.axes_size)
                    ax3.get_yaxis().set_ticks([])
        return axs_to_plot

    def plot_erd_ers_r2_topo(self, ch_to_plot, ax_to_plot=None,
                             welch_seg_len_pct=50,
                             welch_overlap_pct=75,
                             f_lims=(8, 30),
                             t_trial_window=(1000, 6000)):
        """ Plots a topography of the signed r2 between both classes,
        averaged over the PSD bins inside f_lims.

        Parameters
        -----------
        ch_to_plot: list(basestring)
            Exactly two channel labels; only used in the title.
        ax_to_plot: matplotlib.axes
            Axes for the topography. If None, a new figure is created.
        welch_seg_len_pct: int
            Percentage of the trial window used as Welch segment length.
        welch_overlap_pct: int
            Percentage of the segment length used to overlap segments.
        f_lims: tuple()
            Filter cutoff (Hz); r2 is averaged over the bins in this band.
        t_trial_window: tuple()
            Trial window to be considered in ms (relative to the onset).

        Returns
        -----------
        tuple
            (axes, colour-mesh handle of the topography)
        """
        self._check_dataset()
        if len(ch_to_plot) != 2:
            raise ValueError("We need exactly two channels to compute r2 "
                             "topo!")
        if ax_to_plot is None:
            fig = plt.figure(figsize=(5, 5), dpi=300)
            ax_to_plot = fig.add_subplot(111)

        # PSD of every trial: [n_trials x n_freqs x n_cha]
        features, track_info = self._filter_and_epoch(f_lims, t_trial_window)
        seg_len = int(np.round(welch_seg_len_pct / 100 * features.shape[1]))
        overlap = int(np.round(welch_overlap_pct / 100 * seg_len))
        freqs, trials_psd = scisig.welch(features, fs=self.fs,
                                         nperseg=seg_len, noverlap=overlap,
                                         nfft=seg_len, axis=1)
        psd_c1, psd_c2 = self._split_classes(trials_psd,
                                             track_info["mi_labels"])
        # Signed r2 averaged over the analysed band -> one value per channel
        trials_r2 = statistics.signed_r2(psd_c1, psd_c2, signed=True, axis=0)
        band = (freqs >= f_lims[0]) & (freqs <= f_lims[1])
        if not np.any(band):
            band = np.ones(freqs.shape, dtype=bool)
        trials_r2 = np.mean(trials_r2[band, :], axis=0)
        # Symmetric colour limits around 0 for the diverging colormap
        max_r2 = np.max(np.abs(trials_r2))

        # Topoplot
        values = trials_r2.reshape(1, -1)
        topo_settings = {
            "head_radius": 1.0,
            "head_line_width": self.line_width * 2,
            "interp_contour_width": self.line_width,
            "interp_points": 500,
            "cmap": "RdBu_r",
            "clim": (-max_r2, max_r2)
        }
        topo = head_plots.TopographicPlot(
            axes=ax_to_plot, channel_set=self.raw_dataset.channel_set,
            **topo_settings
        )
        topo.update(values=values)
        ax_to_plot.set_title("Signed $r^2$ (%s)" % ' vs. '.join(ch_to_plot),
                             fontsize=self.label_size)
        return ax_to_plot, topo.plot_handles["color-mesh"]
def _extract_erd_ers_features(files, ch_to_plot, order=1000, cutoff=(5, 35),
                              btype='bandpass', temp_filt_method='filtfilt',
                              w_epoch_t=(-1000, 6000), target_fs=None,
                              baseline_mode='trial', w_baseline_t=(-1000, 0),
                              norm='z'):
    """Loads MI recordings, filters them and extracts the trial epochs.

    Each recording is FIR-filtered and then re-referenced with a common
    average reference (always applied) before epoching.

    Parameters
    ----------
    files: list
        Paths of the MI files. The sampling rate and the channel set are
        read from the first one.
    ch_to_plot: list
        Channels requested by the caller (not used during the extraction).
    order, cutoff, btype, temp_filt_method:
        Settings of the FIR filter (ff.FIRFilter).
    w_epoch_t, target_fs, baseline_mode, w_baseline_t, norm:
        Settings of StandardFeatureExtraction.

    Returns
    -------
    tuple
        (features, track_info, fs, lcha, channel_set, saved_args), where
        saved_args maps every extraction setting above to its value.

    Raises
    ------
    ValueError
        If files is empty.
    """
    # Snapshot (copy, not alias) of the arguments before other locals exist
    saved_args = dict(locals())
    del saved_args['files']
    del saved_args['ch_to_plot']
    if len(files) == 0:
        raise ValueError("At least one file is required!")
    # Load files
    first = Recording.load(files[0])
    fs = first.eeg.fs
    channel_set = first.eeg.channel_set
    dataset = MIDataset(channel_set=channel_set, fs=fs,
                        experiment_att_key='midata',
                        biosignal_att_key='eeg', experiment_mode='train')
    dataset.add_recordings(first)
    for file in files[1:]:
        dataset.add_recordings(Recording.load(file))
    # Pre-processing: temporal filter + common average reference
    fir = ff.FIRFilter(order=order, cutoff=cutoff, btype=btype,
                       filt_method=temp_filt_method)
    fir.fit(fs=fs)
    for rec in dataset.recordings:
        eeg = getattr(rec, dataset.biosignal_att_key)
        eeg.signal = fir.transform(signal=eeg.signal)
        eeg.signal = sf.car(signal=eeg.signal)
        setattr(rec, dataset.biosignal_att_key, eeg)
    # Feature extraction
    # NOTE(review): MIPlots passes these options to transform_dataset instead
    # of the constructor; confirm which StandardFeatureExtraction API is
    # current.
    feature_extractor = StandardFeatureExtraction(w_epoch_t=w_epoch_t,
                                                  target_fs=target_fs,
                                                  baseline_mode=baseline_mode,
                                                  w_baseline_t=w_baseline_t,
                                                  norm=norm)
    features, track_info = feature_extractor.transform_dataset(dataset=dataset)
    lcha = dataset.channel_set.l_cha
    return features, track_info, fs, lcha, channel_set, saved_args
def plot_erd_ers_time(files, ch_to_plot, features=None, track_info=None,
                      fs=None, lcha=None, channel_set=None, mov_mean_ms=1000,
                      **kwargs):
    """Plotting function of ERD/ERS from motor imagery runs of MEDUSA.

    One figure per channel is created. Each figure shows the ERD/ERS (%)
    curve of both classes and the r2 between them along time.

    Parameters
    ----------
    files: list
        List of paths pointing to MI files (only used when features is None).
    ch_to_plot: list
        List with the labels of the channels to plot
    features, track_info, fs, lcha, channel_set:
        Pre-computed outputs of _extract_erd_ers_features. If features is
        None, they are computed from files. lcha defaults to
        channel_set.l_cha.
    mov_mean_ms: int
        Width (ms) of the moving average applied to the ERD/ERS and r2
        curves (use 0 to disable the smoothing).
    **kwargs:
        Extraction settings accepted by _extract_erd_ers_features (e.g.,
        w_epoch_t, w_baseline_t, target_fs). When features are provided,
        they must describe how those features were extracted; missing
        settings take the extractor's defaults.

    Returns
    -------
    list(matplotlib.figure.Figure)
        One figure per channel in ch_to_plot.

    Raises
    ------
    ValueError
        If required inputs are missing or a channel is not available.
    """
    # Extraction settings: extractor defaults, overridden by kwargs and,
    # when extracting here, by the settings actually used
    settings = {
        name: param.default
        for name, param in
        inspect.signature(_extract_erd_ers_features).parameters.items()
        if param.default is not inspect.Parameter.empty
    }
    settings.update(kwargs)
    # Extract only if required
    if features is None:
        features, track_info, fs, lcha, channel_set, saved_args = \
            _extract_erd_ers_features(files, ch_to_plot, **kwargs)
        settings.update(saved_args)
    if lcha is None and channel_set is not None:
        lcha = channel_set.l_cha
    if track_info is None or fs is None or lcha is None:
        raise ValueError("track_info, fs and lcha (or channel_set) are "
                         "required when features are provided!")
    w_epoch_t = settings['w_epoch_t']
    w_baseline_t = settings['w_baseline_t']
    new_fs = fs if settings['target_fs'] is None else settings['target_fs']

    labels = np.asarray(track_info["mi_labels"])
    # todo: hardcoded (two classes, labels 0 and 1)
    labels_info = track_info["mi_labels_info"][0]
    # Baseline sample indexes inside the epoch
    idx_baseline = np.round((np.array(w_baseline_t) - w_epoch_t[0]) *
                            new_fs / 1000).astype(int)
    # Power of each class: [n_trials x n_samples x n_cha]
    p_c1 = np.power(features[labels == 0], 2)
    p_c2 = np.power(features[labels == 1], 2)
    p_c1_avg = np.mean(p_c1, axis=0)
    p_c2_avg = np.mean(p_c2, axis=0)
    # Reference power of each channel (baseline of the class average)
    r_c1 = np.mean(p_c1_avg[idx_baseline[0]:idx_baseline[1], :], axis=0)
    r_c2 = np.mean(p_c2_avg[idx_baseline[0]:idx_baseline[1], :], axis=0)
    # Averaged ERD/ERS (%)
    erders_c1 = 100 * (p_c1_avg - r_c1) / r_c1
    erders_c2 = 100 * (p_c2_avg - r_c2) / r_c2
    # r2 between the per-trial ERD/ERS of both classes (same normalisation
    # as MIPlots.plot_erd_ers_time)
    erders_trials_c1 = 100 * (p_c1 - r_c1) / r_c1
    erders_trials_c2 = 100 * (p_c2 - r_c2) / r_c2
    trials_r2 = statistics.signed_r2(erders_trials_c1, erders_trials_c2,
                                     signed=False, axis=0)
    # Smoothing along time
    if mov_mean_ms != 0:
        size = max(1, int(np.floor(mov_mean_ms * new_fs / 1000)))
        erders_c1 = uniform_filter1d(erders_c1, size, axis=0, mode='mirror')
        erders_c2 = uniform_filter1d(erders_c2, size, axis=0, mode='mirror')
        trials_r2 = uniform_filter1d(trials_r2, size, axis=0, mode='mirror')

    # Plotting
    times = np.linspace(w_epoch_t[0], w_epoch_t[1], erders_c1.shape[0])
    # Axes layout in figure-relative units
    left, bottom, width = 0.1, 0, 0.8
    height_psd, height_r2, height_cbar, gap = 0.6, 0.06, 0.06, 0.12
    figs = list()
    for ch in ch_to_plot:
        if ch not in lcha:
            raise ValueError('Channel ' + ch + ' is missing!')
        i = lcha.index(ch)
        fig = plt.figure()
        ax1 = fig.add_axes(
            [left, bottom + height_r2 + height_cbar + gap, width, height_psd],
            xticklabels=[])
        ax2 = fig.add_axes([left, bottom + height_cbar + gap, width, height_r2],
                           yticklabels=[])
        # ERD/ERS(%)
        ax1.minorticks_on()
        ax1.grid(visible=True, which='minor', color='#ededed', linestyle='--')
        ax1.grid(visible=True, which='major')
        ax1.axvline(x=0, color='k', linestyle='--', label='_nolegend_')
        ax1.axvspan(w_baseline_t[0], w_baseline_t[1], alpha=0.1,
                    facecolor='gray', label='_nolegend_')
        ax1.plot(times, erders_c1[:, i], linewidth=2,
                 color=[255 / 255, 174 / 255, 0 / 255],
                 label=labels_info[str(0)])
        ax1.plot(times, erders_c2[:, i], linewidth=2,
                 color=[24 / 255, 255 / 255, 73 / 255],
                 label=labels_info[str(1)])
        ax1.set_xlim(w_epoch_t)
        ax1.title.set_text(ch)
        ax1.set_ylabel(r'ERD/ERS (%)')
        ax1.legend()
        # Signed-r2
        ax2.pcolormesh(times, range(2), np.tile(trials_r2[:, i], reps=[2, 1]),
                       cmap='YlOrRd')
        ax2.axvline(x=0, color='k', linestyle='--', label='_nolegend_')
        ax2.set_ylabel('$r^2$')
        ax2.set_xlabel('Time (ms)')
        figs.append(fig)
    return figs
def plot_erd_ers_freq(files, ch_to_plot, features=None, track_info=None,
                      fs=None, lcha=None, channel_set=None,
                      welch_seg_len_pct=50,
                      welch_overlap_pct=75, mov_mean_hz=0,
                      **kwargs):
    """ This function depicts the ERD/ERS events of MI BCIs over the frequency
    spectrum.

    One figure per channel is created. Each figure shows the mean PSD of both
    classes and the r2 between them.

    Parameters
    ----------
    files: list
        List of paths pointing to MI files (only used when features is None).
    ch_to_plot: list
        List with the labels of the channels to plot
    features, track_info, fs, lcha, channel_set:
        Pre-computed outputs of _extract_erd_ers_features. If features is
        None, they are computed from files. lcha defaults to
        channel_set.l_cha.
    welch_seg_len_pct: int
        Percentage of the epoch used as Welch segment length.
    welch_overlap_pct: int
        Percentage of the segment length used to overlap segments.
    mov_mean_hz: float
        Width (Hz) of the moving average applied to the r2 curve (use 0 to
        disable the smoothing).
    **kwargs:
        Extraction settings accepted by _extract_erd_ers_features (e.g.,
        cutoff, btype, target_fs). When features are provided, they must
        describe how those features were extracted; missing settings take
        the extractor's defaults.

    Returns
    -------
    list(matplotlib.figure.Figure)
        One figure per channel in ch_to_plot.

    Raises
    ------
    ValueError
        If required inputs are missing or a channel is not available.
    """
    # TODO: More options!...
    # TODO: Left and right classes labels are hardcoded!
    # Extraction settings: extractor defaults, overridden by kwargs and,
    # when extracting here, by the settings actually used
    settings = {
        name: param.default
        for name, param in
        inspect.signature(_extract_erd_ers_features).parameters.items()
        if param.default is not inspect.Parameter.empty
    }
    settings.update(kwargs)
    # Extract only if required
    if features is None:
        features, track_info, fs, lcha, channel_set, saved_args = \
            _extract_erd_ers_features(files, ch_to_plot, **kwargs)
        settings.update(saved_args)
    if lcha is None and channel_set is not None:
        lcha = channel_set.l_cha
    if track_info is None or fs is None or lcha is None:
        raise ValueError("track_info, fs and lcha (or channel_set) are "
                         "required when features are provided!")
    cutoff = settings['cutoff']
    btype = settings['btype']
    new_fs = fs if settings['target_fs'] is None else settings['target_fs']

    labels = np.asarray(track_info["mi_labels"])
    # todo: hardcoded (two classes, labels 0 and 1)
    labels_info = track_info["mi_labels_info"][0]
    # PSD of every trial at once: [n_trials x n_freqs x n_cha]
    seg_len = int(np.round(welch_seg_len_pct / 100 * features.shape[1]))
    overlap = int(np.round(welch_overlap_pct / 100 * seg_len))
    freqs, trials_psd = scisig.welch(features, fs=new_fs, nperseg=seg_len,
                                     noverlap=overlap, nfft=seg_len, axis=1)
    # Separate the classes
    trials_psd_c1 = trials_psd[labels == 0]
    trials_psd_c2 = trials_psd[labels == 1]
    # Signed r2: [n_freqs x n_cha]
    trials_r2 = statistics.signed_r2(trials_psd_c1, trials_psd_c2,
                                     signed=False, axis=0)
    if mov_mean_hz != 0 and len(freqs) > 1:
        # The PSD spans 0..fs/2, so convert Hz into bins with its resolution
        size = max(1, int(mov_mean_hz / (freqs[1] - freqs[0])))
        trials_r2 = uniform_filter1d(trials_r2, size, axis=0, mode='nearest')
    # Mean PSD
    m_psd_c1 = np.mean(trials_psd_c1, axis=0)
    m_psd_c2 = np.mean(trials_psd_c2, axis=0)
    # Plot ranges: restrict to the filter pass band
    lims = [0, new_fs / 2]
    if btype == 'bandpass':
        lims = [cutoff[0], cutoff[1]]
    elif btype == 'highpass':
        lims[0] = cutoff[0]
    elif btype == 'lowpass':
        lims[1] = cutoff[1]
    # Axes layout in figure-relative units
    left, bottom, width = 0.1, 0, 0.8
    height_psd, height_r2, height_cbar, gap = 0.6, 0.06, 0.06, 0.12
    figs = list()
    for ch in ch_to_plot:
        if ch not in lcha:
            raise ValueError('Channel ' + ch + ' is missing!')
        i = lcha.index(ch)
        fig = plt.figure()
        ax1 = fig.add_axes(
            [left, bottom + height_r2 + height_cbar + gap, width, height_psd],
            xticklabels=[])
        ax2 = fig.add_axes([left, bottom + height_cbar + gap, width, height_r2],
                           yticklabels=[])
        # Averaged curves
        ax1.minorticks_on()
        ax1.grid(visible=True, which='minor', color='#ededed', linestyle='--')
        ax1.grid(visible=True, which='major')
        ax1.plot(freqs, m_psd_c1[:, i], linewidth=2,
                 color=[255 / 255, 174 / 255, 0 / 255],
                 label=labels_info[str(0)])
        ax1.plot(freqs, m_psd_c2[:, i], linewidth=2,
                 color=[24 / 255, 255 / 255, 73 / 255],
                 label=labels_info[str(1)])
        ax1.set_xlim(lims)
        ax1.title.set_text(ch)
        ax1.set_ylabel(r'PSD ($uV^2/Hz$)')
        ax1.legend()
        # Signed-r2
        ax2.pcolormesh(freqs, range(2), np.tile(trials_r2[:, i], reps=[2, 1]),
                       cmap='YlOrRd', vmin=0)
        ax2.set_xlim(lims)
        ax2.set_ylabel('$r^2$')
        ax2.set_xlabel('Frequency (Hz)')
        figs.append(fig)
    return figs
def plot_r2_topoplot(files, ch_to_plot, features=None, track_info=None,
                     fs=None, lcha=None, channel_set=None,
                     welch_seg_len_pct=50,
                     welch_overlap_pct=75, background=False, **kwargs):
    """Plots a topography of the signed r2 between both MI classes, averaged
    over all PSD frequency bins.

    Parameters
    ----------
    files: list
        List of paths pointing to MI files (only used when features is None).
    ch_to_plot: list
        Channels forwarded to the feature extraction.
    features, track_info, fs, lcha, channel_set:
        Pre-computed outputs of _extract_erd_ers_features. If features is
        None, they are computed from files. channel_set is required.
    welch_seg_len_pct: int
        Percentage of the epoch used as Welch segment length.
    welch_overlap_pct: int
        Percentage of the segment length used to overlap segments.
    background: bool
        If False, the figure background is made transparent.
        NOTE(review): the legacy plot_topography semantics of this flag are
        not visible here; confirm this matches the intended behaviour.
    **kwargs:
        Extraction settings accepted by _extract_erd_ers_features (e.g.,
        target_fs). When features are provided, they must describe how
        those features were extracted; missing settings take the
        extractor's defaults.

    Returns
    -------
    matplotlib.figure.Figure
        Figure containing the topography.
    """
    # Extraction settings: extractor defaults, overridden by kwargs and,
    # when extracting here, by the settings actually used
    settings = {
        name: param.default
        for name, param in
        inspect.signature(_extract_erd_ers_features).parameters.items()
        if param.default is not inspect.Parameter.empty
    }
    settings.update(kwargs)
    # Extract only if required
    if features is None:
        features, track_info, fs, lcha, channel_set, saved_args = \
            _extract_erd_ers_features(files, ch_to_plot, **kwargs)
        settings.update(saved_args)
    if track_info is None or fs is None or channel_set is None:
        raise ValueError("track_info, fs and channel_set are required when "
                         "features are provided!")
    new_fs = fs if settings['target_fs'] is None else settings['target_fs']

    labels = np.asarray(track_info["mi_labels"])
    # PSD of every trial at once: [n_trials x n_freqs x n_cha]
    seg_len = int(np.round(welch_seg_len_pct / 100 * features.shape[1]))
    overlap = int(np.round(welch_overlap_pct / 100 * seg_len))
    _, trials_psd = scisig.welch(features, fs=new_fs, nperseg=seg_len,
                                 noverlap=overlap, nfft=seg_len, axis=1)
    # Signed r2 between classes, averaged over frequencies -> one per channel
    trials_r2 = statistics.signed_r2(trials_psd[labels == 0],
                                     trials_psd[labels == 1],
                                     signed=True, axis=0)
    trials_r2 = np.mean(trials_r2, axis=0)
    # Symmetric colour limits around 0 for the diverging colormap
    max_r2 = np.max(np.abs(trials_r2))

    # Topoplot (same head_plots API as MIPlots.plot_erd_ers_r2_topo)
    fig = plt.figure()
    ax = fig.add_subplot(111)
    topo = head_plots.TopographicPlot(
        axes=ax, channel_set=channel_set,
        head_radius=1.0, head_line_width=2, interp_contour_width=1,
        interp_points=500, cmap='RdBu', clim=(-max_r2, max_r2)
    )
    topo.update(values=trials_r2.reshape(1, -1))
    if not background:
        fig.patch.set_alpha(0)
    return fig