```shell
pip install lightning-trainer-utils
```

- The model wrapper uses the forward function as follows:

  ```python
  output = self.model(**x, **self.forward_kwargs)
  return ModelOutput(**output)
  ```

  It expects the batch as a dict and returns a dict with the keys `loss`, `report`, and `output`.
- The ML model should return a dict with the following keys:
  - `loss`
  - `report`
  - `output` (optional)
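A minimal sketch of a model that satisfies this contract. The `ToyModel` name, layer sizes, and batch fields are illustrative assumptions, not part of the library:

```python
import torch
import torch.nn as nn


class ToyModel(nn.Module):
    """Hypothetical model whose forward returns the dict the wrapper expects."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, inputs, labels):
        logits = self.linear(inputs)
        loss = nn.functional.cross_entropy(logits, labels)
        return {
            "loss": loss,                     # required: scalar used for backprop
            "report": {"loss": loss.item()},  # required: values to report/log
            "output": logits,                 # optional: raw model output
        }


# The wrapper calls self.model(**x, ...), so the batch is a dict of kwargs.
batch = {"inputs": torch.randn(8, 4), "labels": torch.randint(0, 2, (8,))}
out = ToyModel()(**batch)
```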
The step counters are computed as:

- `batch_step = num_samples / (batch_size * num_devices)`
- `trainer_global_step = num_samples / (batch_size * num_devices * grad_accumulation)`

`SaveCheckpoint` also uses `trainer_global_step`.
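The relationship between the two counters can be checked with a quick worked example. The variable names come from the formulas above; the concrete numbers are illustrative:

```python
# Worked example of the step formulas above.
num_samples = 64_000
batch_size = 32
num_devices = 4
grad_accumulation = 2

# One batch step per per-device batch across all devices.
batch_step = num_samples // (batch_size * num_devices)

# Gradient accumulation folds several batch steps into one optimizer step.
trainer_global_step = num_samples // (batch_size * num_devices * grad_accumulation)

print(batch_step)           # 500
print(trainer_global_step)  # 250
```

With `grad_accumulation = 1` the two counters coincide; otherwise `trainer_global_step` lags `batch_step` by exactly that factor, which matters when checkpoint intervals are expressed in optimizer steps.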