Model Output Transform #2837

@david-waterworth

🚀 Feature

The *_training_step implementations used by create_supervised_trainer expect the model to return a single output, which is passed directly to the loss function:

    x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
    y_pred = model(x)
    loss = loss_fn(y_pred, y)

But I find that models commonly return more than just predictions. For example, they may return logits and embeddings, logits and a loss, or a dictionary.
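
For illustration, here is how small extraction transforms could look for those output shapes. Plain lists stand in for tensors so the sketch has no dependencies, and the dictionary key "logits" is an assumption, not part of any particular model's API.

```python
from operator import itemgetter

# Stand-ins for tensors; a real model would return torch.Tensor objects.
logits = [2.0, -1.0, 0.5]
embeddings = [0.3, 0.7]
aux_loss = 0.42

outputs_a = (logits, embeddings)                          # (logits, embeddings)
outputs_b = (logits, aux_loss)                            # (logits, loss)
outputs_c = {"logits": logits, "embeddings": embeddings}  # dictionary (assumed key names)

# Each transform maps the raw model output to what the loss function expects.
take_first = itemgetter(0)           # tuple-style outputs
take_logits = itemgetter("logits")   # dict-style outputs

print(take_first(outputs_a))   # [2.0, -1.0, 0.5]
print(take_first(outputs_b))   # [2.0, -1.0, 0.5]
print(take_logits(outputs_c))  # [2.0, -1.0, 0.5]
```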

It would be nice to be able to pass a transform that is applied to the model output to extract the predictions in the form the loss function requires, for example:

    x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
    outputs = model(x)
    y_pred = model_transform(outputs)
    loss = loss_fn(y_pred, y)
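
A framework-agnostic sketch of that proposed step is below. The names supervised_step and model_transform are hypothetical (not an existing ignite signature), and the backward pass and optimizer step are left out so the plumbing of the transform stays visible.

```python
from typing import Any, Callable, Tuple


def identity(output: Any) -> Any:
    """Default transform: pass the model output through unchanged."""
    return output


def supervised_step(
    model: Callable[[Any], Any],
    loss_fn: Callable[[Any, Any], Any],
    batch: Tuple[Any, Any],
    model_transform: Callable[[Any], Any] = identity,
) -> Any:
    """Forward pass plus loss, with a hypothetical model_transform hook.

    Backward pass and optimizer step are omitted to keep the sketch framework-free.
    """
    x, y = batch
    outputs = model(x)                 # may be a tuple or a dict
    y_pred = model_transform(outputs)  # extract what loss_fn expects
    return loss_fn(y_pred, y)


def toy_model(x: float) -> Tuple[float, list]:
    return x * 2.0, [x, x]  # (prediction, embedding)


def abs_loss(y_pred: float, y: float) -> float:
    return abs(y_pred - y)


print(supervised_step(toy_model, abs_loss, (3.0, 5.0), model_transform=lambda out: out[0]))  # 1.0
```

Defaulting the transform to the identity would keep the current behaviour for models that already return a single output.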

My current workaround requires a lambda to handle the loss and an output_transform for each metric:

    import functools

    import torch.nn.functional as F
    from ignite.engine import create_supervised_evaluator, create_supervised_trainer
    from ignite.metrics import Accuracy, Loss

    # model, optimizer, prepare_batch and params are defined elsewhere
    loss = functools.partial(F.cross_entropy, label_smoothing=params['train']['label_smoothing'])

    trainer = create_supervised_trainer(
        model,
        optimizer,
        # model output is a tuple; pass only the first element (the predictions) to the loss
        loss_fn=lambda output, target: loss(output[0], target),
        prepare_batch=prepare_batch,
        device=params['train']['device'],
    )

    evaluator = create_supervised_evaluator(
        model,
        metrics={
            # evaluator output is (y_pred, y) and y_pred is the model's tuple, so take y_pred[0]
            'accuracy': Accuracy(output_transform=lambda output: (output[0][0], output[1])),
            'loss': Loss(loss, output_transform=lambda output: (output[0][0], output[1])),
        },
        prepare_batch=prepare_batch,
        device=params['train']['device'],
    )

It works, but it was fiddly to make sure I'd done it correctly. It would be much easier if I could pass a single function to create_supervised_trainer, e.g. lambda output: output[0].
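
Until such an option exists, one way to reduce the fiddliness is to derive both the loss wrapper and the metric output_transform from a single extraction function. The helper names wrap_loss and wrap_output_transform are hypothetical; the sketch assumes the evaluator output is the (y_pred, y) pair that the workaround indexes into.

```python
from typing import Any, Callable, Tuple


def wrap_loss(loss: Callable[[Any, Any], Any],
              model_transform: Callable[[Any], Any]) -> Callable[[Any, Any], Any]:
    """Return a loss_fn that applies model_transform to the raw model output first."""
    def loss_fn(y_pred: Any, y: Any) -> Any:
        return loss(model_transform(y_pred), y)
    return loss_fn


def wrap_output_transform(
    model_transform: Callable[[Any], Any],
) -> Callable[[Tuple[Any, Any]], Tuple[Any, Any]]:
    """Return a metric output_transform for evaluator output of the form (y_pred, y)."""
    def output_transform(output: Tuple[Any, Any]) -> Tuple[Any, Any]:
        y_pred, y = output
        return model_transform(y_pred), y
    return output_transform


def first(outputs: Any) -> Any:
    """The single extraction function: keep only the predictions."""
    return outputs[0]


loss_fn = wrap_loss(lambda p, t: abs(p - t), first)
transform = wrap_output_transform(first)
print(loss_fn((6.0, [0.1, 0.2]), 5.0))        # 1.0
print(transform(((6.0, [0.1, 0.2]), 5.0)))    # (6.0, 5.0)
```

With these, the workaround could pass loss_fn=wrap_loss(loss, first) to create_supervised_trainer and output_transform=wrap_output_transform(first) to each metric, so the extraction logic lives in one place.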
