Abstract loss calculator #1210
Conversation
config/default_config.yml
Outdated
```yaml
loss_fcts:
-
-
```
Let's have a discussion about how the config should be structured for this.
config/default_config.yml
Outdated
```yaml
- "mse"
- 1.0
# -
# - "latent:mse"
```
@sophie-xhonneux: the latent loss functions are largely determined by the SSL strategies (with some flexibility, e.g. whether MAE or MSE for JEPA) and they are also . The latents returned by the Teacher are a dict with entries like 'DINO': torch.Tensor and 'iBOT': torch.Tensor. The loss functions should somehow come from the SSLTargetProcessors, no?
That would be an option, definitely. My plan was simply for the loss calculator to have an SSLLossCalculator that has a loss function for each of DINO, iBOT, JEPA-L1, and JEPA-L2. Because at the end of the day we have to specify it somewhere, and there is some tensor reshaping and related plumbing to do.
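For illustration only, a minimal sketch of how such an SSLLossCalculator could look; the class name comes from the comment above, while the dispatch over the teacher's latent dict and the placeholder per-objective loss choices are assumptions, not the PR's implementation.

```python
import torch
import torch.nn.functional as F


class SSLLossCalculator:
    """Hypothetical sketch: one loss function per SSL objective."""

    def __init__(self):
        # JEPA-L1/L2 map naturally to L1/MSE; the soft cross-entropy for DINO/iBOT
        # is only a stand-in for the real objectives.
        self.loss_fcts = {
            "DINO": lambda student, teacher: F.cross_entropy(student, teacher.softmax(dim=-1)),
            "iBOT": lambda student, teacher: F.cross_entropy(student, teacher.softmax(dim=-1)),
            "JEPA-L1": F.l1_loss,
            "JEPA-L2": F.mse_loss,
        }

    def compute_loss(self, student_latents: dict, teacher_latents: dict) -> torch.Tensor:
        # The teacher returns a dict such as {"DINO": tensor, "iBOT": tensor};
        # accumulate the loss of every objective present in it.
        losses = [
            fct(student_latents[name], teacher_latents[name])
            for name, fct in self.loss_fcts.items()
            if name in teacher_latents
        ]
        return torch.stack(losses).sum()
```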
```python
return loss, loss_chs


def mae(
```
This is a toy function which should be removed.
Let's preserve this nice geometric sketch though =)
The sketch is from mse. We could potentially generalize mse to implement any L_p norm; then we could avoid the code duplication.
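As a sketch of that idea, a generic L_p loss could look roughly like this (the name lp_loss and its signature are illustrative, not the repo's API):

```python
import torch


def lp_loss(pred: torch.Tensor, target: torch.Tensor, p: float = 2.0) -> torch.Tensor:
    """Mean L_p error over all elements; p=2 recovers MSE, p=1 recovers MAE."""
    return (pred - target).abs().pow(p).mean()
```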
I will entirely remove the MAE function now as I only used it for testing.
There is some work to be done on the logging and how to carry all the terms for logging (maybe in a separate PR). Example log file currently:
MatKbauer left a comment:
Great progress, I have added a couple of comments. We can still decide whether we want to modularize some repetitive code into functions and implement a latent KL loss already now, or postpone that to later.
config/default_config.yml
Outdated
```yaml
samples_per_mini_epoch: 4096
samples_per_validation: 512
samples_per_mini_epoch: 32
samples_per_validation: 8
```
Let's revert those back to the original settings
```python
latents = {}
latents["posteriors"] = posteriors


return ModelOutput(physical=preds_all, latent=latents)
```
To satisfy the dict definition of physical in the ModelOutput dataclass, we can do something like physical = {"predictions": preds_all} and pass this dict to the ModelOutput class, i.e., ModelOutput(physical=physical, latent=latents).
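For concreteness, a self-contained sketch of the suggestion; the ModelOutput field types and the dummy tensors are assumptions made for illustration only:

```python
import dataclasses

import torch


@dataclasses.dataclass
class ModelOutput:
    physical: dict   # dict of physical-space predictions
    latent: dict     # dict of latent-space tensors / posteriors


# Dummy tensors standing in for the model's actual outputs.
preds_all = torch.zeros(2, 3)
posteriors = [torch.zeros(2, 3)]

latents = {"posteriors": posteriors}
physical = {"predictions": preds_all}   # wrap the predictions to satisfy the dict-typed field
output = ModelOutput(physical=physical, latent=latents)
```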
Isn't the model output always predictions? Also, latents["posteriors"] sounds duplicative, although I can see that we can have latents at different stages in the model and want to compute the loss over these; most of them will just be posteriors in some sense. But no big thing, we can adjust this later.
Agree, I think we can fix it in the diffusion PR.
```python
loss_val = loss_fct(target=target, ens=None, mu=pred)


return loss_val
```
I stumbled here; can we rename loss_val to just loss or loss_value to prevent confusion with validation?
```python
Computes loss given predictions and targets and returns values of LossValues dataclass.
"""


raise NotImplementedError()
```
With the super().compute_loss() call in the LossLatent.compute_loss() method below (here), should we implement this function here in the base class?
Each loss module needs to implement its own compute_loss function, which overrides the one of the base class. If it doesn't, the base-class function will raise this NotImplementedError(). So this is on purpose.
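A minimal illustration of that override pattern (class names follow the discussion; bodies and signatures are placeholders, not the PR's code):

```python
class LossModuleBase:
    def compute_loss(self, preds, targets):
        # The base class deliberately raises: every concrete loss module must override this.
        raise NotImplementedError()


class LossLatent(LossModuleBase):
    def compute_loss(self, preds, targets):
        # A concrete module overrides the base-class method with its own implementation.
        return ((preds - targets) ** 2).mean()
```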
```python
self.loss_fcts = [
    [getattr(losses, name if name != "mse" else "mse_channel_location_weighted"), w, name]
    for name, w in loss_fcts
]
```
Is this always the same and can we move it to the base class?
The exception with "mse_channel_location_weighted" is specific to the physical loss, so I would keep it here for the moment.
```python
self.loss_unweighted_hist[loss_name].append(losses_all)
for loss_name, stddev_all in loss_terms.stddev_all.items():
    self.stdev_unweighted_hist[loss_name].append(stddev_all)
self.loss_model_hist += [loss.item()]
```
This is the same as ~100 lines above for training, isn't it? If so, let's move it into a function.
Yes. Will work on this in a separate logging PR.
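One possible shape for such a shared helper, sketched with the hypothetical name _append_loss_history; the attribute names follow the snippet above, and the losses_all loop is an assumption:

```python
def _append_loss_history(self, loss, loss_terms):
    # Identical bookkeeping for the training and validation paths.
    for loss_name, losses_all in loss_terms.losses_all.items():
        self.loss_unweighted_hist[loss_name].append(losses_all)
    for loss_name, stddev_all in loss_terms.stddev_all.items():
        self.stdev_unweighted_hist[loss_name].append(stddev_all)
    self.loss_model_hist += [loss.item()]
```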
```python
"{}".format(st["name"])
+ f" : {losses_all[st['name']].nanmean():0.4E} \t",
f"{loss_name}" + f" : {loss_values.nanmean():0.4E} \t",
)
```
Nice, this is much clearer now!
```python
log_vals += [loss_values[:, :].nanmean().item()]
for loss_name, stddev_values in stddev_all.items():
    metrics[f"loss.{loss_name}.stddev_avg"] = stddev_values.nanmean().item()
    log_vals += [stddev_values.nanmean().item()]
```
Can we put this into a function too? Looks like the same is done for training ~50 lines above.
Yes, also to be done in the logging PR.
clessig left a comment:
Revert default config back
```python
"""
A dataclass to encapsulate the loss components returned by each loss module.
This provides a structured way to return the primary loss used for optimization,
```
I think the documentation is outdated. We do not return the opt-loss any longer.
In the LossValues dataclass we do return the opt loss. Here I think the dataclass makes sense because the base class predefines that any loss module implemented in the future has to return a LossValues object, which returns the opt loss as well as losses_all and stddev_all.
These are collected by the loss calculator, which then returns the derived opt loss to the trainer separately (i.e. not within a dataclass).
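To make that contract concrete, a rough sketch; the field types and the aggregation helper are assumptions, only the field names come from the diff:

```python
import dataclasses

import torch


@dataclasses.dataclass
class LossValues:
    loss: torch.Tensor                    # opt loss contributed by this module
    losses_all: dict[str, torch.Tensor]   # unweighted per-term losses, for logging
    stddev_all: dict[str, torch.Tensor]   # per-term standard deviations, for logging


def aggregate_opt_loss(values: list[LossValues]) -> torch.Tensor:
    # The loss calculator collects the per-module results and hands the summed
    # opt loss to the trainer separately, i.e. not wrapped in a dataclass.
    return torch.stack([v.loss for v in values]).sum()
```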
```python
class LossModuleBase:
    def __init__(self):
        """
        Base class for loss calculators.
```
This is the base class for LossModules, which correspond to loss terms? The loss calculator is something else.
```python
# Dynamically load loss functions based on configuration and stage
self.loss_fcts = [
    [getattr(losses, name if name != "mse" else "mse_channel_location_weighted"), w]
```
mse_channel_location_weighted doesn't make sense for latent loss
```python
return LossValues(loss=loss, losses_all=losses_all, stddev_all=stddev_all)


class LossPhysicalTwo(LossModuleBase):
```
Remove before merging.
Also, can't we just specify LossPhysical twice in the config?
```
@@ -0,0 +1,38 @@
# ruff: noqa: T201
```
I don't think we should merge this with the PR.
Removing this placeholder file.
```python
@dataclasses.dataclass
class LossValues:
class LossTerms:
```
LossTerms is ambiguous. For me this is what the LossModules are.
I am still wondering if it should be a dataclass at all, as we return the loss separately anyway, i.e. return loss, LossTerms(loss_terms=loss_terms).
In the end, LossTerms is only needed for logging. I would keep it this way for now and update it in the PR which resolves the train logging.
Ok, but let's put it on the todo list for this PR please so that we don't forget.
```python
kl = torch.cat([posterior.kl() for posterior in posteriors])
loss_values.loss += cf.latent_noise_kl_weight * kl.mean()
kl = torch.cat([posterior.kl() for posterior in output.latent])
loss += cf.latent_noise_kl_weight * kl.mean()
```
Either we write a LossModule for this, or we leave it as is and push one soon after.
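If the LossModule route is taken, a hedged sketch of what it could look like; the class name LossLatentKL is hypothetical, while the posterior interface and the weight follow the snippet above:

```python
import torch


class LossLatentKL(LossModuleBase):  # LossModuleBase as in the loss modules above
    def __init__(self, weight: float):
        super().__init__()
        self.weight = weight  # e.g. cf.latent_noise_kl_weight

    def compute_loss(self, output, targets=None):
        # output.latent is assumed to be an iterable of posteriors exposing .kl(),
        # as in the diff above; returning just the weighted KL term for brevity
        # (in the PR's structure this would be wrapped in the logging dataclass).
        kl = torch.cat([posterior.kl() for posterior in output.latent])
        return self.weight * kl.mean()
```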
…evelop/loss_calc_base
Description
Enables generic loss calculation for a given set of prediction-target pairs, which can be in latent and/or physical space and part of student-teacher training or diffusion models.
Proposed structure:
Issue Number
Closes #1178
Is this PR a draft? Mark it as draft.
Checklist before asking for review
- ./scripts/actions.sh lint
- ./scripts/actions.sh unit-test
- ./scripts/actions.sh integration-test
- launch-slurm.py --time 60