[Bug] Tensor Shape and Indexing Mismatches in encoder.py when rs > 1 #1673

@wael-mika

Description

What happened?

Training on the latest develop branch with the unmodified physical JEPA config crashes during the validation pass that runs before training starts (validate_before_training). The run was started with:

  uv run --offline train --base-config config/config_physical_jepa.yml

Traceback (most recent call last):
  File "/users/walmikae/weathergen/WeatherGenerator/src/weathergen/run_train.py", line 184, in train_with_args
    trainer.run(cf, devices)
  File "/users/walmikae/weathergen/WeatherGenerator/src/weathergen/train/trainer.py", line 342, in run
    self.validate_before_training()
  File "/users/walmikae/weathergen/WeatherGenerator/src/weathergen/train/trainer.py", line 380, in validate_before_training
    self.validate(-1, cfg, batch_size)
  File "/users/walmikae/weathergen/WeatherGenerator/src/weathergen/train/trainer.py", line 538, in validate
    preds = self.ema_model.forward_eval(
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/users/walmikae/weathergen/WeatherGenerator/.venv/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/users/walmikae/weathergen/WeatherGenerator/src/weathergen/model/ema.py", line 75, in forward_eval
    out = self.ema_model(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/users/walmikae/weathergen/WeatherGenerator/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/users/walmikae/weathergen/WeatherGenerator/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/users/walmikae/weathergen/WeatherGenerator/src/weathergen/model/model.py", line 563, in forward
    tokens, posteriors = self.encoder(model_params, batch)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/users/walmikae/weathergen/WeatherGenerator/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/users/walmikae/weathergen/WeatherGenerator/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/users/walmikae/weathergen/WeatherGenerator/src/weathergen/model/encoder.py", line 121, in forward
    global_tokens, posteriors = self.assimilate_local(model_params, stream_cell_tokens, batch)
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/users/walmikae/weathergen/WeatherGenerator/src/weathergen/model/encoder.py", line 154, in assimilate_local
    tokens_global[self.num_register_tokens + self.num_class_tokens :] = tokens_global[
                                                                        ^^^^^^^^^^^^^^
RuntimeError: The size of tensor a (24584) must match the size of tensor b (24576) at non-singleton dimension 0
[12] > /users/walmikae/weathergen/WeatherGenerator/src/weathergen/model/encoder.py(154)assimilate_local()
-> tokens_global[self.num_register_tokens + self.num_class_tokens :] = tokens_global[

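The encoder fails in assimilate_local (encoder.py, line 154) when register/class tokens are enabled and the combined batch/timestep dimension (rs) is greater than 1. The slice being assigned to, tokens_global[self.num_register_tokens + self.num_class_tokens :], has 24584 rows, while the right-hand side has 24576 rows, a difference of 8.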
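A minimal sketch of how a mismatch like this can arise, assuming (not confirmed by the traceback) that tokens_global stacks the rs samples along dimension 0 and that each sample is laid out as [register/class prefix | cell tokens]. Under that assumption, skipping the prefix once with tokens_global[N:], where N = num_register_tokens + num_class_tokens, leaves rs*C + (rs - 1)*N rows, while a per-cell right-hand side has rs*C rows. The two sides then differ by (rs - 1)*N, which is zero only when rs = 1. The reported gap of 8 would fit, for example, rs = 2 with N = 8, but the actual values are not in this report. The sizes below and the helper write_cells_per_sample are hypothetical. The helper shows one possible fix: reshape to (rs, N + C, dim) and slice the prefix per sample.

```python
import torch

# Hypothetical sizes, not taken from config_physical_jepa.yml:
# n_prefix stands for num_register_tokens + num_class_tokens,
# n_cells for the number of cell tokens per sample.
rs, n_prefix, n_cells, dim = 2, 4, 6, 3

# Assumed layout: rs samples flattened along dim 0, each [prefix | cells].
tokens_global = torch.zeros(rs * (n_prefix + n_cells), dim)
cell_tokens = torch.ones(rs * n_cells, dim)  # rows to write back

# Buggy pattern: the prefix is skipped only once for the whole flattened tensor.
lhs = tokens_global[n_prefix:]
print(lhs.shape[0], cell_tokens.shape[0])  # 16 12 -> off by (rs - 1) * n_prefix


def write_cells_per_sample(tokens_global, cell_tokens, rs, n_prefix):
    """Overwrite each sample's cell tokens while keeping its prefix tokens."""
    dim = tokens_global.shape[-1]
    # view() shares storage, so writing into per_sample updates tokens_global.
    per_sample = tokens_global.view(rs, -1, dim)  # (rs, n_prefix + n_cells, dim)
    per_sample[:, n_prefix:] = cell_tokens.view(rs, -1, dim)
    return tokens_global


fixed = write_cells_per_sample(tokens_global, cell_tokens, rs, n_prefix)
per_sample = fixed.view(rs, -1, dim)
# Prefix rows are untouched (zeros) and cell rows are overwritten (ones).
assert torch.all(per_sample[:, :n_prefix] == 0)
assert torch.all(per_sample[:, n_prefix:] == 1)
```

With rs = 1 the flat slice and the per-sample slice select the same rows, which is consistent with the bug only showing up for rs > 1.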

What are the steps to reproduce the bug?

  • Branch: develop (latest)
  • HPC: santis & booster
  • Command: uv run --offline train --base-config config/config_physical_jepa.yml

Hedgedoc link to logs and more information (this ticket is public, do not attach files directly): No response

Metadata

Labels: bug (Something isn't working)
Status: No status
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests