Merged
Changes from all commits
247 commits
3f1bb7d
Abstract class for target/aux computation
sophie-xhonneux Oct 30, 2025
03ed148
Start implementing the EMA Teacher
sophie-xhonneux Oct 31, 2025
28d9b22
adding loss calculator base class
Jubeku Nov 4, 2025
192beb6
Option for constructing teacher model flexibly
sophie-xhonneux Nov 4, 2025
aac7e29
Extract get batch size util function
sophie-xhonneux Nov 5, 2025
145d18a
Fix mismatched dtypes in the target computation
sophie-xhonneux Nov 5, 2025
f1e7132
abstract loss calc structure
Jubeku Nov 5, 2025
e822e12
add abstract method to loss calculator base class
Jubeku Nov 6, 2025
d24ef48
add latent loss class
Jubeku Nov 6, 2025
c259c20
update loss calc config and rename files
Jubeku Nov 7, 2025
a19ee16
restructure loss modules
Jubeku Nov 11, 2025
bf3e128
add ModelOutput dataclass
Jubeku Nov 11, 2025
81bd6eb
NOT WORKING: initial draft for index-based masking. Implemented for r…
clessig Nov 12, 2025
51f437f
NOT WORKING: Finished src, target still to be done.
clessig Nov 13, 2025
e4a9cc0
Masking target is working in principle but errors when feeding data t…
clessig Nov 13, 2025
a581405
Working version for ERA5, NPP-ATMS. Problems with SYNOP with empty ce…
clessig Nov 13, 2025
9229e48
Minor cleanup
clessig Nov 13, 2025
db6f285
Fixed linting
clessig Nov 13, 2025
ec38123
Fixed remaining problems that occurred for NPP-ATMS and SYNOP.
clessig Nov 14, 2025
0634105
Enabled support for forecast. Cleaned up some bits and pieces.
clessig Nov 14, 2025
0fa60db
merge develop
Jubeku Nov 14, 2025
cab9fbe
mv streams_data declaration under if condition
Jubeku Nov 14, 2025
20da555
add weight to loss config, add toy loss class LossPhysicalTwo
Jubeku Nov 14, 2025
391b105
Update Abstract Target class based on needs for SSL losses
sophie-xhonneux Nov 14, 2025
ce6c735
Removing centroids options for embedding that was unused and should n…
clessig Nov 14, 2025
8fa544d
Removed unused parameters
clessig Nov 14, 2025
d7b326b
fixed trainer for multiple terms in losses_all, still need to fix log…
Jubeku Nov 14, 2025
5d127bf
Inversion of target output ordering to match input one in forecast mod…
clessig Nov 16, 2025
3ffdc60
fix _log_terminal
Jubeku Nov 17, 2025
debbb8f
Changes to prepare_logging to apply index inversion
clessig Nov 17, 2025
ae5a2e6
added file with ModelBatch and SampleMetadata dataclasses
shmh40 Nov 17, 2025
7f3c718
Updating config to working version
clessig Nov 17, 2025
beb4d6f
fix logging
Jubeku Nov 17, 2025
761e263
update ViewMetadata spec
shmh40 Nov 17, 2025
047b299
draft changes to allow global local view generation in masker and tok…
shmh40 Nov 17, 2025
7d5c300
draft of training_config in default_config
shmh40 Nov 17, 2025
c733280
change view_metadata to dict in ModelInput
shmh40 Nov 17, 2025
a934f97
NOT WORKING: updating class to handle multiple input steps and improv…
clessig Nov 18, 2025
ab9eecc
Merge branch 'shmh40/dev/1270-idx-global-local' of github.com:ecmwf/W…
clessig Nov 18, 2025
c3b5c3b
Added basic support for multi-step sources.
clessig Nov 18, 2025
668912d
Partially enabled correct handling of multiple input steps.
clessig Nov 18, 2025
33394ff
initialize loss as torch tensor with grad
Jubeku Nov 18, 2025
bda52d8
remove level in hist losses dict
Jubeku Nov 18, 2025
053dddd
rename loss.py to loss_functions.py
Jubeku Nov 18, 2025
d094ad0
rename loss.py to loss_functions.py
Jubeku Nov 18, 2025
8b4cbef
return loss with grads separately to trainer
Jubeku Nov 18, 2025
dd6f85a
Added mode and refactored get_sample_data into separate function.
clessig Nov 18, 2025
d0ef572
modify log names
Jubeku Nov 18, 2025
c6805c4
add loss_functions.py
Jubeku Nov 18, 2025
0ccce9e
merge develop
Jubeku Nov 18, 2025
3f379f9
Abstract class for target/aux computation
sophie-xhonneux Oct 30, 2025
7d4734b
Start implementing the EMA Teacher
sophie-xhonneux Oct 31, 2025
901d292
Option for constructing teacher model flexibly
sophie-xhonneux Nov 4, 2025
7ac9e6b
rm loss_fcts in default config
Jubeku Nov 18, 2025
85fa139
Comments
clessig Nov 18, 2025
c1580c4
Renaming
clessig Nov 18, 2025
3c26ddc
updated default config training_config to allow student-teacher
shmh40 Nov 18, 2025
66cf9cd
added stream id to era5 config
shmh40 Nov 18, 2025
36ea287
slight restructure of ViewMetadata
shmh40 Nov 18, 2025
11ad4e6
basic if statement to yield the student and teacher views
shmh40 Nov 18, 2025
b3dfa2f
merge changes
shmh40 Nov 18, 2025
2536cec
correct imports with new batch.py
shmh40 Nov 18, 2025
15e6635
Extract get batch size util function
sophie-xhonneux Nov 19, 2025
1e41df0
Fix mismatched dtypes in the target computation
sophie-xhonneux Nov 5, 2025
106ce11
Lay groundwork for SSL losses
sophie-xhonneux Nov 5, 2025
3a95584
Add the SSL Loss Processing classes
sophie-xhonneux Nov 6, 2025
8e6fe08
Write part of the TargetProcessing forward
sophie-xhonneux Nov 6, 2025
ea3f22b
Add latent prediction heads to the Model
sophie-xhonneux Nov 7, 2025
6fb7fcd
Adapt forward function for latent prediction heads
sophie-xhonneux Nov 7, 2025
149c8cb
Start piping configs through model, trainer, etc
sophie-xhonneux Nov 7, 2025
2afd1ac
adding dinov2 notice
tjhunter Nov 10, 2025
5b725ab
Draft Student Teacher Loss Calculator
sophie-xhonneux Nov 10, 2025
81caf2a
Use infra provided by Abstract Loss Calc
sophie-xhonneux Nov 11, 2025
3af00b1
Run Ruff
sophie-xhonneux Nov 11, 2025
2c78798
Implemented the first draft of the Cropping feature
wael-mika Oct 29, 2025
e66819f
rough first effort producing global and local views
shmh40 Nov 7, 2025
38f9a93
update to return 6 tuple from iter in multi-stream-data-sampler, with…
shmh40 Nov 7, 2025
594064e
Fix class being in the wrong file
sophie-xhonneux Nov 12, 2025
5191bad
Ensure data pipes through model and target
sophie-xhonneux Nov 12, 2025
b7927c2
Wrap latent state into a dataclass
sophie-xhonneux Nov 14, 2025
c5fec37
Progress on computing the loss on correct dims
sophie-xhonneux Nov 15, 2025
2b5e003
Add views.py and run Ruff
sophie-xhonneux Nov 15, 2025
8b647ee
Close in on completing DINO loss
sophie-xhonneux Nov 17, 2025
f0af4db
Revert "rough first effort producing globaland local views"
sophie-xhonneux Nov 18, 2025
e9b3379
Lint code
sophie-xhonneux Nov 18, 2025
208f4e3
Fix rebase of loss loss_calculator
sophie-xhonneux Nov 19, 2025
31dc658
created function for _get_student_teacher_sample_data which returns t…
shmh40 Nov 19, 2025
a824bfc
Not working draft for restructuring
clessig Nov 19, 2025
dfc03f2
Merge branch 'shmh40/dev/1270-idx-global-local' of github.com:ecmwf/W…
clessig Nov 19, 2025
81cf929
Changes for better student teacher structure
clessig Nov 19, 2025
46147d4
More refactoring
clessig Nov 19, 2025
1e70f5c
More refactoring and cleanup
clessig Nov 19, 2025
1235aab
More refactoring. Code working again.
clessig Nov 19, 2025
4613f7a
Cleaned up parametrization
clessig Nov 19, 2025
9fe94f5
Changes necessary for spoofing flag per IOReaderData
clessig Nov 19, 2025
ed26c02
Changes to have spoofing on a per data reader sample
clessig Nov 19, 2025
6d685c0
Moved _get_student_teacher_masks() so that masks are generated for al…
clessig Nov 19, 2025
848880b
Renaming and minor clean up.
clessig Nov 19, 2025
1b1654c
Added basic support for use of ModelBatch class to define rough struc…
clessig Nov 19, 2025
c1d32fb
linting
clessig Nov 20, 2025
6a96065
Linting
clessig Nov 20, 2025
3bca490
linting
clessig Nov 20, 2025
5d5e999
Linting problems but removed unused ViewMetaData dependence
clessig Nov 20, 2025
e8ccb8d
Added required reflexivity between source and target samples to Batch
clessig Nov 20, 2025
d18cf86
Added todo
clessig Nov 20, 2025
940e7f5
Test for compute time regressions
sophie-xhonneux Nov 20, 2025
7462a26
Prepare for merge
sophie-xhonneux Nov 20, 2025
798e12b
Lint the code
sophie-xhonneux Nov 20, 2025
0452d2e
Merge remote-tracking branch 'origin/jk/develop/loss_calc_base' into …
sophie-xhonneux Nov 20, 2025
5c30656
Lint code
sophie-xhonneux Nov 20, 2025
25f6b08
Lint
sophie-xhonneux Nov 20, 2025
e002405
Fix some basic bugs
Nov 20, 2025
b2be982
fix typo in ModelBatch
shmh40 Nov 20, 2025
b34b6da
collect num_source_samples and num_target_samples, add loop over teac…
shmh40 Nov 20, 2025
87ad45f
add teacher num_views parameter to config
shmh40 Nov 20, 2025
9b702c5
Re-enabling inversion of target ordering.
clessig Nov 20, 2025
1806ae5
tidy up, remove unused build_stream_views in tokenizer_masking
shmh40 Nov 20, 2025
647e4b2
multiple idxs for each teacher, need to confirm for not student case,…
shmh40 Nov 20, 2025
91c3d7a
add max_num_targets to era5
shmh40 Nov 21, 2025
1a418bf
add max_num_samples functionality to tokenizer_masking and pass throu…
shmh40 Nov 21, 2025
0ea0181
Removing spurious code / things that should be merged later
clessig Nov 21, 2025
4ae6a64
Merge branch 'develop' of github.com:ecmwf/WeatherGenerator into soph…
clessig Nov 21, 2025
93f66d6
Merge branch 'develop' of github.com:ecmwf/WeatherGenerator into soph…
clessig Nov 21, 2025
47b8297
Linting
clessig Nov 21, 2025
ece1dd0
move build_views_for_stream into masker
shmh40 Nov 21, 2025
65b3a26
Merge branch 'sophiex/dev/abstract-class-teacher-1179' into sophiex/d…
sophie-xhonneux Nov 21, 2025
a6f068a
Lint code
sophie-xhonneux Nov 21, 2025
f54b2ae
Rename identity TargetAndAux module
sophie-xhonneux Nov 21, 2025
b9a60f3
tidy up, remove unused arguments, types
shmh40 Nov 21, 2025
2905cb0
fix masking for NPP-ATMS by correctly selecting final timestep mask a…
shmh40 Nov 22, 2025
2d94d44
Make code runnable
sophie-xhonneux Nov 24, 2025
af9a3c1
merge with develop, include trainer idx_inv_rt, merged default_config…
shmh40 Nov 24, 2025
b193a50
updated configs so code runs. Note default config to be overhauled still
shmh40 Nov 24, 2025
181afc0
Draft for model interface
clessig Nov 25, 2025
18e597a
Merge remote-tracking branch 'origin/shmh40/dev/1270-idx-global-local…
sophie-xhonneux Nov 25, 2025
5768d3f
Make code runnable again
sophie-xhonneux Nov 25, 2025
e21d656
Cleaned up and restructured structure. Not working yet with FSDP
clessig Nov 25, 2025
524959c
Fixes for FSDP/DDP
clessig Nov 25, 2025
1b1ffec
Cleaning up, should be merged when needed
clessig Nov 25, 2025
3d28570
Fixes to FSDP
clessig Nov 25, 2025
587eaf5
Fix incorrect args for model loading and removing unused code.
clessig Nov 25, 2025
abb103b
Linting
clessig Nov 25, 2025
330d8be
Removing old code
clessig Nov 25, 2025
79136a3
- Fixing inference arg order
clessig Nov 25, 2025
6d34197
Fixing interface of get_target_aux_calculator
clessig Nov 25, 2025
ca240a8
Fixing call to target aux calculator
clessig Nov 25, 2025
58ba287
Fixes to get_target_aux_calculator
clessig Nov 25, 2025
7c4167f
Remove stale dataclasses
sophie-xhonneux Nov 25, 2025
5bd60bc
Fix MAE
clessig Nov 25, 2025
fa24fc1
very hacky first pass of full masking_strategy_config for the student…
shmh40 Nov 25, 2025
dff96f2
Merge remote-tracking branch 'origin/clessig/dev/abstract-class-teach…
sophie-xhonneux Nov 25, 2025
69d097c
Merge remote-tracking branch 'origin/shmh40/dev/1270-idx-global-local…
sophie-xhonneux Nov 25, 2025
4f8f62b
instructions for sophie
shmh40 Nov 25, 2025
c27156c
add SampleMetaData integration and functionality, and update masker t…
shmh40 Nov 26, 2025
8f8389f
Prepare for another merge
sophie-xhonneux Nov 26, 2025
e0d7346
remove prints, pdb
shmh40 Nov 26, 2025
f477271
Merge remote-tracking branch 'origin/shmh40/dev/1270-idx-global-local…
sophie-xhonneux Nov 26, 2025
92b184f
Save state
sophie-xhonneux Nov 27, 2025
6d909d6
add mask to SampleMetaData and add forecast_dt to Sample so it is acc…
shmh40 Nov 27, 2025
602a2ee
Merge remote-tracking branch 'origin/shmh40/dev/1270-idx-global-local…
sophie-xhonneux Nov 27, 2025
a00fa64
Save state for Seb
sophie-xhonneux Nov 27, 2025
26f7b5b
add diffusion forecast option for the data sampling, and with noise_l…
shmh40 Nov 27, 2025
619b388
Attempt to make the iBOT loss work
sophie-xhonneux Nov 27, 2025
b47b0fa
Merge branch 'shmh40/dev/1270-idx-global-local' of github.com:ecmwf/W…
clessig Nov 28, 2025
5f803e5
Merge branch 'develop' of github.com:ecmwf/WeatherGenerator into shmh…
clessig Nov 28, 2025
3e4de7a
Linting
clessig Nov 28, 2025
8ef3a4c
Simplified and clarified handling of default target_aux_calculator
clessig Nov 28, 2025
d8998a9
Linting
clessig Nov 28, 2025
652500a
Linting
clessig Nov 28, 2025
03166a2
Linting
clessig Nov 28, 2025
e41a575
Linting
clessig Nov 28, 2025
0db8b62
Linting
clessig Nov 28, 2025
47750a5
Restoring masking as training_mode in default_config
clessig Nov 28, 2025
bc8d23e
More linting
clessig Nov 28, 2025
6289959
Removed duplicate lines due to merging
clessig Nov 28, 2025
d526dfc
Restored masking as training mode. Not working due to NaN in prediction
clessig Nov 28, 2025
657094a
Fixed problem in engines introduced in recent commits merging develop…
clessig Nov 28, 2025
1a37dd1
remove unused mask generation in diffusion_forecast
shmh40 Nov 28, 2025
6ea07e7
restore masking_strategy to random
shmh40 Nov 28, 2025
4281aff
restore loader_num_workers to 8
shmh40 Nov 28, 2025
15b46e9
fix indentation of else: assert False in _get_sample msds
shmh40 Nov 28, 2025
6fe8561
Merge branch 'develop' of github.com:ecmwf/WeatherGenerator into shmh…
clessig Nov 28, 2025
9ae22e8
Pipe data through all ssl loss fns
sophie-xhonneux Nov 28, 2025
dc736e5
merge with dev
tjhunter Dec 2, 2025
2b2c977
linter warnings
tjhunter Dec 2, 2025
c8a2aad
commenting tests
tjhunter Dec 2, 2025
2599ec2
Restructured code so that mask generation and application is cleanly …
clessig Dec 2, 2025
c8a26d7
Commit
clessig Dec 2, 2025
23e0267
Update
clessig Dec 2, 2025
33d9d8d
Merge branch 'shmh40/dev/1270-idx-global-local' of github.com:ecmwf/W…
clessig Dec 2, 2025
9f5e49c
Fixed uv.lock
clessig Dec 2, 2025
3641e1f
Fix for integration test
clessig Dec 2, 2025
9a1a6a9
Re-enabled multi-source training
clessig Dec 3, 2025
402b8de
1390 - Adapt forward pass of new batch object (#1391)
Jubeku Dec 3, 2025
d1f5a21
Merge remote-tracking branch 'origin/shmh40/dev/1270-idx-global-local…
sophie-xhonneux Dec 3, 2025
6b19da9
Merge remote-tracking branch 'origin/shmh40/dev/1270-idx-global-local…
sophie-xhonneux Dec 3, 2025
2cd3971
Completed migration to new batch class by removing reference to old l…
clessig Dec 3, 2025
51754fa
Fixed missing non_blocking=True in to_device()
clessig Dec 3, 2025
69b53a6
Removed old comments
clessig Dec 3, 2025
59510dd
Fixed problem with non_blocking=True
clessig Dec 3, 2025
b69b743
Cleaned up comments and return values a bit
clessig Dec 4, 2025
d36367a
Changed args to embedding
clessig Dec 4, 2025
3f52a8d
Changed core functions to take sample as arg
clessig Dec 4, 2025
9065219
Changed that model takes sample as input
clessig Dec 4, 2025
12bae15
Fixes for diffusion
clessig Dec 4, 2025
7745e47
Switched to lists of model / target strategies
clessig Dec 4, 2025
27ac2bd
Pipe the mask through
sophie-xhonneux Dec 4, 2025
b3b80e2
Merge branch 'shmh40/dev/1270-idx-global-local' into sophiex/dev/ssl-…
sophie-xhonneux Dec 4, 2025
71bc79a
Filter student views for the correct loss
sophie-xhonneux Dec 4, 2025
052b012
Change the masking and msdp to fit student-teacher
sophie-xhonneux Dec 6, 2025
0ddf3f5
Make DINO and iBOT work
sophie-xhonneux Dec 6, 2025
e10855b
Prepare for Model PR introducing class & reg token
sophie-xhonneux Dec 6, 2025
75a4ab2
Integrate the class and register token PR
sophie-xhonneux Dec 7, 2025
63ae111
Fix iBOT loss with correct PredHead
sophie-xhonneux Dec 7, 2025
70c5808
Fix JEPA + Lint code
sophie-xhonneux Dec 7, 2025
995f4c0
Fix DDP
sophie-xhonneux Dec 15, 2025
3dcbe59
Running this code + config for JEPA with DDP
sophie-xhonneux Dec 16, 2025
fa0c5c1
Ran JEPA DDP plot with this
sophie-xhonneux Dec 17, 2025
25a2c0e
Fix FSDP error
sophie-xhonneux Dec 17, 2025
01ce0ba
Fix config
sophie-xhonneux Dec 17, 2025
30a9201
Merge branch 'develop' into sophiex/dev/ssl-losses-1043
sophie-xhonneux Dec 18, 2025
1088e4b
Fix validation
sophie-xhonneux Dec 18, 2025
64a6aed
Stuck on error taking a break
sophie-xhonneux Dec 18, 2025
e4519d8
hot fix to empty tokens_c in encoder when looping over chunks
shmh40 Dec 18, 2025
670784f
Revert "hot fix to empty tokens_c in encoder when looping over chunks"
shmh40 Dec 18, 2025
0212620
hot fix for local assimilation empty tokens_c
shmh40 Dec 18, 2025
e89b781
Add class tokens being variable + Fix bugs
sophie-xhonneux Dec 19, 2025
466b24a
Push remaining changes to default config
sophie-xhonneux Dec 19, 2025
b0e4959
deepcopy configs so we do not pop weight and lose it for inference
shmh40 Dec 19, 2025
697231d
fixed bug in inference with +2 in forecast steps range
shmh40 Dec 19, 2025
e3846f6
add required import to trainer
shmh40 Dec 19, 2025
ffe2bc0
Merge branch 'develop' into sophiex/dev/ssl-losses-1043
sophie-xhonneux Dec 19, 2025
f869824
Update uv.lock
sophie-xhonneux Dec 19, 2025
414cf36
Merge branch 'develop' of github.com:ecmwf/WeatherGenerator into soph…
clessig Dec 23, 2025
7dc52a9
Linting
clessig Dec 23, 2025
8a41e68
Record fstep latent states
sophie-xhonneux Dec 23, 2025
c7fc4df
added two configs, jepa and ibot/dino. Note these configs still try t…
shmh40 Dec 23, 2025
e1f59ab
Merge branch 'develop' of github.com:ecmwf/WeatherGenerator into soph…
clessig Dec 23, 2025
7799138
Address comments from PR review
sophie-xhonneux Dec 23, 2025
108ddcb
Prepare SSL losses for logging
sophie-xhonneux Dec 23, 2025
62b660b
Merge branch 'sophiex/dev/ssl-losses-1043' of github.com:ecmwf/Weathe…
clessig Dec 23, 2025
0944593
Lint
sophie-xhonneux Dec 23, 2025
99f74b5
Address PR comments+ upstream changes
sophie-xhonneux Dec 23, 2025
1ea6ff3
Appease the hidden linter
sophie-xhonneux Dec 23, 2025
c585dc6
Rename ssl_losses_utils
sophie-xhonneux Dec 23, 2025
b4b17b6
Add the untracked file
sophie-xhonneux Dec 23, 2025
ff2c7aa
Removing spurious character
clessig Dec 23, 2025
10 changes: 10 additions & 0 deletions NOTICE
@@ -0,0 +1,10 @@
This project includes code derived from project "DINOv2: Learning Robust Visual Features without Supervision",
originally developed by Meta Platforms, Inc. and affiliates,
licensed under the Apache License, Version 2.0.

Original NOTICE from project DINOv2
--------------------------------------

N/A


90 changes: 84 additions & 6 deletions config/default_config.yml
@@ -1,5 +1,4 @@
streams_directory: "./config/streams/era5_1deg/"
# streams_directory: "./config/streams/era5_nppatms_synop/"

embed_orientation: "channels"
embed_unembed_mode: "block"
@@ -46,6 +45,8 @@ pred_adapter_kv: False
pred_self_attention: True
pred_dyadic_dims: False
pred_mlp_adaln: True
num_class_tokens: 1
num_register_tokens: 7

# number of steps offset applied to first target window; if set to zero and forecast_steps=0 then
# one is training an auto-encoder
@@ -93,7 +94,7 @@ validate_with_ema: True
ema_ramp_up_ratio: 0.09
ema_halflife_in_thousands: 1e-3


### Example validation and training config for mask token modelling in physical space
validation_config: {"losses": {LossPhysical: {weight: 1.0, loss_fcts: [['mse', 1.0]]},}}
# Student-teacher configuration (only used when training_mode == "student_teacher")
# TODO: adapt so that the masking or forecast config entry also sits here
@@ -111,6 +112,36 @@ training_config:
relationship: "complement" # "independent", "subset", "disjoint". Relationship of student views to teacher view.
num_steps_input: 1

# ### Example validation and training config for student-teacher with JEPA
# validation_config:
# losses:
# LossLatentSSLStudentTeacher: {
# "weight": 1.0,
# "JEPA": {'weight': 5, "loss_extra_args": {}, "out_dim": 2048} }
# ### Student-teacher configuration (only used when training_mode == "student_teacher")
# training_config:
# # when this is "masking", we are basically only using the model_input subconfig
# training_mode: "student_teacher" # "masking", "student_teacher", "forecast"
# target_and_aux_calc: "EMATeacher"
# losses :
# LossLatentSSLStudentTeacher: {
# "weight": 1.0,
# "JEPA": {'weight': 5, "loss_extra_args": {}, "out_dim": 2048} }
# model_input:
# - masking_strategy: "random" # "random", "healpix". Masking strategy to use for model input for masking, and local (student) views when doing student-teacher
# num_samples: 1 # if student-teacher, the number of local (student) views to generate
# masking_strategy_config : { diffusion_rn : False, rate : 0.4 }
# # relationship: "independent" #, "subset", "disjoint". Relationship of student views to teacher view.
# relationship: "subset" # "independent", "subset", "disjoint". Relationship of student views to teacher view.
# loss : jepa
# rate_sampling: False # randomly sample the rate per batch
#
# target_input:
# - masking_strategy: "healpix" # Strategy for teacher (global) view: "random", "healpix"
# masking_strategy_config : { diffusion_rn : False, rate : 0.4, hl_mask: 0 }
# num_samples: 1 # number of teacher views to generate
# rate_sampling: False # randomly sample the rate per batch


# - masking_strategy: "random"
# num_samples: 2 # if student-teacher, the number of local (student) views to generate
@@ -217,12 +248,60 @@ training_config:
# relationship: "independent"
# num_steps_input: 1

# ### Example validation and training config for student-teacher with iBOT and DINO
# validation_config:
# losses:
# LossLatentSSLStudentTeacher: {
# "weight": 1.0,
# "iBOT": {'weight': 0.75, "loss_extra_args": { "student_temp": 0.1,},"out_dim": 16384, "teacher_temp": 0.1,
# "teacher_style": "softmax_center", "center_momentum": 0.9},
# "DINO": {'weight': 0.25, "loss_extra_args": { "student_temp": 0.1,}, "out_dim": 16384, "teacher_temp": 0.1,
# "teacher_style": "softmax_center", "center_momentum": 0.9},
# }
#
#
# ### Student-teacher configuration (only used when training_mode == "student_teacher")
# training_config:
# # when this is "masking", we are basically only using the model_input subconfig
# training_mode: "student_teacher" # "masking", "student_teacher", "forecast"
# target_and_aux_calc: "EMATeacher"
# losses :
# LossLatentSSLStudentTeacher: {
# "weight": 1.0,
# "iBOT": {'weight': 0.75, "loss_extra_args": { "student_temp": 0.1,},"out_dim": 4096, # 16384,
# "teacher_temp": 0.1, "teacher_style": "softmax_center", "center_momentum": 0.9},
# "DINO": {'weight': 0.25, "loss_extra_args": { "student_temp": 0.1,}, "out_dim": 4096, # 16384,
# "teacher_temp": 0.1, "teacher_style": "softmax_center", "center_momentum": 0.9},
# }
# model_input:
# - masking_strategy: "random" # "random", "healpix". Masking strategy to use for model input for masking, and local (student) views when doing student-teacher
# num_samples: 1 # if student-teacher, the number of local (student) views to generate
# masking_strategy_config : { diffusion_rn : False, rate : 0.4 }
# relationship: "subset" # "independent", "subset", "disjoint". Relationship of student views to teacher view.
# rate_sampling: False # randomly sample the rate per batch
# loss : ibot
# - masking_strategy: "healpix"
# num_samples: 2 # if student-teacher, the number of local (student) views to generate
# masking_strategy_config : { diffusion_rn : False, rate : 0.4, hl_mask: 1 }
# relationship: "subset" # "independent", "subset", "disjoint". Relationship of student views to teacher view.
# rate_sampling: False # randomly sample the rate per batch
# loss : dino
# - masking_strategy: "healpix"
# num_samples: 1 # if student-teacher, the number of local (student) views to generate
# masking_strategy_config : { diffusion_rn : False, rate : 0.4, hl_mask: 1 }
# relationship: "identity" # "independent", "subset", "disjoint". Relationship of student views to teacher view.
# rate_sampling: False # randomly sample the rate per batch
# loss : dino
#
# target_input:
# - masking_strategy: "healpix" # Strategy for teacher (global) view: "random", "healpix"
# masking_strategy_config : { diffusion_rn : False, rate : 0.4, hl_mask: 0 }
# num_samples: 2 # number of teacher views to generate
# rate_sampling: False # randomly sample the rate per batch


num_register_tokens: 0

num_mini_epochs: 32
samples_per_mini_epoch: 4096
samples_per_mini_epoch: 4096 # 250000 for student-teacher because validation is meaningless
samples_per_validation: 512

shuffle: True
@@ -271,7 +350,6 @@ train_log_freq:
metrics: 20
checkpoint: 250


# Tags for experiment tracking
# These tags will be logged in MLFlow along with completed runs for train, eval, val
# The tags are free-form, with the following rules:
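The student-teacher examples above pair the EMATeacher target-and-aux calculator with latent losses such as JEPA. For orientation, a minimal hypothetical sketch of one such step, assuming names like student, teacher and predictor that are not the repository's actual interfaces:

import torch
import torch.nn.functional as F

@torch.no_grad()
def update_ema_teacher(teacher, student, beta=0.999):
    # Teacher weights track an exponential moving average of the student;
    # the teacher is typically initialized as a deep copy of the student.
    for p_t, p_s in zip(teacher.parameters(), student.parameters()):
        p_t.mul_(beta).add_(p_s, alpha=1.0 - beta)

def jepa_step(student, teacher, predictor, student_view, teacher_view):
    # The teacher encodes the target view without gradients; the student
    # encodes a masked/local view and predicts the teacher's latent state.
    with torch.no_grad():
        target_latent = teacher(teacher_view)
    pred_latent = predictor(student(student_view))
    return F.smooth_l1_loss(pred_latent, target_latent)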
216 changes: 216 additions & 0 deletions config/default_config_dino.yml
@@ -0,0 +1,216 @@
streams_directory: "./config/streams/era5_1deg/"

embed_orientation: "channels"
embed_unembed_mode: "block"
embed_dropout_rate: 0.1

target_cell_local_prediction: True

ae_local_dim_embed: 1024
ae_local_num_blocks: 2
ae_local_num_heads: 16
ae_local_dropout_rate: 0.1
ae_local_with_qk_lnorm: True

ae_local_num_queries: 1
ae_local_queries_per_cell: False
ae_adapter_num_heads: 16
ae_adapter_embed: 128
ae_adapter_with_qk_lnorm: True
ae_adapter_with_residual: True
ae_adapter_dropout_rate: 0.1

ae_global_dim_embed: 2048
ae_global_num_blocks: 8
ae_global_num_heads: 32
ae_global_dropout_rate: 0.1
ae_global_with_qk_lnorm: True
# TODO: switching to < 1 triggers triton-related issues.
# See https://github.com/ecmwf/WeatherGenerator/issues/1050
ae_global_att_dense_rate: 1.0
ae_global_block_factor: 64
ae_global_mlp_hidden_factor: 2
ae_global_trailing_layer_norm: False

ae_aggregation_num_blocks: 2
ae_aggregation_num_heads: 32
ae_aggregation_dropout_rate: 0.1
ae_aggregation_with_qk_lnorm: True
ae_aggregation_att_dense_rate: 1.0
ae_aggregation_block_factor: 64
ae_aggregation_mlp_hidden_factor: 2

decoder_type: PerceiverIOCoordConditioning # CrossAttentionAdaNormConditioning
pred_adapter_kv: False
pred_self_attention: True
pred_dyadic_dims: False
pred_mlp_adaln: True
num_class_tokens: 1
num_register_tokens: 7

# number of steps offset applied to first target window; if set to zero and forecast_steps=0 then
# one is training an auto-encoder
forecast_offset : 0
forecast_delta: 00:00:00
forecast_steps: 0
forecast_policy: null
forecast_freeze_model: False
forecast_att_dense_rate: 1.0
forecast_with_step_conditioning: True # False
fe_num_blocks: 0
fe_num_heads: 16
fe_dropout_rate: 0.1
fe_with_qk_lnorm: True
fe_layer_norm_after_blocks: [] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer
fe_impute_latent_noise_std: 0.0 # 1e-4

healpix_level: 5

with_mixed_precision: True
with_flash_attention: True
compile_model: False
with_fsdp: True
attention_dtype: bf16
mixed_precision_dtype: bf16
mlp_norm_eps: 1e-5
norm_eps: 1e-4

latent_noise_kl_weight: 0.0 # 1e-5
latent_noise_gamma: 2.0
latent_noise_saturate_encodings: 5
latent_noise_use_additive_noise: False
latent_noise_deterministic_latents: True

batch_size_per_gpu: 1
batch_size_validation_per_gpu: 1

# a regex that needs to fully match the name of the modules you want to freeze
# e.g. ".*ERA5" will match any module whose name ends in ERA5
# encoders and decoders that exist per stream have the stream name attached at the end
freeze_modules: ""
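# Illustrative matching logic (a hypothetical sketch, not the repository's exact code):
#   pattern = re.compile(cf.freeze_modules)
#   for name, module in model.named_modules():
#       if pattern.fullmatch(name):   # full match, hence patterns like ".*ERA5"
#           module.requires_grad_(False)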

# whether to track the exponential moving average of weights for validation
validate_with_ema: True
ema_ramp_up_ratio: 0.09
ema_halflife_in_thousands: 1e-3
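# Rough intuition for the EMA decay (a hedged sketch; the trainer may differ):
#   halflife_steps = ema_halflife_in_thousands * 1000
#   beta = 0.5 ** (1.0 / halflife_steps)             # per-step decay factor
#   w_ema = beta * w_ema + (1.0 - beta) * w_model    # ramped up early via ema_ramp_up_ratio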

### Example validation and training config for student-teacher with iBOT and DINO
validation_config:
losses:
LossLatentSSLStudentTeacher: {
"weight": 1.0,
"iBOT": {'weight': 0.75, "loss_extra_args": { "student_temp": 0.1,},"out_dim": 16384, "teacher_temp": 0.1,
"teacher_style": "softmax_center", "center_momentum": 0.9},
"DINO": {'weight': 0.25, "loss_extra_args": { "student_temp": 0.1,}, "out_dim": 16384, "teacher_temp": 0.1,
"teacher_style": "softmax_center", "center_momentum": 0.9},
}

### Student-teacher configuration (only used when training_mode == "student_teacher")
training_config:
# when this is "masking", we are basically only using the model_input subconfig
training_mode: "student_teacher" # "masking", "student_teacher", "forecast"
target_and_aux_calc: "EMATeacher"
losses :
LossLatentSSLStudentTeacher: {
"weight": 1.0,
"iBOT": {'weight': 0.75, "loss_extra_args": { "student_temp": 0.1,},"out_dim": 4096, # 16384,
"teacher_temp": 0.1, "teacher_style": "softmax_center", "center_momentum": 0.9},
"DINO": {'weight': 0.25, "loss_extra_args": { "student_temp": 0.1,}, "out_dim": 4096, # 16384,
"teacher_temp": 0.1, "teacher_style": "softmax_center", "center_momentum": 0.9},
}
model_input:
- masking_strategy: "random" # "random", "healpix". Masking strategy to use for model input for masking, and local (student) views when doing student-teacher
num_samples: 1 # if student-teacher, the number of local (student) views to generate
masking_strategy_config : { diffusion_rn : False, rate : 0.4 }
relationship: "subset" # "independent", "subset", "disjoint". Relationship of student views to teacher view.
rate_sampling: False # randomly sample the rate per batch
loss : ibot
- masking_strategy: "healpix"
num_samples: 2 # if student-teacher, the number of local (student) views to generate
masking_strategy_config : { diffusion_rn : False, rate : 0.4, hl_mask: 1 }
relationship: "subset" # "independent", "subset", "disjoint". Relationship of student views to teacher view.
rate_sampling: False # randomly sample the rate per batch
loss : dino
- masking_strategy: "healpix"
num_samples: 1 # if student-teacher, the number of local (student) views to generate
masking_strategy_config : { diffusion_rn : False, rate : 0.4, hl_mask: 1 }
relationship: "identity" # "independent", "subset", "disjoint". Relationship of student views to teacher view.
rate_sampling: False # randomly sample the rate per batch
loss : dino

target_input:
- masking_strategy: "healpix" # Strategy for teacher (global) view: "random", "healpix"
masking_strategy_config : { diffusion_rn : False, rate : 0.4, hl_mask: 0 }
num_samples: 2 # number of teacher views to generate
rate_sampling: False # randomly sample the rate per batch
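# Reading the views above (an interpretation, not normative): the teacher encodes
# two global healpix views; the randomly masked student view is scored with iBOT
# against the teacher, while the healpix student views ("subset" crops plus one
# "identity" view) are scored with DINO. See the loss sketch after this config.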

num_mini_epochs: 32
samples_per_mini_epoch: 4096 # 250000 for student-teacher because validation is meaningless
samples_per_validation: 512

shuffle: True

lr_scaling_policy: "sqrt"
lr_start: 1e-6
lr_max: 5e-5
lr_final_decay: 1e-6
lr_final: 0.0
lr_steps_warmup: 512
lr_steps_cooldown: 512
lr_policy_warmup: "cosine"
lr_policy_decay: "constant"
lr_policy_cooldown: "linear"
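# Assumed schedule semantics (check the trainer for the authoritative logic):
#   warmup:   cosine ramp from lr_start to lr_max over lr_steps_warmup steps
#   decay:    constant at lr_max for the remaining steps
#   cooldown: linear ramp down to lr_final_decay over lr_steps_cooldown steps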

grad_clip: 1.0
weight_decay: 0.1
norm_type: "LayerNorm"
nn_module: "te"
log_grad_norms: False

start_date: 1979-01-01T00:00
end_date: 2022-12-31T00:00
start_date_val: 2023-10-01T00:00
end_date_val: 2023-12-31T00:00
time_window_step: 06:00:00
time_window_len: 06:00:00
input_window_steps: 1

val_initial: False #True

loader_num_workers: 12
log_validation: 0
streams_output: ["ERA5"]

istep: 0
run_history: []

desc: ""
data_loader_rng_seed: ???
run_id: ???

# The period to log in the training loop (in number of batch steps)
train_log_freq:
terminal: 10
metrics: 20
checkpoint: 250

# Tags for experiment tracking
# These tags will be logged in MLFlow along with completed runs for train, eval, val
# The tags are free-form, with the following rules:
# - tags should be primitive types (strings, numbers, booleans). NO lists or dictionaries
# - tags should not duplicate existing config entries.
# - try to reuse existing tags where possible. MLFlow does not like having too many unique tags
# - do not use long strings in values (less than 20 characters is a good rule of thumb, we may enforce this in the future)
wgtags:
# The name of the organization of the person running the experiment.
# This may be autofilled in the future. Expected values are lowercase strings of
# the organizations codenames in https://confluence.ecmwf.int/display/MAEL/Staff+Contact+List
# e.g. "ecmwf", "cmcc", "metnor", "jsc", "escience"
org: None
# The name of the experiment. This is a distinctive codename for the experiment campaign being run.
# This is expected to be the primary tag for comparing experiments in MLFlow.
# Expected values are lowercase strings with no spaces, just underscores:
# Examples: "rollout_ablation_grid"
exp: None
# *** Experiment-specific tags ***
grid: None
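The "softmax_center" teacher style used above follows the DINO centering recipe. A minimal sketch, assuming a DINO-style module (hypothetical code, not the loss implementation this PR adds):

import torch
import torch.nn.functional as F

class CenteredDistillationLoss(torch.nn.Module):
    # Hypothetical sketch of teacher_style "softmax_center": center and sharpen
    # the teacher logits, then distill into the student with cross-entropy.
    def __init__(self, out_dim=16384, teacher_temp=0.1, student_temp=0.1,
                 center_momentum=0.9):
        super().__init__()
        self.teacher_temp = teacher_temp
        self.student_temp = student_temp
        self.center_momentum = center_momentum
        self.register_buffer("center", torch.zeros(1, out_dim))

    def forward(self, student_logits, teacher_logits):
        # Teacher targets: centered, sharpened, and detached from the graph.
        with torch.no_grad():
            targets = F.softmax(
                (teacher_logits - self.center) / self.teacher_temp, dim=-1
            )
        log_probs = F.log_softmax(student_logits / self.student_temp, dim=-1)
        loss = -(targets * log_probs).sum(dim=-1).mean()
        with torch.no_grad():
            # The center tracks an EMA of the teacher logits to avoid collapse.
            batch_mean = teacher_logits.mean(dim=0, keepdim=True)
            self.center.mul_(self.center_momentum).add_(
                batch_mean, alpha=1.0 - self.center_momentum
            )
        return loss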