
Conversation

Contributor

@Jubeku Jubeku commented Dec 2, 2025

Description

After the stream data refactoring, the new batch object has to be consumed by the model and the target-aux-calculator.
A new function is needed to transfer the batch to the device, and the downstream forward functions have to be adapted.

TODO adapt validate and inference
TODO test forecasting and multiple streams, because predict changed substantially
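For context, a minimal sketch of what the device transfer could look like, assuming `Sample` carries the three per-stream fields passed to the model in this PR (the actual classes in `batch.py` may differ):

```python
from dataclasses import dataclass, field

import torch


@dataclass
class Sample:
    # per-stream tensors, following the tuple handed to the model in this PR
    streams_data: list[torch.Tensor]
    source_cell_lens: list[torch.Tensor]
    target_coords_idx: list[torch.Tensor]

    def to_device(self, device: torch.device) -> "Sample":
        # non_blocking=True overlaps host-to-device copies with compute
        # (only truly asynchronous with pinned host memory)
        return Sample(
            [t.to(device, non_blocking=True) for t in self.streams_data],
            [t.to(device, non_blocking=True) for t in self.source_cell_lens],
            [t.to(device, non_blocking=True) for t in self.target_coords_idx],
        )


@dataclass
class ModelBatch:
    samples: list[Sample] = field(default_factory=list)

    def to_device(self, device: torch.device) -> "ModelBatch":
        return ModelBatch([s.to_device(device) for s in self.samples])
```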

Issue Number

Closes #1390


Checklist before asking for review

  • I have performed a self-review of my code
  • My changes comply with basic sanity checks:
    • I have fixed formatting issues with ./scripts/actions.sh lint
    • I have run unit tests with ./scripts/actions.sh unit-test
    • I have documented my code and I have updated the docstrings.
    • I have added unit tests, if relevant
  • I have tried my changes with data and code:
    • I have run the integration tests with ./scripts/actions.sh integration-test
    • (bigger changes) I have run a full training and I have written in the comment the run_id(s): launch-slurm.py --time 60
    • (bigger changes and experiments) I have shared a HedgeDoc in the GitHub issue with all the configurations and runs for these experiments
  • I have informed and aligned with people impacted by my change:
    • for config changes: the MatterMost channels and/or a design doc
    • for changes of dependencies: the MatterMost software development channel

TODO adapt validate and inference
TODO test forecasting and multiple streams because predict changed
substantially
```diff
 self.model(
     self.model_params,
-    (streams_data, source_cell_lens, target_coords_idxs),
+    (view.streams_data, view.source_cell_lens, view.target_coords_idx),
```
Contributor Author

Should we pass view directly and adapt the model forward here accordingly?

Collaborator

`view` is student/teacher-specific. It should be called `sample`, except in code that is student/teacher-specific (which is only some parts of the loss and some specific masking code).

Contributor Author

I am still wondering whether we should pass `sample` instead of `(sample.streams_data, sample.source_cell_lens, sample.target_coords_idx)`.
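For illustration, a hypothetical version of the forward pass that takes the sample object directly (names follow the discussion above; this is a sketch, not the actual implementation):

```python
# hypothetical: model forward receiving the Sample object itself
# instead of the unpacked (streams_data, source_cell_lens, target_coords_idx) tuple
def forward(self, model_params, sample):
    streams_data = sample.streams_data
    source_cell_lens = sample.source_cell_lens
    target_coords_idx = sample.target_coords_idx
    # ... rest of the forward unchanged
```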


```diff
 # lens for varlen attention
-tcs_lens = target_coords_idxs[idx][fstep]
+tcs_lens = target_coords_idxs[fstep]
```
Contributor Author

It looks like we are missing here that `target_coords_idxs` must be a function of the stream.
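That is, presumably something along these lines, with the stream index retained (`stream_idx` is a hypothetical name):

```python
# lens for varlen attention, per stream and forecast step
tcs_lens = target_coords_idxs[stream_idx][fstep]
```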

@clessig clessig marked this pull request as ready for review December 3, 2025 16:10
@clessig clessig merged commit 402b8de into shmh40/dev/1270-idx-global-local Dec 3, 2025
2 checks passed
@clessig clessig deleted the sophiex/dev/model-forward-adaptation branch December 3, 2025 16:11
clessig added a commit that referenced this pull request Dec 17, 2025
…aining, and pass ModelBatch class (#1283)

* NOT WORKING: initial draft for index-based masking. Implemented for random and healpix masking. Open issues with _coords_local, centroids and probably other things.

* NOT WORKING: Finished src, target still to be done.

* Masking target is working in principle but errors when feeding data to the model.

* Working version for ERA5, NPP-ATMS. Problems with SYNOP with empty cell handling

* Minor cleanup

* Fixed linting

* Fixed remaining problems that occurred for NPP-ATMS and SYNOP.
TODO:
- Forecast still needs to be adapted
- Some more cleanup of variable naming, return values etc

* Enabled support for forecast. Cleaned up some bits and pieces.

* Removing centroids options for embedding that was unused and should not be used.

* Removed unused parameters

* Inversion of target output ordering to match the input one in forecast mode. Unclear how to deal with it with MTM

* Changes to prepare_logging to apply index inversion

* added file with ModelBatch and SampleMetadata dataclasses

* Updating config to working version

* update ViewMetadata spec

* draft changes to allow global local view generation in masker and tokenizer_masking. generate the mask, otherwise using batchify_source and batchify_target as before, with the capacity to remember which mask we have when it comes to generating the targets. Update to inputs_metadata structure but not yet put into practice

* draft of training_config in default_config

* change view_metadata to dict in ModelInput

* NOT WORKING: updating class to handle multiple input steps and improving overall structure

* Added basic support for multi-step sources.

* Partially enabled correct handling of multiple input steps.

* Added mode and refactored get_sample_data into separate function.

* Comments

* Renaming

* updated default config training_config to allow student-teacher

* added stream id to era5 config

* slight restructure of ViewMetadata

* basic if statement to yield the student and teacher views

* correct imports with new batch.py

* created function for _get_student_teacher_sample_data which returns the streams_data of the teacher and multiple streams_datas for the student views.

* Not working draft for restructuring

* Changes for better student teacher structure

* More refactoring

* More refactoring and cleanup

* More refactoring. Code working again.

* Cleaned up parametrization

* Changes necessary for spoofing flag per IOReaderData

* Changes to have spoofing on a per data reader sample

* Moved _get_student_teacher_masks() so that masks are generated for all streams first.

* Renaming and minor clean up.

* Added basic support for use of ModelBatch class to define rough structure and interface.

* linting

* Linting

* linting

* Linting problems but removed unused ViewMetaData dependence

* Added required reflexivity between source and target samples to Batch

* Added todo

* fix typo in ModelBatch

* collect num_source_samples and num_target_samples, add loop over teacher masks hence allowing multiple teacher views, and add source_target_idx to keep track of which student belongs to which teacher

* add teacher num_views parameter to config

* Re-enabling inversion of target ordering.

* tidy up, remove unused build_stream_views in tokenizer_masking

* multiple idxs for each teacher, need to confirm for not student case, and updated ModelBatch for this

* add max_num_targets to era5

* add max_num_samples functionality to tokenizer_masking and pass through in multi_stream_data_sampler. coords_per_cell is a bit nasty

* move build_views_for_stream into masker

* tidy up, remove unused arguments, types

* fix masking for NPP-ATMS by correctly selecting final timestep mask and aligning between source and target. working for num_input_steps = 1, broken for > 1, compute_offsets_scatter_embed not working

* updated configs so code runs. Note default config to be overhauled still

* very hacky first pass of full masking_strategy_config for the student and teacher views. Much to fix up

* instructions for sophie

* add SampleMetaData integration and functionality, and update masker to use SampleMetadata. Pass through source_cell_lens and target_coords_idx to student_teacher_batch in iter, and hence pass through to trainer. source_cell_lens and target_coords_idx are now part of Sample, which is itself the components of ModelBatch. To tidy

* remove prints, pdb

* add mask to SampleMetaData and add forecast_dt to Sample so it is accessible. Can specify the loss in the default config with student-teacher views

* add diffusion forecast option for the data sampling, and with noise_level_rn in the metadata. The Trainer needs to be copied from Sophie's branch, currently we only get so far

* Linting

* Simplified and clarified handling of default target_aux_calculator

* Linting

* Linting

* Linting

* Linting

* Linting

* Restoring masking as training_mode in default_config

* More linting

* Removed duplicate lines due to merging

* Restored masking as training mode. Not working due to NaN in prediction

* Fixed problem in engines introduced in recent commits merging develop. This fixes masking training

* remove unused mask generation in diffusion_forecast

* restore masking_strategy to random

Had placeholder for testing, now back to "random" for masking strategy in the base level of default_config

* restore loader_num_workers to 8

* fix indentation of else: assert False in _get_sample msds

* linter warnings

* commenting tests

* Restructured code so that mask generation and application is cleanly separated

* Commit

* Update

* Fixed uv.lock

* Fix for integration test

* Re-enabled multi-source training

* 1390 - Adapt forward pass of new batch object (#1391)

* Add to device to ModelBatch, etc & adapt model

TODO adapt validate and inference
TODO test forecasting and multiple streams because predict changed
substantially

* Rename view to sample and fix validate

* Revert predict function and fix inference

* Fix invalid access with mask

* Linting

* Fixed handling of target_idxs and other minor issues

---------

Co-authored-by: sophiex <24638638+sophie-xhonneux@users.noreply.github.com>
Co-authored-by: Christian Lessig <christian.lessig@ecmwf.int>

* Completed migration to new batch class by removing reference to old list of lists

* Fixed missing non_blocking=True in to_device()

* Removed old comments

* Fixed problem with non_blocking=True

* Cleaned up comments and return values a bit

* Changed args to embedding

* Changed core functions to take sample as arg

* Changed that model takes sample as input

* Fixes for diffusion

* Switched to lists of model / target strategies

* Updated config

* Changed to per masking strategy loss terms

* Removed old masking options. Still needs to be fully cleaned up

* More robust handling of empty streams

* Fixed incorrect handling of empty target_coords_idx

* Fixed problem when number of model and target samples is different

* Example for config with non-trivial model and target inputs

* Fixed bug in total sample counting

* Re-enabled missing healpix level

* Fixed incorrect handling of masking and student_teacher modes. Follow-up fixes required to handle partially filled source/target streams (e.g. because the source has no target values).

* An encoder formed by embedding + local assimilation + global assimilation (#1397)

* initial changes

* more changes

* removed extra print parameters statement

* changed names for backward checkpoint loading

* added encoder. to module names in sharding

* adding encoder. to embed_engine

* added back the conditions for param printing

* lint

* forecast config

* switch back to MTM config

* lint

* Formatting

* Fix source-target matching problem.

* Enabled multiple input steps. Fixed various robustness issues that arose through this.

This commit also changes the number of forecast steps that are taken. The old loop was at least one step too far. Unclear why the problem occurred now.

* Linting

* Missing update to validation()

* Improved robustness through sanity checking of arguments

* Improved handling of corner cases

* - Fixed incorrect call to get_forecast_steps() in validation
- Fixed interface of target_aux_calculator

* More fixes to validation

* Adding stream_id

* Cleaned up ModelOutput class to have proper access functions and a better structure

* Switched to use dict to internally represent streams_datasets

* Improving robustness of interface of ModelOutput class

* Re-enabling model output

* Ruff

* Minor clean-ups and additional comments

* Minor cleanups

* Cleaned up handling of masks and masking metadata

* Current working version of default_config

* Fixed problem with branches with old code and incomplete cleanup

* Updated to test convergence of integration test.

* Updated settings

* Clessig/ypd/dev/1353 add tokens latent state finalization (#1452)

* Add LatentState

* Add class and register tokens for LatentState, adjust everything accordingly

* Add option in config file + minor changes

* Add pos.emb. for register tokens + remove class tokens + minor fixes

* Minor fix

* Changed empty to zeros pe_register

* Ruffed

* Clean-up and fixed positional encoding

* Fixing things that got lost during last merge

---------

Co-authored-by: Yura Perugachi Diaz <yperugac@santis-ln001.cscs.ch>
Co-authored-by: Yura Perugachi Diaz <yperugac@santis-ln002.cscs.ch>

* Ruffed

* Adding sanity check for register tokens

* Improved structure of LatentState class.

* Improved structure of LatentState

* Fixed problem with num_samples > 1

* Improved representation of batch and batch data and more getter functions

* Re-enabled batch_size/num_samples>1. Still some minor problems and cleanup needed but overall program flow working

* Cleanup

* Fixed bug in source-target correspondence with num_samples>1

* Removing incorrect loss scaling

* Cleaned up predict() in model

* Fixed commenting issues

* Fixed problem with freezing of modules for q_cells. Fixed problem when running in FSDP

* Fixed problem with printing of trainable weights

* Fixed switch for printing of trainable weights

* 1316 Update metrics logging (#1412)

* Add to device to ModelBatch, etc & adapt model

TODO adapt validate and inference
TODO test forecasting and multiple streams because predict changed
substantially

* Rename view to sample and fix validate

* Revert predict function and fix inference

* Fix invalid access with mask

* Linting

* Fixed handling of target_idxs and other minor issues

* Remove duplicate to_device

* move loss history into loss calculator

* handle loss_avg and unflatten loss dict

* fixing train_logger

* update validate logging, failing - need to merge data branch

* rm additional log files and log_vals variable, and collapse to single add_logs fct for train and val

* rm comment

* fix validation

* move prepare losses fct to train_logger script, fix terminal logging for val

* fix ctr_loss_fcts normalization; calculate per stream, per lfct average across channels and fsteps for logging

* Fixed linting

* fix bug in emptying history after logging

---------

Co-authored-by: sophiex <24638638+sophie-xhonneux@users.noreply.github.com>
Co-authored-by: Christian Lessig <christian.lessig@ecmwf.int>

* set target_channels dependent on val or train (#1471)

* fix targetauxoutput and inference to accept target times and coords. Still lists, to be changed to dicts

* updated targetauxbase so we access target, times and coords with a dict and corresponding changes

* Patch for float in loss dict (shouldn't happen to begin with)

* Fixed handling of per-batch precomputation. Functional for num_sample/batch_size=1 but convergence seems broken for num_samples>1.

* Reordering

* Linting

* More linting

* More linting

* Removed old, commented code

* Push current progress for inspection (#1478)

* Push current progress for inspection

* For Seb

* Delete old code

* Rename according to config

* Fix config

* Push what I have

* Successfully build data

* Fix bugs and lint

* Forgot to revert dataloader workers

* Address PR review comments

- rename student-to-teacher
- extract metadata extraction into a function

* prepare branch for ssl merge, ready for data merge

* Removed mock-up

---------

Co-authored-by: Sebastian Hickman <seb.hickman@gmail.com>
Co-authored-by: Christian Lessig <christian.lessig@ecmwf.int>

* Cleaned up

---------

Co-authored-by: Christian Lessig <christian.lessig@ecmwf.int>
Co-authored-by: Tim Hunter <tim.hunter@ecmwf.int>
Co-authored-by: Julian Kuehnert <Jubeku@users.noreply.github.com>
Co-authored-by: sophiex <24638638+sophie-xhonneux@users.noreply.github.com>
Co-authored-by: kctezcan <kctezcan@gmail.com>
Co-authored-by: Yura Perugachi Diaz <yperugac@santis-ln001.cscs.ch>
Co-authored-by: Yura Perugachi Diaz <yperugac@santis-ln002.cscs.ch>
clessig added a commit that referenced this pull request Dec 23, 2025
…1507)

* NOT WORKING: initial draft for index-based masking. Implemented for random and healpix masking. Open issues with _coords_local, centroids and probably other things.

* NOT WORKING: Finished src, target still to be done.

* Masking target is working in principle but errors when feeding data to the model.

* Working version for ERA5, NPP-ATMS. Problems with SYNOP with empty cell handling

* Minor cleanup

* Fixed linting

* Fixed remaining problems that occurred for NPP-ATMS and SYNOP.
TODO:
- Forecast still needs to be adapted
- Some more cleanup of variable naming, return values etc

* Enabled support for forecast. Cleaned up some bits and pieces.

* Removing centroids options for embedding that was unused and should not be used.

* Removed unused parameters

* Inversion of target output ordering to match the input one in forecast mode. Unclear how to deal with it with MTM

* Changes to prepare_logging to apply index inversion

* added file with ModelBatch and SampleMetadata dataclasses

* Updating config to working version

* update ViewMetadata spec

* draft changes to allow global local view generation in masker and tokenizer_masking. generate the mask, otherwise using batchify_source and batchify_target as before, with the capacity to remember which mask we have when it comes to generating the targets. Update to inputs_metadata structure but not yet put into practice

* draft of training_config in default_config

* change view_metadata to dict in ModelInput

* NOT WORKING: updating class to handle multiple input steps and improving overall structure

* Added basic support for multi-step sources.

* Partially enabled correct handling of multiple input steps.

* Added mode and refactored get_sample_data into separate function.

* Comments

* Renaming

* updated default config training_config to allow student-teacher

* added stream id to era5 config

* slight restructure of ViewMetadata

* basic if statement to yield the student and teacher views

* correct imports with new batch.py

* created function for _get_student_teacher_sample_data which returns the streams_data of the teacher and multiple streams_datas for the student views.

* Not working draft for restructuring

* Changes for better student teacher structure

* More refactoring

* More refactoring and cleanup

* More refactoring. Code working again.

* Cleaned up parametrization

* Changes necessary for spoofing flag per IOReaderData

* Changes to have spoofing on a per data reader sample

* Moved _get_student_teacher_masks() so that masks are generated for all streams first.

* Renaming and minor clean up.

* Added basic support for use of ModelBatch class to define rough structure and interface.

* linting

* Linting

* linting

* Linting problems but removed unused ViewMetaData dependence

* Added required reflexivity between source and target samples to Batch

* Added todo

* fix typo in ModelBatch

* collect num_source_samples and num_target_samples, add loop over teacher masks hence allowing multiple teacher views, and add source_target_idx to keep track of which student belongs to which teacher

* add teacher num_views parameter to config

* Re-enabling inversion of target ordering.

* tidy up, remove unused build_stream_views in tokenizer_masking

* multiple idxs for each teacher, need to confirm for not student case, and updated ModelBatch for this

* add max_num_targets to era5

* add max_num_samples functionality to tokenizer_masking and pass through in multi_stream_data_sampler. coords_per_cell is a bit nasty

* move build_views_for_stream into masker

* tidy up, remove unused arguments, types

* fix masking for NPP-ATMS by correctly selecting final timestep mask and aligning between source and target. working for num_input_steps = 1, broken for > 1, compute_offsets_scatter_embed not working

* updated configs so code runs. Note default config to be overhauled still

* very hacky first pass of full masking_strategy_config for the student and teacher views. Much to fix up

* instructions for sophie

* add SampleMetaData integration and functionality, and update masker to use SampleMetadata. Pass through source_cell_lens and target_coords_idx to student_teacher_batch in iter, and hence pass through to trainer. source_cell_lens and target_coords_idx are now part of Sample, which is itself the components of ModelBatch. To tidy

* remove prints, pdb

* add mask to SampleMetaData and add forecast_dt to Sample so it is accessible. Can specify the loss in the default config with student-teacher views

* add diffusion forecast option for the data sampling, and with noise_level_rn in the metadata. The Trainer needs to be copied from Sophie's branch, currently we only get so far

* Linting

* Simplified and clarified handling of default target_aux_calculator

* Linting

* Linting

* Linting

* Linting

* Linting

* Restoring masking as training_mode in default_config

* More linting

* Removed duplicate lines due to merging

* Restored masking as training mode. Not working due to NaN in prediction

* Fixed problem in engines introduced in recent commits merging develop. This fixes masking training

* remove unused mask generation in diffusion_forecast

* restore masking_strategy to random

Had placeholder for testing, now back to "random" for masking strategy in the base level of default_config

* restore loader_num_workers to 8

* fix indentation of else: assert False in _get_sample msds

* linter warnings

* commenting tests

* Restructured code so that mask generation and application is cleanly separated

* Commit

* Update

* Fixed uv.lock

* Fix for integration test

* Re-enabled multi-source training

* 1390 - Adapt forward pass of new batch object (#1391)

* Add to device to ModelBatch, etc & adapt model

TODO adapt validate and inference
TODO test forecasting and multiple streams because predict changed
substantially

* Rename view to sample and fix validate

* Revert predict function and fix inference

* Fix invalid access with mask

* Linting

* Fixed handling of target_idxs and other minor issues

---------

Co-authored-by: sophiex <24638638+sophie-xhonneux@users.noreply.github.com>
Co-authored-by: Christian Lessig <christian.lessig@ecmwf.int>

* Completed migration to new batch class by removing reference to old list of lists

* Fixed missing non_blocking=True in to_device()

* Removed old comments

* Fixed problem with non_blocking=True

* Cleaned up comments and return values a bit

* Changed args to embedding

* Changed core functions to take sample as arg

* Changed that model takes sample as input

* Fixes for diffusion

* Switched to lists of model / target strategies

* Updated config

* Changed to per masking strategy loss terms

* Removed old masking options. Still needs to be fully cleaned up

* More robust handling of empty streams

* Fixed incorrect handling of empty target_coords_idx

* Fixed problem when number of model and target samples is different

* Example for config with non-trivial model and target inputs

* Fixed bug in total sample counting

* Re-enabled missing healpix level

* Fixed incorrect handling of masking and student_teacher modes. Follow-up fixes required to handle partially filled source/target streams (e.g. because the source has no target values).

* An encoder formed by embedding + local assimilation + global assimilation (#1397)

* initial changes

* more changes

* removed extra print parameters statement

* changed names for backward checkpoint loading

* added encoder. to module names in sharding

* adding encoder. to embed_engine

* added back the conditions for param printing

* lint

* forecast config

* switch back to MTM config

* lint

* Formatting

* Fix source-target matching problem.

* Enabled multiple input steps. Fixed various robustness issues that arose through this.

This commit also changes the number of forecast steps that are taken. The old loop was at least one step too far. Unclear why the problem occurred now.

* Linting

* Missing update to validation()

* Improved robustness through sanity checking of arguments

* Improved handling of corner cases

* - Fixed incorrect call to get_forecast_steps() in validation
- Fixed interface of target_aux_calculator

* More fixes to validation

* Adding stream_id

* Healpix cropping simple implementation

* Cleaned up ModelOutput class to have proper access functions and a better structure

* Switched to use dict to internally represent streams_datasets

* Improving robustness of interface of ModelOutput class

* Re-enabling model output

* Healpix cropping simple implementation with control over the num_samples and overlap + fixing the num_sample bug

* Fixed lint

* Ruff

* Minor clean-ups and additional comments

* Minor cleanups

* Cleaned up handling of masks and masking metadata

* Current working version of default_config

* Fixed problem with branches with old code and incomplete cleanup

* Updated to test convergence of integration test.

* Updated settings

* Clessig/ypd/dev/1353 add tokens latent state finalization (#1452)

* Add LatentState

* Add class and register tokens for LatentState, adjust everything accordingly

* Add option in config file + minor changes

* Add pos.emb. for register tokens + remove class tokens + minor fixes

* Minor fix

* Changed empty to zeros pe_register

* Ruffed

* Clean-up and fixed positional encoding

* Fixing things that got lost during last merge

---------

Co-authored-by: Yura Perugachi Diaz <yperugac@santis-ln001.cscs.ch>
Co-authored-by: Yura Perugachi Diaz <yperugac@santis-ln002.cscs.ch>

* Ruffed

* Adding sanity check for register tokens

* Improved structure of LatentState class.

* Improved structure of LatentState

* Fixed problem with num_samples > 1

* Improved representation of batch and batch data and more getter functions

* Re-enabled batch_size/num_samples>1. Still some minor problems and cleanup needed but overall program flow working

* Cleanup

* Fixed bug in source-target correspondence with num_samples>1

* Removing incorrect loss scaling

* Cleaned up predict() in model

* Fixed commenting issues

* Fixed problem with freezing of modules for q_cells. Fixed problem when running in FSDP

* Fixed problem with printing of trainable weights

* Fixed switch for printing of trainable weights

* 1316 Update metrics logging (#1412)

* Add to device to ModelBatch, etc & adapt model

TODO adapt validate and inference
TODO test forecasting and multiple streams because predict changed
substantially

* Rename view to sample and fix validate

* Revert predict function and fix inference

* Fix invalid access with mask

* Linting

* Fixed handling of target_idxs and other minor issues

* Remove duplicate to_device

* move loss history into loss calculator

* handle loss_avg and unflatten loss dict

* fixing train_logger

* update validate logging, failing - need to merge data branch

* rm additional log files and log_vals variable, and collapse to single add_logs fct for train and val

* rm comment

* fix validation

* move prepare losses fct to train_logger script, fix terminal logging for val

* fix ctr_loss_fcts normalization; calculate per stream, per lfct average across channels and fsteps for logging

* Fixed linting

* fix bug in emptying history after logging

---------

Co-authored-by: sophiex <24638638+sophie-xhonneux@users.noreply.github.com>
Co-authored-by: Christian Lessig <christian.lessig@ecmwf.int>

* set target_channels dependent on val or train (#1471)

* fix targetauxoutput and inference to accept target times and coords. Still lists, to be changed to dicts

* updated targetauxbase so we access target, times and coords with a dict and corresponding changes

* Patch for float in loss dict (shouldn't happen to begin with)

* Fixed handling of per-batch precomputation. Functional for num_sample/batch_size=1 but convergence seems broken for num_samples>1.

* Reordering

* Linting

* More linting

* More linting

* Removed old, commented code

* Push current progress for inspection (#1478)

* Push current progress for inspection

* For Seb

* Delete old code

* Rename according to config

* Fix config

* Push what I have

* Successfully build data

* Fix bugs and lint

* Forgot to revert dataloader workers

* Address PR review comments

- rename student-to-teacher
- extract metadata extraction into a function

* prepare branch for ssl merge, ready for data merge

* Removed mock-up

---------

Co-authored-by: Sebastian Hickman <seb.hickman@gmail.com>
Co-authored-by: Christian Lessig <christian.lessig@ecmwf.int>

* Cleaned up

* actually merging develop in properly with merge conflicts resolved

* move imports, move functions

* make functions for different types of cropping rather than if else, remove overlap from select_spatially_contiguous_cells

* restore overlap control in select spatially contiguous cell

* clean up cropping, remove overlap control of crops for now

* make overlap work with source and target masks. working now as random selection. need to think about this if we want to explicitly do overlap of crops.

* restored config somewhat

* lint

* restore config a bit

* remove extra commented out code

* clean up

* remove logging

* invert healpix_cropping masking so aligned with masking and healpix

* lint and updated comments

* remove overlap code, deal with in _get_mask complement, subset etc

* update comments and trues falses masking

* remove legacy argument constraint_keep_mask in the docstring for 2 masking functions

* remove legacy overlap_ratio and overlap from the docstrings in masking

* remove overlap ratio unused arg from example configs for cropping

* make cropping a function, and build shared _prepare_healpix_masking for healpix and healpix_cropping preparation

* rename healpix preparation function

* removed old docstrings and lint

---------

Co-authored-by: Christian Lessig <christian.lessig@ecmwf.int>
Co-authored-by: Tim Hunter <tim.hunter@ecmwf.int>
Co-authored-by: Julian Kuehnert <Jubeku@users.noreply.github.com>
Co-authored-by: sophiex <24638638+sophie-xhonneux@users.noreply.github.com>
Co-authored-by: kctezcan <kctezcan@gmail.com>
Co-authored-by: Wael Almikaeel <wael.almikaeel.95@gmail.com>
Co-authored-by: Yura Perugachi Diaz <yperugac@santis-ln001.cscs.ch>
Co-authored-by: Yura Perugachi Diaz <yperugac@santis-ln002.cscs.ch>
clessig added a commit that referenced this pull request Dec 23, 2025
* Abstract class for target/aux computation

Implemented Identity class

TODO: implement EMATeacher

* Start implementing the EMA Teacher

The big question on the EMA teacher side, to me, is how to allow for a
flexible teacher and student architecture that can differ

We updated some APIs of the abstract base class to allow the ema_model
forward, subject to change given the loss calculator, which is imho the
second big question mark

* adding loss calculator base class

* Option for constructing teacher model flexibly

* Extract get batch size util function

Easier to read and as batchsize gets more complicated in SSL this will
be a useful abstraction

* Fix mismatched dtypes in the target computation

It runs so far. Next steps:
 - Route all the config options
 - Start writing the loss functions to understand the state requirements

* abstract loss calc structure

* add abstract method to loss calculator base class

* add latent loss class

* update loss calc config and rename files

* restructure loss modules

* add ModelOutput dataclass

* NOT WORKING: initial draft for index-based masking. Implemented for random and healpix masking. Open issues with _coords_local, centroids and probably other things.

* NOT WORKING: Finished src, target still to be done.

* Masking target is working in principle but errors when feeding data to the model.

* Working version for ERA5, NPP-ATMS. Problems with SYNOP with empty cell handling

* Minor cleanup

* Fixed linting

* Fixed remaining problems that occurred for NPP-ATMS and SYNOP.
TODO:
- Forecast still needs to be adapted
- Some more cleanup of variable naming, return values etc

* Enabled support for forecast. Cleaned up some bits and pieces.

* mv streams_data declaration under if condition

* add weight to loss config, add toy loss class LossPhysicalTwo

* Update Abstract Target class based on needs for SSL losses

* Removing centroids options for embedding that was unused and should not be used.

* Removed unused parameters

* fixed trainer for multiple terms in losses_all, still need to fix logging

* Inversion of target output ordering to match the input one in forecast mode. Unclear how to deal with it with MTM

* fix _log_terminal

* Changes to prepare_logging to apply index inversion

* added file with ModelBatch and SampleMetadata dataclasses

* Updating config to working version

* fix logging

* update ViewMetadata spec

* draft changes to allow global local view generation in masker and tokenizer_masking. generate the mask, otherwise using batchify_source and batchify_target as before, with the capacity to remember which mask we have when it comes to generating the targets. Update to inputs_metadata structure but not yet put into practice

* draft of training_config in default_config

* change view_metadata to dict in ModelInput

* NOT WORKING: updating class to handle multiple input steps and improving overall structure

* Added basic support for multi-step sources.

* Partially enabled correct handling of multiple input steps.

* initialize loss as torch tensor with grad

* remove level in hist losses dict

* rename loss.py to loss_functions.py

* rename loss.py to loss_functions.py

* return loss with grads seperately to trainer

* Added mode and refactored get_sample_data into separate function.

* modify log names

* add loss_functions.py

* Abstract class for target/aux computation

Implemented Identity class

TODO: implement EMATeacher

* Start implementing the EMA Teacher

The big question on the EMA teacher side, to me, is how to allow for a
flexible teacher and student architecture that can differ

We updated some APIs of the abstract base class to allow the ema_model
forward, subject to change given the loss calculator, which is imho the
second big question mark

* Option for constructing teacher model flexibly

* rm loss_fcts in default config

* Comments

* Renaming

* updated default config training_config to allow student-teacher

* added stream id to era5 config

* slight restructure of ViewMetadata

* basic if statement to yield the student and teacher views

* correct imports with new batch.py

* Extract get batch size util function

Easier to read and as batchsize gets more complicated in SSL this will
be a useful abstraction

* Fix mismatched dtypes in the target computation

It runs so far. Next steps:
 - Route all the config options
 - Start writing the loss functions to understand the state requirements

* Lay groundwork for SSL losses

This involves creating stateful classes for each of the losses and the
EMATeacher being able to run additional neural network heads for these
losses.
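
A minimal sketch of what one such stateful loss class could look like, loosely following DINO-style centering (names, buffer layout, and temperatures are assumptions, not the actual implementation):

```python
import torch
import torch.nn.functional as F


class CenteredDistillationLoss(torch.nn.Module):
    """Stateful SSL loss: keeps a running center of the teacher outputs."""

    def __init__(self, out_dim: int, center_momentum: float = 0.9):
        super().__init__()
        self.center_momentum = center_momentum
        self.register_buffer("center", torch.zeros(1, out_dim))

    def forward(self, student_logits, teacher_logits, t_student=0.1, t_teacher=0.04):
        # teacher targets: centered and sharpened, no gradients
        with torch.no_grad():
            targets = F.softmax((teacher_logits - self.center) / t_teacher, dim=-1)
            # update the running center -- this is the loss state
            batch_center = teacher_logits.mean(dim=0, keepdim=True)
            self.center.mul_(self.center_momentum).add_(
                batch_center, alpha=1.0 - self.center_momentum
            )
        log_probs = F.log_softmax(student_logits / t_student, dim=-1)
        return -(targets * log_probs).sum(dim=-1).mean()
```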

* Add the SSL Loss Processing classes

* Write part of the TargetProcessing forward

TODO: create the various teacher head modules and run them.
TODO: merge the abstract loss calculator and create the SSL one

* Add latent prediction heads to the Model

After much consideration I decided to add the latent prediction heads to
the Model, because they also need to benefit from exponential moving
average of the weights and this gets unnecessarily cumbersome if they
are outside the Model.

TODO: make JEPA different between student and teacher
TODO: use this new structure in EMATeacher
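
For reference, the usual form of the exponential moving average that the heads need to participate in (a sketch, assuming teacher and student expose their parameters in matching order):

```python
import torch


@torch.no_grad()
def ema_update(teacher: torch.nn.Module, student: torch.nn.Module, momentum: float = 0.996):
    # teacher <- momentum * teacher + (1 - momentum) * student
    for p_t, p_s in zip(teacher.parameters(), student.parameters()):
        p_t.mul_(momentum).add_(p_s, alpha=1.0 - momentum)
```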

* Adapt forward function for latent prediction heads

To prevent crazy nesting of model output values we created a ModelOutput
Dataclass (akin to how it is done in huggingface), and we run all the
latent_prediction heads.
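
A sketch of such a flat output container (field names here are illustrative assumptions, not the repo's actual ModelOutput):

```python
from dataclasses import dataclass, field

import torch


@dataclass
class ModelOutput:
    # flat, named fields instead of nested tuples, akin to huggingface-style outputs
    latent_state: torch.Tensor | None = None
    predictions: list[torch.Tensor] = field(default_factory=list)
    # outputs of the latent prediction heads, keyed by loss name (e.g. "jepa", "ibot")
    head_outputs: dict[str, torch.Tensor] = field(default_factory=dict)
```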

* Start piping configs through model, trainer, etc

Will need adapting based on the abstract loss calculator

Currently is awaiting the streams data branch to check piping of data
and configuring this

* adding dinov2 notice

* Draft Student Teacher Loss Calculator

TODO: initialise it and register
TODO: weight the loss
TODO: route the kwargs
TODO: check shapes of tensors

* Use infra provided by Abstract Loss Calc

Completes config option routing, weighting, and registering TODOs

* Run Ruff

* Implemented the first draft of the Cropping feature

* rough first effort producing globaland local views

* update to return a 6-tuple from iter in multi-stream-data-sampler, with locals_prepared

* Fix class being in the wrong file

* Ensure data pipes through model and target

This is a DRAFT!

This commit assumes that the data augmentations of the stream_data
object (see shmh40/dev/global-local) will fit into the Batch data class
(trainer.py).

The goal was to ensure all data reaches the LossCalculator.

Large scale todos:
- Pass Metadata about local2global correspondence, etc to the LossCalculator
- Upgrade the Model heads to produce the correct dimensions
- Verify the Data shapes against DINOv2

Smaller todos:
- Ensure teacher params are requires_grad = false
- clean up code

* Wrap latent state into a dataclass

to simplify loss calculation later

* Progress on computing the loss on correct dims

Added the new ViewMetadata and ModelBatch dataclasses that will come from
the cropping PR

Added LatentState dataclass to compute the latent heads on the correct
part of the latent state

TODOs:
1. Deal with DINO local and global component
2. Write JEPA loss function in loss.py
3. Test iBOT with actual mask and route student temperature
4. TODOs in the code
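
Given the class and register tokens that LatentState grows elsewhere in this history, the split is presumably along these lines (field names and shapes are assumed):

```python
from dataclasses import dataclass

import torch


@dataclass
class LatentState:
    # assumed layout, so loss heads can select the part of the latent state they need
    class_tokens: torch.Tensor     # [batch, num_class_tokens, dim]
    register_tokens: torch.Tensor  # [batch, num_register_tokens, dim]
    tokens: torch.Tensor           # [batch, num_tokens, dim] cell/patch tokens
```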

* Add views.py and run Ruff

* Close in on completing DINO loss

TODO needs to deal with the tuple part of the DINO loss
TODO understand how the ModelBatch input structure affects the loss
terms

* Revert "rough first effort producing globaland local views"

This reverts commit 3fa0033.

* Lint code

* Fix rebase of loss loss_calculator

* created function for _get_student_teacher_sample_data which returns the streams_data of the teacher and multiple streams_datas for the student views.

* Not working draft for restructuring

* Changes for better student teacher structure

* More refactoring

* More refactoring and cleanup

* More refactoring. Code working again.

* Cleaned up parametrization

* Changes necessary for spoofing flag per IOReaderData

* Changes to have spoofing on a per data reader sample

* Moved _get_student_teacher_masks() so that masks are generated for all streams first.

* Renaming and minor clean up.

* Added basic support for use of ModelBatch class to define rough structure and interface.

* linting

* Linting

* linting

* Linting problems but removed unused ViewMetaData dependence

* Added required reflexivity between source and target samples to Batch

* Added todo

* Test for compute time regressions

* Prepare for merge

* Lint the code

* Lint code

* Lint

* Fix some basic bugs

* fix typo in ModelBatch

* collect num_source_samples and num_target_samples, add loop over teacher masks hence allowing multiple teacher views, and add source_target_idx to keep track of which student belongs to which teacher

* add teacher num_views parameter to config

* Re-enabling inversion of target ordering.

* tidy up, remove unused build_stream_views in tokenizer_masking

* multiple idxs for each teacher, need to confirm for not student case, and updated ModelBatch for this

* add max_num_targets to era5

* add max_num_samples functionality to tokenizer_masking and pass through in multi_stream_data_sampler. coords_per_cell is a bit nasty

* Removing spurious code / things that should be merged later

* Linting

* move build_views_for_stream into masker

* Lint code

* Rename identity TargetAndAux module

* tidy up, remove unused arguments, types

* fix masking for NPP-ATMS by correctly selecting final timestep mask and aligning between source and target. working for num_input_steps = 1, broken for > 1, compute_offsets_scatter_embed not working

* Make code runnable

* updated configs so code runs. Note default config to be overhauled still

* Draft for model interface

* Make code runnable again

Seems slow again

* Cleaned up and restructured structure. Not working yet with FSDP

* Fixes for FSDP/DDP

* Cleaning up, should be merged when needed

* Fixes to FSDP

* Fix incorrect args for model loading and removing unused code.

* Linting

* Removing old code

* - Fixing inference arg order
- Fixing subtle problem with world_size_original that should be taken from config when available

* Fixing interface of get_target_aux_calculator

* Fixing call to target aux calculator

* Fixes to get_target_aux_calculator

* Remove stale dataclasses

* Fix MAE

* very hacky first pass of full masking_strategy_config for the student and teacher views. Much to fix up

* instructions for sophie

* add SampleMetaData integration and functionality, and update masker to use SampleMetadata. Pass through source_cell_lens and target_coords_idx to student_teacher_batch in iter, and hence pass through to trainer. source_cell_lens and target_coords_idx are now part of Sample, which is itself the components of ModelBatch. To tidy

* Prepare for another merge

* remove prints, pdb

* Save state

* add mask to SampleMetaData and add forecast_dt to Sample so it is accessible. Can specify the loss in the default config with student-teacher views

* Save state for Seb

Currently reviving the EMATeacher creation

Memory is an issue, had to hardcode a smaller latent space

* add diffusion forecast option for the data sampling, and with noise_level_rn in the metadata. The Trainer needs to be copied from Sophie's branch, currently we only get so far

* Attempt to make the iBOT loss work

TODO force 1 ibot student view per global view
TODO there is a bug with the mask causing a leaf error in pytorch
TODO remove all the hardcoded reduced latent space

* Linting

* Simplified and clarified handling of default target_aux_calculator

* Linting

* Linting

* Linting

* Linting

* Linting

* Restoring masking as training_mode in default_config

* More linting

* Removed duplicate lines due to merging

* Restored masking as training mode. Not working due to NaN in prediction

* Fixed problem in engines introduced in recent commits merging develop. This fixes masking training

* remove unused mask generation in diffusion_forecast

* restore masking_strategy to random

Had placeholder for testing, now back to "random" for masking strategy in the base level of default_config

* restore loader_num_workers to 8

* fix indentation of else: assert False in _get_sample msds

* Pipe data through all ssl loss fns

TODO iBOT head should output class tokens as well as patch tokens
TODO remove hardcoded assignments, should be based on config
TODO deal with the memory hungriness of it all
TODO carefully inspect for bugs

* linter warnings

* commenting tests

* Restructured code so that mask generation and application is cleanly separated

* Commit

* Update

* Fixed uv.lock

* Fix for integration test

* Re-enabled multi-source training

* 1390 - Adapt forward pass of new batch object (#1391)

* Add to device to ModelBatch, etc & adapt model

TODO adapt validate and inference
TODO test forecasting and multiple streams because predict changed
substantially

* Rename view to sample and fix validate

* Revert predict function and fix inference

* Fix invalid access with mask

* Linting

* Fixed handling of target_idxs and other minor issues

---------

Co-authored-by: sophiex <24638638+sophie-xhonneux@users.noreply.github.com>
Co-authored-by: Christian Lessig <christian.lessig@ecmwf.int>

* Completed migration to new batch class by removing reference to old list of lists

* Fixed missing non_blocking=True in to_device()

* Removed old comments

* Fixed problem with non_blocking=True

* Cleaned up comments and return values a bit

* Changed args to embedding

* Changed core functions to take sample as arg

* Changed that model takes sample as input

* Fixes for diffusion

* Switched to lists of model / target strategies

* Pipe the mask through

* Filter student views for the correct loss

* Change the masking and msdp to fit student-teacher

1. We ensure that for each target view all the student views are generated
2. We ensure that the target views have their mask applied to the input

* Make DINO and iBOT work

TODO: use the target mask to reduce memory

* Prepare for Model PR introducing class & reg token

Thus, right now it breaks. The big question is memory!

* Integrate the class and register token PR

Done manually because I couldn't figure out how to merge from a fork

* Fix iBOT loss with correct PredHead

Limitation: iBOT loss needs num_class_tokens to be 1

* Fix JEPA + Lint code

* Fix DDP

It had unused parameters from the decoders; these had to be removed

* Running this code + config for JEPA with DDP

* Ran JEPA DDP plot with this

* Fix FSDP error

Potentially a slowdown, but I don't understand FSDP well enough for a better fix

* Fix config

* Fix validation

* Stuck on error taking a break

* hot fix to empty tokens_c in encoder when looping over chunks

* Revert "hot fix to empty tokens_c in encoder when looping over chunks"

This reverts commit e4519d8.

* hot fix for local assimilation empty tokens_c

* Add class tokens being variable + Fix bugs

* Push remaining changes to default config

* deepcopy configs so we do not pop weight and lose it for inference

* fixed bug in inference with +2 in forecast steps range

* add required import to trainer

* Update uv.lock

* Linting

* Record fstep latent states

* added two configs, jepa and ibot/dino. Note these configs still try to merge/overwrite the default config with the --config flag

* Address comments from PR review

* Prepare SSL losses for logging

Currently nothing happens in the terminal but I don't know why that is

* Lint

* Address PR comments+ upstream changes

* Appease the hidden linter

* Rename ssl_losses_utils

* Add the untracked file

* Removing spurious character

---------

Co-authored-by: Jubeku <julian.kuehnert@ecmwf.int>
Co-authored-by: Christian Lessig <christian.lessig@ecmwf.int>
Co-authored-by: Sebastian Hickman <seb.hickman@ecmwf.int>
Co-authored-by: Tim Hunter <tim.hunter@ecmwf.int>
Co-authored-by: Wael Almikaeel <wael.almikaeel.95@gmail.com>
Co-authored-by: Sophie Xhonneux <sxhonneu@santis-ln001.cscs.ch>
Co-authored-by: Sebastian Hickman <seb.hickman@gmail.com>
Co-authored-by: Seb Hickman <56727418+shmh40@users.noreply.github.com>
Co-authored-by: Julian Kuehnert <Jubeku@users.noreply.github.com>
TillHae pushed a commit to TillHae/WeatherGenerator that referenced this pull request Dec 25, 2025
…aining, and pass ModelBatch class (ecmwf#1283)

* NOT WORKING: initial draft for index-based masking. Implemented for random and healpix masking. Open issues with _coords_local, centroids and probably other things.

* NOT WORKING: Finished src, target still to be done.

* Masking target is working in principle but errors when feeding data to the model.

* Working version for ERA5, NPP-ATMS. Problems with SYNOP with empty cell handling

* Minor cleanup

* Fixed linting

* Fixed remaining problems that occured for NPP-ATMS and SYNOP.
TODO:
- Forecast still needs to be adapted
- Some more cleanup of variable naming, return values etc

* Enabled support for forecast. Cleaned up some bits and pieces.

* Removing centroids options for embedding that was unused and should not be used.

* Removed unused parameters

* Inversion of target output ordering to match input one in forcast mode. Unclear how to deal with it with MTM

* Changes to  prepare_logging to apply index inversion

* added file with ModelBatch and SampleMetadata dataclasses

* Updating config to working version

* update ViewMetadata spec

* draft changes to allow global local view generation in masker and tokenizer_masking. generate the mask, otherwise using batchify_source and batchify_target as before, with the capacity to remember what mask we have now when it comes to generating the targets. Update to inputs_metadata structure but not put in to practice

* draft of training_config in default_config

* change view_metadata to dict in ModelInput

* NOT WORKING: updating class to handle multiple input steps and improving overall structure

* Added basic support for multi-step sources.

* Partially enabled correct handling of multiple input steps.

* Added mode and refactored get_sample_data into separate function.

* Comments

* Renaming

* updated default config training_config to allow student-teacher

* added stream id to era5 config

* slight restructure of ViewMetadata

* basic if statement to yield the student and teacher views

* correct imports with new batch.py

* created function for _get_student_teacher_sample_data which returns the streams_data of the teacher and multiple streams_datas for the student views.

* Not working draft for restructuring

* Changes for better student teacher structure

* More refactoring

* More refactoring and cleanup

* More refactoring. Code working again.

* Cleaned up parametrization

* Changes necessary for spoofing flag per IOReaderData

* Changes to have spoofing on a per data reader sample

* Moved _get_student_teacher_masks() so that masks are generated for all streams first.

* Renaming and minor clean up.

* Added basic support for use of ModelBatch class to define rough structure and interface.

* linting

* Linting

* linting

* Linting problems but removed unused ViewMetaData dependence

* Added required reflexivity between source and target samples to Batch

* Added todo

* fix typo in ModelBatch

* collect num_source_samples and num_target_samples, add loop over teacher masks hence allowing multiple teacher views, and add source_target_idx to keep track of which student belongs to which teacher

* add teacher num_views parameter to config

* Re-enabling inversion of targert ordering.

* tidy up, remove unused build_stream_views in tokenizer_masking

* multiple idxs for each teacher, need to confirm for not student case, and updated ModelBatch for this

* add max_num_targets to era5

* add max_num_samples functionality to tokenizer_masking and pass through in multi_stream_data_sampler. coords_per_cell is a bit nasty

* move build_views_for_stream into masker

* tidy up, remove unused arguments, types

* fix masking for NPP-ATMS by correctly selecting final timestep mask and aligning between source and target. working for num_input_steps = 1, broken for > 1, compute_offsets_scatter_embed not working

* updated configs so code runs. Note default config to be overhauled still

* very hacky first pass of full masking_strategy_config for the student and teacher views. Much to fix up

* instructions for sophie

* add SampleMetaData integration and functionality, and update masker to use SampleMetadata. Pass through source_cell_lens and target_coords_idx to student_teacher_batch in iter, and hence pass through to trainer. source_cell_lens and target_coords_idx are now part of Sample, which is itself the components of ModelBatch. To tidy

* remove prints, pdb

* add mask to SampleMetaData and add forecast_dt to Sample so it is accessible. Can specify the loss in the default config with student-teacher views

* add diffusion forecast option for the data sampling, and with noise_level_rn in the metadata. The Trainer needs to be copied from Sophies branch, currently we only get so far

* Linting

* Simplified and clarified handling of default target_aux_calcualtor

* Linting

* Linting

* Linting

* Linting

* Linting

* Restoring masking as training_mode in default_config

* More linting

* Removed duplicate lines due to mergeing

* Restored masking as training mode. Not working due to NaN in prediction

* Fixed problem in engines introduced in recent commits merging develop. This fixes masking training

* remove unused mask generation in diffusion_forecast

* restore masking_strategy to random

Had placeholder for testing, now back to "random" for masking strategy in the base level of default_config

* restore loader_num_workers to 8

* fix indentation of else: assert False in _get_sample msds

* linter warnings

* commenting tests

* Restructured code so that mask generation and application is cleanly separated

* Commit

* Update

* Fixed uv.lock

* Fix for integration test

* Re-enabled multi-source training

* 1390 - Adapt forward pass of new batch object (ecmwf#1391)

* Add to device to ModelBatch, etc & adapt model

TODO adapt validate and inference
TODO test forecasting and multiple stream because predict changed
substantially

* Rename view to sample and fix validate

* Revert predict function and fix inference

* Fix invalid access with mask

* Linting

* Fixed handling of target_idxs and other minor issues

---------

Co-authored-by: sophiex <24638638+sophie-xhonneux@users.noreply.github.com>
Co-authored-by: Christian Lessig <christian.lessig@ecmwf.int>

* Completed migration to new batch class by removing reference to old list of lists

* Fixed missing non_blocking=True in to_device()

* Removed old comments

* Fixed problem with non_blocking=True

* Cleaned up comments and return values a bit

* Changed args to embedding

* Changed core functions to take sample as arg

* Changed that model takes sample as input

* Fixes for diffusion

* Switched to lists of model / target stratgies

* Updated config

* Changed to per masking strategy loss terms

* Removed old masking options. Still needs to be fully cleaned up

* More robust handling of empty streams

* Fixed incorrect handling of empty target_coords_idx

* Fixed problem when number of model and target samples is different

* Example for config with non-trivial model and target inputs

* Fixed bug in total sample counting

* Re-enabled missing healpix level

* Fixed incorrect handling of masking and student_teacher modes. Follow up fixes required to handle partially filler source/target streams (because source has no target values, eg).

* An encoder formed by embedding + local assimilation + global assimilation (ecmwf#1397)

* initial changes

* more changes

* removed extra print parameters statement

* changed names for backward checkpoint loading

* added encoder. to module names in sharding

* adding encoder. to embed_engine

* added back the conditions for param printong

* lint

* forecast config

* switch back to MTM config

* lint

* Formatting

* Fix source-target matching problem.

* Enabled multiple input steps. Fixed various robustness that arose through this.

This commit also changes the number of forecast steps that are taken. The old loop was at least one step too far. Unclear why the problem occurred now.

* Linting

* Missing update to validation()

* Improved robustness through sanity checking of arguments

* Improved handling of corner cases

* - Fixed incorrect call to get_forecast_steps() in validation
- Fixed interface of target_aux_calculator

* More fixed to validation

* Adding stream_id

* Cleaned up ModelOutput class to have proper access functions and a better structure

* Switched to use dict to internally represent streams_datasets

* Improving robustness of interface of ModelOutput class

* Re-enabling model output

* Ruff

* Minor clean-ups and additional comments

* Minor cleanups

* Cleaned up handling of masks and masking metadata

* Current working version of default_config

* Fixed problem with branches with old code and incomplete cleanup

* Updated to test convergence of integration test.

* Updated settings

* Clessig/ypd/dev/1353 add tokens latent state finalization (ecmwf#1452)

* Add LatentState

* Add class and register tokens for LatentState, adjust everything accordingly

* Add option in config file + minor changes

* Add pos.emb. for register tokens + remove class tokens + minor fixes

* Minor fix

* Changed empty to zeros pe_register

* Ruffed

* Clean-up and fixed positional encoding

* Fixing things that got lost during last merge

---------

Co-authored-by: Yura Perugachi Diaz <yperugac@santis-ln001.cscs.ch>
Co-authored-by: Yura Perugachi Diaz <yperugac@santis-ln002.cscs.ch>

* Ruffed

* Adding sanity check for register tokens

* Improved strucutre of LatentState class.

* Improved structure of LatentState

* Fixed problem wiht num_samples > 1

* Improved representation of batch and batch data and more getter functions

* Re-enabled batch_size/num_samples>1. Still some minor problems and cleanup needed but overall program flow working

* Cleanup

* Fixed bug in source-target correspondence with num_samples>1

* Removing incorrect loss scaling

* Cleaned up predict() in model

* Fixed commenting issues

* Fixed problem with freezing of modules for q_cells. Fixed problem when running in FSDP

* Fixed problem with printing of trainable weights

* Fixed switch for printing of trainable weights

* 1316 Update metrics logging (ecmwf#1412)

* Add to_device to ModelBatch, etc & adapt model

TODO adapt validate and inference
TODO test forecasting and multiple streams because predict changed
substantially

* Rename view to sample and fix validate

* Revert predict function and fix inference

* Fix invalid access with mask

* Linting

* Fixed handling of target_idxs and other minor issues

* Remove duplicate to_device

* move loss history into loss calculator

* handle loss_avg and unflatten loss dict

* fixing train_logger

* update validate logging, failing - need to merge data branch

* rm additional log files and log_vals variable, and collapse to single add_logs fct for train and val

* rm comment

* fix validation

* move prepare losses fct to train_logger script, fix terminal logging for val

* fix ctr_loss_fcts normalization; calculate per stream, per lfct average across channels and fsteps for logging

* Fixed linting

* fix bug in emptying history after logging

---------

Co-authored-by: sophiex <24638638+sophie-xhonneux@users.noreply.github.com>
Co-authored-by: Christian Lessig <christian.lessig@ecmwf.int>

* set target_channels dependent on val or train (ecmwf#1471)

* fix targetauxoutput and inference to accept target times and coords. Still lists, to be changed to dicts

* updated targetauxbase so we access target, times and coords with a dict and corresponding changes

* Patch for float in loss dict (shouldn't happen to begin with)

* Fixed handling of per-batch precomputation. Functional for num_sample/batch_size=1 but convergence seems broken for num_samples>1.

* Reordering

* Linting

* More linting

* More linting

* Removed old, commented code

* Push current progress for inspection (ecmwf#1478)

* Push current progress for inspection

* For Seb

* Delete old code

* Rename according to config

* Fix config

* Push what I have

* Successfully build data

* Fix bugs and lint

* Forgot to revert dataloader workers

* Address PR review comments

- rename student-to-teacher
- extract metadata extraction into a function

* prepare branch for ssl merge, ready for data merge

* Removed mock-up

---------

Co-authored-by: Sebastian Hickman <seb.hickman@gmail.com>
Co-authored-by: Christian Lessig <christian.lessig@ecmwf.int>

* Cleaned up

---------

Co-authored-by: Christian Lessig <christian.lessig@ecmwf.int>
Co-authored-by: Tim Hunter <tim.hunter@ecmwf.int>
Co-authored-by: Julian Kuehnert <Jubeku@users.noreply.github.com>
Co-authored-by: sophiex <24638638+sophie-xhonneux@users.noreply.github.com>
Co-authored-by: kctezcan <kctezcan@gmail.com>
Co-authored-by: Yura Perugachi Diaz <yperugac@santis-ln001.cscs.ch>
Co-authored-by: Yura Perugachi Diaz <yperugac@santis-ln002.cscs.ch>
TillHae pushed a commit to TillHae/WeatherGenerator that referenced this pull request Dec 25, 2025
…cmwf#1507)

* NOT WORKING: initial draft for index-based masking. Implemented for random and healpix masking. Open issues with _coords_local, centroids and probably other things.

* NOT WORKING: Finished src, target still to be done.

* Masking target is working in principle but errors when feeding data to the model.

* Working version for ERA5, NPP-ATMS. Problems with SYNOP with empty cell handling

* Minor cleanup

* Fixed linting

* Fixed remaining problems that occurred for NPP-ATMS and SYNOP.
TODO:
- Forecast still needs to be adapted
- Some more cleanup of variable naming, return values etc

* Enabled support for forecast. Cleaned up some bits and pieces.

* Removing the centroids option for embedding that was unused and should not be used.

* Removed unused parameters

* Inversion of target output ordering to match the input one in forecast mode. Unclear how to deal with it with MTM

* Changes to prepare_logging to apply index inversion

* added file with ModelBatch and SampleMetadata dataclasses

* Updating config to working version

* update ViewMetadata spec

* draft changes to allow global local view generation in masker and tokenizer_masking. generate the mask, otherwise using batchify_source and batchify_target as before, with the capacity to remember what mask we have now when it comes to generating the targets. Update to inputs_metadata structure but not put in to practice

* draft of training_config in default_config

* change view_metadata to dict in ModelInput

* NOT WORKING: updating class to handle multiple input steps and improving overall structure

* Added basic support for multi-step sources.

* Partially enabled correct handling of multiple input steps.

* Added mode and refactored get_sample_data into separate function.

* Comments

* Renaming

* updated default config training_config to allow student-teacher

* added stream id to era5 config

* slight restructure of ViewMetadata

* basic if statement to yield the student and teacher views

* correct imports with new batch.py

* created function for _get_student_teacher_sample_data which returns the streams_data of the teacher and multiple streams_datas for the student views.

* Not working draft for restructuring

* Changes for better student teacher structure

* More refactoring

* More refactoring and cleanup

* More refactoring. Code working again.

* Cleaned up parametrization

* Changes necessary for spoofing flag per IOReaderData

* Changes to have spoofing on a per data reader sample

* Moved _get_student_teacher_masks() so that masks are generated for all streams first.

* Renaming and minor clean up.

* Added basic support for use of ModelBatch class to define rough structure and interface.

* linting

* Linting

* linting

* Linting problems but removed unused ViewMetaData dependence

* Added required reflexivity between source and target samples to Batch

* Added todo

* fix typo in ModelBatch

* collect num_source_samples and num_target_samples, add loop over teacher masks hence allowing multiple teacher views, and add source_target_idx to keep track of which student belongs to which teacher
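
A small sketch of the bookkeeping this commit describes; the view counts are assumed config values, and only source_target_idx is named in the log:

```python
num_teacher_views = 2  # assumed config value
num_student_views = 4  # assumed: students generated per teacher view

# source_target_idx[i] records which teacher (target) view the i-th
# student (source) view is compared against in the loss.
source_target_idx = [
    t_idx
    for t_idx in range(num_teacher_views)
    for _ in range(num_student_views)
]
assert len(source_target_idx) == num_teacher_views * num_student_views
```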

* add teacher num_views parameter to config

* Re-enabling inversion of target ordering.

* tidy up, remove unused build_stream_views in tokenizer_masking

* multiple idxs for each teacher, need to confirm for not student case, and updated ModelBatch for this

* add max_num_targets to era5

* add max_num_samples functionality to tokenizer_masking and pass through in multi_stream_data_sampler. coords_per_cell is a bit nasty

* move build_views_for_stream into masker

* tidy up, remove unused arguments, types

* fix masking for NPP-ATMS by correctly selecting final timestep mask and aligning between source and target. working for num_input_steps = 1, broken for > 1, compute_offsets_scatter_embed not working

* updated configs so code runs. Note default config to be overhauled still

* very hacky first pass of full masking_strategy_config for the student and teacher views. Much to fix up

* instructions for sophie

* add SampleMetaData integration and functionality, and update masker to use SampleMetadata. Pass through source_cell_lens and target_coords_idx to student_teacher_batch in iter, and hence pass through to trainer. source_cell_lens and target_coords_idx are now part of Sample, which is itself the components of ModelBatch. To tidy
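
Putting the entries above together, a hedged sketch of how these containers might nest; only the names that appear in the log are grounded (mask and forecast_dt come from the entries just below), the rest is illustrative:

```python
from dataclasses import dataclass, field
from typing import Optional

import torch


@dataclass
class SampleMetadata:
    mask: Optional[torch.Tensor] = None  # per-view mask, see entry below


@dataclass
class Sample:
    streams_data: list[torch.Tensor]
    source_cell_lens: torch.Tensor
    target_coords_idx: torch.Tensor
    forecast_dt: int = 0  # see entry below
    metadata: SampleMetadata = field(default_factory=SampleMetadata)


@dataclass
class ModelBatch:
    samples: list[Sample] = field(default_factory=list)
```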

* remove prints, pdb

* add mask to SampleMetaData and add forecast_dt to Sample so it is accessible. Can specify the loss in the default config with student-teacher views

* add diffusion forecast option for the data sampling, and with noise_level_rn in the metadata. The Trainer needs to be copied from Sophies branch, currently we only get so far

* Linting

* Simplified and clarified handling of default target_aux_calculator

* Linting

* Linting

* Linting

* Linting

* Linting

* Restoring masking as training_mode in default_config

* More linting

* Removed duplicate lines due to merging

* Restored masking as training mode. Not working due to NaN in prediction

* Fixed problem in engines introduced in recent commits merging develop. This fixes masking training

* remove unused mask generation in diffusion_forecast

* restore masking_strategy to random

Had placeholder for testing, now back to "random" for masking strategy in the base level of default_config

* restore loader_num_workers to 8

* fix indentation of else: assert False in _get_sample msds

* linter warnings

* commenting tests

* Restructured code so that mask generation and application is cleanly separated

* Commit

* Update

* Fixed uv.lock

* Fix for integration test

* Re-enabled multi-source training

* 1390 - Adapt forward pass of new batch object (ecmwf#1391)

* Add to_device to ModelBatch, etc & adapt model

TODO adapt validate and inference
TODO test forecasting and multiple streams because predict changed
substantially

* Rename view to sample and fix validate

* Revert predict function and fix inference

* Fix invalid access with mask

* Linting

* Fixed handling of target_idxs and other minor issues

---------

Co-authored-by: sophiex <24638638+sophie-xhonneux@users.noreply.github.com>
Co-authored-by: Christian Lessig <christian.lessig@ecmwf.int>

* Completed migration to new batch class by removing reference to old list of lists

* Fixed missing non_blocking=True in to_device()

* Removed old comments

* Fixed problem with non_blocking=True

* Cleaned up comments and return values a bit

* Changed args to embedding

* Changed core functions to take sample as arg

* Changed model to take sample as input

* Fixes for diffusion

* Switched to lists of model / target strategies

* Updated config

* Changed to per masking strategy loss terms

* Removed old masking options. Still needs to be fully cleaned up

* More robust handling of empty streams

* Fixed incorrect handling of empty target_coords_idx

* Fixed problem when number of model and target samples is different

* Example for config with non-trivial model and target inputs

* Fixed bug in total sample counting

* Re-enabled missing healpix level

* Fixed incorrect handling of masking and student_teacher modes. Follow-up fixes required to handle partially filled source/target streams (e.g. when the source has no target values).

* An encoder formed by embedding + local assimilation + global assimilation (ecmwf#1397)

* initial changes

* more changes

* removed extra print parameters statement

* changed names for backward checkpoint loading

* added encoder. to module names in sharding

* adding encoder. to embed_engine

* added back the conditions for param printing

* lint

* forecast config

* switch back to MTM config

* lint

* Formatting

* Fix source-target matching problem.

* Enabled multiple input steps. Fixed various robustness issues that arose through this.

This commit also changes the number of forecast steps that are taken. The old loop was at least one step too far. Unclear why the problem occurred now.

* Linting

* Missing update to validation()

* Improved robustness through sanity checking of arguments

* Improved handling of corner cases

* - Fixed incorrect call to get_forecast_steps() in validation
- Fixed interface of target_aux_calculator

* More fixes to validation

* Adding stream_id

* Healpix cropping simple implementation

* Cleaned up ModelOutput class to have proper access functions and a better structure

* Switched to use dict to internally represent streams_datasets

* Improving robustness of interface of ModelOutput class

* Re-enabling model output

* Healpix cropping simple implementation with control over the num_samples and overlap + fixing the num_sample bug

* Fixed lint

* Ruff

* Minor clean-ups and additional comments

* Minor cleanups

* Cleaned up handling of masks and masking metadata

* Current working version of default_config

* Fixed problem with branches with old code and incomplete cleanup

* Updated to test convergence of integration test.

* Updated settings

* Clessig/ypd/dev/1353 add tokens latent state finalization (ecmwf#1452)

* Add LatentState

* Add class and register tokens for LatentState, adjust everything accordingly

* Add option in config file + minor changes

* Add pos.emb. for register tokens + remove class tokens + minor fixes

* Minor fix

* Changed empty to zeros pe_register

* Ruffed

* Clean-up and fixed positional encoding

* Fixing things that got lost during last merge

---------

Co-authored-by: Yura Perugachi Diaz <yperugac@santis-ln001.cscs.ch>
Co-authored-by: Yura Perugachi Diaz <yperugac@santis-ln002.cscs.ch>

* Ruffed

* Adding sanity check for register tokens

* Improved structure of LatentState class.

* Improved structure of LatentState

* Fixed problem with num_samples > 1

* Improved representation of batch and batch data and more getter functions

* Re-enabled batch_size/num_samples>1. Still some minor problems and cleanup needed but overall program flow working

* Cleanup

* Fixed bug in source-target correspondence with num_samples>1

* Removing incorrect loss scaling

* Cleaned up predict() in model

* Fixed commenting issues

* Fixed problem with freezing of modules for q_cells. Fixed problem when running in FSDP

* Fixed problem with printing of trainable weights

* Fixed switch for printing of trainable weights

* 1316 Update metrics logging (ecmwf#1412)

* Add to_device to ModelBatch, etc & adapt model

TODO adapt validate and inference
TODO test forecasting and multiple streams because predict changed
substantially

* Rename view to sample and fix validate

* Revert predict function and fix inference

* Fix invalid access with mask

* Linting

* Fixed handling of target_idxs and other minor issues

* Remove duplicate to_device

* move loss history into loss calculator

* handle loss_avg and unflatten loss dict

* fixing train_logger

* update validate logging, failing - need to merge data branch

* rm additional log files and log_vals variable, and collapse to single add_logs fct for train and val

* rm comment

* fix validation

* move prepare losses fct to train_logger script, fix terminal logging for val

* fix ctr_loss_fcts normalization; calculate per stream, per lfct average across channels and fsteps for logging

* Fixed linting

* fix bug in emptying history after logging

---------

Co-authored-by: sophiex <24638638+sophie-xhonneux@users.noreply.github.com>
Co-authored-by: Christian Lessig <christian.lessig@ecmwf.int>

* set target_channels dependent on val or train (ecmwf#1471)

* fix targetauxoutput and inference to accept target times and coords. Still lists, to be changed to dicts

* updated targetauxbase so we access target, times and coords with a dict and corresponding changes

* Patch for float in loss dict (shouldn't happen to begin with)

* Fixed handling of per-batch precomputation. Functional for num_sample/batch_size=1 but convergence seems broken for num_samples>1.

* Reordering

* Linting

* More linting

* More linting

* Removed old, commented code

* Push current progress for inspection (ecmwf#1478)

* Push current progress for inspection

* For Seb

* Delete old code

* Rename according to config

* Fix config

* Push what I have

* Successfully build data

* Fix bugs and lint

* Forgot to revert dataloader workers

* Address PR review comments

- rename student-to-teacher
- extract metadata extraction into a function

* prepare branch for ssl merge, ready for data merge

* Removed mock-up

---------

Co-authored-by: Sebastian Hickman <seb.hickman@gmail.com>
Co-authored-by: Christian Lessig <christian.lessig@ecmwf.int>

* Cleaned up

* actually merging develop in properly with merge conflicts resolved

* move imports, move functions

* make functions for different types of cropping rather than if else, remove overlap from select_spatially_contiguous_cells

* restore overlap control in select spatially contiguous cell

* clean up cropping, remove overlap control of crops for now

* make overlap work with source and target masks. working now as random selection. need to think about this if we want to explicitly do overlap of crops.

* restored config somewhat

* lint

* restore config a bit

* remove extra commented out code

* clean up

* remove logging

* invert healpix_cropping masking so aligned with masking and healpix

* lint and updated comments

* remove overlap code, deal with in _get_mask complement, subset etc

* update comments and true/false conventions in masking

* remove legacy argument constraint_keep_mask in the docstring for 2 masking functions

* remove legacy overlap_ratio and overlap from the docstrings in masking

* remove overlap ratio unused arg from example configs for cropping

* make cropping a function, and build shared _prepare_healpix_masking for healpix and healpix_cropping preparation

* rename healpix preparation function

* removed old docstrings and lint

---------

Co-authored-by: Christian Lessig <christian.lessig@ecmwf.int>
Co-authored-by: Tim Hunter <tim.hunter@ecmwf.int>
Co-authored-by: Julian Kuehnert <Jubeku@users.noreply.github.com>
Co-authored-by: sophiex <24638638+sophie-xhonneux@users.noreply.github.com>
Co-authored-by: kctezcan <kctezcan@gmail.com>
Co-authored-by: Wael Almikaeel <wael.almikaeel.95@gmail.com>
Co-authored-by: Yura Perugachi Diaz <yperugac@santis-ln001.cscs.ch>
Co-authored-by: Yura Perugachi Diaz <yperugac@santis-ln002.cscs.ch>
TillHae pushed a commit to TillHae/WeatherGenerator that referenced this pull request Dec 25, 2025
* Abstract class for target/aux computation

Implemented Identity class

TODO: implement EMATeacher

* Start implementing the EMA Teacher

The big question on the EMA teacher side to me is how to allow for a
flexible teacher and student architecture that can differ

We updated some APIs of the abstract base class to allow the ema_model
forward, subject to change given the loss calculator, which is imho the
second big question mark
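
For reference, the standard student-teacher EMA mechanics this commit refers to; the momentum value and helper names are illustrative, not the repo's confirmed implementation:

```python
import copy

import torch


def build_teacher(student: torch.nn.Module) -> torch.nn.Module:
    # The teacher starts as a copy of the student and is never trained by
    # backprop; its parameters only track the student via the EMA update.
    teacher = copy.deepcopy(student)
    for p in teacher.parameters():
        p.requires_grad = False
    return teacher


@torch.no_grad()
def ema_update(teacher, student, momentum: float = 0.996):
    # p_teacher <- momentum * p_teacher + (1 - momentum) * p_student
    for pt, ps in zip(teacher.parameters(), student.parameters()):
        pt.mul_(momentum).add_(ps.detach(), alpha=1.0 - momentum)
```

The update is typically called once per optimizer step, after the student's weights have changed.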

* adding loss calculator base class

* Option for constructing teacher model flexibly

* Extract get batch size util function

Easier to read and as batchsize gets more complicated in SSL this will
be a useful abstraction

* Fix mismatched dtypes in the target computation

It runs so far. Next steps:
 - Route all the config options
 - Start writing the loss functions to understand the state requirements

* abstract loss calc structure

* add abstract method to loss calculator base class

* add latent loss class

* update loss calc config and rename files

* restructure loss modules

* add ModelOutput dataclass

* NOT WORKING: initial draft for index-based masking. Implemented for random and healpix masking. Open issues with _coords_local, centroids and probably other things.

* NOT WORKING: Finished src, target still to be done.

* Masking target is working in principle but errors when feeding data to the model.

* Working version for ERA5, NPP-ATMS. Problems with SYNOP with empty cell handling

* Minor cleanup

* Fixed linting

* Fixed remaining problems that occurred for NPP-ATMS and SYNOP.
TODO:
- Forecast still needs to be adapted
- Some more cleanup of variable naming, return values etc

* Enabled support for forecast. Cleaned up some bits and pieces.

* mv streams_data declaration under if condition

* add weight to loss config, add toy loss class LossPhysicalTwo

* Update Abstract Target class based on needs for SSL losses

* Removing the centroids option for embedding that was unused and should not be used.

* Removed unused parameters

* fixed trainer for multiple terms in losses_all, still need to fix logging

* Inversion of target output ordering to match the input one in forecast mode. Unclear how to deal with it with MTM

* fix _log_terminal

* Changes to prepare_logging to apply index inversion

* added file with ModelBatch and SampleMetadata dataclasses

* Updating config to working version

* fix logging

* update ViewMetadata spec

* draft changes to allow global local view generation in masker and tokenizer_masking. generate the mask, otherwise using batchify_source and batchify_target as before, with the capacity to remember what mask we have now when it comes to generating the targets. Update to inputs_metadata structure but not put in to practice

* draft of training_config in default_config

* change view_metadata to dict in ModelInput

* NOT WORKING: updating class to handle multiple input steps and improving overall structure

* Added basic support for multi-step sources.

* Partially enabled correct handling of multiple input steps.

* initialize loss as torch tensor with grad

* remove level in hist losses dict

* rename loss.py to loss_functions.py

* rename loss.py to loss_functions.py

* return loss with grads separately to trainer

* Added mode and refactored get_sample_data into separate function.

* modify log names

* add loss_functions.py

* Abstract class for target/aux computation

Implemented Identity class

TODO: implement EMATeacher

* Start implementing the EMA Teacher

The big question on the EMA teacher side to me is how to allow for a
flexible teacher and student architecture that can differ

We updated some APIs of the abstract base class to allow the ema_model
forward, subject to change given the loss calculator, which is imho the
second big question mark

* Option for constructing teacher model flexibly

* rm loss_fcts in default config

* Comments

* Renaming

* updated default config training_config to allow student-teacher

* added stream id to era5 config

* slight restructure of ViewMetadata

* basic if statement to yield the student and teacher views

* correct imports with new batch.py

* Extract get batch size util function

Easier to read and as batchsize gets more complicated in SSL this will
be a useful abstraction

* Fix mismatched dtypes in the target computation

It runs so far. Next steps:
 - Route all the config options
 - Start writing the loss functions to understand the state requirements

* Lay groundwork for SSL losses

This involves creating stateful classes for each of the losses and the
EMATeacher being able to run additional neural network heads for these
losses.

* Add the SSL Loss Processing classes

* Write part of the TargetProcessing forward

TODO: create the various teacher head modules and run them.
TODO: merge the abstract loss calculator and create the SSL one

* Add latent prediction heads to the Model

After much consideration I decided to add the latent prediction heads to
the Model, because they also need to benefit from exponential moving
average of the weights and this gets unnecessarily cumbersome if they
are outside the Model.

TODO: make JEPA different between student and teacher
TODO: use this new structure in EMATeacher
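
A minimal sketch of the design choice stated above: keeping the heads inside the Model means one EMA pass over model.parameters() also covers the heads; the head names and dimensions are assumptions:

```python
import torch


class ModelWithLatentHeads(torch.nn.Module):
    def __init__(self, backbone: torch.nn.Module, dim: int, out_dim: int):
        super().__init__()
        self.backbone = backbone
        # Registered as submodules, so copy.deepcopy / EMA updates over
        # the whole model include the latent prediction heads for free.
        self.latent_heads = torch.nn.ModuleDict(
            {
                "dino": torch.nn.Linear(dim, out_dim),
                "ibot": torch.nn.Linear(dim, out_dim),
            }
        )
```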

* Adapt forward function for latent prediction heads

To prevent crazy nesting of model output values we created a ModelOutput
Dataclass (akin to how it is done in huggingface), and we run all the
latent_prediction heads.
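
A hedged sketch of such a huggingface-style output container; the field names are assumptions:

```python
from dataclasses import dataclass, field

import torch


@dataclass
class ModelOutputSketch:
    # physical-space predictions, keyed by stream name
    preds: dict[str, torch.Tensor] = field(default_factory=dict)
    # outputs of the latent prediction heads, keyed by loss name
    latents: dict[str, torch.Tensor] = field(default_factory=dict)
```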

* Start piping configs through model, trainer, etc

Will need adapting based on the abstract loss calculator

Currently is awaiting the streams data branch to check piping of data
and configuring this

* adding dinov2 notice

* Draft Student Teacher Loss Calculator

TODO: initialise it and register
TODO: weight the loss
TODO: route the kwargs
TODO: check shapes of tensors

* Use infra provided by Abstract Loss Calc

Completes config option routing, weighting, and registering TODOs

* Run Ruff

* Implemented the first draft of the Cropping feature

* rough first effort producing global and local views

* update to return a 6-tuple from iter in multi-stream-data-sampler, with locals_prepared

* Fix class being in the wrong file

* Ensure data pipes through model and target

This is a DRAFT!

This commit assumes that the data augmentations of the stream_data
object (see shmh40/dev/global-local) will fit into the Batch data class
(trainer.py).

The goal was to ensure all data reaches the LossCalculator.

Large scale todos:
- Pass Metadata about local2global correspondence, etc. to the LossCalculator
- Upgrade the Model heads to produce the correct dimensions
- Verify the Data shapes against DINOv2

Smaller todos:
- Ensure teacher params are requires_grad = false
- clean up code

* Wrap latent state into a dataclass

to simplify loss calculation later

* Progress on computing the loss on correct dims

Added the new ViewMetadata and ModelBatch dataclasses that will come from
the cropping PR

Added LatentState dataclass to compute the latent heads on the correct
part of the latent state

TODOs:
1. Deal with DINO local and global component
2. Write JEPA loss function in loss.py
3. Test iBOT with actual mask and route student temperature
4. TODOs in the code
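
As a reference point for item 2 above (not the repo's confirmed implementation), a JEPA-style objective regresses predicted latents onto detached teacher latents:

```python
import torch.nn.functional as F


def jepa_loss(predicted_latent, teacher_latent):
    # Regression in latent space; the teacher side is detached so only
    # the student/predictor branch receives gradients. Smooth L1 is one
    # common choice here; plain MSE is another.
    return F.smooth_l1_loss(predicted_latent, teacher_latent.detach())
```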

* Add views.py and run Ruff

* Close in on completing DINO loss

TODO needs to deal with the tuple part of the DINO loss
TODO understand how the ModelBatch input structure affects the loss
terms
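
For orientation, the standard DINO objective this commit works toward is a cross-entropy between the student distribution and a centered, sharpened teacher distribution; a minimal sketch with the center update and multi-view pairing omitted:

```python
import torch.nn.functional as F


def dino_loss(student_logits, teacher_logits, center,
              student_temp=0.1, teacher_temp=0.04):
    # Teacher targets are centered, sharpened, and detached so that
    # gradients only flow through the student branch.
    targets = F.softmax((teacher_logits - center) / teacher_temp, dim=-1)
    log_probs = F.log_softmax(student_logits / student_temp, dim=-1)
    return -(targets.detach() * log_probs).sum(dim=-1).mean()
```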

* Revert "rough first effort producing globaland local views"

This reverts commit 3fa0033.

* Lint code

* Fix rebase of loss loss_calculator

* created function for _get_student_teacher_sample_data which returns the streams_data of the teacher and multiple streams_datas for the student views.

* Not working draft for restructuring

* Changes for better student teacher structure

* More refactoring

* More refactoring and cleanup

* More refactoring. Code working again.

* Cleaned up parametrization

* Changes necessary for spoofing flag per IOReaderData

* Changes to have spoofing on a per data reader sample

* Moved _get_student_teacher_masks() so that masks are generated for all streams first.

* Renaming and minor clean up.

* Added basic support for use of ModelBatch class to define rough structure and interface.

* linting

* Linting

* linting

* Linting problems but removed unused ViewMetaData dependence

* Added required reflexivity between source and target samples to Batch

* Added todo

* Test for compute time regressions

* Prepare for merge

* Lint the code

* Lint code

* Lint

* Fix some basic bugs

* fix typo in ModelBatch

* collect num_source_samples and num_target_samples, add loop over teacher masks hence allowing multiple teacher views, and add source_target_idx to keep track of which student belongs to which teacher

* add teacher num_views parameter to config

* Re-enabling inversion of target ordering.

* tidy up, remove unused build_stream_views in tokenizer_masking

* multiple idxs for each teacher, need to confirm for not student case, and updated ModelBatch for this

* add max_num_targets to era5

* add max_num_samples functionality to tokenizer_masking and pass through in multi_stream_data_sampler. coords_per_cell is a bit nasty

* Removing spurious code / things that should be merged later

* Linting

* move build_views_for_stream into masker

* Lint code

* Rename identity TargetAndAux module

* tidy up, remove unused arguments, types

* fix masking for NPP-ATMS by correctly selecting final timestep mask and aligning between source and target. working for num_input_steps = 1, broken for > 1, compute_offsets_scatter_embed not working

* Make code runnable

* updated configs so code runs. Note default config to be overhauled still

* Draft for model interface

* Make code runnable again

Seems slow again

* Cleaned up and restructured the code. Not working yet with FSDP

* Fixes for FSDP/DDP

* Cleaning up, should be merged when needed

* Fixes to FSDP

* Fix incorrect args for model loading and remove unused code.

* Linting

* Removing old code

* - Fixing inference arg order
- Fixing subtle problem with world_size_original that should be taken from config when available

* Fixing interface of get_target_aux_calculator

* Fixing call to target aux calculator

* Fixes to get_target_aux_calculator

* Remove stale dataclasses

* Fix MAE

* very hacky first pass of full masking_strategy_config for the student and teacher views. Much to fix up

* instructions for sophie

* add SampleMetaData integration and functionality, and update masker to use SampleMetadata. Pass through source_cell_lens and target_coords_idx to student_teacher_batch in iter, and hence pass through to trainer. source_cell_lens and target_coords_idx are now part of Sample, which is itself the components of ModelBatch. To tidy

* Prepare for another merge

* remove prints, pdb

* Save state

* add mask to SampleMetaData and add forecast_dt to Sample so it is accessible. Can specify the loss in the default config with student-teacher views

* Save state for Seb

Currently reviving the EMATeacher creation

Memory is an issue, had to hardcode a smaller latent space

* add diffusion forecast option for the data sampling, and with noise_level_rn in the metadata. The Trainer needs to be copied from Sophies branch, currently we only get so far

* Attempt to make the iBOT loss work

TODO force 1 ibot student view per global view
TODO there is a bug with the mask causing a leaf error in pytorch
TODO remove all the hardcoded reduced latent space

* Linting

* Simplified and clarified handling of default target_aux_calculator

* Linting

* Linting

* Linting

* Linting

* Linting

* Restoring masking as training_mode in default_config

* More linting

* Removed duplicate lines due to merging

* Restored masking as training mode. Not working due to NaN in prediction

* Fixed problem in engines introduced in recent commits merging develop. This fixes masking training

* remove unused mask generation in diffusion_forecast

* restore masking_strategy to random

Had placeholder for testing, now back to "random" for masking strategy in the base level of default_config

* restore loader_num_workers to 8

* fix indentation of else: assert False in _get_sample msds

* Pipe data through all ssl loss fns

TODO iBOT head should output class tokens as well as patch tokens
TODO remove hardcoded assignments, should be based on config
TODO deal with the memory hungriness of it all
TODO carefully inspect for bugs

* linter warnings

* commenting tests

* Restructured code so that mask generation and application is cleanly separated

* Commit

* Update

* Fixed uv.lock

* Fix for integration test

* Re-enabled multi-source training

* 1390 - Adapt forward pass of new batch object (ecmwf#1391)

* Add to_device to ModelBatch, etc & adapt model

TODO adapt validate and inference
TODO test forecasting and multiple streams because predict changed
substantially

* Rename view to sample and fix validate

* Revert predict function and fix inference

* Fix invalid access with mask

* Linting

* Fixed handling of target_idxs and other minor issues

---------

Co-authored-by: sophiex <24638638+sophie-xhonneux@users.noreply.github.com>
Co-authored-by: Christian Lessig <christian.lessig@ecmwf.int>

* Completed migration to new batch class by removing reference to old list of lists

* Fixed missing non_blocking=True in to_device()

* Removed old comments

* Fixed problem with non_blocking=True

* Cleaned up comments and return values a bit

* Changed args to embedding

* Changed core functions to take sample as arg

* Changed model to take sample as input

* Fixes for diffusion

* Switched to lists of model / target strategies

* Pipe the mask through

* Filter student views for the correct loss

* Change the masking and msdp to fit student-teacher

1. We ensure that for each target view all the student views are generated
2. We ensure that the target views have their mask applied to the input

* Make DINO and iBOT work

TODO: use the target mask to reduce memory

* Prepare for Model PR introducing class & reg token

Thus, right now it breaks. The big question is memory!

* Integrate the class and register token PR

Done manually because I couldn't figure out how to merge from a fork

* Fix iBOT loss with correct PredHead

Limitation: iBOT loss needs num_class_tokens to be 1

* Fix JEPA + Lint code

* Fix DDP

It had unused parameters from the decoders; these had to be removed

* Running this code + config for JEPA with DDP

* Ran JEPA DDP plot with this

* Fix FSDP error

Potentially a slowdown, but I don't understand FSDP well enough for a better fix

* Fix config

* Fix validation

* Stuck on an error, taking a break

* hot fix to empty tokens_c in encoder when looping over chunks

* Revert "hot fix to empty tokens_c in encoder when looping over chunks"

This reverts commit e4519d8.

* hot fix for local assimilation empty tokens_c

* Make the number of class tokens variable + fix bugs

* Push remaining changes to default config

* deepcopy configs so we do not pop weight and lose it for inference

* fixed bug in inference with +2 in forecast steps range

* add required import to trainer

* Update uv.lock

* Linting

* Record fstep latent states

* added two configs, jepa and ibot/dino. Note these configs still try to merge/overwrite the default config with the --config flag

* Address comments from PR review

* Prepare SSL losses for logging

Currently nothing happens in the terminal but I don't know why that is

* Lint

* Address PR comments + upstream changes

* Appease the hidden linter

* Rename ssl_losses_utils

* Add the untracked file

* Removing spurious character

---------

Co-authored-by: Jubeku <julian.kuehnert@ecmwf.int>
Co-authored-by: Christian Lessig <christian.lessig@ecmwf.int>
Co-authored-by: Sebastian Hickman <seb.hickman@ecmwf.int>
Co-authored-by: Tim Hunter <tim.hunter@ecmwf.int>
Co-authored-by: Wael Almikaeel <wael.almikaeel.95@gmail.com>
Co-authored-by: Sophie Xhonneux <sxhonneu@santis-ln001.cscs.ch>
Co-authored-by: Sebastian Hickman <seb.hickman@gmail.com>
Co-authored-by: Seb Hickman <56727418+shmh40@users.noreply.github.com>
Co-authored-by: Julian Kuehnert <Jubeku@users.noreply.github.com>
moritzhauschulz added a commit to moritzhauschulz/WeatherGenerator that referenced this pull request Jan 7, 2026
commit 9336fe1
Author: moritzhauschulz <moritz.hauschulz@gmail.com>
Date:   Fri Dec 12 20:10:50 2025 +0100

    requested changes

commit dadde23
Author: moritzhauschulz <moritz.hauschulz@gmail.com>
Date:   Mon Dec 8 18:54:44 2025 +0100

    remove 1 line

commit c871f9c
Author: moritzhauschulz <moritz.hauschulz@gmail.com>
Date:   Mon Dec 8 18:16:50 2025 +0100

    remove unnecessary statement

commit e3e46eb
Author: moritzhauschulz <moritz.hauschulz@gmail.com>
Date:   Mon Dec 8 12:49:03 2025 +0100

    lint

commit 559add7
Author: moritzhauschulz <moritz.hauschulz@gmail.com>
Date:   Mon Dec 8 12:47:35 2025 +0100

    rename flag and simplify cases

commit f6e1c39
Author: moritzhauschulz <moritz.hauschulz@gmail.com>
Date:   Thu Dec 4 21:07:42 2025 +0100

    reset config and lint

commit 27cb0c8
Author: moritzhauschulz <moritz.hauschulz@gmail.com>
Date:   Thu Dec 4 20:57:14 2025 +0100

    repeat flag

commit bf17bfe
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Thu Dec 4 16:53:51 2025 +0100

    Updated config

commit 7745e47
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Thu Dec 4 16:35:19 2025 +0100

    Switched to lists of model / target strategies

commit 12bae15
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Thu Dec 4 15:01:07 2025 +0100

    Fixes for diffusion

commit 9065219
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Thu Dec 4 13:33:42 2025 +0100

    Changed model to take sample as input

commit 3f52a8d
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Thu Dec 4 13:32:14 2025 +0100

    Changed core functions to take sample as arg

commit d36367a
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Thu Dec 4 13:31:55 2025 +0100

    Changed args to embedding

commit b69b743
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Thu Dec 4 13:30:41 2025 +0100

    Cleaned up comments and return values a bit

commit 59510dd
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Thu Dec 4 00:01:50 2025 +0100

    Fixed problem with non_blocking=True

commit 69b53a6
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Thu Dec 4 00:00:42 2025 +0100

    Removed old comments

commit 51754fa
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Thu Dec 4 00:00:20 2025 +0100

    Fixed missing non_blocking=True in to_device()

commit 2cd3971
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Wed Dec 3 23:56:41 2025 +0100

    Completed migration to new batch class by removing reference to old list of lists

commit 402b8de
Author: Julian Kuehnert <Jubeku@users.noreply.github.com>
Date:   Wed Dec 3 17:11:15 2025 +0100

    1390 - Adapt forward pass of new batch object (ecmwf#1391)

    * Add to_device to ModelBatch, etc & adapt model

    TODO adapt validate and inference
    TODO test forecasting and multiple streams because predict changed
    substantially

    * Rename view to sample and fix validate

    * Revert predict function and fix inference

    * Fix invalid access with mask

    * Linting

    * Fixed handling of target_idxs and other minor issues

    ---------

    Co-authored-by: sophiex <24638638+sophie-xhonneux@users.noreply.github.com>
    Co-authored-by: Christian Lessig <christian.lessig@ecmwf.int>

commit 9a1a6a9
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Wed Dec 3 13:12:52 2025 +0100

    Re-enabled multi-source training

commit 3641e1f
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Wed Dec 3 00:20:42 2025 +0100

    Fix for integration test

commit 9f5e49c
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Wed Dec 3 00:20:25 2025 +0100

    Fixed uv.lock

commit 33d9d8d
Merge: 23e0267 c8a2aad
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Wed Dec 3 00:13:05 2025 +0100

    Merge branch 'shmh40/dev/1270-idx-global-local' of github.com:ecmwf/WeatherGenerator into shmh40/dev/1270-idx-global-local

commit 23e0267
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Wed Dec 3 00:11:48 2025 +0100

    Update

commit c8a26d7
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Wed Dec 3 00:11:37 2025 +0100

    Commit

commit 2599ec2
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Wed Dec 3 00:10:13 2025 +0100

    Restructured code so that mask generation and application is cleanly separated

commit c8a2aad
Author: Tim Hunter <tim.hunter@ecmwf.int>
Date:   Tue Dec 2 17:06:56 2025 +0100

    commenting tests

commit 2b2c977
Author: Tim Hunter <tim.hunter@ecmwf.int>
Date:   Tue Dec 2 17:03:41 2025 +0100

    linter warnings

commit dc736e5
Merge: 6fe8561 7ff6e0b
Author: Tim Hunter <tim.hunter@ecmwf.int>
Date:   Tue Dec 2 16:48:24 2025 +0100

    merge with dev

commit 6fe8561
Merge: 15b46e9 f136d60
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Fri Nov 28 14:16:41 2025 +0100

    Merge branch 'develop' of github.com:ecmwf/WeatherGenerator into shmh40/dev/1270-idx-global-local

commit 15b46e9
Author: Sebastian Hickman <seb.hickman@gmail.com>
Date:   Fri Nov 28 13:30:54 2025 +0100

    fix indentation of else: assert False in _get_sample msds

commit 4281aff
Author: Sebastian Hickman <seb.hickman@gmail.com>
Date:   Fri Nov 28 12:40:24 2025 +0100

    restore loader_num_workers to 8

commit 6ea07e7
Author: Seb Hickman <56727418+shmh40@users.noreply.github.com>
Date:   Fri Nov 28 11:34:41 2025 +0000

    restore masking_strategy to random

    Had placeholder for testing, now back to "random" for masking strategy in the base level of default_config

commit 1a37dd1
Author: Sebastian Hickman <seb.hickman@gmail.com>
Date:   Fri Nov 28 10:31:43 2025 +0100

    remove unused mask generation in diffusion_forecast

commit 657094a
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Fri Nov 28 08:59:39 2025 +0100

    Fixed problem in engines introduced in recent commits merging develop. This fixes masking training

commit d526dfc
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Fri Nov 28 08:37:02 2025 +0100

    Restored masking as training mode. Not working due to NaN in prediction

commit 6289959
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Fri Nov 28 08:36:38 2025 +0100

    Removed duplicate lines due to merging

commit bc8d23e
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Fri Nov 28 08:18:01 2025 +0100

    More linting

commit 47750a5
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Fri Nov 28 08:10:09 2025 +0100

    Restoring masking as training_mode in default_config

commit 0db8b62
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Fri Nov 28 08:09:41 2025 +0100

    Linting

commit e41a575
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Fri Nov 28 08:09:28 2025 +0100

    Linting

commit 03166a2
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Fri Nov 28 08:09:10 2025 +0100

    Linting

commit 652500a
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Fri Nov 28 08:08:53 2025 +0100

    Linting

commit d8998a9
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Fri Nov 28 08:08:38 2025 +0100

    Linting

commit 8ef3a4c
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Fri Nov 28 08:08:04 2025 +0100

    Simplified and clarified handling of default target_aux_calculator

commit 3e4de7a
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Fri Nov 28 08:07:51 2025 +0100

    Linting

commit 5f803e5
Merge: b47b0fa 0e2801b
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Fri Nov 28 08:03:02 2025 +0100

    Merge branch 'develop' of github.com:ecmwf/WeatherGenerator into shmh40/dev/1270-idx-global-local

commit b47b0fa
Merge: 9b702c5 26f7b5b
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Fri Nov 28 07:09:19 2025 +0100

    Merge branch 'shmh40/dev/1270-idx-global-local' of github.com:ecmwf/WeatherGenerator into shmh40/dev/1270-idx-global-local

commit 26f7b5b
Author: Sebastian Hickman <seb.hickman@gmail.com>
Date:   Thu Nov 27 15:33:22 2025 +0100

    add diffusion forecast option for the data sampling, and with noise_level_rn in the metadata. The Trainer needs to be copied from Sophies branch, currently we only get so far

commit 6d909d6
Author: Sebastian Hickman <seb.hickman@gmail.com>
Date:   Thu Nov 27 11:32:32 2025 +0100

    add mask to SampleMetaData and add forecast_dt to Sample so it is accessible. Can specify the loss in the default config with student-teacher views

commit e0d7346
Author: Sebastian Hickman <seb.hickman@gmail.com>
Date:   Wed Nov 26 14:31:52 2025 +0100

    remove prints, pdb

commit c27156c
Author: Sebastian Hickman <seb.hickman@gmail.com>
Date:   Wed Nov 26 12:35:03 2025 +0100

    add SampleMetaData integration and functionality, and update masker to use SampleMetadata. Pass through source_cell_lens and target_coords_idx to student_teacher_batch in iter, and hence pass through to trainer. source_cell_lens and target_coords_idx are now part of Sample, which is itself the components of ModelBatch. To tidy

commit 4f8f62b
Author: Sebastian Hickman <seb.hickman@gmail.com>
Date:   Tue Nov 25 18:56:56 2025 +0100

    instructions for sophie

commit fa24fc1
Author: Sebastian Hickman <seb.hickman@gmail.com>
Date:   Tue Nov 25 16:36:52 2025 +0100

    very hacky first pass of full masking_strategy_config for the student and teacher views. Much to fix up

commit b193a50
Author: Sebastian Hickman <seb.hickman@gmail.com>
Date:   Mon Nov 24 17:13:37 2025 +0100

    updated configs so code runs. Note default config to be overhauled still

commit af9a3c1
Merge: 2905cb0 b452bd2
Author: Sebastian Hickman <seb.hickman@gmail.com>
Date:   Mon Nov 24 16:37:55 2025 +0100

    merge with develop, include trainer idx_inv_rt, merged default_config, rm tokenizer_forecast

commit 2905cb0
Author: Sebastian Hickman <seb.hickman@ecmwf.int>
Date:   Sat Nov 22 13:59:37 2025 +0000

    fix masking for NPP-ATMS by correctly selecting final timestep mask and aligning between source and target. working for num_input_steps = 1, broken for > 1, compute_offsets_scatter_embed not working

commit b9a60f3
Author: Sebastian Hickman <seb.hickman@ecmwf.int>
Date:   Fri Nov 21 18:38:40 2025 +0000

    tidy up, remove unused arguments, types

commit ece1dd0
Author: Sebastian Hickman <seb.hickman@ecmwf.int>
Date:   Fri Nov 21 16:22:27 2025 +0000

    move build_views_for_stream into masker

commit 1a418bf
Author: Sebastian Hickman <seb.hickman@ecmwf.int>
Date:   Fri Nov 21 12:54:33 2025 +0000

    add max_num_samples functionality to tokenizer_masking and pass through in multi_stream_data_sampler. coords_per_cell is a bit nasty

commit 91c3d7a
Author: Sebastian Hickman <seb.hickman@ecmwf.int>
Date:   Fri Nov 21 12:53:31 2025 +0000

    add max_num_targets to era5

commit 647e4b2
Author: Sebastian Hickman <seb.hickman@ecmwf.int>
Date:   Thu Nov 20 18:31:45 2025 +0000

    multiple idxs for each teacher, need to confirm for not student case, and updated ModelBatch for this

commit 1806ae5
Author: Sebastian Hickman <seb.hickman@ecmwf.int>
Date:   Thu Nov 20 16:28:30 2025 +0000

    tidy up, remove unused build_stream_views in tokenizer_masking

commit 9b702c5
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Thu Nov 20 14:34:34 2025 +0100

    Re-enabling inversion of target ordering.

commit 87ad45f
Author: Sebastian Hickman <seb.hickman@ecmwf.int>
Date:   Thu Nov 20 13:10:34 2025 +0000

    add teacher num_views parameter to config

commit b34b6da
Author: Sebastian Hickman <seb.hickman@ecmwf.int>
Date:   Thu Nov 20 13:09:19 2025 +0000

    collect num_source_samples and num_target_samples, add loop over teacher masks hence allowing multiple teacher views, and add source_target_idx to keep track of which student belongs to which teacher

commit b2be982
Author: Sebastian Hickman <seb.hickman@ecmwf.int>
Date:   Thu Nov 20 13:07:47 2025 +0000

    fix typo in ModelBatch

commit d18cf86
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Thu Nov 20 08:26:40 2025 +0100

    Added todo

commit e8ccb8d
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Thu Nov 20 08:22:26 2025 +0100

    Added required reflexivity between source and target samples to Batch

commit 5d5e999
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Thu Nov 20 08:21:31 2025 +0100

    Linting problems but removed unused ViewMetaData dependence

commit 3bca490
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Thu Nov 20 08:21:13 2025 +0100

    linting

commit 6a96065
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Thu Nov 20 08:20:42 2025 +0100

    Linting

commit c1d32fb
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Thu Nov 20 08:20:21 2025 +0100

    linting

commit 1b1654c
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Wed Nov 19 22:32:05 2025 +0100

    Added basic support for use of ModelBatch class to define rough structure and interface.

commit 848880b
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Wed Nov 19 20:06:41 2025 +0100

    Renaming and minor clean up.

commit 6d685c0
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Wed Nov 19 19:57:46 2025 +0100

    Moved _get_student_teacher_masks() so that masks are generated for all streams first.

commit ed26c02
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Wed Nov 19 19:57:23 2025 +0100

    Changes to have spoofing on a per data reader sample

commit 9fe94f5
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Wed Nov 19 19:30:48 2025 +0100

    Changes necessary for spoofing flag per IOReaderData

commit 4613f7a
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Wed Nov 19 17:58:10 2025 +0100

    Cleaned up parametrization

commit 1235aab
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Wed Nov 19 17:47:40 2025 +0100

    More refactoring. Code working again.

commit 1e70f5c
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Wed Nov 19 17:09:20 2025 +0100

    More refactoring and cleanup

commit 46147d4
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Wed Nov 19 17:01:29 2025 +0100

    More refactoring

commit 81cf929
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Wed Nov 19 15:58:57 2025 +0100

    Changes for better student teacher structure

commit dfc03f2
Merge: a824bfc 31dc658
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Wed Nov 19 15:58:37 2025 +0100

    Merge branch 'shmh40/dev/1270-idx-global-local' of github.com:ecmwf/WeatherGenerator into shmh40/dev/1270-idx-global-local

commit a824bfc
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Wed Nov 19 12:23:47 2025 +0100

    Not working draft for restructuring

commit 31dc658
Author: Sebastian Hickman <seb.hickman@ecmwf.int>
Date:   Wed Nov 19 11:04:29 2025 +0000

    created function for _get_student_teacher_sample_data which returns the streams_data of the teacher and multiple streams_datas for the student views.

commit 2536cec
Author: Sebastian Hickman <seb.hickman@ecmwf.int>
Date:   Tue Nov 18 17:40:26 2025 +0000

    correct imports with new batch.py

commit b3dfa2f
Merge: 11ad4e6 c1580c4
Author: Sebastian Hickman <seb.hickman@ecmwf.int>
Date:   Tue Nov 18 17:36:15 2025 +0000

    merge changes

commit 11ad4e6
Author: Sebastian Hickman <seb.hickman@ecmwf.int>
Date:   Tue Nov 18 17:34:19 2025 +0000

    basic if statement to yield the student and teacher views

commit 36ea287
Author: Sebastian Hickman <seb.hickman@ecmwf.int>
Date:   Tue Nov 18 17:33:53 2025 +0000

    slight restructure of ViewMetadata

commit 66cf9cd
Author: Sebastian Hickman <seb.hickman@ecmwf.int>
Date:   Tue Nov 18 17:33:08 2025 +0000

    added stream id to era5 config

commit 3c26ddc
Author: Sebastian Hickman <seb.hickman@ecmwf.int>
Date:   Tue Nov 18 17:32:00 2025 +0000

    updated default config training_config to allow student-teacher

commit c1580c4
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Tue Nov 18 16:30:44 2025 +0100

    Renaming

commit 85fa139
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Tue Nov 18 16:28:46 2025 +0100

    Comments

commit dd6f85a
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Tue Nov 18 15:30:22 2025 +0100

    Added mode and refactored get_sample_data into separate function.

commit 668912d
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Tue Nov 18 13:47:40 2025 +0100

    Partially enabled correct handling of multiple input steps.

commit c3b5c3b
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Tue Nov 18 12:02:17 2025 +0100

    Added basic support for multi-step sources.

commit ab9eecc
Merge: a934f97 c733280
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Tue Nov 18 10:00:37 2025 +0100

    Merge branch 'shmh40/dev/1270-idx-global-local' of github.com:ecmwf/WeatherGenerator into shmh40/dev/1270-idx-global-local

commit a934f97
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Tue Nov 18 09:58:19 2025 +0100

    NOT WORKING: updating class to handle multiple input steps and improving overall structure

commit c733280
Author: Sebastian Hickman <seb.hickman@ecmwf.int>
Date:   Mon Nov 17 18:32:40 2025 +0000

    change view_metadata to dict in ModelInput

commit 7d5c300
Author: Sebastian Hickman <seb.hickman@ecmwf.int>
Date:   Mon Nov 17 18:22:33 2025 +0000

    draft of training_config in default_config

commit 047b299
Author: Sebastian Hickman <seb.hickman@ecmwf.int>
Date:   Mon Nov 17 18:19:56 2025 +0000

    draft changes to allow global local view generation in masker and tokenizer_masking. generate the mask, otherwise using batchify_source and batchify_target as before, with the capacity to remember what mask we have now when it comes to generating the targets. Update to inputs_metadata structure but not put in to practice

commit 761e263
Author: Sebastian Hickman <seb.hickman@ecmwf.int>
Date:   Mon Nov 17 18:13:57 2025 +0000

    update ViewMetadata spec

commit 7f3c718
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Mon Nov 17 14:51:01 2025 +0100

    Updating config to working version

commit ae5a2e6
Author: Sebastian Hickman <seb.hickman@ecmwf.int>
Date:   Mon Nov 17 11:54:18 2025 +0000

    added file with ModelBatch and SampleMetadata dataclasses

commit debbb8f
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Mon Nov 17 12:28:07 2025 +0100

    Changes to prepare_logging to apply index inversion

commit 5d127bf
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Sun Nov 16 17:01:08 2025 +0100

    Inversion of target output ordering to match the input one in forecast mode. Unclear how to deal with it with MTM

commit 8fa544d
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Fri Nov 14 20:43:57 2025 +0100

    Removed unused parameters

commit ce6c735
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Fri Nov 14 16:56:51 2025 +0100

    Removing the centroids option for embedding that was unused and should not be used.

commit 0634105
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Fri Nov 14 09:59:13 2025 +0100

    Enabled support for forecast. Cleaned up some bits and pieces.

commit ec38123
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Fri Nov 14 08:27:21 2025 +0100

    Fixed remaining problems that occurred for NPP-ATMS and SYNOP.
    TODO:
    - Forecast still needs to be adapted
    - Some more cleanup of variable naming, return values etc

commit db6f285
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Thu Nov 13 23:26:31 2025 +0100

    Fixed linting

commit 9229e48
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Thu Nov 13 23:19:21 2025 +0100

    Minor cleanup

commit a581405
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Thu Nov 13 23:17:29 2025 +0100

    Working version for ERA5, NPP-ATMS. Problems with SYNOP with empty cell handling

commit e4a9cc0
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Thu Nov 13 18:58:28 2025 +0100

    Masking target is working in principle but errors when feeding data to the model.

commit 51f437f
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Thu Nov 13 07:04:23 2025 +0100

    NOT WORKING: Finished src, target still to be done.

commit 81bd6eb
Author: Christian Lessig <christian.lessig@ecmwf.int>
Date:   Wed Nov 12 09:38:53 2025 +0100

    NOT WORKING: initial draft for index-based masking. Implemented for random and healpix masking. Open issues with _coords_local, centroids and probably other things.
