"""
Created on Thu Aug 25 17:09:18 2022
@author: Diego Marcos-Martínez
"""
import numpy as np
import matplotlib.pyplot as plt
def time_plot(epoch, fs=1.0, ch_labels=None):
    """
    Plot a multichannel signal as vertically stacked traces against time.

    The epochs are concatenated one after the other along the time axis and
    every channel is shifted downwards by a constant spacing so that traces
    do not overlap. When there is more than one epoch, dashed red vertical
    lines mark the boundaries between consecutive epochs. The figure is shown
    with ``plt.show(block=True)``.

    Parameters
    ---------
    epoch: numpy array
        Signal with shape of [n_epochs, n_samples, n_channels],
        [n_samples, n_channels] or [n_samples] (a single channel).
    fs: float
        Sampling rate. Value 1 as default, in which case the time axis is
        expressed in samples.
    ch_labels: list of strings or None
        List containing the channel labels, one per channel. If None, the
        labels "Channel 0", "Channel 1", ... are used.

    Raises
    ------
    ValueError
        If `epoch` is empty or does not have 1, 2 or 3 dimensions, if `fs`
        is not positive, or if the number of labels does not match the
        number of channels.
    """
    epoch = np.asarray(epoch)
    # Normalise the input to [n_epochs, n_samples, n_channels]
    if epoch.ndim == 1:
        epoch = epoch[np.newaxis, :, np.newaxis]
    elif epoch.ndim == 2:
        epoch = epoch[np.newaxis, :, :]
    elif epoch.ndim != 3:
        raise ValueError(
            "epoch must have shape [n_epochs, n_samples, n_channels], "
            f"[n_samples, n_channels] or [n_samples]; got {epoch.shape}")
    if epoch.size == 0:
        raise ValueError("epoch must contain at least one sample")
    if fs <= 0:
        raise ValueError(f"fs must be positive, got {fs}")
    blocks, samples_per_block, channels = epoch.shape
    if ch_labels is None:
        ch_labels = [f"Channel {idx}" for idx in range(channels)]
    elif len(ch_labels) != channels:
        raise ValueError(
            f"Expected {channels} channel labels, got {len(ch_labels)}")

    # Join the epochs into one continuous [n_epochs * n_samples, n_channels]
    # signal so that the whole recording is drawn in a single row.
    epoch_c = __reshape_signal(epoch)

    # Shift channel k down by k times the largest distance between the top of
    # a channel and the bottom of the channel drawn above it, so that no two
    # traces overlap.
    channel_offset = np.zeros(channels)
    if channels > 1:
        channel_offset[1:] = np.max(epoch_c[:, 1:].max(axis=0) -
                                    epoch_c[:, :-1].min(axis=0))
    channel_offset = np.cumsum(channel_offset)
    epoch_c = epoch_c - channel_offset
    max_val, min_val = epoch_c.max(), epoch_c.min()

    # Exact sample times (index / fs). No integer rounding, so recordings
    # shorter than one time unit are not collapsed onto t = 0.
    display_times = np.arange(epoch_c.shape[0]) / fs

    # Plot
    plt.plot(display_times, epoch_c, 'k', linewidth=0.5)
    plt.yticks(-channel_offset, labels=ch_labels)
    # Dashed red line at the first sample of every epoch except the first one
    for block in range(1, blocks):
        boundary = block * samples_per_block / fs
        plt.plot([boundary, boundary], [min_val, max_val],
                 '--', color='red', linewidth=1.5)
    plt.show(block=True)
def __reshape_signal(epochs):
    """Concatenate an epoched signal into one continuous signal.

    Auxiliary function that places the epochs one after the other along the
    time axis so that the whole signal can be plotted in a single row.

    Parameters
    ----------
    epochs: numpy array
        Signal with shape [n_epochs, n_samples, n_channels].

    Returns
    -------
    numpy array
        A new array (the input is never modified through it) with shape
        [n_epochs * n_samples, n_channels].

    Raises
    ------
    ValueError
        If `epochs` is not three-dimensional.
    """
    epochs = np.asarray(epochs)
    if epochs.ndim != 3:
        raise ValueError(
            "epochs must have shape [n_epochs, n_samples, n_channels], "
            f"got an array with shape {epochs.shape}")
    blocks, samples_per_block, channels = epochs.shape
    # C-order reshape puts the samples of epoch k right after those of epoch
    # k - 1; copy() guarantees the result never aliases the caller's array.
    return epochs.reshape(blocks * samples_per_block, channels).copy()