"""
Created on Thu Aug 25 17:09:18 2022
Edited on Mon Jan 09 14:00:00 2023
@author: Diego Marcos-Martínez
"""
import warnings
# External imports
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('Qt5Agg')
from matplotlib.widgets import Slider
from matplotlib.widgets import Button
# Medusa imports
from medusa.utils import check_dimensions
def __plot_epochs_lines(ax, blocks, samples_per_block, fs, min_val, max_val):
    """Aux function to plot vertical lines separating consecutive epochs.

    Parameters
    ----------
    ax: matplotlib axes
        Axes where the lines are drawn.
    blocks: int
        Number of epochs of the signal. No line is drawn when it is 1.
    samples_per_block: int
        Number of samples of each epoch.
    fs: float
        Sampling rate (Hz).
    min_val, max_val: float
        Vertical extent of the lines.
    """
    # Boundaries lie at multiples of the epoch duration. The duration is kept
    # as a float: truncating it to an integer misplaces every boundary when
    # an epoch does not last a whole number of seconds.
    epoch_duration = samples_per_block / fs
    t_ = np.arange(1, blocks) * epoch_duration
    ax.vlines(t_, min_val, max_val, colors='k',
              linewidth=2, linestyles='solid')
def __plot_events_lines(ax, events_dict, min_val, max_val, display_times):
    """Aux function to plot vertical lines corresponding to marked events.

    Parameters
    ----------
    ax: matplotlib axes
        Axes where the dashed event lines are drawn. If it already has a
        legend (e.g. the one created for the condition shades), the event
        entries are appended to it.
    events_dict: dict
        Dictionary with the keys 'events', 'events_labels' and
        'events_times' (see the time_plot documentation). Each entry of
        'events' must have a 'desc-name'. Its 'label' is the value used in
        'events_labels'; if 'label' is missing, the entry's position in
        'events' is used instead.
    min_val, max_val: float
        Vertical extent of the lines.
    display_times: numpy ndarray
        Time vector (s) of the plotted signal. Event timestamps cannot be
        greater than its last value.

    Raises
    ------
    ValueError
        If events_dict is malformed, if 'events_labels' and 'events_times'
        have different lengths, if a label of 'events_labels' is not defined
        in 'events', or if a timestamp exceeds the signal duration.
    """
    doc_hint = ("Please, read carefully the time_plot documentation "
                "to know how to define 'events_dict' properly.")
    # Check errors
    if not isinstance(events_dict, dict):
        raise ValueError("'events_dict' must be a dict. " + doc_hint)
    for key in ('events', 'events_labels', 'events_times'):
        if key not in events_dict:
            raise ValueError(f"'events_dict' must have '{key}' key. "
                             + doc_hint)

    # Each event definition gets a fixed position that selects both its
    # legend name and its color. This way the mapping does not depend on
    # which labels happen to occur in 'events_labels'.
    event_defs = list(events_dict['events'].values())
    events_names = [ev['desc-name'] for ev in event_defs]
    label_to_idx = {}
    for idx, ev in enumerate(event_defs):
        label_to_idx.setdefault(ev.get('label', idx), idx)
    cmap = plt.get_cmap('rainbow')(np.linspace(0, 1, len(events_names)))

    # Accept scalars as well as sequences
    events_order = np.atleast_1d(np.asarray(events_dict['events_labels']))
    events_timestamps = np.atleast_1d(
        np.asarray(events_dict['events_times'], dtype=float))
    if events_order.shape != events_timestamps.shape:
        raise ValueError("'events_labels' and 'events_times' must have the "
                         "same length. " + doc_hint)
    # Check if events_timestamps are referenced to recording start
    if np.any(events_timestamps > display_times[-1]):
        raise ValueError("Incorrect format of events_timestamps. "
                         "The values must be referenced to the beginning "
                         "of the record, so that the first timestamp has value 0.")
    # Validate every label before drawing anything
    event_types = np.unique(events_order)
    unknown = [lbl for lbl in event_types if lbl not in label_to_idx]
    if unknown:
        raise ValueError(f"Event labels {unknown} are not defined in "
                         "events_dict['events']. " + doc_hint)

    # Keep existing legend entries (e.g. condition shades). Handles and
    # labels are both read from the legend itself so they stay aligned.
    previous_handles, previous_labels = [], []
    if ax.legend_ is not None:
        # 'legendHandles' was renamed to 'legend_handles' in Matplotlib 3.7
        # and removed later; support both spellings.
        handles = getattr(ax.legend_, 'legend_handles', None)
        if handles is None:
            handles = ax.legend_.legendHandles
        previous_handles = list(handles)
        previous_labels = [text.get_text()
                           for text in ax.legend_.get_texts()]

    # One dashed line collection per event type
    new_handles = []
    for event_type in event_types:
        idx = label_to_idx[event_type]
        t_ = events_timestamps[events_order == event_type]
        new_handles.append(ax.vlines(t_, min_val, max_val,
                                     colors=cmap[idx], linewidth=2,
                                     linestyles='dashed',
                                     label=events_names[idx]))

    # Create legend above the plot
    ax.legend(handles=previous_handles + new_handles,
              labels=previous_labels + [h.get_label() for h in new_handles],
              loc='upper center', bbox_to_anchor=(0.5, 1.15),
              ncol=3, fancybox=True, shadow=True)
