diff --git a/classy_vision/hooks/tensorboard_plot_hook.py b/classy_vision/hooks/tensorboard_plot_hook.py index e57edf1523..3dd16e88ff 100644 --- a/classy_vision/hooks/tensorboard_plot_hook.py +++ b/classy_vision/hooks/tensorboard_plot_hook.py @@ -60,6 +60,16 @@ def on_phase_start(self, task: "tasks.ClassyTask") -> None: self.wall_times = [] self.num_steps_global = [] + if not is_master(): + return + + # log the parameters before training starts + if task.train and task.train_phase_idx == 0: + for name, parameter in task.base_model.named_parameters(): + self.tb_writer.add_histogram( + f"Parameters/{name}", parameter, global_step=-1 + ) + def on_step(self, task: "tasks.ClassyTask") -> None: """Store the observed learning rates.""" if self.learning_rates is None: @@ -92,8 +102,8 @@ def on_phase_end(self, task: "tasks.ClassyTask") -> None: logging.info(f"Plotting to Tensorboard for {phase_type} phase {phase_type_idx}") phase_type = task.phase_type - loss_key = f"{phase_type}_loss" - learning_rate_key = f"{phase_type}_learning_rate_updates" + loss_key = f"Losses/{phase_type}/step" + learning_rate_key = f"Learning Rate/{phase_type}" if task.train: for loss, learning_rate, global_step, wall_time in zip( @@ -109,10 +119,14 @@ def on_phase_end(self, task: "tasks.ClassyTask") -> None: global_step=global_step, walltime=wall_time, ) + for name, parameter in task.base_model.named_parameters(): + self.tb_writer.add_histogram( + f"Parameters/{name}", parameter, global_step=phase_type_idx + ) loss_avg = sum(task.losses) / (batches * task.get_batchsize_per_replica()) - loss_key = "avg_{phase_type}_loss".format(phase_type=task.phase_type) + loss_key = "Losses/{phase_type}/avg".format(phase_type=task.phase_type) self.tb_writer.add_scalar(loss_key, loss_avg, global_step=phase_type_idx) # plot meters which return a dict @@ -122,13 +136,13 @@ def on_phase_end(self, task: "tasks.ClassyTask") -> None: continue for name, value in meter.value.items(): if isinstance(value, float): - meter_key = f"{phase_type}_{meter.name}_{name}" + meter_key = f"Meters/{phase_type}/{meter.name}/{name}" self.tb_writer.add_scalar( meter_key, value, global_step=phase_type_idx ) else: log.warn( - f"Skipping meter name {meter.name}_{name} with value: {value}" + f"Skipping meter name {meter.name}/{name} with value: {value}" ) continue