-
Notifications
You must be signed in to change notification settings - Fork 0
/
plot.py
99 lines (86 loc) · 3.66 KB
/
plot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import numpy as np
import scipy.stats as stat
from misc import load_pkl, make_path, select_files
from epoch import Epoched
import params
import warnings
import os
import sys
import matplotlib
matplotlib.use("TkAgg")
from matplotlib import pyplot as plt
# Default line colors: the colors of matplotlib's current property cycle.
def_cols = plt.rcParams['axes.prop_cycle'].by_key()['color']
# Default line styles: ten solid lines.
def_style = ['-'] * 10
def plot_conds(epoched, conds_to_plot='all', plot_colors=def_cols,
               plot_style=def_style, plot_error=True, sample_rate=250,
               back_time=60, plot_title='Pupil Diameter',
               plot_fname='pupil_diameter_plot',
               out_dir='', base_name='', y_label='',
               plot_nums=True, **params):
    '''
    Averages across trials in each condition and saves a plot

    Arguments:
        epoched: An Epoched object, like the one output by epoch().
            Its `matrix` is reduced over axis 2 and then indexed as
            [condition, sample] (presumably axis 2 is the trial axis;
            confirm against Epoched).
        conds_to_plot: A list of indices of conditions to plot, or 'all'
            to plot all conditions
        plot_colors: list of colors in hex format. E.g. '#00FF00'.
            Needs at least one entry per plotted condition.
        plot_style: list of style specs like ['-', ':'].
            Needs at least one entry per plotted condition.
        plot_error: True or False; when True a shaded band of
            +/- one standard error is drawn around each mean
        sample_rate: Sample rate of the eye tracker (hz)
        back_time: Time to plot before event (ms)
        plot_title: Title of the plot
        plot_fname: File name of the saved image
        out_dir: Outputs will be saved to this directory
        base_name: string is appended to output files
        y_label: Label of the y axis
        plot_nums: True or False; when True the per-condition count
            (num_trials - num_rejected) is appended to the title
        **params: Extra keyword arguments are accepted and ignored, so
            a full parameter dict can be passed with **

    Raises:
        AssertionError: if fewer colors or styles than plotted
            conditions are supplied.
    '''
    print('\nPlotting...\n')

    # Calculate the mean and standard error (ignoring NaNs). Warnings are
    # suppressed while doing so.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        errors = stat.sem(epoched.matrix, axis=2, ddof=1, nan_policy='omit')
        flattened = np.nanmean(epoched.matrix, axis=2)

    # Only compare against 'all' when a string was given; comparing an
    # array of indices with a string would be an elementwise comparison.
    if isinstance(conds_to_plot, str) and conds_to_plot == 'all':
        conds_to_plot = list(range(epoched.n_categs))

    # One color and one style per plotted condition is sufficient.
    assert len(plot_colors) >= len(conds_to_plot), \
        'Require at least as many colors as plotted conditions'
    assert len(plot_style) >= len(conds_to_plot), \
        'Require at least as many styles as plotted conditions'

    # Time axis in ms, shifted so that 0 is back_time ms after the
    # first sample.
    times = [1000 * i / sample_rate - back_time
             for i in range(epoched.total_samples)]

    fig, ax = plt.subplots()
    fig.set_size_inches(8, 5)

    # cond is the index of the condition, count is the order of plotting
    for count, cond in enumerate(conds_to_plot):
        mean = flattened[cond, :]
        ax.plot(times, mean, label=epoched.names[cond],
                color=plot_colors[count], ls=plot_style[count])
        # Plot the error band
        if plot_error:
            err = errors[cond, :]
            ax.fill_between(times, mean - err, mean + err,
                            alpha=0.25, color=plot_colors[count])

    if plot_nums:
        # NOTE(review): counts are listed for every entry of num_trials,
        # not only for the conditions in conds_to_plot -- confirm intent.
        num_plotted = [total - rejected for total, rejected
                       in zip(epoched.num_trials, epoched.num_rejected)]
        plot_title += '\n # trials plotted: ' + str(num_plotted)

    # Formatting...
    ax.set_xlabel('Time (ms)')
    ax.set_ylabel(y_label)
    ax.axvline(x=0, lw=0.5, color='0')  # marks time 0
    ax.set_xlim((times[0], times[-1]))
    # Shrink the axes horizontally to leave room for the outside legend.
    box = ax.get_position()
    ax.set_position([box.x0, box.y0, box.width * 0.9, box.height])
    ax.legend(bbox_to_anchor=(1, 1.04), frameon=False)
    ax.set_title(plot_title)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    fig.savefig(make_path(plot_fname, '.png', out_dir=out_dir,
                          base_name=base_name),
                bbox_inches='tight', dpi=600)
    # pyplot keeps figures alive until they are closed; release this one
    # so repeated calls do not accumulate open figures.
    plt.close(fig)
if __name__ == '__main__':
    # Script entry point: read parameters from the command line, let
    # select_files pick a .pkl file, and save the plot next to that file.
    # The result is stored under a new name so the imported `params`
    # module is not rebound to the returned parameters.
    plot_params = params.get_params(sys.argv)
    selected = select_files('.pkl')
    if not selected:
        # Nothing was selected; exit cleanly instead of an IndexError.
        sys.exit('No .pkl file selected; nothing to plot.')
    fname = selected[0]
    plot_params['out_dir'] = os.path.dirname(fname)
    print(fname)
    # NOTE(review): load_pkl presumably unpickles the file -- only open
    # files from trusted sources.
    epoched = load_pkl(fname)
    plot_conds(epoched, **plot_params)