-
Notifications
You must be signed in to change notification settings - Fork 49
Add 2D RoPE #1445
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add 2D RoPE #1445
Conversation
clessig
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some minor things to address. I think it would be worthwhile (since it's easy, and scientifically plausible) to try out 2D rope in every layer in the global and forecast engines. We should also think how to make sure the decoder gets consistent data, i.e. with a positional encoding that does not change with the forecast step (we also didn't think about this for the current global positional encoding).
src/weathergen/model/model.py
Outdated
| ) | ||
| self.pe_global = torch.nn.Parameter(pe, requires_grad=False) | ||
|
|
||
| ### ROPE COORDS (for 2D RoPE when use_2D_rope=True) ### |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
standard formatting for comment, no need for ### or all caps
src/weathergen/model/model.py
Outdated
| self.pe_global = torch.nn.Parameter(pe, requires_grad=False) | ||
|
|
||
| ### ROPE COORDS (for 2D RoPE when use_2D_rope=True) ### | ||
| self.use_2D_rope = cf.use_2D_rope |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use cf.get( 'use_rope_2D, False)` for backward compatability
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Parameter here and config flag should be rope_2D or use_rope_2D. That's easier to parse when we have different rope variants potentially in the future
| else: | ||
| self.rope_coords = None | ||
|
|
||
| ### HEALPIX NEIGHBOURS ### |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Standard formatting for comment
| self.rope_coords.data.copy_(coords_flat) | ||
|
|
||
| # Clear pe_global when using 2D RoPE | ||
| self.pe_global.data.fill_(0.0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this at all needed with 2D rope
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, because I kept the following code unchanged:
Line 738 - 743:
if self.cf.ae_local_queries_per_cell:
tokens_global = (self.q_cells + model_params.pe_global).repeat(batch_size, 1, 1)
else:
tokens_global = (
self.q_cells.repeat(self.num_healpix_cells, 1, 1) + model_params.pe_global
)
Line 840 - 844:
# recover batch dimension and build global token list
tokens_global = (
tokens_global.reshape([batch_size, self.num_healpix_cells, s[-2], s[-1]])
+ model_params.pe_global
).flatten(1, 2)
|
|
||
|
|
||
| #################################################################################################### | ||
| def apply_rotary_pos_emb_2d(q, k, coords, base=10000.0, unsqueeze_dim=1): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove apply_ from name
config/default_config.yml
Outdated
| # Use 2D RoPE instead of traditional global positional encoding | ||
| # When True: uses 2D RoPE based on healpix cell coordinates (lat/lon) | ||
| # When False: uses traditional pe_global positional encoding | ||
| use_2D_rope: False |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
use_rope_2D: False (see below for explanation)
src/weathergen/model/model.py
Outdated
|
|
||
| tokens = self.forecast(model_params, tokens, fstep) | ||
| # Apply 2D RoPE coords only on the first forecast step | ||
| is_first_forecast = fstep == forecast_offset |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can just have fstep == forecast_offset directly in the function argument
|
|
||
|
|
||
| #################################################################################################### | ||
| # Rotary positional embeddings (2D) adapted from Qwen3 & LLama for reuse in WeatherGenerator. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add to the copyright notice of the file. See Sophie's branch for examples.
- Add copyright attribution for rotate_half() and apply_rotary_pos_emb() functions - Rename apply_rotary_pos_emb_2d() to rotary_pos_emb_2d() for consistency - Rename config parameter use_2D_rope to rope_2D for better extensibility when supporting different RoPE variants in the future
|
Will be merged as #1540 |
Description
Add 2D RoPE to Global and Forecast Engine (on all forecast steps).
Issue Number
Closes #1109
Checklist before asking for review
./scripts/actions.sh lint./scripts/actions.sh unit-test./scripts/actions.sh integration-testlaunch-slurm.py --time 60