
Commit c50b48d

mannatsingh authored and facebook-github-bot committed
Plot histograms of parameters to tensorboard (facebookresearch#432)

Summary:
Pull Request resolved: facebookresearch#432

- Plot the histogram of weights for every parameter in the model at the end of every train phase.
- Updated the various scalars plotted to Tensorboard to have their own tags, just like "Speed", to organize things better.

Adding the activations and gradients is non-trivial since they depend on the input, so skipping that for now.

Reviewed By: vreis

Differential Revision: D20427992

fbshipit-source-id: d157f73eac3e910733f41cdccd40087431805b25
1 parent 9f405b2 commit c50b48d
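For readers unfamiliar with TensorBoard's tag grouping: slash-separated tags such as "Parameters/{name}" or "Losses/{phase_type}" make the TensorBoard UI collapse related plots into one section, which is what the retagging in this commit is for. Below is a minimal standalone sketch (not part of the diff) of the same calls the hook makes, using torch.utils.tensorboard directly; the model and log directory are placeholders.

import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter

# Placeholder model and log dir; in the hook these come from task.base_model
# and the SummaryWriter held by the hook.
model = nn.Linear(10, 2)
writer = SummaryWriter(log_dir="./runs/example")

# Histogram of every parameter, as on_phase_start does with global_step=-1
# before training and on_phase_end does with the phase index afterwards.
for name, parameter in model.named_parameters():
    writer.add_histogram(f"Parameters/{name}", parameter, global_step=-1)

# Scalars grouped under their own tags, e.g. "Losses/train".
writer.add_scalar("Losses/train", 0.42, global_step=0)
writer.close()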

File tree

1 file changed: +20 -11 lines changed


classy_vision/hooks/tensorboard_plot_hook.py

Lines changed: 20 additions & 11 deletions
@@ -60,6 +60,16 @@ def on_phase_start(self, task: "tasks.ClassyTask") -> None:
         self.wall_times = []
         self.num_steps_global = []
 
+        if not is_master():
+            return
+
+        # log the parameters before training starts
+        if task.train and task.train_phase_idx == 0:
+            for name, parameter in task.base_model.named_parameters():
+                self.tb_writer.add_histogram(
+                    f"Parameters/{name}", parameter, global_step=-1
+                )
+
     def on_step(self, task: "tasks.ClassyTask") -> None:
         """Store the observed learning rates."""
         if self.learning_rates is None:
@@ -92,27 +102,26 @@ def on_phase_end(self, task: "tasks.ClassyTask") -> None:
         logging.info(f"Plotting to Tensorboard for {phase_type} phase {phase_type_idx}")
 
         phase_type = task.phase_type
-        loss_key = f"{phase_type}_loss"
-        learning_rate_key = f"{phase_type}_learning_rate_updates"
+        learning_rate_key = f"Learning Rate/{phase_type}"
 
         if task.train:
-            for loss, learning_rate, global_step, wall_time in zip(
-                task.losses, self.learning_rates, self.num_steps_global, self.wall_times
+            for learning_rate, global_step, wall_time in zip(
+                self.learning_rates, self.num_steps_global, self.wall_times
             ):
-                loss /= task.get_batchsize_per_replica()
-                self.tb_writer.add_scalar(
-                    loss_key, loss, global_step=global_step, walltime=wall_time
-                )
                 self.tb_writer.add_scalar(
                     learning_rate_key,
                     learning_rate,
                     global_step=global_step,
                     walltime=wall_time,
                 )
+            for name, parameter in task.base_model.named_parameters():
+                self.tb_writer.add_histogram(
+                    f"Parameters/{name}", parameter, global_step=phase_type_idx
+                )
 
         loss_avg = sum(task.losses) / (batches * task.get_batchsize_per_replica())
 
-        loss_key = "avg_{phase_type}_loss".format(phase_type=task.phase_type)
+        loss_key = "Losses/{phase_type}".format(phase_type=task.phase_type)
         self.tb_writer.add_scalar(loss_key, loss_avg, global_step=phase_type_idx)
 
         # plot meters which return a dict
@@ -122,13 +131,13 @@ def on_phase_end(self, task: "tasks.ClassyTask") -> None:
                 continue
             for name, value in meter.value.items():
                 if isinstance(value, float):
-                    meter_key = f"{phase_type}_{meter.name}_{name}"
+                    meter_key = f"Meters/{phase_type}/{meter.name}/{name}"
                     self.tb_writer.add_scalar(
                         meter_key, value, global_step=phase_type_idx
                     )
                 else:
                     log.warn(
-                        f"Skipping meter name {meter.name}_{name} with value: {value}"
+                        f"Skipping meter name {meter.name}/{name} with value: {value}"
                     )
                     continue
 
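A hedged sketch of wiring this hook into training, not shown in the diff. Assumptions: the hook's constructor takes the SummaryWriter it writes to (matching the self.tb_writer attribute used above), and the task exposes set_hooks(), as the library's classy_train.py script did around this time.

from torch.utils.tensorboard import SummaryWriter
from classy_vision.hooks import TensorboardPlotHook

# Hypothetical log directory; only the master process actually plots,
# since the hook guards its TensorBoard writes with is_master().
tb_writer = SummaryWriter(log_dir="./tensorboard_logs")
hook = TensorboardPlotHook(tb_writer)

# task.set_hooks([hook])  # attach alongside any other hooks before training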