Conversation

@Jubeku Jubeku commented Nov 5, 2025

Description

Enables generic loss calculation for a given set of prediction-target pairs, which can be in latent and/or physical space and part of student-teacher training or diffusion models.

Proposed structure:

  • Classes:
    • LossCalculator class: iterates over all special loss classes and returns a combined loss object.
    • LossCalculatorBase class: generic loss calculator structure
    • LossCalculatorPhysical, LossCalculatorLatent, etc.: specific subclasses of LossCalculatorBase
  • DataClasses:
    • LossValues: Predefines the items that are returned by the loss calculator classes
    • InputOutput: Predefines the items/structure of model predictions and targets (can include forecast step logic)
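The proposed structure might be sketched as follows. Class and field names (LossCalculator, LossCalculatorBase, LossCalculatorPhysical, LossValues) follow the description above; the method signatures and the toy MSE body are illustrative assumptions, not the actual implementation.

```python
import dataclasses


@dataclasses.dataclass
class LossValues:
    """Items returned by each loss calculator (per the DataClasses bullet)."""
    loss: float        # scalar loss contribution used for optimization
    losses_all: dict   # per-term unweighted losses, kept for logging
    stddev_all: dict   # per-term standard deviations, kept for logging


class LossCalculatorBase:
    """Generic loss calculator structure; subclasses override compute_loss."""
    def compute_loss(self, preds, targets) -> LossValues:
        raise NotImplementedError()


class LossCalculatorPhysical(LossCalculatorBase):
    """Toy physical-space loss: weighted MSE over flat lists of floats."""
    def __init__(self, weight: float):
        self.weight = weight

    def compute_loss(self, preds, targets) -> LossValues:
        mse = sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(preds)
        return LossValues(loss=self.weight * mse,
                          losses_all={"mse": mse}, stddev_all={})


class LossCalculator:
    """Iterates over all special loss classes and combines their losses."""
    def __init__(self, calculators):
        self.calculators = calculators

    def compute(self, preds, targets) -> float:
        return sum(c.compute_loss(preds, targets).loss for c in self.calculators)
```

The key design point is that every specific loss module returns the same LossValues shape, so the top-level calculator can combine and log them uniformly.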

Issue Number

Closes #1178

Is this PR a draft? Mark it as draft.

Checklist before asking for review

  • I have performed a self-review of my code
  • My changes comply with basic sanity checks:
    • I have fixed formatting issues with ./scripts/actions.sh lint
    • I have run unit tests with ./scripts/actions.sh unit-test
    • I have documented my code and I have updated the docstrings.
    • I have added unit tests, if relevant
  • I have tried my changes with data and code:
    • I have run the integration tests with ./scripts/actions.sh integration-test
    • (bigger changes) I have run a full training and I have written in the comment the run_id(s): launch-slurm.py --time 60
  • (bigger changes and experiments) I have shared a HedgeDoc in the GitHub issue with all the configurations and runs for these experiments
  • I have informed and aligned with people impacted by my change:
    • for config changes: the MatterMost channels and/or a design doc
    • for changes of dependencies: the MatterMost software development channel

@Jubeku Jubeku self-assigned this Nov 5, 2025
@github-actions github-actions bot added the model Related to model training or definition (not generic infra) label Nov 5, 2025
@clessig clessig self-requested a review November 6, 2025 20:54

loss_fcts:
-
-
Collaborator:

Let's have a discussion how the config should be structured for this

- "mse"
- 1.0
# -
# - "latent:mse"
Collaborator:

@sophie-xhonneux : the latent loss functions are largely determined by the SSL strategies (with some flexibility, e.g. MAE vs. MSE for JEPA). The latents returned by the Teacher are a dict with entries like 'DINO': torch.Tensor and 'iBOT': torch.Tensor. Shouldn't the loss function somehow come from the SSLTargetProcessors?

Contributor:

That would definitely be an option. My plan was simply for the loss calculator to have an SSLLossCalculator with a loss function for each of DINO, iBOT, JEPA-L1, and JEPA-L2. Because at the end of the day we have to specify it somewhere, and there is some tensor reshaping and related work to do.


Jubeku commented Nov 17, 2025

_log_terminal now prints logs for each loss module:
(screenshot: per-module loss log output)

...corresponding to the following toy config:

training_mode_config: {"losses": {LossPhysical: {weight: 0.7, loss_fcts: [['mse', 0.8], ['mae', 0.2]]},
                                  LossPhysicalTwo: {weight: 0.3, loss_fcts: [['mse', 1.0]]},
                                  }
                      }
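One way such a config could be consumed: a hedged sketch in which the build_loss_modules helper is hypothetical; only the config dict mirrors the toy example above.

```python
# Toy config, mirroring the example above (module names as plain strings).
training_mode_config = {
    "losses": {
        "LossPhysical": {"weight": 0.7, "loss_fcts": [["mse", 0.8], ["mae", 0.2]]},
        "LossPhysicalTwo": {"weight": 0.3, "loss_fcts": [["mse", 1.0]]},
    }
}


def build_loss_modules(cfg):
    """Flatten the per-module config into (name, weight, loss_fcts) triples
    that a loss calculator could iterate over."""
    return [
        (name, spec["weight"], spec["loss_fcts"])
        for name, spec in cfg["losses"].items()
    ]


modules = build_loss_modules(training_mode_config)
# Module weights sum to 1.0 here, so the combined loss stays normalized.
assert abs(sum(weight for _, weight, _ in modules) - 1.0) < 1e-9
```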

return loss, loss_chs


def mae(
Contributor Author:

This is a toy function which should be removed.

Contributor:

Let's preserve this nice geometric sketch though =)

Collaborator:

The sketch is from mse. We could potentially generalize MSE to implement any L_p norm. Then we could avoid the code duplication.
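A minimal sketch of that generalization, with plain Python floats standing in for tensors; the lp_loss helper is a hypothetical name, but p=2 recovers the MSE reduction and p=1 the MAE, so the two functions need not be duplicated.

```python
def lp_loss(pred, target, p: float = 2.0) -> float:
    """Mean of |pred - target|^p; p=2 gives MSE, p=1 gives MAE."""
    assert len(pred) == len(target)
    return sum(abs(a - b) ** p for a, b in zip(pred, target)) / len(pred)
```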

Contributor Author:

I will entirely remove the MAE function now as I only used it for testing.


Jubeku commented Nov 18, 2025

There is some work to be done on the logging and how to carry all the terms for logging (maybe in a separate PR).

Example log file currently:
v7ad0hi5_train_metrics.json

@Jubeku Jubeku marked this pull request as ready for review November 18, 2025 15:26

@MatKbauer MatKbauer left a comment


Great progress, I have added a couple of comments. We can still decide whether we want to modularize some repetitive code into functions and implement a latent KL loss already now, or postpone it to later.

samples_per_mini_epoch: 4096
samples_per_validation: 512
samples_per_mini_epoch: 32
samples_per_validation: 8
Contributor:

Let's revert those back to the original settings

latents = {}
latents["posteriors"] = posteriors

return ModelOutput(physical=preds_all, latent=latents)
Contributor:

To satisfy the dict-definition of physical in the ModelOutput dataclass, we can do something like

physical = {"predictions": preds_all}

and pass this dict to the ModelOutput class, i.e., ModelOutput(physical=physical, latent=latents)
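A minimal sketch of that suggestion, assuming a ModelOutput dataclass with dict-valued physical and latent fields; only the field names come from the snippet above, the dataclass body and the stand-in values are assumptions.

```python
import dataclasses


@dataclasses.dataclass
class ModelOutput:
    physical: dict  # e.g. {"predictions": preds_all}
    latent: dict    # e.g. {"posteriors": posteriors}


preds_all = [0.1, 0.2]        # stand-in for the model's physical predictions
posteriors = ["posterior_0"]  # stand-in for latent posteriors

# Wrap the raw predictions in a dict to satisfy the dict-typed field.
physical = {"predictions": preds_all}
latents = {"posteriors": posteriors}
output = ModelOutput(physical=physical, latent=latents)
```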

Collaborator:

Isn't the model output always predictions? Also, latents["posteriors"] sounds duplicative, although I can see that we can have latents at different stages in the model and want to compute the loss over these; most of them will be posteriors in some sense. But no big thing, we can adjust this later.

Contributor Author:

Agree, I think we can fix it in the diffusion PR.


loss_val = loss_fct(target=target, ens=None, mu=pred)

return loss_val
Contributor:

I stumbled here; can we rename loss_val to just loss or loss_value, to prevent confusion with validation?

Computes loss given predictions and targets and returns values of LossValues dataclass.
"""

raise NotImplementedError()
Contributor:

With the super().compute_loss() call in LossLatent.compute_loss() below, should we implement this function here in the base class?

Contributor Author:

Each loss module needs to implement its own compute_loss function, which overrides the one of the base class. If it doesn't, the base class's function will raise this NotImplementedError(). So this is on purpose.

self.loss_fcts = [
[getattr(losses, name if name != "mse" else "mse_channel_location_weighted"), w, name]
for name, w in loss_fcts
]
Contributor:

Is this always the same and can we move it to the base class?

Contributor Author:

The exception with "mse_channel_location_weighted" is specific to the physical loss. So I would keep it here for the moment.
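For reference, the name-to-function resolution under discussion can be reproduced with a stand-in namespace; the lambda bodies below are placeholders, only the getattr remapping mirrors the snippet above.

```python
import types

# Stand-in for the repository's losses module; the bodies are placeholders.
losses = types.SimpleNamespace(
    mse_channel_location_weighted=lambda pred, target: 0.0,
    mae=lambda pred, target: 0.0,
)

loss_fcts_cfg = [["mse", 0.8], ["mae", 0.2]]
# "mse" is remapped to the channel/location-weighted variant for the
# physical loss; all other names are looked up verbatim.
loss_fcts = [
    [getattr(losses, name if name != "mse" else "mse_channel_location_weighted"), w, name]
    for name, w in loss_fcts_cfg
]
```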

self.loss_unweighted_hist[loss_name].append(losses_all)
for loss_name, stddev_all in loss_terms.stddev_all.items():
self.stdev_unweighted_hist[loss_name].append(stddev_all)
self.loss_model_hist += [loss.item()]
Contributor:

This is the same as ~100 lines above for training, isn't it? If so, let's move it into a function.

Contributor Author:

Yes. Will work on this in a separate logging PR.

"{}".format(st["name"])
+ f" : {losses_all[st['name']].nanmean():0.4E} \t",
f"{loss_name}" + f" : {loss_values.nanmean():0.4E} \t",
)
Contributor:

Nice, this is much clearer now!

log_vals += [loss_values[:, :].nanmean().item()]
for loss_name, stddev_values in stddev_all.items():
metrics[f"loss.{loss_name}.stddev_avg"] = stddev_values.nanmean().item()
log_vals += [stddev_values.nanmean().item()]
Contributor:

Can we put this into a function too? Looks like the same is done for training ~50 lines above.

Contributor Author:

Yes, also to be done in the logging PR.
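A sketch of the shared helper being suggested, assuming both the training and validation blocks reduce per-name value collections into a metrics dict plus a flat log list; the function name and the plain-list reduction (standing in for nanmean()) are illustrative, not the repository's API.

```python
def collect_loss_metrics(losses_all, stddev_all):
    """Reduce per-loss collections to a metrics dict and flat log values."""
    metrics, log_vals = {}, []
    for loss_name, loss_values in losses_all.items():
        avg = sum(loss_values) / len(loss_values)   # stands in for nanmean()
        metrics[f"loss.{loss_name}.avg"] = avg
        log_vals.append(avg)
    for loss_name, stddev_values in stddev_all.items():
        avg = sum(stddev_values) / len(stddev_values)
        metrics[f"loss.{loss_name}.stddev_avg"] = avg
        log_vals.append(avg)
    return metrics, log_vals
```

Both call sites could then share this one reduction instead of duplicating the loops.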



@clessig clessig left a comment


Revert default config back



"""
A dataclass to encapsulate the loss components returned by each loss module.
This provides a structured way to return the primary loss used for optimization,
Collaborator:

I think the documentation is outdated. We do not return the opt-loss any longer.

Contributor Author:

In the LossValues dataclass we do return the opt loss. Here I think the dataclass makes sense because the base class predefines that any loss module implemented in the future has to return a LossValues object, which carries the opt loss as well as losses_all and stddev_all.
These are collected by the loss calculator, which then returns the derived opt loss to the trainer separately (i.e. not within a dataclass).

class LossModuleBase:
def __init__(self):
"""
Base class for loss calculators.
Collaborator:

This is the base class for LossModules, which correspond to loss terms? The loss calculator is something else.


# Dynamically load loss functions based on configuration and stage
self.loss_fcts = [
[getattr(losses, name if name != "mse" else "mse_channel_location_weighted"), w]
Collaborator:

mse_channel_location_weighted doesn't make sense for the latent loss.

return LossValues(loss=loss, losses_all=losses_all, stddev_all=stddev_all)


class LossPhysicalTwo(LossModuleBase):
Collaborator:

Remove before merging.

Collaborator:

Also, can't we just specify LossPhysical twice in the config?

@@ -0,0 +1,38 @@
# ruff: noqa: T201
Collaborator:

I don't think we should merge this with the PR.

Contributor Author:

Removing this placeholder file.


@dataclasses.dataclass
class LossValues:
class LossTerms:
Collaborator:

LossTerms is ambiguous. For me, this is what the LossModules are.

Contributor Author:

I am still wondering whether it should be a dataclass at all, since we return the loss separately anyway, i.e. return loss, LossTerms(loss_terms=loss_terms).
In the end, LossTerms is only needed for logging. I would keep it this way for now and update it in the PR that resolves the train logging.

Collaborator:

Ok, but let's put it on the todo list for this PR please so that we don't forget.

kl = torch.cat([posterior.kl() for posterior in posteriors])
loss_values.loss += cf.latent_noise_kl_weight * kl.mean()
kl = torch.cat([posterior.kl() for posterior in output.latent])
loss += cf.latent_noise_kl_weight * kl.mean()
Collaborator:

Either we write a LossModule for this, or we leave it as is and push this soon after.
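If it becomes its own module, a hedged sketch could look like the following; the Posterior stand-in and the LossLatentKL class are hypothetical, only the posterior.kl() accumulation mirrors the snippet above.

```python
class Posterior:
    """Toy stand-in exposing the kl() interface used in the snippet above."""
    def __init__(self, kl_value: float):
        self._kl = kl_value

    def kl(self) -> float:
        return self._kl


class LossLatentKL:
    """KL term wrapped as its own loss module."""
    def __init__(self, weight: float):
        self.weight = weight  # corresponds to cf.latent_noise_kl_weight

    def compute_loss(self, posteriors) -> float:
        # Mirrors torch.cat([p.kl() for p in posteriors]).mean(),
        # with plain floats standing in for tensors.
        kls = [p.kl() for p in posteriors]
        return self.weight * sum(kls) / len(kls)
```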



@clessig clessig merged commit 94bc7c9 into develop Nov 21, 2025
5 checks passed
@clessig clessig deleted the jk/develop/loss_calc_base branch November 21, 2025 12:09
@Jubeku Jubeku mentioned this pull request Dec 4, 2025
4 tasks

Labels

model Related to model training or definition (not generic infra)

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

Abstract Loss Calculators

5 participants