
Conversation


@clessig clessig commented Dec 12, 2025

Description

This finalizes #1408 by @yperugachidiaz, so most of the credit goes to her.

Issue Number

Closes #1353

Is this PR a draft? Mark it as draft.

Checklist before asking for review

  • I have performed a self-review of my code
  • My changes comply with basic sanity checks:
    • I have fixed formatting issues with ./scripts/actions.sh lint
    • I have run unit tests with ./scripts/actions.sh unit-test
    • I have documented my code and I have updated the docstrings.
    • I have added unit tests, if relevant
  • I have tried my changes with data and code:
    • I have run the integration tests with ./scripts/actions.sh integration-test
    • (bigger changes) I have run a full training and I have written in the comment the run_id(s): launch-slurm.py --time 60
    • (bigger changes and experiments) I have shared a HedgeDoc in the GitHub issue with all the configurations and runs for these experiments
  • I have informed and aligned with people impacted by my change:
    • for config changes: the MatterMost channels and/or a design doc
    • for changes of dependencies: the MatterMost software development channel

@clessig clessig merged commit 7e7ff8e into shmh40/dev/1270-idx-global-local Dec 12, 2025
1 check passed
clessig added a commit that referenced this pull request Dec 17, 2025
…aining, and pass ModelBatch class (#1283)

* NOT WORKING: initial draft for index-based masking. Implemented for random and healpix masking. Open issues with _coords_local, centroids and probably other things.

* NOT WORKING: Finished src, target still to be done.

* Masking target is working in principle but errors when feeding data to the model.

* Working version for ERA5, NPP-ATMS. Problems with SYNOP with empty cell handling

* Minor cleanup

* Fixed linting

* Fixed remaining problems that occurred for NPP-ATMS and SYNOP.
TODO:
- Forecast still needs to be adapted
- Some more cleanup of variable naming, return values etc

* Enabled support for forecast. Cleaned up some bits and pieces.

* Removing centroids option for embedding, which was unused and should not be used.

* Removed unused parameters

* Inversion of target output ordering to match the input ordering in forecast mode. Unclear how to handle this for MTM

* Changes to prepare_logging to apply index inversion

* added file with ModelBatch and SampleMetadata dataclasses

* Updating config to working version

* update ViewMetadata spec

* draft changes to allow global local view generation in masker and tokenizer_masking. generate the mask, otherwise using batchify_source and batchify_target as before, with the capacity to remember what mask we have now when it comes to generating the targets. Update to inputs_metadata structure but not yet put into practice

* draft of training_config in default_config

* change view_metadata to dict in ModelInput

* NOT WORKING: updating class to handle multiple input steps and improving overall structure

* Added basic support for multi-step sources.

* Partially enabled correct handling of multiple input steps.

* Added mode and refactored get_sample_data into separate function.

* Comments

* Renaming

* updated default config training_config to allow student-teacher

* added stream id to era5 config

* slight restructure of ViewMetadata

* basic if statement to yield the student and teacher views

* correct imports with new batch.py

* created function for _get_student_teacher_sample_data which returns the streams_data of the teacher and multiple streams_datas for the student views.

* Not working draft for restructuring

* Changes for better student teacher structure

* More refactoring

* More refactoring and cleanup

* More refactoring. Code working again.

* Cleaned up parametrization

* Changes necessary for spoofing flag per IOReaderData

* Changes to have spoofing on a per data reader sample

* Moved _get_student_teacher_masks() so that masks are generated for all streams first.

* Renaming and minor clean up.

* Added basic support for use of ModelBatch class to define rough structure and interface.

* linting

* Linting

* linting

* Linting problems but removed unused ViewMetaData dependence

* Added required reflexivity between source and target samples to Batch

* Added todo

* fix typo in ModelBatch

* collect num_source_samples and num_target_samples, add loop over teacher masks hence allowing multiple teacher views, and add source_target_idx to keep track of which student belongs to which teacher

* add teacher num_views parameter to config

* Re-enabling inversion of target ordering.

* tidy up, remove unused build_stream_views in tokenizer_masking

* multiple idxs for each teacher, need to confirm for not student case, and updated ModelBatch for this

* add max_num_targets to era5

* add max_num_samples functionality to tokenizer_masking and pass through in multi_stream_data_sampler. coords_per_cell is a bit nasty

* move build_views_for_stream into masker

* tidy up, remove unused arguments, types

* fix masking for NPP-ATMS by correctly selecting final timestep mask and aligning between source and target. working for num_input_steps = 1, broken for > 1, compute_offsets_scatter_embed not working

* updated configs so code runs. Note default config to be overhauled still

* very hacky first pass of full masking_strategy_config for the student and teacher views. Much to fix up

* instructions for sophie

* add SampleMetaData integration and functionality, and update masker to use SampleMetadata. Pass through source_cell_lens and target_coords_idx to student_teacher_batch in iter, and hence pass through to trainer. source_cell_lens and target_coords_idx are now part of Sample, which is itself the components of ModelBatch. To tidy

* remove prints, pdb

* add mask to SampleMetaData and add forecast_dt to Sample so it is accessible. Can specify the loss in the default config with student-teacher views

* add diffusion forecast option for the data sampling, with noise_level_rn in the metadata. The Trainer needs to be copied from Sophie's branch; currently we only get so far

* Linting

* Simplified and clarified handling of default target_aux_calculator

* Linting

* Linting

* Linting

* Linting

* Linting

* Restoring masking as training_mode in default_config

* More linting

* Removed duplicate lines due to merging

* Restored masking as training mode. Not working due to NaN in prediction

* Fixed problem in engines introduced in recent commits merging develop. This fixes masking training

* remove unused mask generation in diffusion_forecast

* restore masking_strategy to random

Had placeholder for testing, now back to "random" for masking strategy in the base level of default_config

* restore loader_num_workers to 8

* fix indentation of else: assert False in _get_sample msds

* linter warnings

* commenting tests

* Restructured code so that mask generation and application is cleanly separated

* Commit

* Update

* Fixed uv.lock

* Fix for integration test

* Re-enabled multi-source training

* 1390 - Adapt forward pass of new batch object (#1391)

* Add to device to ModelBatch, etc & adapt model

TODO adapt validate and inference
TODO test forecasting and multiple stream because predict changed
substantially

* Rename view to sample and fix validate

* Revert predict function and fix inference

* Fix invalid access with mask

* Linting

* Fixed handling of target_idxs and other minor issues

---------

Co-authored-by: sophiex <24638638+sophie-xhonneux@users.noreply.github.com>
Co-authored-by: Christian Lessig <christian.lessig@ecmwf.int>

* Completed migration to new batch class by removing reference to old list of lists

* Fixed missing non_blocking=True in to_device()
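
The non_blocking=True pattern referenced in these commits lets host-to-device copies overlap with computation when the host memory is pinned. A minimal sketch of the pattern, using a stand-in tensor class instead of PyTorch so the example is self-contained; to_device and FakeTensor are illustrative names, not the repository's code:

```python
class FakeTensor:
    """Stand-in for torch.Tensor, recording how .to() was called."""
    def __init__(self, data):
        self.data = data
        self.device = "cpu"
        self.last_to_kwargs = None

    def to(self, device, non_blocking=False):
        # Return a "moved" copy, remembering the transfer arguments.
        moved = FakeTensor(self.data)
        moved.device = device
        moved.last_to_kwargs = {"non_blocking": non_blocking}
        return moved

def to_device(tensors, device):
    # With pinned host memory, non_blocking=True makes the copy asynchronous,
    # so it can overlap with ongoing computation on the device.
    return [t.to(device, non_blocking=True) for t in tensors]
```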

* Removed old comments

* Fixed problem with non_blocking=True

* Cleaned up comments and return values a bit

* Changed args to embedding

* Changed core functions to take sample as arg

* Changed that model takes sample as input

* Fixes for diffusion

* Switched to lists of model / target strategies

* Updated config

* Changed to per masking strategy loss terms

* Removed old masking options. Still needs to be fully cleaned up

* More robust handling of empty streams

* Fixed incorrect handling of empty target_coords_idx

* Fixed problem when number of model and target samples is different

* Example for config with non-trivial model and target inputs

* Fixed bug in total sample counting

* Re-enabled missing healpix level

* Fixed incorrect handling of masking and student_teacher modes. Follow-up fixes required to handle partially filled source/target streams (e.g. because the source has no target values).

* An encoder formed by embedding + local assimilation + global assimilation (#1397)

* initial changes

* more changes

* removed extra print parameters statement

* changed names for backward checkpoint loading

* added encoder. to module names in sharding

* adding encoder. to embed_engine

* added back the conditions for param printong

* lint

* forecast config

* switch back to MTM config

* lint

* Formatting

* Fix source-target matching problem.

* Enabled multiple input steps. Fixed various robustness issues that arose from this.

This commit also changes the number of forecast steps that are taken. The old loop was at least one step too far. Unclear why the problem occurred now.

* Linting

* Missing update to validation()

* Improved robustness through sanity checking of arguments

* Improved handling of corner cases

* - Fixed incorrect call to get_forecast_steps() in validation
- Fixed interface of target_aux_calculator

* More fixes to validation

* Adding stream_id

* Cleaned up ModelOutput class to have proper access functions and a better structure

* Switched to use dict to internally represent streams_datasets

* Improving robustness of interface of ModelOutput class

* Re-enabling model output

* Ruff

* Minor clean-ups and additional comments

* Minor cleanups

* Cleaned up handling of masks and masking metadata

* Current working version of default_config

* Fixed problem with branches with old code and incomplete cleanup

* Updated to test convergence of integration test.

* Updated settings

* Clessig/ypd/dev/1353 add tokens latent state finalization (#1452)

* Add LatentState

* Add class and register tokens for LatentState, adjust everything accordingly

* Add option in config file + minor changes

* Add pos.emb. for register tokens + remove class tokens + minor fixes

* Minor fix

* Changed empty to zeros pe_register

* Ruffed

* Clean-up and fixed positional encoding

* Fixing things that got lost during last merge

---------

Co-authored-by: Yura Perugachi Diaz <yperugac@santis-ln001.cscs.ch>
Co-authored-by: Yura Perugachi Diaz <yperugac@santis-ln002.cscs.ch>

* Ruffed

* Adding sanity check for register tokens

* Improved structure of LatentState class.

* Improved structure of LatentState

* Fixed problem with num_samples > 1

* Improved representation of batch and batch data and more getter functions

* Re-enabled batch_size/num_samples>1. Still some minor problems and cleanup needed but overall program flow working

* Cleanup

* Fixed bug in source-target correspondence with num_samples>1

* Removing incorrect loss scaling

* Cleaned up predict() in model

* Fixed commenting issues

* Fixed problem with freezing of modules for q_cells. Fixed problem when running in FSDP

* Fixed problem with printing of trainable weights

* Fixed switch for printing of trainable weights

* 1316 Update metrics logging (#1412)

* Add to device to ModelBatch, etc & adapt model

TODO adapt validate and inference
TODO test forecasting and multiple stream because predict changed
substantially

* Rename view to sample and fix validate

* Revert predict function and fix inference

* Fix invalid access with mask

* Linting

* Fixed handling of target_idxs and other minor issues

* Remove duplicate to_device

* move loss history into loss calculator

* handle loss_avg and unflatten loss dict

* fixing train_logger

* update validate logging, failing - need to merge data branch

* rm additional log files and log_vals variable, and collapse to single add_logs fct for train and val

* rm comment

* fix validation

* move prepare losses fct to train_logger script, fix terminal logging for val

* fix ctr_loss_fcts normalization; calculate per stream, per lfct average across channels and fsteps for logging

* Fixed linting

* fix bug in emptying history after logging

---------

Co-authored-by: sophiex <24638638+sophie-xhonneux@users.noreply.github.com>
Co-authored-by: Christian Lessig <christian.lessig@ecmwf.int>

* set target_channels dependent on val or train (#1471)

* fix targetauxoutput and inference to accept target times and coords. Still lists, to be changed to dicts

* updated targetauxbase so we access target, times and coords with a dict and corresponding changes

* Patch for float in loss dict (shouldn't happen to begin with)

* Fixed handling of per-batch precomputation. Functional for num_sample/batch_size=1 but convergence seems broken for num_samples>1.

* Reordering

* Linting

* More linting

* More linting

* Removed old, commented code

* Push current progress for inspection (#1478)

* Push current progress for inspection

* For Seb

* Delete old code

* Rename according to config

* Fix config

* Push what I have

* Successfully build data

* Fix bugs and lint

* Forgot to revert dataloader workers

* Address PR review comments

- rename student-to-teacher
- extract metadata extraction into a function

* prepare branch for ssl merge, ready for data merge

* Removed mock-up

---------

Co-authored-by: Sebastian Hickman <seb.hickman@gmail.com>
Co-authored-by: Christian Lessig <christian.lessig@ecmwf.int>

* Cleaned up

---------

Co-authored-by: Christian Lessig <christian.lessig@ecmwf.int>
Co-authored-by: Tim Hunter <tim.hunter@ecmwf.int>
Co-authored-by: Julian Kuehnert <Jubeku@users.noreply.github.com>
Co-authored-by: sophiex <24638638+sophie-xhonneux@users.noreply.github.com>
Co-authored-by: kctezcan <kctezcan@gmail.com>
Co-authored-by: Yura Perugachi Diaz <yperugac@santis-ln001.cscs.ch>
Co-authored-by: Yura Perugachi Diaz <yperugac@santis-ln002.cscs.ch>
clessig added a commit that referenced this pull request Dec 23, 2025
…1507)

* NOT WORKING: initial draft for index-based masking. Implemented for random and healpix masking. Open issues with _coords_local, centroids and probably other things.

* NOT WORKING: Finished src, target still to be done.

* Masking target is working in principle but errors when feeding data to the model.

* Working version for ERA5, NPP-ATMS. Problems with SYNOP with empty cell handling

* Minor cleanup

* Fixed linting

* Fixed remaining problems that occurred for NPP-ATMS and SYNOP.
TODO:
- Forecast still needs to be adapted
- Some more cleanup of variable naming, return values etc

* Enabled support for forecast. Cleaned up some bits and pieces.

* Removing centroids option for embedding, which was unused and should not be used.

* Removed unused parameters

* Inversion of target output ordering to match the input ordering in forecast mode. Unclear how to handle this for MTM

* Changes to prepare_logging to apply index inversion

* added file with ModelBatch and SampleMetadata dataclasses

* Updating config to working version

* update ViewMetadata spec

* draft changes to allow global local view generation in masker and tokenizer_masking. generate the mask, otherwise using batchify_source and batchify_target as before, with the capacity to remember what mask we have now when it comes to generating the targets. Update to inputs_metadata structure but not yet put into practice

* draft of training_config in default_config

* change view_metadata to dict in ModelInput

* NOT WORKING: updating class to handle multiple input steps and improving overall structure

* Added basic support for multi-step sources.

* Partially enabled correct handling of multiple input steps.

* Added mode and refactored get_sample_data into separate function.

* Comments

* Renaming

* updated default config training_config to allow student-teacher

* added stream id to era5 config

* slight restructure of ViewMetadata

* basic if statement to yield the student and teacher views

* correct imports with new batch.py

* created function for _get_student_teacher_sample_data which returns the streams_data of the teacher and multiple streams_datas for the student views.

* Not working draft for restructuring

* Changes for better student teacher structure

* More refactoring

* More refactoring and cleanup

* More refactoring. Code working again.

* Cleaned up parametrization

* Changes necessary for spoofing flag per IOReaderData

* Changes to have spoofing on a per data reader sample

* Moved _get_student_teacher_masks() so that masks are generated for all streams first.

* Renaming and minor clean up.

* Added basic support for use of ModelBatch class to define rough structure and interface.

* linting

* Linting

* linting

* Linting problems but removed unused ViewMetaData dependence

* Added required reflexivity between source and target samples to Batch

* Added todo

* fix typo in ModelBatch

* collect num_source_samples and num_target_samples, add loop over teacher masks hence allowing multiple teacher views, and add source_target_idx to keep track of which student belongs to which teacher

* add teacher num_views parameter to config

* Re-enabling inversion of target ordering.

* tidy up, remove unused build_stream_views in tokenizer_masking

* multiple idxs for each teacher, need to confirm for not student case, and updated ModelBatch for this

* add max_num_targets to era5

* add max_num_samples functionality to tokenizer_masking and pass through in multi_stream_data_sampler. coords_per_cell is a bit nasty

* move build_views_for_stream into masker

* tidy up, remove unused arguments, types

* fix masking for NPP-ATMS by correctly selecting final timestep mask and aligning between source and target. working for num_input_steps = 1, broken for > 1, compute_offsets_scatter_embed not working

* updated configs so code runs. Note default config to be overhauled still

* very hacky first pass of full masking_strategy_config for the student and teacher views. Much to fix up

* instructions for sophie

* add SampleMetaData integration and functionality, and update masker to use SampleMetadata. Pass through source_cell_lens and target_coords_idx to student_teacher_batch in iter, and hence pass through to trainer. source_cell_lens and target_coords_idx are now part of Sample, which is itself the components of ModelBatch. To tidy

* remove prints, pdb

* add mask to SampleMetaData and add forecast_dt to Sample so it is accessible. Can specify the loss in the default config with student-teacher views

* add diffusion forecast option for the data sampling, with noise_level_rn in the metadata. The Trainer needs to be copied from Sophie's branch; currently we only get so far

* Linting

* Simplified and clarified handling of default target_aux_calculator

* Linting

* Linting

* Linting

* Linting

* Linting

* Restoring masking as training_mode in default_config

* More linting

* Removed duplicate lines due to merging

* Restored masking as training mode. Not working due to NaN in prediction

* Fixed problem in engines introduced in recent commits merging develop. This fixes masking training

* remove unused mask generation in diffusion_forecast

* restore masking_strategy to random

Had placeholder for testing, now back to "random" for masking strategy in the base level of default_config

* restore loader_num_workers to 8

* fix indentation of else: assert False in _get_sample msds

* linter warnings

* commenting tests

* Restructured code so that mask generation and application is cleanly separated

* Commit

* Update

* Fixed uv.lock

* Fix for integration test

* Re-enabled multi-source training

* 1390 - Adapt forward pass of new batch object (#1391)

* Add to device to ModelBatch, etc & adapt model

TODO adapt validate and inference
TODO test forecasting and multiple stream because predict changed
substantially

* Rename view to sample and fix validate

* Revert predict function and fix inference

* Fix invalid access with mask

* Linting

* Fixed handling of target_idxs and other minor issues

---------

Co-authored-by: sophiex <24638638+sophie-xhonneux@users.noreply.github.com>
Co-authored-by: Christian Lessig <christian.lessig@ecmwf.int>

* Completed migration to new batch class by removing reference to old list of lists

* Fixed missing non_blocking=True in to_device()

* Removed old comments

* Fixed problem with non_blocking=True

* Cleaned up comments and return values a bit

* Changed args to embedding

* Changed core functions to take sample as arg

* Changed that model takes sample as input

* Fixes for diffusion

* Switched to lists of model / target strategies

* Updated config

* Changed to per masking strategy loss terms

* Removed old masking options. Still needs to be fully cleaned up

* More robust handling of empty streams

* Fixed incorrect handling of empty target_coords_idx

* Fixed problem when number of model and target samples is different

* Example for config with non-trivial model and target inputs

* Fixed bug in total sample counting

* Re-enabled missing healpix level

* Fixed incorrect handling of masking and student_teacher modes. Follow-up fixes required to handle partially filled source/target streams (e.g. because the source has no target values).

* An encoder formed by embedding + local assimilation + global assimilation (#1397)

* initial changes

* more changes

* removed extra print parameters statement

* changed names for backward checkpoint loading

* added encoder. to module names in sharding

* adding encoder. to embed_engine

* added back the conditions for param printing

* lint

* forecast config

* switch back to MTM config

* lint

* Formatting

* Fix source-target matching problem.

* Enabled multiple input steps. Fixed various robustness issues that arose from this.

This commit also changes the number of forecast steps that are taken. The old loop was at least one step too far. Unclear why the problem occurred now.

* Linting

* Missing update to validation()

* Improved robustness through sanity checking of arguments

* Improved handling of corner cases

* - Fixed incorrect call to get_forecast_steps() in validation
- Fixed interface of target_aux_calculator

* More fixes to validation

* Adding stream_id

* Healpix cropping simple implementation

* Cleaned up ModelOutput class to have proper access functions and a better structure

* Switched to use dict to internally represent streams_datasets

* Improving robustness of interface of ModelOutput class

* Re-enabling model output

* Healpix cropping simple implementation with control over the num_samples and overlap + fixing the num_sample bug

* Fixed lint

* Ruff

* Minor clean-ups and additional comments

* Minor cleanups

* Cleaned up handling of masks and masking metadata

* Current working version of default_config

* Fixed problem with branches with old code and incomplete cleanup

* Updated to test convergence of integration test.

* Updated settings

* Clessig/ypd/dev/1353 add tokens latent state finalization (#1452)

* Add LatentState

* Add class and register tokens for LatentState, adjust everything accordingly

* Add option in config file + minor changes

* Add pos.emb. for register tokens + remove class tokens + minor fixes

* Minor fix

* Changed empty to zeros pe_register

* Ruffed

* Clean-up and fixed positional encoding

* Fixing things that got lost during last merge

---------

Co-authored-by: Yura Perugachi Diaz <yperugac@santis-ln001.cscs.ch>
Co-authored-by: Yura Perugachi Diaz <yperugac@santis-ln002.cscs.ch>

* Ruffed

* Adding sanity check for register tokens

* Improved structure of LatentState class.

* Improved structure of LatentState

* Fixed problem with num_samples > 1

* Improved representation of batch and batch data and more getter functions

* Re-enabled batch_size/num_samples>1. Still some minor problems and cleanup needed but overall program flow working

* Cleanup

* Fixed bug in source-target correspondence with num_samples>1

* Removing incorrect loss scaling

* Cleaned up predict() in model

* Fixed commenting issues

* Fixed problem with freezing of modules for q_cells. Fixed problem when running in FSDP

* Fixed problem with printing of trainable weights

* Fixed switch for printing of trainable weights

* 1316 Update metrics logging (#1412)

* Add to device to ModelBatch, etc & adapt model

TODO adapt validate and inference
TODO test forecasting and multiple stream because predict changed
substantially

* Rename view to sample and fix validate

* Revert predict function and fix inference

* Fix invalid access with mask

* Linting

* Fixed handling of target_idxs and other minor issues

* Remove duplicate to_device

* move loss history into loss calculator

* handle loss_avg and unflatten loss dict

* fixing train_logger

* update validate logging, failing - need to merge data branch

* rm additional log files and log_vals variable, and collapse to single add_logs fct for train and val

* rm comment

* fix validation

* move prepare losses fct to train_logger script, fix terminal logging for val

* fix ctr_loss_fcts normalization; calculate per stream, per lfct average across channels and fsteps for logging

* Fixed linting

* fix bug in emptying history after logging

---------

Co-authored-by: sophiex <24638638+sophie-xhonneux@users.noreply.github.com>
Co-authored-by: Christian Lessig <christian.lessig@ecmwf.int>

* set target_channels dependent on val or train (#1471)

* fix targetauxoutput and inference to accept target times and coords. Still lists, to be changed to dicts

* updated targetauxbase so we access target, times and coords with a dict and corresponding changes

* Patch for float in loss dict (shouldn't happen to begin with)

* Fixed handling of per-batch precomputation. Functional for num_sample/batch_size=1 but convergence seems broken for num_samples>1.

* Reordering

* Linting

* More linting

* More linting

* Removed old, commented code

* Push current progress for inspection (#1478)

* Push current progress for inspection

* For Seb

* Delete old code

* Rename according to config

* Fix config

* Push what I have

* Successfully build data

* Fix bugs and lint

* Forgot to revert dataloader workers

* Address PR review comments

- rename student-to-teacher
- extract metadata extraction into a function

* prepare branch for ssl merge, ready for data merge

* Removed mock-up

---------

Co-authored-by: Sebastian Hickman <seb.hickman@gmail.com>
Co-authored-by: Christian Lessig <christian.lessig@ecmwf.int>

* Cleaned up

* actually merging develop in properly, with merge conflicts resolved

* move imports, move functions

* make functions for different types of cropping rather than if else, remove overlap from select_spatially_contiguous_cells

* restore overlap control in select spatially contiguous cell

* clean up cropping, remove overlap control of crops for now

* make overlap work with source and target masks. working now as random selection. need to think about this if we want to explicitly do overlap of crops.

* restored config somewhat

* lint

* restore config a bit

* remove extra commented out code

* clean up

* remove logging

* invert healpix_cropping masking so aligned with masking and healpix

* lint and updated comments

* remove overlap code, deal with in _get_mask complement, subset etc

* update comments and trues falses masking

* remove legacy argument constraint_keep_mask in the docstring for 2 masking functions

* remove legacy overlap_ratio and overlap from the docstrings in masking

* remove overlap ratio unused arg from example configs for cropping

* make cropping a function, and build shared _prepare_healpix_masking for healpix and healpix_cropping preparation

* rename healpix preparation function

* removed old docstrings and lint

---------

Co-authored-by: Christian Lessig <christian.lessig@ecmwf.int>
Co-authored-by: Tim Hunter <tim.hunter@ecmwf.int>
Co-authored-by: Julian Kuehnert <Jubeku@users.noreply.github.com>
Co-authored-by: sophiex <24638638+sophie-xhonneux@users.noreply.github.com>
Co-authored-by: kctezcan <kctezcan@gmail.com>
Co-authored-by: Wael Almikaeel <wael.almikaeel.95@gmail.com>
Co-authored-by: Yura Perugachi Diaz <yperugac@santis-ln001.cscs.ch>
Co-authored-by: Yura Perugachi Diaz <yperugac@santis-ln002.cscs.ch>
TillHae pushed a commit to TillHae/WeatherGenerator that referenced this pull request Dec 25, 2025
…aining, and pass ModelBatch class (ecmwf#1283)

* NOT WORKING: initial draft for index-based masking. Implemented for random and healpix masking. Open issues with _coords_local, centroids and probably other things.

* NOT WORKING: Finished src, target still to be done.

* Masking target is working in principle but errors when feeding data to the model.

* Working version for ERA5, NPP-ATMS. Problems with SYNOP with empty cell handling

* Minor cleanup

* Fixed linting

* Fixed remaining problems that occured for NPP-ATMS and SYNOP.
TODO:
- Forecast still needs to be adapted
- Some more cleanup of variable naming, return values etc

* Enabled support for forecast. Cleaned up some bits and pieces.

* Removing centroids options for embedding that was unused and should not be used.

* Removed unused parameters

* Inversion of target output ordering to match input one in forcast mode. Unclear how to deal with it with MTM

* Changes to  prepare_logging to apply index inversion

* added file with ModelBatch and SampleMetadata dataclasses

* Updating config to working version

* update ViewMetadata spec

* draft changes to allow global local view generation in masker and tokenizer_masking. generate the mask, otherwise using batchify_source and batchify_target as before, with the capacity to remember what mask we have now when it comes to generating the targets. Update to inputs_metadata structure but not put in to practice

* draft of training_config in default_config

* change view_metadata to dict in ModelInput

* NOT WORKING: updating class to handle multiple input steps and improving overall structure

* Added basic support for multi-step sources.

* Partially enabled correct handling of multiple input steps.

* Added mode and refactored get_sample_data into separate function.

* Comments

* Renaming

* updated default config training_config to allow student-teacher

* added stream id to era5 config

* slight restructure of ViewMetadata

* basic if statement to yield the student and teacher views

* correct imports with new batch.py

* created function for _get_student_teacher_sample_data which returns the streams_data of the teacher and multiple streams_datas for the student views.

* Not working draft for restructuring

* Changes for better student teacher structure

* More refactoring

* More refactoring and cleanup

* More refactoring. Code working again.

* Cleaned up parametrization

* Changes necessary for spoofing flag per IOReaderData

* Changes to have spoofing on a per data reader sample

* Moved _get_student_teacher_masks() so that masks are generated for all streams first.

* Renaming and minor clean up.

* Added basic support for use of ModelBatch class to define rough structure and interface.

* linting

* Linting

* linting

* Linting problems but removed unused ViewMetaData dependence

* Added required reflexivity between source and target samples to Batch

* Added todo

* fix typo in ModelBatch

* collect num_source_samples and num_target_samples, add loop over teacher masks hence allowing multiple teacher views, and add source_target_idx to keep track of which student belongs to which teacher

* add teacher num_views parameter to config

* Re-enabling inversion of target ordering.

* tidy up, remove unused build_stream_views in tokenizer_masking

* multiple idxs for each teacher, need to confirm for not student case, and updated ModelBatch for this

* add max_num_targets to era5

* add max_num_samples functionality to tokenizer_masking and pass it through in multi_stream_data_sampler. coords_per_cell is a bit nasty

* move build_views_for_stream into masker

* tidy up, remove unused arguments, types

* fix masking for NPP-ATMS by correctly selecting final timestep mask and aligning between source and target. working for num_input_steps = 1, broken for > 1, compute_offsets_scatter_embed not working

* updated configs so code runs. Note default config to be overhauled still

* very hacky first pass of full masking_strategy_config for the student and teacher views. Much to fix up

* instructions for sophie

* add SampleMetaData integration and functionality, and update the masker to use SampleMetadata. Pass source_cell_lens and target_coords_idx through to student_teacher_batch in iter, and hence through to the trainer. source_cell_lens and target_coords_idx are now part of Sample, instances of which are the components of ModelBatch. To tidy
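The Sample/ModelBatch layering this commit describes can be sketched as below; only source_cell_lens and target_coords_idx are named in the commit message, so every other field and method here is an illustrative assumption:

```python
from dataclasses import dataclass, field
from typing import Any

# Hypothetical sketch of the batch containers described above, not the
# repository's actual classes.
@dataclass
class Sample:
    source_cell_lens: list[int]           # number of source tokens per cell
    target_coords_idx: list[int]          # indices mapping targets back to coords
    meta: dict[str, Any] = field(default_factory=dict)  # e.g. masks, forecast_dt

@dataclass
class ModelBatch:
    samples: list[Sample] = field(default_factory=list)  # a batch is a list of Samples

    def add(self, sample: Sample) -> None:
        self.samples.append(sample)

    def __len__(self) -> int:
        return len(self.samples)
```

The point of the design is that everything the trainer needs per sample travels inside Sample, rather than as parallel lists of lists.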

* remove prints, pdb

* add mask to SampleMetaData and add forecast_dt to Sample so it is accessible. Can specify the loss in the default config with student-teacher views

* add diffusion forecast option for the data sampling, with noise_level_rn in the metadata. The Trainer needs to be copied from Sophie's branch; currently we only get so far

* Linting

* Simplified and clarified handling of the default target_aux_calculator

* Linting

* Linting

* Linting

* Linting

* Linting

* Restoring masking as training_mode in default_config

* More linting

* Removed duplicate lines due to merging

* Restored masking as training mode. Not working due to NaN in prediction

* Fixed problem in engines introduced in recent commits merging develop. This fixes masking training

* remove unused mask generation in diffusion_forecast

* restore masking_strategy to random

Had placeholder for testing, now back to "random" for masking strategy in the base level of default_config

* restore loader_num_workers to 8

* fix indentation of else: assert False in _get_sample msds

* linter warnings

* commenting tests

* Restructured code so that mask generation and application is cleanly separated

* Commit

* Update

* Fixed uv.lock

* Fix for integration test

* Re-enabled multi-source training

* 1390 - Adapt forward pass of new batch object (ecmwf#1391)

* Add to device to ModelBatch, etc & adapt model

TODO adapt validate and inference
TODO test forecasting and multiple stream because predict changed
substantially

* Rename view to sample and fix validate

* Revert predict function and fix inference

* Fix invalid access with mask

* Linting

* Fixed handling of target_idxs and other minor issues

---------

Co-authored-by: sophiex <24638638+sophie-xhonneux@users.noreply.github.com>
Co-authored-by: Christian Lessig <christian.lessig@ecmwf.int>

* Completed migration to new batch class by removing reference to old list of lists

* Fixed missing non_blocking=True in to_device()

* Removed old comments

* Fixed problem with non_blocking=True
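A generic to_device helper with non_blocking=True can be sketched as follows. The `.to(device, non_blocking=...)` signature follows the PyTorch convention, but the helper itself is an illustration, not the repository's implementation; note that non-blocking host-to-device copies only overlap with compute when the host memory is pinned, otherwise they silently degrade to synchronous copies:

```python
# Illustrative recursive mover: anything exposing .to() is moved, containers
# are traversed, everything else is returned unchanged.
def to_device(obj, device, non_blocking=True):
    if hasattr(obj, "to"):
        return obj.to(device, non_blocking=non_blocking)
    if isinstance(obj, (list, tuple)):
        return type(obj)(to_device(x, device, non_blocking) for x in obj)
    if isinstance(obj, dict):
        return {k: to_device(v, device, non_blocking) for k, v in obj.items()}
    return obj
```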

* Cleaned up comments and return values a bit

* Changed args to embedding

* Changed core functions to take sample as arg

* Changed that model takes sample as input

* Fixes for diffusion

* Switched to lists of model / target strategies

* Updated config

* Changed to per masking strategy loss terms

* Removed old masking options. Still needs to be fully cleaned up

* More robust handling of empty streams

* Fixed incorrect handling of empty target_coords_idx

* Fixed problem when number of model and target samples is different

* Example for config with non-trivial model and target inputs

* Fixed bug in total sample counting

* Re-enabled missing healpix level

* Fixed incorrect handling of masking and student_teacher modes. Follow-up fixes required to handle partially filled source/target streams (e.g. because a source has no target values).

* An encoder formed by embedding + local assimilation + global assimilation (ecmwf#1397)

* initial changes

* more changes

* removed extra print parameters statement

* changed names for backward checkpoint loading

* added encoder. to module names in sharding

* adding encoder. to embed_engine

* added back the conditions for param printing

* lint

* forecast config

* switch back to MTM config

* lint

* Formatting

* Fix source-target matching problem.

* Enabled multiple input steps. Fixed various robustness issues that arose from this.

This commit also changes the number of forecast steps that are taken. The old loop was at least one step too far. Unclear why the problem occurred now.
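The "one step too far" described above is the classic rollout-loop off-by-one. A minimal sketch (step_fn and state are placeholders, not WeatherGenerator APIs):

```python
# Illustrative autoregressive rollout: exactly num_forecast_steps
# applications of the model step. An off-by-one such as
# range(num_forecast_steps + 1) would run one step too many.
def rollout(step_fn, state, num_forecast_steps):
    states = []
    for _ in range(num_forecast_steps):
        state = step_fn(state)
        states.append(state)
    return states
```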

* Linting

* Missing update to validation()

* Improved robustness through sanity checking of arguments

* Improved handling of corner cases

* - Fixed incorrect call to get_forecast_steps() in validation
- Fixed interface of target_aux_calculator

* More fixes to validation

* Adding stream_id

* Cleaned up ModelOutput class to have proper access functions and a better structure

* Switched to use dict to internally represent streams_datasets

* Improving robustness of interface of ModelOutput class

* Re-enabling model output

* Ruff

* Minor clean-ups and additional comments

* Minor cleanups

* Cleaned up handling of masks and masking metadata

* Current working version of default_config

* Fixed problem with branches with old code and incomplete cleanup

* Updated to test convergence of integration test.

* Updated settings

* Clessig/ypd/dev/1353 add tokens latent state finalization (ecmwf#1452)

* Add LatentState

* Add class and register tokens for LatentState, adjust everything accordingly

* Add option in config file + minor changes

* Add pos.emb. for register tokens + remove class tokens + minor fixes

* Minor fix

* Changed empty to zeros pe_register

* Ruffed

* Clean-up and fixed positional encoding

* Fixing things that got lost during last merge

---------

Co-authored-by: Yura Perugachi Diaz <yperugac@santis-ln001.cscs.ch>
Co-authored-by: Yura Perugachi Diaz <yperugac@santis-ln002.cscs.ch>
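The register-token handling in this nested PR ("Changed empty to zeros pe_register") can be illustrated with a torch-free sketch; all names and shapes here are assumptions for illustration, not the actual LatentState implementation:

```python
# Illustration only: register tokens are prepended to the latent sequence
# and receive all-zero positional-embedding slots (zeros rather than
# uninitialized memory, mirroring the "empty to zeros" fix above).
def prepend_register_tokens(latent, registers, pos_emb):
    """latent/registers: lists of dim-length vectors; pos_emb aligns with latent."""
    dim = len(registers[0]) if registers else 0
    pe_register = [[0.0] * dim for _ in registers]
    return registers + latent, pe_register + pos_emb
```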

* Ruffed

* Adding sanity check for register tokens

* Improved structure of LatentState class.

* Improved structure of LatentState

* Fixed problem with num_samples > 1

* Improved representation of batch and batch data and more getter functions

* Re-enabled batch_size/num_samples>1. Still some minor problems and cleanup needed but overall program flow working

* Cleanup

* Fixed bug in source-target correspondence with num_samples>1

* Removing incorrect loss scaling

* Cleaned up predict() in model

* Fixed commenting issues

* Fixed problem with freezing of modules for q_cells. Fixed problem when running in FSDP

* Fixed problem with printing of trainable weights

* Fixed switch for printing of trainable weights

* 1316 Update metrics logging (ecmwf#1412)

* Add to device to ModelBatch, etc & adapt model

TODO adapt validate and inference
TODO test forecasting and multiple stream because predict changed
substantially

* Rename view to sample and fix validate

* Revert predict function and fix inference

* Fix invalid access with mask

* Linting

* Fixed handling of target_idxs and other minor issues

* Remove duplicate to_device

* move loss history into loss calculator

* handle loss_avg and unflatten loss dict

* fixing train_logger

* update validate logging, failing - need to merge data branch

* rm additional log files and log_vals variable, and collapse to single add_logs fct for train and val

* rm comment

* fix validation

* move prepare losses fct to train_logger script, fix terminal logging for val

* fix ctr_loss_fcts normalization; calculate per stream, per lfct average across channels and fsteps for logging
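The per-stream, per-loss-function reduction described in this commit can be sketched as follows; the dict layout is an assumption, the repository's actual logging structures may differ:

```python
# Illustrative reduction: per stream and per loss function, average the
# per-channel values across all forecast steps for logging.
def average_losses(losses):
    """losses: {stream: {loss_fct: [[value per channel] per forecast step]}}"""
    out = {}
    for stream, fcts in losses.items():
        out[stream] = {}
        for fct, per_fstep in fcts.items():
            vals = [v for fstep in per_fstep for v in fstep]
            out[stream][fct] = sum(vals) / len(vals)
    return out
```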

* Fixed linting

* fix bug in emptying history after logging

---------

Co-authored-by: sophiex <24638638+sophie-xhonneux@users.noreply.github.com>
Co-authored-by: Christian Lessig <christian.lessig@ecmwf.int>

* set target_channels dependent on val or train (ecmwf#1471)

* fix targetauxoutput and inference to accept target times and coords. Still lists, to be changed to dicts

* updated targetauxbase so we access target, times and coords with a dict and corresponding changes

* Patch for float in loss dict (shouldn't happen to begin with)

* Fixed handling of per-batch precomputation. Functional for num_samples/batch_size=1, but convergence seems broken for num_samples>1.

* Reordering

* Linting

* More linting

* More linting

* Removed old, commented code

* Push current progress for inspection (ecmwf#1478)

* Push current progress for inspection

* For Seb

* Delete old code

* Rename according to config

* Fix config

* Push what I have

* Successfully build data

* Fix bugs and lint

* Forgot to revert dataloader workers

* Address PR review comments

- rename student-to-teacher
- extract metadata extraction into a function

* prepare branch for ssl merge, ready for data merge

* Removed mock-up

---------

Co-authored-by: Sebastian Hickman <seb.hickman@gmail.com>
Co-authored-by: Christian Lessig <christian.lessig@ecmwf.int>

* Cleaned up

---------

Co-authored-by: Christian Lessig <christian.lessig@ecmwf.int>
Co-authored-by: Tim Hunter <tim.hunter@ecmwf.int>
Co-authored-by: Julian Kuehnert <Jubeku@users.noreply.github.com>
Co-authored-by: sophiex <24638638+sophie-xhonneux@users.noreply.github.com>
Co-authored-by: kctezcan <kctezcan@gmail.com>
Co-authored-by: Yura Perugachi Diaz <yperugac@santis-ln001.cscs.ch>
Co-authored-by: Yura Perugachi Diaz <yperugac@santis-ln002.cscs.ch>
TillHae pushed a commit to TillHae/WeatherGenerator that referenced this pull request Dec 25, 2025
…cmwf#1507)

* Healpix cropping simple implementation

* Healpix cropping simple implementation with control over the num_samples and overlap + fixing the num_sample bug

* Fixed lint

* actually merging develop in properly with merge conflicts resolved

* move imports, move functions

* make functions for different types of cropping rather than if else, remove overlap from select_spatially_contiguous_cells

* restore overlap control in select spatially contiguous cell

* clean up cropping, remove overlap control of crops for now

* make overlap work with source and target masks. Working now as random selection; needs more thought if we want to explicitly do overlap of crops.

* restored config somewhat

* lint

* restore config a bit

* remove extra commented out code

* clean up

* remove logging

* invert healpix_cropping masking so it is aligned with masking and healpix

* lint and updated comments

* remove overlap code, deal with in _get_mask complement, subset etc

* update comments and the True/False conventions in masking

* remove legacy argument constraint_keep_mask in the docstring for 2 masking functions

* remove legacy overlap_ratio and overlap from the docstrings in masking

* remove overlap ratio unused arg from example configs for cropping

* make cropping a function, and build shared _prepare_healpix_masking for healpix and healpix_cropping preparation

* rename healpix preparation function

* removed old docstrings and lint

---------

Co-authored-by: Christian Lessig <christian.lessig@ecmwf.int>
Co-authored-by: Tim Hunter <tim.hunter@ecmwf.int>
Co-authored-by: Julian Kuehnert <Jubeku@users.noreply.github.com>
Co-authored-by: sophiex <24638638+sophie-xhonneux@users.noreply.github.com>
Co-authored-by: kctezcan <kctezcan@gmail.com>
Co-authored-by: Wael Almikaeel <wael.almikaeel.95@gmail.com>
Co-authored-by: Yura Perugachi Diaz <yperugac@santis-ln001.cscs.ch>
Co-authored-by: Yura Perugachi Diaz <yperugac@santis-ln002.cscs.ch>