Labels: bug (Something isn't working)
Description
What happened?
Running training on the latest develop branch with the physical JEPA config (config unchanged) fails during the validation pass that runs before training (validate_before_training). Command used: uv run --offline train --base-config config/config_physical_jepa.yml
Traceback (most recent call last):
File "/users/walmikae/weathergen/WeatherGenerator/src/weathergen/run_train.py", line 184, in train_with_args
trainer.run(cf, devices)
File "/users/walmikae/weathergen/WeatherGenerator/src/weathergen/train/trainer.py", line 342, in run
self.validate_before_training()
File "/users/walmikae/weathergen/WeatherGenerator/src/weathergen/train/trainer.py", line 380, in validate_before_training
self.validate(-1, cfg, batch_size)
File "/users/walmikae/weathergen/WeatherGenerator/src/weathergen/train/trainer.py", line 538, in validate
preds = self.ema_model.forward_eval(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/users/walmikae/weathergen/WeatherGenerator/.venv/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/users/walmikae/weathergen/WeatherGenerator/src/weathergen/model/ema.py", line 75, in forward_eval
out = self.ema_model(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/users/walmikae/weathergen/WeatherGenerator/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/users/walmikae/weathergen/WeatherGenerator/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/users/walmikae/weathergen/WeatherGenerator/src/weathergen/model/model.py", line 563, in forward
tokens, posteriors = self.encoder(model_params, batch)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/users/walmikae/weathergen/WeatherGenerator/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/users/walmikae/weathergen/WeatherGenerator/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/users/walmikae/weathergen/WeatherGenerator/src/weathergen/model/encoder.py", line 121, in forward
global_tokens, posteriors = self.assimilate_local(model_params, stream_cell_tokens, batch)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/users/walmikae/weathergen/WeatherGenerator/src/weathergen/model/encoder.py", line 154, in assimilate_local
tokens_global[self.num_register_tokens + self.num_class_tokens :] = tokens_global[
^^^^^^^^^^^^^^
RuntimeError: The size of tensor a (24584) must match the size of tensor b (24576) at non-singleton dimension 0
[12] > /users/walmikae/weathergen/WeatherGenerator/src/weathergen/model/encoder.py(154)assimilate_local()
-> tokens_global[self.num_register_tokens + self.num_class_tokens :] = tokens_global[
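The encoder does not handle this slice assignment correctly when register/class tokens are enabled and the combined batch/timestep dimension (rs) is greater than 1: the two sides of the assignment at encoder.py line 154 differ by 8 rows along dimension 0 (24584 vs. 24576).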
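The symptom can be illustrated with a small pure-Python sketch. It assumes a hypothetical layout in which tokens_global stacks rs samples, each consisting of num_prefix register/class tokens followed by num_cells cell tokens; the real layout in encoder.py is not shown in this report, and all function names below are illustrative. Under that assumption, a single global slice [num_prefix:] skips only the first sample's prefix tokens and leaves (rs - 1) * num_prefix extra rows, while slicing each sample separately yields the expected count.

```python
# Hypothetical model of the flattened token layout. The names and the layout
# itself are assumptions for illustration, not WeatherGenerator's encoder code.


def build_tokens(rs: int, num_prefix: int, num_cells: int) -> list[tuple[str, int, int]]:
    """Assumed layout: per sample, num_prefix register/class tokens, then num_cells cells."""
    tokens = []
    for s in range(rs):
        tokens += [("prefix", s, i) for i in range(num_prefix)]
        tokens += [("cell", s, i) for i in range(num_cells)]
    return tokens


def cells_global_slice(tokens, num_prefix):
    """Mimics tokens_global[num_prefix:]: drops the first sample's prefix only."""
    return tokens[num_prefix:]


def cells_per_sample(tokens, rs, num_prefix):
    """Drops every sample's prefix tokens, then flattens the cells again."""
    per_sample = len(tokens) // rs
    out = []
    for s in range(rs):
        block = tokens[s * per_sample:(s + 1) * per_sample]
        out += block[num_prefix:]
    return out


def row_mismatch(rs, num_prefix, num_cells):
    """Extra rows left by the global slice compared with rs * num_cells cell tokens."""
    tokens = build_tokens(rs, num_prefix, num_cells)
    return len(cells_global_slice(tokens, num_prefix)) - rs * num_cells


for rs in (1, 2, 3):
    # prints "1 0", "2 4", "3 8": the gap grows as (rs - 1) * num_prefix
    print(rs, row_mismatch(rs, num_prefix=4, num_cells=10))
```

With rs = 1 the global slice and the per-sample slice agree, which is consistent with the report that the failure only appears for rs > 1. Whether a per-sample slice (for example after reshaping to one block per sample) is the right fix depends on how encoder.py actually lays out tokens_global.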
What are the steps to reproduce the bug?
- Branch: develop (latest)
- HPC: santis & booster
- Command: uv run --offline train --base-config config/config_physical_jepa.yml
HedgeDoc link to logs and more information. This ticket is public; do not attach files directly.
No response