fix: one-variable model #478

Open · wants to merge 8 commits into main

Conversation

lzampier (Member)

Description

The current version of anemoi fails when training models with a single prognostic variable, because several `.squeeze()` calls reduce arrays to scalars. I replaced them with `.reshape(-1)`. This fixes the issue in my tests and still works with more complex models that have multiple variables.
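For context, a minimal NumPy sketch (the shapes here are hypothetical, not anemoi's actual tensors) of why `.squeeze()` breaks the single-variable case while `.reshape(-1)` does not:

```python
import numpy as np

# Hypothetical layout: (batch, grid_points, variables); with one
# prognostic variable the trailing axis has length 1.
y_true = np.zeros((2, 4, 1))
slice_ = y_true[0, 0, :]          # shape (1,)

# .squeeze() drops every size-1 axis, leaving a 0-d scalar that breaks
# downstream code expecting at least a 1-D array.
print(slice_.squeeze().ndim)      # 0

# .reshape(-1) always returns a 1-D array, however many size-1 axes exist.
print(slice_.reshape(-1).shape)   # (1,)
```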

What problem does this change solve?

Enables models with a single prognostic variable to be trained in anemoi-train. This is arguably an edge case, but one that should be supported nevertheless.

What issue or task does this change relate to?

This change addresses issue #477.

By opening this pull request, I affirm that all authors agree to the Contributor License Agreement.

@lzampier lzampier requested a review from anaprietonem August 13, 2025 14:12
@lzampier lzampier added bug Something isn't working training ATS Approval Not Needed No approval needed by ATS labels Aug 13, 2025
@github-project-automation github-project-automation bot moved this to To be triaged in Anemoi-dev Aug 13, 2025
@lzampier lzampier changed the title from "Fix/one variable model" to "fix: one-variable model" Aug 13, 2025
@lzampier
Copy link
Member Author

I can add that these changes also solve the error in the configuration described in #477.

```diff
@@ -229,15 +229,15 @@ def plot_power_spectrum(
     grid_pc_lon, grid_pc_lat = np.meshgrid(regular_pc_lon, regular_pc_lat)

     for plot_idx, (variable_idx, (variable_name, output_only)) in enumerate(parameters.items()):
-        yt = y_true[..., variable_idx].squeeze()
-        yp = y_pred[..., variable_idx].squeeze()
+        yt = (y_true if y_true.ndim == 1 else y_true[..., variable_idx]).reshape(-1)
```
Collaborator:

@lzampier - why do you need the `if` here - could you just do `y_true[..., variable_idx].reshape(-1)`? And the same for x

Member Author:

Probably yes. I was overcautious. Let me test later.

Collaborator:

Great. Apart from that, it would be good to add a test for `print_variable_scaling`, checking that it works for both one and more than one variable: https://github.com/ecmwf/anemoi-core/blob/5a311574b9fb3ab6d610e8e80fc3d91eab061a82/training/src/anemoi/training/losses/utils.py

Could you give that a go? I can review later, or if you have doubts just drop me a message!

Member Author:

I checked, and yes, we need the `y_true if y_true.ndim == 1 else` part; otherwise the plotting fails. Regarding the tests, I have made an attempt that you can see in the newest commit. Let me know if you think it is appropriate. The behavior is a bit awkward to test, so maybe you have a better idea.
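To illustrate the point with a made-up 1-D input (not the actual plotting call): without the `ndim` guard, indexing a 1-D array with `[..., variable_idx]` selects a single element instead of the whole series:

```python
import numpy as np

# Hypothetical: with a single variable, the array handed to the plotting
# code may already be 1-D (5 grid points here).
y_true = np.arange(5.0)
variable_idx = 0

# Unguarded: ellipsis indexing on a 1-D array picks ONE element.
wrong = y_true[..., variable_idx].reshape(-1)
print(wrong.shape)  # (1,) -- four of the five values are gone

# Guarded, as in the PR: the full series is preserved.
yt = (y_true if y_true.ndim == 1 else y_true[..., variable_idx]).reshape(-1)
print(yt.shape)     # (5,)
```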

@lzampier lzampier marked this pull request as ready for review August 18, 2025 15:09
```diff
@@ -0,0 +1,129 @@
+# (C) Copyright 2024 Anemoi contributors.
```
Collaborator:

header should be 2025

```python
loss = _FakeLoss(base)

handler = _ListHandler()
old_level = LOGGER.level
```
Collaborator:

why do we need this for the LOGGER?

```python
# ------------------------------ Tests ---------------------------------------


def test_print_variable_scaling_single_var_flattens_but_not_scalar() -> None:
```
Collaborator:

My proposal for the test would be to reuse `tests/unit/train/test_loss_scaling.py::test_variable_loss_scaling_vals`.

There we can get the `data_indices` and the loss, and then call `print_variable_scaling(loss, indices)`, checking that the test passes with one and with many variables.
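A self-contained sketch of what such a test could check; everything here (`print_variable_scaling`, `_ListHandler`, the variable names and values) is a stand-in for the real anemoi-training helpers, assuming the helper logs one line per variable. It also shows why the test saves and restores the LOGGER level: the messages are emitted at INFO, which is filtered out by default.

```python
import logging

import numpy as np

LOGGER = logging.getLogger("anemoi_scaling_demo")


def print_variable_scaling(scaling: np.ndarray, names: list[str]) -> None:
    """Stand-in for the real helper: logs one scaling value per variable.

    reshape(-1) keeps a 1-D array even for a single variable, so zip()
    never receives the 0-d scalar that .squeeze() could produce.
    """
    for name, value in zip(names, scaling.reshape(-1)):
        LOGGER.info("Variable %s scaling: %.3f", name, value)


class _ListHandler(logging.Handler):
    """Collects formatted log records so a test can assert on them."""

    def __init__(self) -> None:
        super().__init__()
        self.messages: list[str] = []

    def emit(self, record: logging.LogRecord) -> None:
        self.messages.append(record.getMessage())


def run_case(names: list[str], scaling: np.ndarray) -> list[str]:
    handler = _ListHandler()
    LOGGER.addHandler(handler)
    old_level = LOGGER.level          # mirrors the snippet in the PR's test
    LOGGER.setLevel(logging.INFO)     # INFO is filtered out at default level
    try:
        print_variable_scaling(scaling, names)
    finally:
        LOGGER.setLevel(old_level)
        LOGGER.removeHandler(handler)
    return handler.messages


# One log line per variable, for single- and multi-variable configurations.
assert len(run_case(["t2m"], np.array([[0.5]]))) == 1
assert len(run_case(["t2m", "u10"], np.array([[0.5, 1.0]]))) == 2
```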