def __plot_condition_shades(ax, conditions_dict, display_times, min_val,
                            max_val):
    """Aux function to plot background shades corresponding to the different
    conditions during the signal recording.

    Parameters
    ----------
    ax: matplotlib axes
        Axes where the shades are drawn. Its legend is replaced by one entry
        per condition.
    conditions_dict: dict
        Dictionary with the keys 'conditions', 'conditions_labels' and
        'conditions_times' (see the time_plot documentation). Each entry of
        'conditions' must have a 'desc-name'. Its 'label' is the value used
        in 'conditions_labels'; if 'label' is missing, the entry's position
        in 'conditions' is used instead. Labels and times are expected as
        [start X, end X, start Y, end Y, ...]. Sequences where only the start
        of each condition is given are completed automatically (with a
        warning), and the last open condition ends at display_times[-1].
    display_times: numpy ndarray
        Time vector (s) of the plotted signal.
    min_val, max_val: float
        Vertical extent of the shades.

    Raises
    ------
    ValueError
        If conditions_dict is malformed, if labels and times have different
        lengths, if a timestamp exceeds the signal duration, if the sequence
        cannot be arranged into start/end pairs, or if a label is not defined
        in 'conditions'.
    """
    doc_hint = ("Please, read carefully the time_plot documentation "
                "to know how to define 'conditions_dict' properly.")
    # Check errors
    if not isinstance(conditions_dict, dict):
        raise ValueError("'conditions_dict' must be a dict. " + doc_hint)
    for key in ('conditions', 'conditions_labels', 'conditions_times'):
        if key not in conditions_dict:
            raise ValueError(f"'conditions_dict' must have '{key}' key. "
                             + doc_hint)

    # Each condition definition gets a fixed position that selects both its
    # legend name and its color.
    cond_defs = list(conditions_dict['conditions'].values())
    conditions_names = [cond['desc-name'] for cond in cond_defs]
    label_to_idx = {}
    for idx, cond in enumerate(cond_defs):
        label_to_idx.setdefault(cond.get('label', idx), idx)

    # Accept scalars, lists, tuples and arrays for labels and timestamps
    condition_timestamps = np.atleast_1d(
        np.asarray(conditions_dict['conditions_times'], dtype=float))
    labels_order = np.atleast_1d(
        np.asarray(conditions_dict['conditions_labels']))
    if labels_order.shape != condition_timestamps.shape:
        raise ValueError("'conditions_labels' and 'conditions_times' must "
                         "have the same length. " + doc_hint)
    if labels_order.size == 0:
        # Nothing to shade
        return
    # Check if timestamps are referenced to recording start
    if np.any(condition_timestamps > display_times[-1]):
        raise ValueError("Incorrect format of condition_timestamps. "
                         "The values must be referenced to the beginning "
                         "of the record, so that the first timestamp has value 0.")
    unknown = [lbl for lbl in np.unique(labels_order)
               if lbl not in label_to_idx]
    if unknown:
        raise ValueError(f"Condition labels {unknown} are not defined in "
                         "conditions_dict['conditions']. " + doc_hint)

    # Check if all conditions have a start and an end and fix it if not: a
    # label that differs from both neighbours is a lone start, so it is
    # duplicated as its own end at the next condition's start time.
    c_idx = 0
    corrected = False
    while c_idx < len(labels_order) - 1:
        if labels_order[c_idx] != labels_order[c_idx + 1] and (
                c_idx == 0 or labels_order[c_idx] != labels_order[c_idx - 1]):
            labels_order = np.insert(labels_order, c_idx + 1,
                                     labels_order[c_idx])
            condition_timestamps = np.insert(condition_timestamps, c_idx + 1,
                                             condition_timestamps[c_idx + 1])
            corrected = True
        c_idx += 1
    # A last condition without end lasts until the end of the signal. A
    # single label is always such an unterminated condition.
    if len(labels_order) == 1 or labels_order[-1] != labels_order[-2]:
        labels_order = np.append(labels_order, labels_order[-1])
        condition_timestamps = np.append(condition_timestamps,
                                         display_times[-1])
        corrected = True
    if len(labels_order) % 2 != 0:
        raise ValueError("The conditions could not be arranged into "
                         "[start, end] pairs. " + doc_hint)
    if corrected:
        warnings.warn("The dictionary of conditions does not follow the "
                      "correct format ([Start condition X, End condition X,"
                      " Start condition Y, End condition Y ...]). "
                      "The labels and timestamps vectors have been "
                      "automatically corrected. Check that the result "
                      "is correct. ")

    # Draw one shade per [start, end] pair
    cmap = plt.get_cmap('jet')(np.linspace(0, 1, len(conditions_names)))
    legend_patches = {}
    for start in range(0, len(condition_timestamps), 2):
        idx = label_to_idx[labels_order[start]]
        name = conditions_names[idx]
        patch = ax.fill_betweenx([min_val, max_val],
                                 condition_timestamps[start],
                                 condition_timestamps[start + 1],
                                 color=cmap[idx], alpha=0.3, label=name)
        # Only the first shade of each condition goes to the legend
        legend_patches.setdefault(name, patch)
    # Create legend above the plot
    ax.legend(handles=list(legend_patches.values()),
              loc='upper center', bbox_to_anchor=(0.5, 1.15),
              ncol=3, fancybox=True, shadow=True)
