[big merge][student-teacher] Sophiex/dev/ssl losses 1043 #1205
Conversation
Implemented Identity class. TODO: implement EMATeacher.
The big question on the EMA teacher side, to me, is how to allow for a flexible teacher and student architecture that can differ. We updated some APIs of the abstract base class to allow the ema_model forward, subject to change given the loss calculator, which is imho the second big question mark.
Easier to read, and as batch size handling gets more complicated in SSL this will be a useful abstraction.
It runs so far. Next steps: - Route all the config options - Start writing the loss functions to understand the state requirements
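For reference on the EMATeacher TODO above, here is a minimal sketch (an assumption about the eventual design, not the PR's actual implementation) of how an EMA teacher is typically maintained in student-teacher SSL: a frozen copy of the student whose weights track an exponential moving average of the student's weights.

```python
import copy
import torch


class EMATeacher(torch.nn.Module):
    """Minimal EMA teacher sketch; class and method names are illustrative only."""

    def __init__(self, student: torch.nn.Module, momentum: float = 0.996):
        super().__init__()
        self.momentum = momentum
        # Deep-copy the student so the architectures match at init; a flexible
        # setup would instead build the teacher from its own config.
        self.model = copy.deepcopy(student)
        for p in self.model.parameters():
            p.requires_grad = False  # the teacher is never trained directly

    @torch.no_grad()
    def update(self, student: torch.nn.Module) -> None:
        # EMA update after each optimizer step: theta_t <- m * theta_t + (1 - m) * theta_s
        for p_t, p_s in zip(self.model.parameters(), student.parameters()):
            p_t.mul_(self.momentum).add_(p_s.detach(), alpha=1.0 - self.momentum)

    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)
```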
clessig left a comment
Didn't look through the actual computations line by line since this seems to be copy-pasted from the reference code?
@@ -0,0 +1,304 @@
# (C) Copyright 2025 WeatherGenerator contributors.
This file should go to . They need to be torch.nn.Modules because these are NNs, even if they are not necessarily trained themselves. I think ssl_target_processing.py (since you probably still don't like ssl_target_predictors.py).
You just want to rename the file? I can do that, I just want to make sure I understand correctly.
Yes, I would rename it and move it to weathergen/models/. Sorry, the comment was incomplete. For me, there is key functionality in this file and I wouldn't expect it in a file called utils. It's also mainly a model component, although I see that it's close to the loss computation.
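To illustrate the suggestion (hypothetical names, not the file's actual API): target/aux processing implemented as torch.nn.Module subclasses, so they compose with the rest of the model even when they carry no trained parameters of their own.

```python
import abc

import torch


class TargetProcessing(torch.nn.Module, abc.ABC):
    """Abstract target/aux computation; being an nn.Module lets it hold
    buffers and heads and move with the model, even if it is never trained."""

    @abc.abstractmethod
    def forward(self, model_output: torch.Tensor, batch: dict) -> torch.Tensor:
        """Produce the target used by the loss calculator."""


class IdentityTargetProcessing(TargetProcessing):
    def forward(self, model_output: torch.Tensor, batch: dict) -> torch.Tensor:
        # Supervised case: the target is taken from the batch unchanged.
        return batch["target"]
```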
import torch.nn.functional as F

def lossfunc(t, s, temp):
The name is not very descriptive :) Maybe latent_logit_loss.py? JEPA uses MAE (and one could conceivably replace it with MSE), which is already implemented in loss.py. Ideally we could reuse what is there.
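For orientation, a sketch of what a temperature-scaled latent logit loss with this signature usually computes (DINO-style cross-entropy between teacher and student token distributions); the PR's actual lossfunc may differ in details.

```python
import torch
import torch.nn.functional as F


def lossfunc(t: torch.Tensor, s: torch.Tensor, temp: float) -> torch.Tensor:
    # t: teacher probabilities (already centered/sharpened),
    # s: raw student logits, temp: student temperature.
    # Cross-entropy between t and softmax(s / temp), summed over the
    # prototype dimension and averaged over tokens.
    return -(t * F.log_softmax(s / temp, dim=-1)).sum(dim=-1).mean()
```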
Q *= B  # the columns must sum to 1 so that Q is an assignment
return Q.t()

# def forward(self, student_patch_tokens, teacher_patch_tokens, student_masks_flat):
Can we remove the stale code? What does it implement?
The stale code is there for reference because it needs to go to the loss calculator later.
I will do all the clean-up once we are much closer to actually merging :)
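For context on the snippet above: the `Q *= B` / `return Q.t()` lines look like the tail of a Sinkhorn-Knopp centering step as used in DINOv2-style teacher centering. A single-process sketch under that assumption (the reference code additionally all-reduces across ranks):

```python
import torch


@torch.no_grad()
def sinkhorn_knopp(teacher_logits: torch.Tensor, temp: float, n_iters: int = 3) -> torch.Tensor:
    # Alternately normalize rows (prototypes) and columns (samples) so that
    # Q becomes a soft assignment of samples to prototypes.
    Q = torch.exp(teacher_logits / temp).t()  # (K prototypes, B samples)
    K, B = Q.shape
    Q /= Q.sum()
    for _ in range(n_iters):
        Q /= Q.sum(dim=1, keepdim=True)  # each row sums to 1 ...
        Q /= K                           # ... then to 1/K
        Q /= Q.sum(dim=0, keepdim=True)  # each column sums to 1 ...
        Q /= B                           # ... then to 1/B
    Q *= B  # the columns must sum to 1 so that Q is an assignment
    return Q.t()
```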
def __init__(
    self,
    patch_out_dim,
Would it be better to take a dict as arg, since we may want to implement *TargetProcessing variants that require different args?
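A hypothetical illustration of the dict-as-arg suggestion (names are made up for the example): each *TargetProcessing variant pulls only the keys it needs, so constructor signatures do not have to change across variants.

```python
import torch


class IBOTTargetProcessing(torch.nn.Module):
    def __init__(self, cfg: dict):
        super().__init__()
        # Each variant reads only the config keys it cares about.
        self.patch_out_dim = cfg["patch_out_dim"]
        self.center_momentum = cfg.get("center_momentum", 0.9)
        self.register_buffer("center", torch.zeros(1, self.patch_out_dim))
```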
tjhunter left a comment
Some initial comments; looking forward to seeing it in action.
High-level comment: the current teacher-student framework wraps the whole model. Do we want that? I always thought it would be applied more locally, up to the global assimilation engine. That would simplify future interactions with the diffusion part in the forecasting engine.
src/weathergen/train/trainer.py (Outdated)
rampup_ratio=cf.get("ema_ramp_up_ratio", 0.09),
is_model_sharded=(cf.with_ddp and cf.with_fsdp),
)
elif cf["training_mode"] == "student-teacher":
Small comment: in general, prefer cf.get(...) for backward compatibility.
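Illustration of the point (the "masking" fallback is assumed for the example, not taken from the code): cf.get(...) lets configs written before the key existed keep working, whereas cf["training_mode"] raises a KeyError on them.

```python
cf = {}  # an older config that predates the "training_mode" key

training_mode = cf.get("training_mode", "masking")  # falls back instead of raising KeyError
if training_mode == "student-teacher":
    print("student-teacher training")
```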
Force-pushed 1b96469 to c5eea85.
…andom and healpix masking. Open issues with _coords_local, centroids and probably other things.
TODO: - Forecast still needs to be adapted - Some more cleanup of variable naming, return values etc
This reverts commit e4519d8.
…iex/dev/ssl-losses-1043
…o merge/overwrite the default config with the --config flag
…iex/dev/ssl-losses-1043
Currently nothing happens in the terminal but I don't know why that is
…rGenerator into sophiex/dev/ssl-losses-1043
clessig left a comment
Awesome work! Excited to see the experiments!
* Abstract class for target/aux computation: Implemented Identity class. TODO: implement EMATeacher
* Start implementing the EMA Teacher: The big question on the EMA teacher side to me is how to allow for a flexible teacher and student architecture that can differ. We updated some APIs of the abstract base class to allow the ema_model forward, subject to change given the loss calculator, which is imho the second big question mark
* adding loss calculator base class
* Option for constructing teacher model flexibly
* Extract get batch size util function: Easier to read, and as batch size gets more complicated in SSL this will be a useful abstraction
* Fix mismatched dtypes in the target computation: It runs so far. Next steps: - Route all the config options - Start writing the loss functions to understand the state requirements
* abstract loss calc structure
* add abstract method to loss calculator base class
* add latent loss class
* update loss calc config and rename files
* restructure loss modules
* add ModelOutput dataclass
* NOT WORKING: initial draft for index-based masking. Implemented for random and healpix masking. Open issues with _coords_local, centroids and probably other things.
* NOT WORKING: Finished src, target still to be done.
* Masking target is working in principle but errors when feeding data to the model.
* Working version for ERA5, NPP-ATMS. Problems with SYNOP with empty cell handling
* Minor cleanup
* Fixed linting
* Fixed remaining problems that occurred for NPP-ATMS and SYNOP. TODO: - Forecast still needs to be adapted - Some more cleanup of variable naming, return values etc
* Enabled support for forecast. Cleaned up some bits and pieces.
* mv streams_data declaration under if condition
* add weight to loss config, add toy loss class LossPhysicalTwo
* Update Abstract Target class based on needs for SSL losses
* Removing centroids option for embedding that was unused and should not be used.
* Removed unused parameters
* fixed trainer for multiple terms in losses_all, still need to fix logging
* Inversion of target output ordering to match input one in forecast mode. Unclear how to deal with it with MTM
* fix _log_terminal
* Changes to prepare_logging to apply index inversion
* added file with ModelBatch and SampleMetadata dataclasses
* Updating config to working version
* fix logging
* update ViewMetadata spec
* draft changes to allow global/local view generation in masker and tokenizer_masking. Generate the mask, otherwise using batchify_source and batchify_target as before, with the capacity to remember what mask we have now when it comes to generating the targets. Update to inputs_metadata structure but not put into practice
* draft of training_config in default_config
* change view_metadata to dict in ModelInput
* NOT WORKING: updating class to handle multiple input steps and improving overall structure
* Added basic support for multi-step sources.
* Partially enabled correct handling of multiple input steps.
* initialize loss as torch tensor with grad
* remove level in hist losses dict
* rename loss.py to loss_functions.py
* rename loss.py to loss_functions.py
* return loss with grads separately to trainer
* Added mode and refactored get_sample_data into separate function.
* modify log names
* add loss_functions.py
* Abstract class for target/aux computation: Implemented Identity class. TODO: implement EMATeacher
* Start implementing the EMA Teacher: The big question on the EMA teacher side to me is how to allow for a flexible teacher and student architecture that can differ. We updated some APIs of the abstract base class to allow the ema_model forward, subject to change given the loss calculator, which is imho the second big question mark
* Option for constructing teacher model flexibly
* rm loss_fcts in default config
* Comments
* Renaming
* updated default config training_config to allow student-teacher
* added stream id to era5 config
* slight restructure of ViewMetadata
* basic if statement to yield the student and teacher views
* correct imports with new batch.py
* Extract get batch size util function: Easier to read, and as batch size gets more complicated in SSL this will be a useful abstraction
* Fix mismatched dtypes in the target computation: It runs so far. Next steps: - Route all the config options - Start writing the loss functions to understand the state requirements
* Lay groundwork for SSL losses: This involves creating stateful classes for each of the losses and the EMATeacher being able to run additional neural network heads for these losses.
* Add the SSL Loss Processing classes
* Write part of the TargetProcessing forward. TODO: create the various teacher head modules and run them. TODO: merge the abstract loss calculator and create the SSL one
* Add latent prediction heads to the Model: After much consideration I decided to add the latent prediction heads to the Model, because they also need to benefit from exponential moving average of the weights and this gets unnecessarily cumbersome if they are outside the Model. TODO: make JEPA different between student and teacher. TODO: use this new structure in EMATeacher
* Adapt forward function for latent prediction heads: To prevent crazy nesting of model output values we created a ModelOutput dataclass (akin to how it is done in huggingface), and we run all the latent_prediction heads.
* Start piping configs through model, trainer, etc.: Will need adapting based on the abstract loss calculator. Currently awaiting the streams data branch to check piping of data and configuring this
* adding dinov2 notice
* Draft Student Teacher Loss Calculator. TODO: initialise it and register. TODO: weight the loss. TODO: route the kwargs. TODO: check shapes of tensors
* Use infra provided by Abstract Loss Calc: Completes config option routing, weighting, and registering TODOs
* Run Ruff
* Implemented the first draft of the Cropping feature
* rough first effort producing global and local views
* update to return 6-tuple from iter in multi-stream-data-sampler, with locals_prepared
* Fix class being in the wrong file
* Ensure data pipes through model and target: This is a DRAFT! This commit assumes that the data augmentations of the stream_data object (see shmh40/dev/global-local) will fit into the Batch data class (trainer.py). The goal was to ensure all data reaches the LossCalculator. Large-scale todos: - Pass metadata about local2global correspondence, etc. to the LossCalculator - Upgrade the Model heads to produce the correct dimensions - Verify the data shapes against DINOv2. Smaller todos: - Ensure teacher params are requires_grad = False - clean up code
* Wrap latent state into a dataclass to simplify loss calculation later
* Progress on computing the loss on correct dims: Added the new ViewMetadata and ModelBatch dataclasses that will come from the cropping PR. Added LatentState dataclass to compute the latent heads on the correct part of the latent state. TODOs: 1. Deal with DINO local and global component 2. Write JEPA loss function in loss.py 3. Test iBOT with actual mask and route student temperature 4. TODOs in the code
* Add views.py and run Ruff
* Close in on completing DINO loss. TODO: needs to deal with the tuple part of the DINO loss. TODO: understand how the ModelBatch input structure affects the loss terms
* Revert "rough first effort producing global and local views": This reverts commit 3fa0033.
* Lint code
* Fix rebase of loss loss_calculator
* created function for _get_student_teacher_sample_data which returns the streams_data of the teacher and multiple streams_datas for the student views.
* Not working draft for restructuring
* Changes for better student teacher structure
* More refactoring
* More refactoring and cleanup
* More refactoring. Code working again.
* Cleaned up parametrization
* Changes necessary for spoofing flag per IOReaderData
* Changes to have spoofing on a per data reader sample
* Moved _get_student_teacher_masks() so that masks are generated for all streams first.
* Renaming and minor clean up.
* Added basic support for use of ModelBatch class to define rough structure and interface.
* linting
* Linting
* linting
* Linting problems but removed unused ViewMetaData dependence
* Added required reflexivity between source and target samples to Batch
* Added todo
* Test for compute time regressions
* Prepare for merge
* Lint the code
* Lint code
* Lint
* Fix some basic bugs
* fix typo in ModelBatch
* collect num_source_samples and num_target_samples, add loop over teacher masks hence allowing multiple teacher views, and add source_target_idx to keep track of which student belongs to which teacher
* add teacher num_views parameter to config
* Re-enabling inversion of target ordering.
* tidy up, remove unused build_stream_views in tokenizer_masking
* multiple idxs for each teacher, need to confirm for the non-student case, and updated ModelBatch for this
* add max_num_targets to era5
* add max_num_samples functionality to tokenizer_masking and pass through in multi_stream_data_sampler. coords_per_cell is a bit nasty
* Removing spurious code / things that should be merged later
* Linting
* move build_views_for_stream into masker
* Lint code
* Rename identity TargetAndAux module
* tidy up, remove unused arguments, types
* fix masking for NPP-ATMS by correctly selecting final timestep mask and aligning between source and target. Working for num_input_steps = 1, broken for > 1, compute_offsets_scatter_embed not working
* Make code runnable
* updated configs so code runs. Note default config to be overhauled still
* Draft for model interface
* Make code runnable again. Seems slow again
* Cleaned up and restructured structure. Not working yet with FSDP
* Fixes for FSDP/DDP
* Cleaning up, should be merged when needed
* Fixes to FSDP
* Fix incorrect args for model loading and removing unused code.
* Linting
* Removing old code
* Fixing inference arg order; fixing subtle problem with world_size_original that should be taken from config when available
* Fixing interface of get_target_aux_calculator
* Fixing call to target aux calculator
* Fixes to get_target_aux_calculator
* Remove stale dataclasses
* Fix MAE
* very hacky first pass of full masking_strategy_config for the student and teacher views. Much to fix up
* instructions for sophie
* add SampleMetaData integration and functionality, and update masker to use SampleMetadata. Pass through source_cell_lens and target_coords_idx to student_teacher_batch in iter, and hence pass through to trainer. source_cell_lens and target_coords_idx are now part of Sample, which is itself the component of ModelBatch. To tidy
* Prepare for another merge
* remove prints, pdb
* Save state
* add mask to SampleMetaData and add forecast_dt to Sample so it is accessible. Can specify the loss in the default config with student-teacher views
* Save state for Seb: Currently reviving the EMATeacher creation. Memory is an issue, had to hardcode a smaller latent space
* add diffusion forecast option for the data sampling, and with noise_level_rn in the metadata. The Trainer needs to be copied from Sophie's branch, currently we only get so far
* Attempt to make the iBOT loss work. TODO: force 1 iBOT student view per global view. TODO: there is a bug with the mask causing a leaf error in PyTorch. TODO: remove all the hardcoded reduced latent space
* Linting
* Simplified and clarified handling of default target_aux_calculator
* Linting
* Linting
* Linting
* Linting
* Linting
* Restoring masking as training_mode in default_config
* More linting
* Removed duplicate lines due to merging
* Restored masking as training mode. Not working due to NaN in prediction
* Fixed problem in engines introduced in recent commits merging develop. This fixes masking training
* remove unused mask generation in diffusion_forecast
* restore masking_strategy to random: Had placeholder for testing, now back to "random" for masking strategy in the base level of default_config
* restore loader_num_workers to 8
* fix indentation of else: assert False in _get_sample msds
* Pipe data through all ssl loss fns. TODO: iBOT head should output class tokens as well as patch tokens. TODO: remove hardcoded assignments, should be based on config. TODO: deal with the memory hungriness of it all. TODO: carefully inspect for bugs
* linter warnings
* commenting tests
* Restructured code so that mask generation and application is cleanly separated
* Commit
* Update
* Fixed uv.lock
* Fix for integration test
* Re-enabled multi-source training
* 1390 - Adapt forward pass of new batch object (ecmwf#1391)
  * Add to device to ModelBatch, etc & adapt model. TODO: adapt validate and inference. TODO: test forecasting and multiple streams because predict changed substantially
  * Rename view to sample and fix validate
  * Revert predict function and fix inference
  * Fix invalid access with mask
  * Linting
  * Fixed handling of target_idxs and other minor issues
  Co-authored-by: sophiex <24638638+sophie-xhonneux@users.noreply.github.com>
  Co-authored-by: Christian Lessig <christian.lessig@ecmwf.int>
* Completed migration to new batch class by removing reference to old list of lists
* Fixed missing non_blocking=True in to_device()
* Removed old comments
* Fixed problem with non_blocking=True
* Cleaned up comments and return values a bit
* Changed args to embedding
* Changed core functions to take sample as arg
* Changed model to take sample as input
* Fixes for diffusion
* Switched to lists of model / target strategies
* Pipe the mask through
* Filter student views for the correct loss
* Change the masking and msdp to fit student-teacher: 1. We ensure that for each target view all the student views are generated 2. We ensure that the target views have their mask applied to the input
* Make DINO and iBOT work. TODO: use the target mask to reduce memory
* Prepare for Model PR introducing class & reg token: Thus, right now it breaks. The big question is memory!
* Integrate the class and register token PR: Done manually because I couldn't figure out how to merge from a fork
* Fix iBOT loss with correct PredHead. Limitation: iBOT loss needs num_class_tokens to be 1
* Fix JEPA + Lint code
* Fix DDP: It had unused parameters from the decoders; these had to be removed
* Running this code + config for JEPA with DDP
* Ran JEPA DDP plot with this
* Fix FSDP error: Potentially a slowdown, but I don't understand FSDP well enough for a better fix
* Fix config
* Fix validation
* Stuck on error, taking a break
* hot fix to empty tokens_c in encoder when looping over chunks
* Revert "hot fix to empty tokens_c in encoder when looping over chunks": This reverts commit e4519d8.
* hot fix for local assimilation empty tokens_c
* Add class tokens being variable + Fix bugs
* Push remaining changes to default config
* deepcopy configs so we do not pop weight and lose it for inference
* fixed bug in inference with +2 in forecast steps range
* add required import to trainer
* Update uv.lock
* Linting
* Record fstep latent states
* added two configs, jepa and ibot/dino. Note these configs still try to merge/overwrite the default config with the --config flag
* Address comments from PR review
* Prepare SSL losses for logging: Currently nothing happens in the terminal but I don't know why that is
* Lint
* Address PR comments + upstream changes
* Appease the hidden linter
* Rename ssl_losses_utils
* Add the untracked file
* Removing spurious character

---------
Co-authored-by: Jubeku <julian.kuehnert@ecmwf.int>
Co-authored-by: Christian Lessig <christian.lessig@ecmwf.int>
Co-authored-by: Sebastian Hickman <seb.hickman@ecmwf.int>
Co-authored-by: Tim Hunter <tim.hunter@ecmwf.int>
Co-authored-by: Wael Almikaeel <wael.almikaeel.95@gmail.com>
Co-authored-by: Sophie Xhonneux <sxhonneu@santis-ln001.cscs.ch>
Co-authored-by: Sebastian Hickman <seb.hickman@gmail.com>
Co-authored-by: Seb Hickman <56727418+shmh40@users.noreply.github.com>
Co-authored-by: Julian Kuehnert <Jubeku@users.noreply.github.com>
Description
[DRAFT] PR introducing the SSL student-teacher latent losses. This PR relies on both the abstract loss calculator #1178 and the abstract target/aux class #1179.
The idea is to get early feedback and notice issues by making the code more concrete.
Issue Number
Closes #1043
Is this PR a draft? Mark it as draft.
Checklist before asking for review
- ./scripts/actions.sh lint
- ./scripts/actions.sh unit-test
- ./scripts/actions.sh integration-test
- launch-slurm.py --time 60