Add GenDA diffusion model with sensor conditioning#216
Conversation
|
The failing pre-commit checks appear to come from existing Ruff/docstring violations in unrelated modules (aurora/cafa/anemoi) rather than the new GenDA implementation itself. The new GenDA test passes locally with: python -m pytest tests/test_genda.py Please let me know if you'd like me to additionally help clean up the existing lint/docstring issues. |
|
If you have time, cleaning up those ruff issues would be fantastic, but no worries if not |
for more information, see https://pre-commit.ci
|
hey @jacobbieker Current local status:
The remaining CI Ruff failures appear to come from pre-existing issues in unrelated modules. |
jacobbieker
left a comment
There was a problem hiding this comment.
Thanks for this! Sorry for the delay. Overall, I think its looking good, just a few comments to resolve.
| """Dataloaders and data processing utilities""" | ||
|
|
||
| from .anemoi_dataloader import AnemoiDataset | ||
| from .nnja_ai import SensorDataset |
There was a problem hiding this comment.
I would prefer not to move these around this way.
|
|
||
| def test_genda_forward(): | ||
|
|
||
| model = GenDA( |
There was a problem hiding this comment.
Can you add a test where the lat and lons are not a grid? But scattered around? I want to make sure it handles irregular graphs
| @@ -0,0 +1 @@ | |||
| """Conditioning utilities for GenDA sensor-guided diffusion.""" | |||
There was a problem hiding this comment.
Does something need to be added here? Otherwise this should be removed
|
Added an irregular/scattered point test for GenDA. The test currently verifies that irregular point layouts are rejected with a clear validation error, since the current implementation assumes structured Regular grid forward-pass tests continue to pass locally. |
Pull Request
Description
Implemented a new
GenDAdiffusion-based model with sensor conditioning support using the existing GenCast graph infrastructure.Changes made
Added
GenDAmodel implementation ingraph_weather/models/genda/Added sensor conditioning support using:
sensor_masksensor_valuesAdded classifier-free guidance style conditioning dropout during training
Added guided forward support for conditional inference
Integrated graph processing pipeline using existing GenCast graph builder and processor layers
Added tensor shape validation for conditioning inputs
Added package exports through
__init__.pyAdded unit test for forward-pass validation
Testing
A lightweight test configuration was used to validate:
Fixes #214
How Has This Been Tested?
Test command used:
If your changes affect data processing, have you plotted any changes? i.e. have you done a quick sanity check?
Checklist: