callbacks.py
import visual_visdom


def _solver_loss_cb(log, visdom, model=None, tasks=None, iters_per_task=None, replay=False, progress_bar=True):
    '''Initiates function for keeping track of, and reporting on, the progress of the solver's training.'''

def cb(bar, iteration, loss_dict, task=1):
'''Callback-function, to call on every iteration to keep track of training progress.'''
if task is None:
task = 0
# progress-bar
if progress_bar and bar is not None:
task_stm = "" if (tasks is None) else " Task: {}/{} |".format(task, tasks)
bar.set_description(
' <SOLVER> |{t_stm} training loss: {loss:.3} | accuracy: {prec:.3} |'
.format(t_stm=task_stm, loss=loss_dict['loss_total'], prec=loss_dict['accuracy'])
)
bar.update(1)
# log the loss of the solver (to visdom)
if (iteration % log == 0) and (visdom is not None):
plot_data = [loss_dict['loss_total']]
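            # convert the within-task iteration number into an overall iteration-count (x-axis of the plot)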
i = (task-1)*iters_per_task + iteration
visual_visdom.visualize_scalars(
scalars=plot_data, names=["Total loss"], iteration=i,
title="Solver loss", env=visdom["env"], ylabel="training loss"
)
# Return the callback-function.
return cb


def _task_loss_cb(model, test_datasets, log, visdom, iters_per_task, vis_name=""):
    '''Initiates function for keeping track of, and reporting on, the accuracy of the solver on each task's test set.'''

def cb(iter, task=1):
'''Callback-function, to call on every iteration to keep track of training progress.'''
if task is None:
task = 0
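        # convert the within-task iteration number into an overall iteration-count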
iteration = (task-1)*iters_per_task + iter
if (iteration % log == 0) and (visdom is not None):
            # evaluate the model on the test set(s)
            loss_dict = model.test(task, test_datasets, verbose=False)
            # pad the accuracy-list with zeros, so there is one entry per test set
            while len(loss_dict["Accuracy"]) < len(test_datasets):
                loss_dict["Accuracy"].append(0)
            loss_dict["Task"] = range(len(test_datasets))
            plot_data = loss_dict["Accuracy"]
names = ["task"+str(s+1) for s in loss_dict["Task"]]
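            # also keep the raw accuracy-values in the visdom-dictionary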
visdom["values"].append({"iter": iteration, "acc": plot_data})
visual_visdom.visualize_scalars(
scalars=plot_data, names=names, iteration=iteration,
title="Task accuracy"+vis_name, env=visdom["env"], ylabel="accuracy per task"
)
# Return the callback-function.
return cb


def _generator_training_callback(log, visdom, model, tasks=None, iters_per_task=None, replay=False, progress_bar=True):
    '''Initiates functions for keeping track of, and reporting on, the progress of the generator's training.'''

def cb(bar, iter, loss_dict, task=1):
'''Callback-function, to perform on every iteration to keep track of training progress.'''
iteration = iter
# progress-bar
if progress_bar and bar is not None:
task_stm = " Class: {} |".format(task)
bar.set_description(
' <GAN> |{t_stm} d cost: {loss:.3} | g cost: {prec:.3} |'
.format(t_stm=task_stm, loss=loss_dict['d_cost'], prec=loss_dict['g_cost'])
)
bar.update(1)
        # log the losses of the generator (to visdom)
        if (iteration % log == 0) and (visdom is not None):
plot_data = [loss_dict['d_cost'], loss_dict['g_cost']]
names = ['Discriminator cost', 'Generator cost']
visual_visdom.visualize_scalars(
scalars=plot_data, names=names, iteration=iteration,
title="GENERATOR: loss class{t}".format(t=task), env=visdom["env"], ylabel="training loss"
)
# Return the callback-function
return cb
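

# --- Minimal usage sketch (added for illustration; not part of the original training code). ---
# It shows how the solver- and generator-callbacks defined above could be driven from a
# training loop. It assumes the `tqdm` package for the progress bars and uses made-up loss
# values; with `visdom=None`, all visdom-plotting is skipped, so no visdom-server is needed
# (running this file directly does still require the local `visual_visdom` module to be
# importable, since it is imported at the top).
if __name__ == "__main__":
    import tqdm

    iters = 20

    # solver-callback: report training loss / accuracy on a progress bar
    solver_bar = tqdm.tqdm(total=iters)
    solver_cb = _solver_loss_cb(log=5, visdom=None, tasks=2, iters_per_task=iters)
    for i in range(1, iters + 1):
        # in a real training loop, these values would come from the model's train-step
        solver_cb(solver_bar, i, {'loss_total': 1. / i, 'accuracy': 1. - 1. / i}, task=1)
    solver_bar.close()

    # generator-callback: report discriminator- and generator-cost on a progress bar
    gen_bar = tqdm.tqdm(total=iters)
    gen_cb = _generator_training_callback(log=5, visdom=None, model=None)
    for i in range(1, iters + 1):
        gen_cb(gen_bar, i, {'d_cost': .5 / i, 'g_cost': .7 / i}, task=1)
    gen_bar.close()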