✨ Add visualization tools #50

Merged · 11 commits · Oct 15, 2023
Merge branch 'dev' of github.com:ENSTA-U2IS/torch-uncertainty into visualization
o-laurent committed Oct 15, 2023
commit f089499976cb608eccc5544f1a991e3f27afb2ba
torch_uncertainty/routines/classification.py — 16 additions & 0 deletions
@@ -456,11 +456,27 @@ def training_step(

```python
    def training_step(
        self, batch: Tuple[Tensor, Tensor], batch_idx: int
    ) -> STEP_OUTPUT:
        batch = self.mixup(*batch)
        # eventual input repeat is done in the model
        inputs, targets = self.format_batch_fn(batch)

        if self.is_elbo:
            # the ELBO criterion performs the forward pass internally
            loss = self.criterion(inputs, targets)
        else:
            logits = self.forward(inputs)

            # BCEWithLogitsLoss expects float targets
            if self.binary_cls and self.loss == nn.BCEWithLogitsLoss:
                logits = logits.squeeze(-1)
                targets = targets.float()

            loss = self.criterion(logits, targets)

        self.log("train_loss", loss)
        return loss
```
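The `targets.float()` cast matters because `BCEWithLogitsLoss` computes binary cross-entropy directly from raw logits and expects floating-point targets (possibly soft labels in [0, 1]) rather than integer class indices. As a minimal standalone sketch — not code from this repository — the numerically stable formula it evaluates per element is:

```python
import math

def bce_with_logits(logit: float, target: float) -> float:
    """Binary cross-entropy from a raw logit, in the numerically
    stable form: max(x, 0) - x*y + log(1 + exp(-|x|))."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))

# A logit of 0.0 means p = 0.5, so the loss against target 1.0 is -log(0.5) = log(2)
print(bce_with_logits(0.0, 1.0))  # ~0.6931
```

Because the target enters the formula multiplicatively, it must be a float tensor in PyTorch; passing integer targets to `BCEWithLogitsLoss` raises a dtype error, which is what the cast in the routine above avoids.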