def __reshape_signal(epochs):
    """Concatenate the epochs along time so the signal is drawn in one row.

    Parameters
    ----------
    epochs: numpy ndarray
        Array of shape [n_epochs, n_samples, n_channels].

    Returns
    -------
    numpy ndarray
        New array (never a view of ``epochs``) with shape
        [n_epochs * n_samples, n_channels].
    """
    n_blocks, n_samples, n_channels = epochs.shape
    flat_shape = (int(n_blocks * n_samples), n_channels)
    # Copy first so the caller's array is never aliased by the result
    return np.reshape(epochs.copy(), flat_shape)
def time_plot(signal, fs=1.0, ch_labels=None, time_to_show=None,
              ch_to_show=None, ch_offset=None, color='k',
              conditions_dict=None, events_dict=None, show_epoch_lines=True,
              show=False, fig=None, axes=None):
    """Plot a multichannel signal in the time domain.

    Each channel is drawn in its own row, shifted vertically by a constant
    offset. Optionally, background shades mark experimental conditions,
    dashed vertical lines mark events, and solid vertical lines separate
    epochs.

    Parameters
    ---------
    signal: numpy ndarray
        Signal with shape of [n_epochs, n_samples, n_channels] or
        [n_samples, n_channels]. Epochs are concatenated along time.
    fs: float
        Sampling rate (Hz). Value 1 as default.
    ch_labels: list of strings or None
        List containing one label per channel. If None, the labels
        "Channel 0", "Channel 1", ... are used.
    time_to_show: float or None
        Width (in seconds) of the time window displayed. If it is greater
        than the entire signal duration, the signal duration is used instead.
        If None, the minimum between five seconds and the entire duration of
        the signal is used.
    ch_to_show: int or None
        Number of channels depicted at once. It must be a positive integer.
        If it is greater than the number of channels, a warning is issued and
        it is set to the number of channels. If None, all channels are shown.
    ch_offset: float or None
        Vertical distance between consecutive channels. If None, it is
        computed as time_to_show * 2 * std(|signal|).
    color: string or tuple
        Color of the signal line. It is plotted in black by default.
    conditions_dict: dict or None
        Dictionary with the following structure:
            {'conditions':{'con_1':{'desc-name':'Condition 1','label':0},
                           'con_2':{'desc-name':'Condition 2','label':1}},
             'conditions_labels': [0,0,1,1,0,0],
             'conditions_times': [0,14,14,28,28,35]}
        In this example, the sub-dictionary 'conditions' includes each
        condition with a descriptor name ('desc-name'), which will be shown
        in the time-plot legend, and the label that identifies the condition.
        'conditions_labels' must be a list containing the order of start and
        end of each condition. Finally, 'conditions_times' must be a list
        with the same length as 'conditions_labels' containing the time
        stamps (in seconds) of the start and the end of each condition. These
        time stamps must be referenced to the start of the signal recording
        (the value 14 in the example means the 14th second after the start of
        recording) and cannot exceed the signal duration. As the end of a
        condition coincides with the start of the following one, the same
        time stamp must be included twice (see 14 and 28 in
        'conditions_times' in the example).
    events_dict: dict or None
        Dictionary with the following structure:
            {'events':{'event_1':{'desc-name':'Event 1','label':0},
                       'event_2':{'desc-name':'Event 2','label':1}},
             'events_labels': [0,1,1,1,0,1],
             'events_times': [0,14,15.4,28,2,35]}
        In this example, the sub-dictionary 'events' includes each event with
        a descriptor name ('desc-name'), which will be shown in the time-plot
        legend, and the label that identifies the event. 'events_labels' must
        be a list containing the order in which each event occurred. Finally,
        'events_times' must be a list with the same length as 'events_labels'
        containing the time stamps (in seconds) of each event. As in
        'conditions_dict', these time stamps must be referenced to the start
        of the signal recording.
    show_epoch_lines: bool
        If the signal is divided in epochs and this parameter is True, solid
        black vertical lines are plotted at the boundaries between epochs.
        True is the default value.
    show: bool
        If True, call matplotlib.pyplot.show() before returning.
    fig: matplotlib.pyplot.figure or None
        If a matplotlib figure is specified, the plot is displayed inside it.
        If None, the figure of `axes` is used when `axes` is given; otherwise
        a new figure is created.
    axes: matplotlib.pyplot.axes or None
        If a matplotlib axes is specified, the plot is displayed inside it.
        Otherwise, a new subplot is added to the figure.

    Returns
    -------
    fig: matplotlib figure containing the plot (and the sliders, if any).
    axes: matplotlib axes where the signal is plotted.

    Raises
    ------
    ValueError
        If ch_labels is not a list or its length differs from the number of
        channels, if time_to_show is not positive, or if ch_to_show is not a
        positive integer.

    Notes
    -----
    If time_to_show or ch_to_show parameters are defined and the signal is
    partially represented, vertical and horizontal sliders will be added to
    control the represented channels and time window, respectively. The
    vertical slider can be controlled by pressing the up and down arrows, and
    by dragging the marker. The horizontal slider can be controlled by
    pressing the right or left arrows, and by dragging the marker.
    """
    # Check signal dimensions. check_dimensions (medusa) is expected to
    # return a 3-D array [n_epochs, n_samples, n_channels], as unpacked below.
    signal = check_dimensions(signal)
    blocks, samples_per_block, channels = signal.shape

    # Check channel labels
    if ch_labels is None:
        ch_labels = [f"Channel {idx}" for idx in range(channels)]
    elif not isinstance(ch_labels, list):
        raise ValueError("Channel labels ('ch_labels') must be entered "
                         "as a list.")
    elif len(ch_labels) != channels:
        raise ValueError(f"The number of channel labels ({len(ch_labels)}) "
                         f"must match the number of channels ({channels}).")

    # Concatenate epochs so the whole signal is drawn in a single row
    epoch_c = __reshape_signal(signal)
    n_samples = epoch_c.shape[0]
    duration = n_samples / fs

    # Set the width (s) of the displayed time window
    if time_to_show is None:
        # The default time window is 5 seconds
        time_to_show = min(5, duration)
    elif time_to_show <= 0:
        raise ValueError("'time_to_show' must be a positive number.")
    else:
        # Windows longer than the signal are clipped to its duration
        time_to_show = min(time_to_show, duration)

    # Shift each channel downwards by a multiple of the offset
    if ch_offset is None:
        ch_offset = time_to_show * 2 * np.std(np.abs(epoch_c.ravel()))
    ch_off = np.arange(channels) * ch_offset
    epoch_c = epoch_c - ch_off
    max_val, min_val = epoch_c.max(), epoch_c.min()

    # Time vector spanning the whole recording. The duration is not truncated
    # to an integer, which would compress the axis for non-integer durations.
    display_times = np.linspace(0, duration, n_samples)

    # Initialize plot. When only axes are given, use their own figure so the
    # sliders are placed next to the plot.
    if fig is None:
        fig = axes.figure if axes is not None else plt.figure()
    if axes is None:
        axes = fig.add_subplot(111)
    fig.patch.set_alpha(0)
    axes.set_alpha(0)
    fig.subplots_adjust(left=0.15, bottom=0.15, right=0.85)

    # Set number of channels displayed at once
    if ch_to_show is None:
        ch_to_show = channels
    else:
        if not isinstance(ch_to_show, (int, np.integer)):
            raise ValueError("Channel to show ('ch_to_show') must be an "
                             "integer value.")
        if ch_to_show < 1:
            raise ValueError("Channel to show ('ch_to_show') must be a "
                             "positive integer.")
        if ch_to_show > channels:
            # Recoverable: warn and continue instead of aborting the plot
            warnings.warn("Entered a channel to show ('ch_to_show') value "
                          "greater than the number of channels in recording. "
                          "This parameter will be set equal to the number of "
                          "channels.")
            ch_to_show = channels

    ch_slider, time_slider = None, None
    max_val_ch_slider, max_val_time_slider = 0, 0.0

    # Create a time slider only if the time window to show is shorter than
    # the signal duration
    if time_to_show < duration:
        # Number of samples in the displayed window
        max_x = int(time_to_show * fs)
        ax_time = fig.add_axes([0.15, 0.02, 0.70, 0.03])
        # The slider value is the left edge (s) of the displayed window
        max_val_time_slider = display_times[-1] - max_x / fs
        time_slider = Slider(ax=ax_time, label='', valmin=0,
                             valmax=max_val_time_slider, valinit=0,
                             color='k')
        time_slider.valtext.set_visible(False)

        def __update_time(val):
            # Move the x-axis window and redraw
            axes.set_xlim(val, max_x / fs + val)
            fig.canvas.draw()

        time_slider.on_changed(__update_time)
    else:
        # The whole signal fits in the window: show up to the last sample
        # (indexing display_times with n_samples would be out of bounds)
        max_x = n_samples - 1

    # Create a channel slider only if fewer channels than available are shown
    if ch_to_show < channels:
        max_y = ch_to_show
        ax_ch = fig.add_axes([0.86, 0.15, 0.02, 0.73])
        max_val_ch_slider = channels - ch_to_show
        # The slider value is minus the index of the topmost shown channel
        ch_slider = Slider(ax=ax_ch, label='', valmin=-max_val_ch_slider,
                           valmax=0, valinit=0, valstep=1,
                           color='k', orientation='vertical')
        ch_slider.valtext.set_visible(False)

        def __update_ch(val):
            # Show channels [top, top + max_y) and redraw. There are at least
            # two channels here, so ch_off[1] exists.
            top = int(round(-val))
            axes.set_ylim(-ch_off[top + max_y - 1] - 0.5 * ch_off[1],
                          -ch_off[top] + 0.5 * ch_off[1])
            fig.canvas.draw()

        ch_slider.on_changed(__update_ch)
    else:
        # All channels are shown
        max_y = channels

    def on_key(event):
        # Arrow keys move the sliders one step, clamped to their range
        if ch_slider is not None:
            if event.key == 'up' and ch_slider.val != 0:
                ch_slider.set_val(ch_slider.val + 1)
            elif event.key == 'down' and ch_slider.val != -max_val_ch_slider:
                ch_slider.set_val(ch_slider.val - 1)
        if time_slider is not None:
            if event.key == 'right' and time_slider.val < max_val_time_slider:
                time_slider.set_val(min(time_slider.val + 1,
                                        max_val_time_slider))
            elif event.key == 'left' and time_slider.val > 0:
                time_slider.set_val(max(time_slider.val - 1, 0))

    fig.canvas.mpl_connect('key_press_event', on_key)

    # Call the aux function to plot conditions
    if conditions_dict is not None:
        __plot_condition_shades(axes, conditions_dict, display_times,
                                min_val, max_val)
    # Call the aux function to plot vertical lines to mark the events
    if events_dict is not None:
        __plot_events_lines(axes, events_dict, min_val, max_val,
                            display_times)

    # Plot the signal. Color strings keep working as a format string (as
    # before), while RGB(A) tuples must go through the 'color' keyword.
    if isinstance(color, str):
        axes.plot(display_times, epoch_c, color, linewidth=0.5)
    else:
        axes.plot(display_times, epoch_c, color=color, linewidth=0.5)
    axes.set_yticks(-ch_off, labels=ch_labels)
    if channels > 1:
        axes.set_ylim(-ch_off[max_y - 1] - 0.5 * ch_off[1],
                      -ch_off[0] + 0.5 * ch_off[1])
    axes.set_xlim(0, display_times[max_x])
    axes.set_xlabel('Time (s)')

    # Call the aux function to plot vertical lines to split signal in epochs
    if show_epoch_lines:
        __plot_epochs_lines(axes, blocks, samples_per_block, fs,
                            min_val, max_val)
    if show:
        plt.show()
    return fig, axes
