GradsScalarHandler logs 0 gradients if default update function is used #2459

@egaznep

Description

🐛 Bug description

Logging the gradients per epoch / iteration is a useful way to debug an under-performing model. Ignite provides an easy-to-use tensorboard_logger handler for this, accessible as ignite.contrib.handlers.tensorboard_logger.GradsScalarHandler. However, the default update function used by engines created with create_supervised_trainer zeroes the gradients before returning, so the handler logs all-zero gradients every time.

Steps to reproduce:
My code is too complicated at the moment to provide clear insight, and I lack the time to produce a minimal (not-)working example, so I will provide abstracted steps.

  1. Generate an Engine / DeterministicEngine for an arbitrary problem with the create_supervised_trainer method.
  2. Establish a TensorboardLogger and a .
  3. Attach a GradsScalarHandler / your choice of a gradient logger. Also log the training loss or some other metric.
  4. Start the training and check TensorBoard: the gradient norms / gradients are constantly 0, even though the losses/metrics indicate that learning is taking place.
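The failure mode can be illustrated without ignite or torch at all. The sketch below simulates the event ordering of the default update function, with a handler firing on Events.ITERATION_COMPLETED and a plain float standing in for a parameter's gradient tensor; the names `default_update` and `run` are made up for illustration, not ignite API.

```python
def default_update(state):
    """Mimics the default update fn from create_supervised_trainer:
    it zeroes gradients before returning."""
    state["grad"] = 1.23          # loss.backward() populates gradients
    state["step_count"] += 1      # optimizer.step()
    state["grad"] = 0.0           # optimizer.zero_grad() runs last


def run(update, n_iters):
    """Run n_iters iterations and record what a handler attached to
    Events.ITERATION_COMPLETED (i.e. right after update() returns)
    would observe."""
    state = {"grad": None, "step_count": 0}
    logged = []
    for _ in range(n_iters):
        update(state)
        logged.append(state["grad"])  # GradsScalarHandler fires here
    return logged


print(run(default_update, 3))  # [0.0, 0.0, 0.0] -- only zeros are ever logged
```

The training itself is unaffected (the step count still advances), which is why the loss keeps improving while the logged gradients stay at zero.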

Solution proposal

Preserving the gradients until epoch end is tricky, but not required for my purposes. If we are OK with using Events.ITERATION_COMPLETED as the cue to log gradients, we can simply modify the default update function as follows:

(assuming engine.state.iteration counts from 1).

    def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[torch.Tensor]]:
        # (model, optimizer, loss_fn, etc. are closed over inside
        # create_supervised_trainer.)
        # Zero gradients at the start of each accumulation window rather
        # than at the end of update, so they are still populated when
        # Events.ITERATION_COMPLETED fires. With iteration counting from 1,
        # (iteration - 1) % n == 0 marks the first iteration of every
        # window for any n; (iteration + 1) % n == 0 only coincides with
        # it for n <= 2.
        if (engine.state.iteration - 1) % gradient_accumulation_steps == 0:
            optimizer.zero_grad()
        model.train()
        x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        if gradient_accumulation_steps > 1:
            loss = loss / gradient_accumulation_steps
        loss.backward()
        if engine.state.iteration % gradient_accumulation_steps == 0:
            optimizer.step()
        return output_transform(x, y, y_pred, loss)

This way, when update completes and Events.ITERATION_COMPLETED fires, non-zero gradients are still available to be logged.
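As a sanity check on the windowing arithmetic, here is a small pure-Python sketch (a hypothetical `schedule` helper, not ignite code) that lists when an update function zeroing at the first iteration of each accumulation window would zero gradients and step the optimizer, assuming engine.state.iteration counts from 1:

```python
def schedule(gradient_accumulation_steps, n_iters):
    """Return the ordered (event, iteration) pairs produced by the
    proposed update function over n_iters iterations."""
    events = []
    for iteration in range(1, n_iters + 1):
        # zero gradients at the start of each accumulation window
        if (iteration - 1) % gradient_accumulation_steps == 0:
            events.append(("zero", iteration))
        events.append(("backward", iteration))
        # step once a full window of gradients has accumulated
        if iteration % gradient_accumulation_steps == 0:
            events.append(("step", iteration))
    return events


# With 3 accumulation steps over 6 iterations: gradients are zeroed at
# iterations 1 and 4, the optimizer steps at iterations 3 and 6, and a
# backward() is the last gradient-touching call of every iteration, so
# Events.ITERATION_COMPLETED always sees non-zero gradients.
print(schedule(3, 6))
```

Note that with gradient_accumulation_steps > 1 the logged gradients are partially accumulated on all but the last iteration of each window, which is usually acceptable for debugging purposes.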

Environment (latest version of Ignite still has the same bug)

  • PyTorch Version (e.g., 1.4): 1.10.1
  • Ignite Version (e.g., 0.3.0): 0.4.7
  • OS (e.g., Linux):
  • How you installed Ignite (conda, pip, source): conda
  • Python version: 3.9.7
  • Any other relevant information:
