take torch.nn.Module model into account when moving to device #3167

Open
wants to merge 2 commits into
base: main

faaany
Contributor

@faaany faaany commented Oct 14, 2024

What does this PR do?

In #3133, I introduced a check that moves the model to the target device only when the model is on "cpu". However, that check doesn't take plain torch.nn.Module models into account, e.g.:

from accelerate import Accelerator, DDPCommunicationHookType, DistributedDataParallelKwargs
from torch.utils.data import DataLoader, TensorDataset
import torch

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.layer(x)


# Create a dummy dataset
data = torch.randn(100, 10)
targets = torch.randint(0, 10, (100,))
dataset = TensorDataset(data, targets)

# Define the loss function
criterion = torch.nn.CrossEntropyLoss()

# DDP Communication Hook setup
ddp_kwargs = DistributedDataParallelKwargs(comm_hook=DDPCommunicationHookType.FP16)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])

model = MyModel()
optimizer = torch.optim.Adam(model.parameters())
data_loader = DataLoader(dataset, batch_size=16)

model, optimizer, data_loader = accelerator.prepare(model, optimizer, data_loader)
# Training loop
for data, targets in data_loader:
    outputs = model(data)
    loss = criterion(outputs, targets)
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()

In this case, the model doesn't have a device attribute, so the check raises an AttributeError. This PR fixes that issue.
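The diff for this PR is not shown in this conversation, so the following is only a minimal sketch of one way to avoid the AttributeError, not the actual change. It assumes a hypothetical helper, `infer_device`, that uses `model.device` when it exists and otherwise falls back to the device of the first parameter. `PlainModule` and `ModelWithDevice` are hypothetical stand-ins so the snippet runs without PyTorch. The helper itself relies only on `parameters()` and `.device`, so it should also work with a real torch.nn.Module.

```python
from types import SimpleNamespace


def infer_device(model, default="cpu"):
    """Return model.device if present, else the device of the first parameter."""
    device = getattr(model, "device", None)
    if device is not None:
        return device
    # Plain torch.nn.Module objects expose parameters() but no .device attribute.
    first_param = next(iter(model.parameters()), None)
    # A model without parameters has no device to report, so use the default.
    return default if first_param is None else first_param.device


class PlainModule:
    """Hypothetical stand-in for a torch.nn.Module subclass (no .device attribute)."""

    def __init__(self, devices):
        self._params = [SimpleNamespace(device=d) for d in devices]

    def parameters(self):
        return iter(self._params)


class ModelWithDevice(PlainModule):
    """Hypothetical stand-in for a model class that defines a .device property."""

    @property
    def device(self):
        return "xpu:0"


plain_device = infer_device(PlainModule(["cpu", "cpu"]))
wrapped_device = infer_device(ModelWithDevice(["xpu:0"]))
empty_device = infer_device(PlainModule([]))
print(plain_device, wrapped_device, empty_device)
```

The "only move the model when it is on cpu" check could then compare the result of such a helper against "cpu" instead of reading `model.device` directly.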

@muellerzr

@faaany faaany changed the title from "bug fix" to "take torch.nn.Module model into account when moving to device" Oct 14, 2024
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@faaany faaany marked this pull request as ready for review October 15, 2024 06:46
@dvrogozh
Contributor

The change in #3133 introduced a significant regression when using Accelerate with the PyTorch 2.5 XPU backend (I am not using IPEX). See the log below for the failing tests. With #3133 reverted or this PR applied, most of these failures go away. @muellerzr: can this PR be reviewed and, if considered OK, merged, or else #3133 reverted?

  • Without this PR:
FAILED tests/test_accelerator.py::AcceleratorTester::Just test that passing None to accelerator.prepare() works. - AttributeError: 'Linear' object has no attribute 'device'
FAILED tests/test_accelerator.py::AcceleratorTester::test_free_memory_dereferences_prepared_components - AttributeError: 'Linear' object has no attribute 'device'
FAILED tests/test_accelerator.py::AcceleratorTester::Checks that `_is_accelerator_prepared` is set properly - AttributeError: 'Linear' object has no attribute 'device'
FAILED tests/test_accelerator.py::AcceleratorTester::test_prepared_objects_are_referenced - AttributeError: 'Linear' object has no attribute 'device'
FAILED tests/test_accelerator.py::AcceleratorTester::Test that setting `use_stateful_dataloader=True` in `DataLoaderConfiguration` prepares a `StatefulDataLoader` object instead of a `DataLoader` object. - AttributeError: 'Linear' object has no attribute 'device'
FAILED tests/test_accelerator.py::AcceleratorTester::test_save_load_model_use_pytorch - AttributeError: 'Linear' object has no attribute 'device'
FAILED tests/test_accelerator.py::AcceleratorTester::test_save_load_model_use_safetensors - AttributeError: 'Linear' object has no attribute 'device'
FAILED tests/test_accelerator.py::AcceleratorTester::test_save_load_model_use_safetensors_tied_weights - AttributeError: 'ModelWithTiedWeights' object has no attribute 'device'
FAILED tests/test_accelerator.py::AcceleratorTester::test_save_load_model_with_hooks_use_pytorch - AttributeError: 'Linear' object has no attribute 'device'
FAILED tests/test_accelerator.py::AcceleratorTester::test_save_load_model_with_hooks_use_safetensors - AttributeError: 'Linear' object has no attribute 'device'
FAILED tests/test_accelerator.py::AcceleratorTester::Test that saving and loading a model with a stateful dataloader returns the same model, [with use_safetensors=False, tied_weights=False, num_workers=0, dispatch_batches=True]
        and that the dataloader's iterator is restored properly. - AttributeError: 'ModelForTest' object has no attribute 'device'
FAILED tests/test_accelerator.py::AcceleratorTester::Test that saving and loading a model with a stateful dataloader returns the same model, [with use_safetensors=False, tied_weights=False, num_workers=0, dispatch_batches=False]
        and that the dataloader's iterator is restored properly. - AttributeError: 'ModelForTest' object has no attribute 'device'
FAILED tests/test_accelerator.py::AcceleratorTester::Test that saving and loading a model with a stateful dataloader returns the same model, [with use_safetensors=False, tied_weights=False, num_workers=2, dispatch_batches=True]
        and that the dataloader's iterator is restored properly. - AttributeError: 'ModelForTest' object has no attribute 'device'
FAILED tests/test_accelerator.py::AcceleratorTester::Test that saving and loading a model with a stateful dataloader returns the same model, [with use_safetensors=False, tied_weights=False, num_workers=2, dispatch_batches=False]
        and that the dataloader's iterator is restored properly. - AttributeError: 'ModelForTest' object has no attribute 'device'
FAILED tests/test_accelerator.py::AcceleratorTester::Test that saving and loading a model with a stateful dataloader returns the same model, [with use_safetensors=False, tied_weights=True, num_workers=0, dispatch_batches=True]
        and that the dataloader's iterator is restored properly. - AttributeError: 'ModelForTest' object has no attribute 'device'
FAILED tests/test_accelerator.py::AcceleratorTester::Test that saving and loading a model with a stateful dataloader returns the same model, [with use_safetensors=False, tied_weights=True, num_workers=0, dispatch_batches=False]
        and that the dataloader's iterator is restored properly. - AttributeError: 'ModelForTest' object has no attribute 'device'
FAILED tests/test_accelerator.py::AcceleratorTester::Test that saving and loading a model with a stateful dataloader returns the same model, [with use_safetensors=False, tied_weights=True, num_workers=2, dispatch_batches=True]
        and that the dataloader's iterator is restored properly. - AttributeError: 'ModelForTest' object has no attribute 'device'
FAILED tests/test_accelerator.py::AcceleratorTester::Test that saving and loading a model with a stateful dataloader returns the same model, [with use_safetensors=False, tied_weights=True, num_workers=2, dispatch_batches=False]
        and that the dataloader's iterator is restored properly. - AttributeError: 'ModelForTest' object has no attribute 'device'
FAILED tests/test_accelerator.py::AcceleratorTester::Test that saving and loading a model with a stateful dataloader returns the same model, [with use_safetensors=True, tied_weights=False, num_workers=0, dispatch_batches=True]
        and that the dataloader's iterator is restored properly. - AttributeError: 'ModelForTest' object has no attribute 'device'
FAILED tests/test_accelerator.py::AcceleratorTester::Test that saving and loading a model with a stateful dataloader returns the same model, [with use_safetensors=True, tied_weights=False, num_workers=0, dispatch_batches=False]
        and that the dataloader's iterator is restored properly. - AttributeError: 'ModelForTest' object has no attribute 'device'
FAILED tests/test_accelerator.py::AcceleratorTester::Test that saving and loading a model with a stateful dataloader returns the same model, [with use_safetensors=True, tied_weights=False, num_workers=2, dispatch_batches=True]
        and that the dataloader's iterator is restored properly. - AttributeError: 'ModelForTest' object has no attribute 'device'
FAILED tests/test_accelerator.py::AcceleratorTester::Test that saving and loading a model with a stateful dataloader returns the same model, [with use_safetensors=True, tied_weights=False, num_workers=2, dispatch_batches=False]
        and that the dataloader's iterator is restored properly. - AttributeError: 'ModelForTest' object has no attribute 'device'
FAILED tests/test_accelerator.py::AcceleratorTester::Test that saving and loading a model with a stateful dataloader returns the same model, [with use_safetensors=True, tied_weights=True, num_workers=0, dispatch_batches=True]
        and that the dataloader's iterator is restored properly. - AttributeError: 'ModelForTest' object has no attribute 'device'
FAILED tests/test_accelerator.py::AcceleratorTester::Test that saving and loading a model with a stateful dataloader returns the same model, [with use_safetensors=True, tied_weights=True, num_workers=0, dispatch_batches=False]
        and that the dataloader's iterator is restored properly. - AttributeError: 'ModelForTest' object has no attribute 'device'
FAILED tests/test_accelerator.py::AcceleratorTester::Test that saving and loading a model with a stateful dataloader returns the same model, [with use_safetensors=True, tied_weights=True, num_workers=2, dispatch_batches=True]
        and that the dataloader's iterator is restored properly. - AttributeError: 'ModelForTest' object has no attribute 'device'
FAILED tests/test_accelerator.py::AcceleratorTester::Test that saving and loading a model with a stateful dataloader returns the same model, [with use_safetensors=True, tied_weights=True, num_workers=2, dispatch_batches=False]
        and that the dataloader's iterator is restored properly. - AttributeError: 'ModelForTest' object has no attribute 'device'
FAILED tests/test_cli.py::
    Test case for verifying the `accelerate launch` CLI operates correctly.
    If a `default_config.yaml` file is located in the cache it will temporarily move it
    for the duration of the tests.
    ::test_accelerate_test - RuntimeError: 'accelerate-launch /home/dvrogozh/git/huggingface/accelerate/src/accelerate/test_utils/scripts/t...
FAILED tests/test_cli.py::
    Test case for checking the output of `accelerate estimate-memory` is correct.

    - Uses `estimate_command` when trying to catch raised errors
    - Uses `gather_data` when just verifying the calculations are correct
    ::test_gated - AssertionError: (<class 'huggingface_hub.utils._errors.GatedRepoError'>, <class 'OSError'>) not raised : Repo ...
FAILED tests/test_grad_sync.py::SyncScheduler::test_gradient_sync_gpu - AttributeError: 'RegressionModel' object has no attribute 'device'
FAILED tests/test_state_checkpointing.py::CheckpointTest_use_safetensors::test_automatic_loading - AttributeError: 'DummyModel' object has no attribute 'device'
FAILED tests/test_state_checkpointing.py::CheckpointTest_use_safetensors::test_can_resume_training - AttributeError: 'DummyModel' object has no attribute 'device'
FAILED tests/test_state_checkpointing.py::CheckpointTest_use_safetensors::test_can_resume_training_checkpoints_relative_path - AttributeError: 'DummyModel' object has no attribute 'device'
FAILED tests/test_state_checkpointing.py::CheckpointTest_use_safetensors::test_can_resume_training_with_folder - AttributeError: 'DummyModel' object has no attribute 'device'
FAILED tests/test_state_checkpointing.py::CheckpointTest_use_safetensors::test_map_location - RuntimeError: 'accelerate launch --num_processes=1 --monitor_interval=0.1 /home/dvrogozh/git/huggingface/accel...
FAILED tests/test_state_checkpointing.py::CheckpointTest_use_safetensors::test_with_save_limit - AttributeError: 'DummyModel' object has no attribute 'device'
FAILED tests/test_state_checkpointing.py::CheckpointTest_use_safetensors::test_with_scheduler - AttributeError: 'DummyModel' object has no attribute 'device'
FAILED tests/test_state_checkpointing.py::CheckpointTest_use_pytorch::test_automatic_loading - AttributeError: 'DummyModel' object has no attribute 'device'
FAILED tests/test_state_checkpointing.py::CheckpointTest_use_pytorch::test_can_resume_training - AttributeError: 'DummyModel' object has no attribute 'device'
FAILED tests/test_state_checkpointing.py::CheckpointTest_use_pytorch::test_can_resume_training_checkpoints_relative_path - AttributeError: 'DummyModel' object has no attribute 'device'
FAILED tests/test_state_checkpointing.py::CheckpointTest_use_pytorch::test_can_resume_training_with_folder - AttributeError: 'DummyModel' object has no attribute 'device'
FAILED tests/test_state_checkpointing.py::CheckpointTest_use_pytorch::test_map_location - RuntimeError: 'accelerate launch --num_processes=1 --monitor_interval=0.1 /home/dvrogozh/git/huggingface/accel...
FAILED tests/test_state_checkpointing.py::CheckpointTest_use_pytorch::test_with_save_limit - AttributeError: 'DummyModel' object has no attribute 'device'
FAILED tests/test_state_checkpointing.py::CheckpointTest_use_pytorch::test_with_scheduler - AttributeError: 'DummyModel' object has no attribute 'device'
FAILED tests/test_utils.py::UtilsTester::test_dynamo - torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
====================== 44 failed, 196 passed, 140 skipped, 30 warnings in 224.95s (0:03:44) =======================
  • With this PR:
FAILED tests/test_cli.py::
    Test case for checking the output of `accelerate estimate-memory` is correct.

    - Uses `estimate_command` when trying to catch raised errors
    - Uses `gather_data` when just verifying the calculations are correct
    ::test_gated - AssertionError: (<class 'huggingface_hub.utils._errors.GatedRepoError'>, <class 'OSError'>) not raised : Repo ...
FAILED tests/test_utils.py::UtilsTester::test_dynamo - torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
======================= 2 failed, 238 passed, 140 skipped, 30 warnings in 231.72s (0:03:51) =======================

@faaany
Contributor Author

faaany commented Oct 27, 2024

Thanks @dvrogozh for pasting the unit test results. @muellerzr @BenjaminBossan, could you take a look and help merge this PR? Thanks a lot!