if __name__ == "__main__":
    # Example of use: plot a synthetic 16-channel recording of 60 s split in
    # 10 epochs, with marked events and two alternating conditions.
    fs = 256  # sampling rate (Hz)
    T = 60  # duration (s)
    t = np.arange(0, T, 1 / fs)
    l_cha = ['F7', 'F3', 'FZ', 'F4', 'F8', 'FCz', 'C3', 'CZ', 'C4', 'CPz',
             'P3', 'PZ', 'P4', 'PO7', 'POZ', 'PO8']
    A = 1  # noise amplitude
    sigma = 0.5  # standard deviation of the Gaussian noise
    f = 5  # frequency of sinusoids (Hz)
    ps = np.linspace(0, -np.pi / 2, len(l_cha))  # phase differences
    np.random.seed(0)
    # Define signal: phase-shifted sinusoids plus Gaussian noise
    signal = np.empty((len(t), len(l_cha)))
    for c in range(len(l_cha)):
        signal[:, c] = np.sin(2 * np.pi * f * t - ps[c]) + A * np.random.normal(
            0, sigma, size=t.shape)
    # Split into 10 epochs: [n_epochs, n_samples, n_channels]
    signal = signal.reshape((10, signal.shape[0] // 10, signal.shape[1]))
    # Define events and conditions dicts
    e_dict = {'events': {'event_1': {'desc-name': 'Event 1', 'label': 0},
                         'event_2': {'desc-name': 'Event 2', 'label': 1},
                         'event_3': {'desc-name': 'Event 3', 'label': 2},
                         'event_4': {'desc-name': 'Event 4', 'label': 3},
                         'event_5': {'desc-name': 'Event 5', 'label': 4}},
              'events_labels': [0, 1, 1, 2, 0, 1, 3, 0, 1, 4],
              'events_times': [5, 14, 15.4, 28, 2, 35, 43, 49, 53, 58.5]}
    c_dict = {'conditions': {'con_1': {'desc-name': 'Condition 1', 'label': 0},
                             'con_2': {'desc-name': 'Condition 2', 'label': 1}},
              'conditions_labels': [0, 0, 1, 1, 0, 0, 1, 1, 0, 0],
              'conditions_times': [0, 14, 14, 28, 28, 35, 35, 50, 50, 60]}
    # Plot the signal
    time_plot(signal=signal, fs=fs, ch_labels=l_cha, time_to_show=None,
              ch_to_show=None, ch_offset=None, conditions_dict=c_dict,
              events_dict=e_dict, show_epoch_lines=True, show=True)