🐛 Bug description
Logging the gradients per epoch / iteration is a useful way to debug an under-performing model. Ignite provides an easy-to-use TensorBoard logging handler for this, e.g. `ignite.contrib.handlers.tensorboard_logger.GradsScalarHandler`. However, the default `update` function used by engines created with `create_supervised_trainer` zeroes the gradients before it returns, so the handler logs zeroed-out gradients every time.
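To make the failure mode concrete, the behaviour described above amounts to something like the following simplified sketch (this paraphrases the description, it is not the actual library source; gradient accumulation and AMP are omitted):

```python
def default_update(engine, batch):
    # model, optimizer, loss_fn, prepare_batch, device, non_blocking are
    # closure variables of the trainer factory
    model.train()
    x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()  # gradients are cleared before the function returns,
                           # so ITERATION_COMPLETED handlers only ever see zeros
    return loss.item()
```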
Steps to reproduce:
My code is too complicated at the moment to give clear insight, and I am short on time to produce a minimal (not-)working example, so I will describe the steps in the abstract (a sketch based on them follows the list).
- Generate an `Engine` / `DeterministicEngine` on an arbitrary problem via the `create_supervised_trainer` method.
- Set up a `TensorboardLogger`.
- Attach a `GradsScalarHandler` (or your choice of gradient logger). Also log the training loss or some other metric.
- Start the training, check TensorBoard, and observe constant-zero gradient norms / gradients, even though the losses/metrics imply that some sort of improvement/learning is taking place.
Solution proposal
Preserving the gradients until the end of the epoch is tricky, but not required for my purposes. If we are OK with using `Events.ITERATION_COMPLETED` as the cue to log gradients, then we can simply modify the default `update` function as follows (assuming `engine.state.iteration` counts from 1):
```python
from typing import Any, Sequence, Tuple, Union

import torch
from ignite.engine import Engine

# model, optimizer, loss_fn, prepare_batch, device, non_blocking,
# gradient_accumulation_steps and output_transform are closure variables
# of the factory that builds the update function (as in create_supervised_trainer)
def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[torch.Tensor]]:
    # zero the gradients at the *start* of each accumulation window rather than
    # after optimizer.step(), so they are still there when ITERATION_COMPLETED fires
    if (engine.state.iteration - 1) % gradient_accumulation_steps == 0:
        optimizer.zero_grad()
    model.train()
    x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    if gradient_accumulation_steps > 1:
        loss = loss / gradient_accumulation_steps
    loss.backward()
    if engine.state.iteration % gradient_accumulation_steps == 0:
        optimizer.step()
    return output_transform(x, y, y_pred, loss)
```
This way, when `update` completes and `Events.ITERATION_COMPLETED` fires, there are non-zero gradients available to be logged.
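For illustration, wiring this modified `update` into an `Engine` and checking that the handler now sees non-zero gradients could look roughly like this (reusing the placeholder `model` and `tb_logger` objects from the sketch above):

```python
import torch
from ignite.engine import Engine, Events
from ignite.contrib.handlers.tensorboard_logger import GradsScalarHandler

trainer = Engine(update)  # `update` is the modified function above

tb_logger.attach(
    trainer,
    log_handler=GradsScalarHandler(model, reduction=torch.norm),
    event_name=Events.ITERATION_COMPLETED,
)

@trainer.on(Events.ITERATION_COMPLETED)
def sanity_check_grads(engine):
    # with the modified update, this sum should be non-zero after most iterations
    total = sum(p.grad.abs().sum().item() for p in model.parameters() if p.grad is not None)
    print(f"iter {engine.state.iteration}: sum |grad| = {total:.6f}")
```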
Environment (latest version of Ignite still has the same bug)
- PyTorch Version (e.g., 1.4): 1.10.1
- Ignite Version (e.g., 0.3.0): 0.4.7
- OS (e.g., Linux):
- How you installed Ignite (`conda`, `pip`, source): conda
- Python version: 3.9.7
- Any other relevant information: