
Conversation

@csjfwang (Contributor) commented Dec 10, 2025

Description

Add 2D RoPE to Global and Forecast Engine (on all forecast steps).

Issue Number

Closes #1109

Checklist before asking for review

  • I have performed a self-review of my code
  • My changes comply with basic sanity checks:
    • I have fixed formatting issues with ./scripts/actions.sh lint
    • I have run unit tests with ./scripts/actions.sh unit-test
    • I have documented my code and I have updated the docstrings.
    • I have added unit tests, if relevant
  • I have tried my changes with data and code:
    • I have run the integration tests with ./scripts/actions.sh integration-test
    • (bigger changes) I have run a full training and I have written in the comment the run_id(s): launch-slurm.py --time 60
    • (bigger changes and experiments) I have shared a hedgedoc in the github issue with all the configurations and runs for these experiments
  • I have informed and aligned with people impacted by my change:
    • for config changes: the MatterMost channels and/or a design doc
    • for changes of dependencies: the MatterMost software development channel

wang85 and others added 30 commits July 16, 2025 10:07
@csjfwang (Contributor, Author) commented Dec 10, 2025

@csjfwang marked this pull request as ready for review December 10, 2025 15:59
@clessig (Collaborator) left a comment

Some minor things to address. I think it would be worthwhile (since it's easy, and scientifically plausible) to try out 2D RoPE in every layer in the global and forecast engines. We should also think about how to make sure the decoder gets consistent data, i.e. a positional encoding that does not change with the forecast step (we also didn't think about this for the current global positional encoding).

)
self.pe_global = torch.nn.Parameter(pe, requires_grad=False)

### ROPE COORDS (for 2D RoPE when use_2D_rope=True) ###
Collaborator:

standard formatting for comment, no need for ### or all caps

self.pe_global = torch.nn.Parameter(pe, requires_grad=False)

### ROPE COORDS (for 2D RoPE when use_2D_rope=True) ###
self.use_2D_rope = cf.use_2D_rope
Collaborator:

Use `cf.get('use_rope_2D', False)` for backward compatibility

Collaborator:

The parameter here and the config flag should be rope_2D or use_rope_2D. That's easier to parse if we potentially add different RoPE variants in the future.
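
Taken together, the two inline suggestions would amount to something like the sketch below (the attribute name, the buffer shape, and the constructor context are assumptions, not code from this PR); the existing else branch shown in the next excerpt would stay as it is:

# Hypothetical sketch: read the renamed flag with a default so that older
# configs without the key still load, and only allocate the coordinate
# buffer when 2D RoPE is enabled.
self.rope_2D = cf.get("rope_2D", False)
if self.rope_2D:
    # one (lat, lon) pair per healpix cell, filled later via
    # self.rope_coords.data.copy_(coords_flat)
    self.rope_coords = torch.nn.Parameter(
        torch.zeros(self.num_healpix_cells, 2), requires_grad=False
    )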

else:
self.rope_coords = None

### HEALPIX NEIGHBOURS ###
Collaborator:

Standard formatting for comment

self.rope_coords.data.copy_(coords_flat)

# Clear pe_global when using 2D RoPE
self.pe_global.data.fill_(0.0)
Collaborator:

Is this at all needed with 2D RoPE?

Contributor Author:

Yes, because I kept the following code unchanged: pe_global is still added to the global tokens in these places, so it is zero-filled to make those additions a no-op when 2D RoPE is used.

Line 738 - 743:

if self.cf.ae_local_queries_per_cell:
    tokens_global = (self.q_cells + model_params.pe_global).repeat(batch_size, 1, 1)
else:
    tokens_global = (
        self.q_cells.repeat(self.num_healpix_cells, 1, 1) + model_params.pe_global
    )

Line 840 - 844:

# recover batch dimension and build global token list
tokens_global = (
    tokens_global.reshape([batch_size, self.num_healpix_cells, s[-2], s[-1]])
    + model_params.pe_global
).flatten(1, 2)



####################################################################################################
def apply_rotary_pos_emb_2d(q, k, coords, base=10000.0, unsqueeze_dim=1):
Collaborator:

Remove apply_ from name
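
For context, a 2D RoPE of this kind typically splits the head dimension in half and rotates each half with frequencies derived from one of the two spatial coordinates (here the lat/lon of the healpix cells). A minimal sketch of that idea, not the implementation in this PR (the tensor shapes, coordinate layout, and frequency scheme are assumptions):

import torch

def rotate_half(x):
    # standard RoPE helper: swap the two halves of the last dimension
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def rotary_pos_emb_2d(q, k, coords, base=10000.0, unsqueeze_dim=1):
    # q, k: (batch, heads, tokens, head_dim); coords: (batch, tokens, 2) holding lat/lon
    half = q.shape[-1] // 2
    inv_freq = 1.0 / (
        base ** (torch.arange(0, half, 2, device=q.device, dtype=torch.float32) / half)
    )

    def cos_sin(pos):
        freqs = pos[..., None].float() * inv_freq   # (batch, tokens, half/2)
        emb = torch.cat((freqs, freqs), dim=-1)     # (batch, tokens, half)
        return emb.cos().unsqueeze(unsqueeze_dim), emb.sin().unsqueeze(unsqueeze_dim)

    cos_lat, sin_lat = cos_sin(coords[..., 0])
    cos_lon, sin_lon = cos_sin(coords[..., 1])

    def rotate(x):
        a, b = x[..., :half], x[..., half:]
        a = a * cos_lat + rotate_half(a) * sin_lat  # rotate first half by latitude
        b = b * cos_lon + rotate_half(b) * sin_lon  # rotate second half by longitude
        return torch.cat((a, b), dim=-1)

    return rotate(q), rotate(k)

Because such a rotation depends only on the static cell coordinates, applying it at every forecast step would not make the encoding drift across steps.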

# Use 2D RoPE instead of traditional global positional encoding
# When True: uses 2D RoPE based on healpix cell coordinates (lat/lon)
# When False: uses traditional pe_global positional encoding
use_2D_rope: False
Collaborator:

use_rope_2D: False (see below for explanation)


tokens = self.forecast(model_params, tokens, fstep)
# Apply 2D RoPE coords only on the first forecast step
is_first_forecast = fstep == forecast_offset
Collaborator:

You can just have fstep == forecast_offset directly in the function argument



####################################################################################################
# Rotary positional embeddings (2D) adapted from Qwen3 & LLama for reuse in WeatherGenerator.
Collaborator:

Can you add to the copyright notice of the file? See Sophie's branch for examples.

Jifeng Wang added 4 commits December 12, 2025 15:46
- Add copyright attribution for rotate_half() and apply_rotary_pos_emb()
  functions

- Rename apply_rotary_pos_emb_2d() to rotary_pos_emb_2d() for consistency

- Rename config parameter use_2D_rope to rope_2D for better extensibility
  when supporting different RoPE variants in the future
@kctezcan mentioned this pull request Dec 30, 2025
@clessig (Collaborator) commented Jan 5, 2026

Will be merged as #1540

Development

Successfully merging this pull request may close these issues.

Update positional encodings in model to be purely local
