🚀 Feature
The `*_training_step` implementations used by `create_supervised_trainer` expect only a single output from the model, which is passed directly to the loss function, i.e.

```python
x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
y_pred = model(x)
loss = loss_fn(y_pred, y)
```
But I find that models commonly return more than just predictions: they may return logits and embeddings, logits and a loss, or perhaps a dictionary.
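For illustration, here is a minimal sketch of such a model (the class and names are my own, not from Ignite): its `forward` returns both logits and the intermediate embedding, so the raw output cannot be fed straight into a loss function.

```python
import torch
import torch.nn as nn

# Hypothetical example: a classifier whose forward() returns both
# the logits and the intermediate embedding, not a single tensor.
class EmbeddingClassifier(nn.Module):
    def __init__(self, in_features=8, hidden=4, n_classes=3):
        super().__init__()
        self.encoder = nn.Linear(in_features, hidden)
        self.head = nn.Linear(hidden, n_classes)

    def forward(self, x):
        emb = self.encoder(x)    # embeddings
        logits = self.head(emb)  # predictions
        return logits, emb       # tuple, not a single tensor
```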
It would be nice if you could pass a transform that is applied to the model's output in order to extract the predictions in the form required by the loss function, i.e.

```python
x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
outputs = model(x)
y_pred = model_transform(outputs)
loss = loss_fn(y_pred, y)
```
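Outside of Ignite, the proposed hook could be prototyped roughly as below. This is my own sketch, not Ignite's implementation; `model_transform` defaults to the identity, so the current single-output behaviour would be unchanged.

```python
import torch

# Sketch of a supervised training step with the proposed
# model_transform hook (illustration only, not Ignite code).
def training_step(model, optimizer, loss_fn, batch,
                  model_transform=lambda output: output):
    model.train()
    x, y = batch
    optimizer.zero_grad()
    outputs = model(x)                 # may be a tuple, dict, ...
    y_pred = model_transform(outputs)  # e.g. lambda output: output[0]
    loss = loss_fn(y_pred, y)
    loss.backward()
    optimizer.step()
    return loss.item()
```

With a tuple-returning model one would pass `model_transform=lambda output: output[0]`; with a dict-returning model, something like `lambda output: output['logits']`.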
I have a current work-around, but it requires a lambda to handle the loss and an `output_transform` for each metric:

```python
loss = functools.partial(F.cross_entropy, label_smoothing=params['train']['label_smoothing'])

trainer = create_supervised_trainer(
    model,
    optimizer,
    # model output is a tuple, need to extract the first element
    loss_fn=lambda output, target: loss(output[0], target),
    prepare_batch=prepare_batch,
    device=params['train']['device'])

evaluator = create_supervised_evaluator(
    model,
    metrics={
        # model output is a tuple, need to extract the first element
        'accuracy': Accuracy(output_transform=lambda output: (output[0][0], output[1])),
        'loss': Loss(loss, output_transform=lambda output: (output[0][0], output[1]))},
    prepare_batch=prepare_batch,
    device=params['train']['device'])
```
It works fine, but it was a bit fiddly to make sure I'd done it correctly. It would be a lot easier if I could just pass a single function to `create_supervised_trainer`, e.g. `lambda output: output[0]`.