State dictionary retrieval from offloaded modules #2619

Merged · 7 commits · Jun 3, 2024
src/accelerate/utils/modeling.py (43 additions, 0 deletions)

@@ -1538,6 +1538,49 @@ def get_state_dict_offloaded_model(model: nn.Module):
    return state_dict


def get_state_dict_from_offload(
    module: nn.Module,
    module_name: str,
    state_dict: Dict[str, Union[str, torch.Tensor]],
    device_to_put_offload: Union[int, str, torch.device] = "cpu",
):
    """
    Retrieve the state dictionary (with parameters) from an offloaded module and load into a specified device
    (defaults to cpu).

    Args:
        module (`torch.nn.Module`):
            The module we want to retrieve a state dictionary from.
        module_name (`str`):
            The name of the module of interest.
        state_dict (`Dict[str, Union[str, torch.Tensor]]`):
            Dictionary of {module names: parameters}.
        device_to_put_offload (`Union[int, str, torch.device]`):
            Device to load offloaded parameters into, defaults to the cpu.
    """
    from ..hooks import AlignDevicesHook

    root = module_name[: module_name.rfind(".")]  # module name without .weight or .bias
    preforward = False  # set if we onload the module's parameters via its hook below
    if hasattr(module, "_hf_hook") and isinstance(module._hf_hook, AlignDevicesHook) and module._hf_hook.offload:
        # assign the device to which the offloaded parameters will be sent
        original_device = module._hf_hook.execution_device
        module._hf_hook.execution_device = device_to_put_offload
        module._hf_hook.pre_forward(module)
        preforward = True

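    # copy the module's parameters into the given state dict (onloaded above if they were offloaded)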
    for m_key, params in module.state_dict().items():
        if (root + f".{m_key}") in state_dict:
            state_dict[root + f".{m_key}"] = params

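    # offload the parameters again and restore the hook's original execution device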
    if preforward:
        module._hf_hook.post_forward(module, torch.tensor([]))
        module._hf_hook.execution_device = original_device

    return state_dict


def load_checkpoint_in_model(
    model: nn.Module,
    checkpoint: Union[str, os.PathLike],
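For reference, here is a minimal usage sketch of the new helper (not part of the diff). The toy TinyModel, the checkpoint path, and the device map are illustrative assumptions, and the snippet requires an accelerate build that includes this change.

# Illustrative sketch, not part of the PR: offload one layer of a toy model to
# disk, then onload its parameters into a plain CPU state dict.
import tempfile

import torch
import torch.nn as nn

from accelerate import load_checkpoint_and_dispatch
from accelerate.utils.modeling import get_state_dict_from_offload


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(4, 4)
        self.linear2 = nn.Linear(4, 4)

    def forward(self, x):
        return self.linear2(self.linear1(x))


model = TinyModel()
with tempfile.TemporaryDirectory() as tmp_dir:
    checkpoint = f"{tmp_dir}/model.bin"
    torch.save(model.state_dict(), checkpoint)

    # Dispatch with linear2 offloaded to disk; its weight becomes a meta tensor.
    load_checkpoint_and_dispatch(
        model, checkpoint, device_map={"linear1": "cpu", "linear2": "disk"}, offload_folder=tmp_dir
    )

    state_dict = get_state_dict_from_offload(
        model.linear2, "linear2.weight", {"linear2.weight": ""}, device_to_put_offload="cpu"
    )
    print(state_dict["linear2.weight"].shape)  # torch.Size([4, 4])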
tests/test_accelerator.py (28 additions, 1 deletion)

@@ -28,7 +28,7 @@
from accelerate.test_utils import require_bnb, require_multi_device, require_non_cpu, slow, torch_device
from accelerate.test_utils.testing import AccelerateTestCase, require_cuda, require_non_torch_xla
from accelerate.utils import patch_environment
-from accelerate.utils.modeling import load_checkpoint_in_model
+from accelerate.utils.modeling import get_state_dict_from_offload, load_checkpoint_in_model


def create_components():
@@ -269,6 +269,33 @@ def test_save_model_offload(self, use_safetensors):
        output = model(inputs)
        assert torch.allclose(expected, output, atol=1e-5)

    @parameterized.expand([True, False], name_func=parameterized_custom_name_func)
    @require_cuda
    def test_get_state_dict_from_offload(self, use_safetensors):
        accelerator = Accelerator()

        device_map = {"linear1": "cpu", "batchnorm": "disk", "linear2": "disk"}
        model = ModelForTest()
        offloaded_layer_weight = model.linear2.weight
        with tempfile.TemporaryDirectory() as tmp_dir:
            accelerator.save_model(model, tmp_dir, safe_serialization=use_safetensors)
            # load the model with offloaded layers
            load_checkpoint_and_dispatch(model, tmp_dir, device_map=device_map, offload_folder=tmp_dir)
            cpu_onloaded_layer = get_state_dict_from_offload(
                model.linear2, "linear2.weight", {"linear2.weight": ""}, device_to_put_offload="cpu"
            )
            cuda_onloaded_layer = get_state_dict_from_offload(
                model.linear2, "linear2.weight", {"linear2.weight": ""}, device_to_put_offload=0
            )
            cpu_onloaded_layer_weight = cpu_onloaded_layer["linear2.weight"]
            cuda_onloaded_layer_weight = cuda_onloaded_layer["linear2.weight"]

            assert torch.allclose(offloaded_layer_weight, cpu_onloaded_layer_weight)
            assert torch.allclose(
                offloaded_layer_weight, cuda_onloaded_layer_weight.to("cpu")
            )  # tensors must be on the same device for torch.allclose()
            assert cpu_onloaded_layer_weight.device.type == "cpu"
            assert cuda_onloaded_layer_weight.device.type == "cuda"

    @parameterized.expand([True, False], name_func=parameterized_custom_name_func)
    def test_save_load_model_with_hooks(self, use_safetensors):
        accelerator = Accelerator()
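get_state_dict_from_offload is a per-module primitive; the whole-model variant, get_state_dict_offloaded_model, already sits directly above it in modeling.py. Purely to illustrate how the primitive composes, the following hypothetical helper (not part of the PR) rebuilds a full CPU state dict one module at a time; it assumes every parameter lives on a named submodule.

# Hypothetical helper, not part of the PR: collect all parameters of a
# dispatched model into a CPU state dict using the new per-module primitive.
from typing import Dict

import torch
import torch.nn as nn

from accelerate.utils.modeling import get_state_dict_from_offload


def full_cpu_state_dict(model: nn.Module) -> Dict[str, torch.Tensor]:
    # Placeholder values mark the keys the helper should fill in.
    state_dict = {name: "" for name, _ in model.named_parameters()}
    seen = set()
    for param_name in list(state_dict):
        module_name = param_name.rsplit(".", 1)[0]  # assumes "submodule.param" names
        if module_name in seen:
            continue  # one call per module copies all of its matching parameters
        seen.add(module_name)
        module = model.get_submodule(module_name)
        state_dict = get_state_dict_from_offload(module, param_name, state_dict)
    return state_dict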