[train v2] Fix trainer import deserialization when captured within a Ray task #50862

Merged (6 commits) on Feb 24, 2025

justinvyu
Contributor

Summary

Avoid importing from ray.train.torch in the TorchTrainer definition file, since that import causes a circular import error when the class is captured in the scope of a Ray task.

The following script currently fails with a circular import error when the TorchTrainer class is deserialized on the task or actor that captured it. The same problem affects the Tune + Train integration, because the Tune function trainable runs as a Ray actor task and therefore captures the imported class.

# Run with RAY_TRAIN_V2_ENABLED=1

import ray
from ray.train.torch import TorchTrainer

@ray.remote
def task():
    # Referencing the class captures it in the task's scope.
    TorchTrainer

ray.get(task.remote())
Traceback (most recent call last):
  File "/home/ray/anaconda3/lib/python3.11/site-packages/ray/_private/function_manager.py", line 647, in _load_actor_class_from_gcs
    actor_class = pickle.loads(pickled_class)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ray/anaconda3/lib/python3.11/site-packages/ray/train/v2/torch/torch_trainer.py", line 4, in <module>
    from ray.train.torch import TorchConfig
  File "/home/ray/anaconda3/lib/python3.11/site-packages/ray/train/torch/__init__.py", line 28, in <module>
    from ray.train.v2.torch.torch_trainer import TorchTrainer  # noqa: F811
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ImportError: cannot import name 'TorchTrainer' from partially initialized module 'ray.train.v2.torch.torch_trainer' (most likely due to a circular import) (/home/ray/anaconda3/lib/python3.11/site-packages/ray/train/v2/torch/torch_trainer.py)
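
The import order in the traceback can be reproduced without Ray. The sketch below is a minimal illustration, not the code from this PR: it builds a hypothetical package named `demo`, where `demo.torch` stands in for `ray.train.torch` and `demo.v2.trainer` stands in for `ray.train.v2.torch.torch_trainer`. It then imports the trainer module first in a fresh interpreter, which is the order the traceback shows during deserialization. The "fixed" variant defers the import to call time, which is one common way to avoid a module-level import from the re-exporting package; the names `Config`, `Trainer`, and `import_trainer_first` are assumptions made for the example.

```python
import os
import subprocess
import sys
import tempfile
from pathlib import Path

# Hypothetical stand-ins: demo.torch plays the role of ray.train.torch,
# demo.v2.trainer plays the role of ray.train.v2.torch.torch_trainer.
TORCH_INIT = (
    "from demo.torch.config import Config\n"
    "from demo.v2.trainer import Trainer\n"  # re-export, like the v2 TorchTrainer
)
CONFIG = "class Config:\n    pass\n"

# Module-level import of the package that re-exports Trainer: circular
# whenever this module happens to be imported before demo.torch.
BROKEN_TRAINER = (
    "from demo.torch import Config\n"
    "\n"
    "class Trainer:\n"
    "    def __init__(self):\n"
    "        self.config = Config()\n"
)

# Import deferred to call time, after this module has finished loading.
FIXED_TRAINER = (
    "class Trainer:\n"
    "    def __init__(self):\n"
    "        from demo.torch import Config\n"
    "        self.config = Config()\n"
)


def import_trainer_first(trainer_source: str) -> subprocess.CompletedProcess:
    """Write the package to disk, then import the trainer module first
    in a fresh interpreter so no earlier import hides the cycle."""
    with tempfile.TemporaryDirectory() as tmp:
        root = Path(tmp)
        for pkg in ("demo", "demo/torch", "demo/v2"):
            (root / pkg).mkdir()
        (root / "demo/__init__.py").write_text("")
        (root / "demo/v2/__init__.py").write_text("")
        (root / "demo/torch/__init__.py").write_text(TORCH_INIT)
        (root / "demo/torch/config.py").write_text(CONFIG)
        (root / "demo/v2/trainer.py").write_text(trainer_source)
        return subprocess.run(
            [sys.executable, "-c", "from demo.v2.trainer import Trainer; Trainer()"],
            env={**os.environ, "PYTHONPATH": str(root)},
            capture_output=True,
            text=True,
        )


broken = import_trainer_first(BROKEN_TRAINER)
fixed = import_trainer_first(FIXED_TRAINER)

print("broken variant failed:", broken.returncode != 0)
print("fixed variant succeeded:", fixed.returncode == 0)
```

Importing `demo.torch` first would succeed with either variant, because `Config` is already bound on the package by the time the trainer module asks for it. That is why the cycle only surfaces when the trainer module is the entry point.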

Signed-off-by: Justin Yu <justinvyu@anyscale.com>
@justinvyu enabled auto-merge (squash) on February 24, 2025 21:56
@github-actions bot added the "go" label (add ONLY when ready to merge, run all tests) on Feb 24, 2025
@justinvyu merged commit 84d8da3 into ray-project:master on Feb 24, 2025
7 checks passed
@justinvyu deleted the fix_trainer_ray_serialization branch on February 25, 2025 00:58