28 changes: 15 additions & 13 deletions src/transformers/integrations/accelerate.py
@@ -16,6 +16,8 @@
 and simplicity/ease of use.
 """

+from __future__ import annotations
+
 import copy
 import inspect
 import os
@@ -24,7 +26,6 @@
 from typing import TYPE_CHECKING

 from safetensors import safe_open
-from safetensors.torch import save_file

 from ..utils import (
     is_accelerate_available,
@@ -40,6 +41,7 @@
 if is_torch_available():
     import torch
     import torch.nn as nn
+    from safetensors.torch import save_file

 if is_accelerate_available():
     from accelerate import dispatch_model
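Note on the import hunks above: `safetensors.torch` imports `torch` itself, so moving `save_file` inside the existing `is_torch_available()` block keeps the module importable when torch is absent. Below is a minimal, self-contained sketch of that guarded-import pattern; the `is_torch_available` stand-in is simplified for illustration and is not transformers' actual implementation, and it assumes safetensors is installed whenever torch is.

from importlib.util import find_spec


def is_torch_available() -> bool:
    # Simplified stand-in for transformers' availability check.
    return find_spec("torch") is not None


if is_torch_available():
    import torch
    import torch.nn as nn
    # safetensors.torch pulls in torch, so it must sit behind the same guard.
    from safetensors.torch import save_file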
@@ -90,7 +92,7 @@ def get_module_size_with_ties(
     return module_size_with_ties, tied_module_names, tied_modules


-def check_and_set_device_map(device_map: "torch.device | int | str | dict | None") -> dict | str | None:
+def check_and_set_device_map(device_map: torch.device | int | str | dict | None) -> dict | str | None:
     from ..modeling_utils import get_torch_context_manager_or_global_device

     # Potentially detect context manager or global device, and use it (only if no device_map was provided)
@@ -139,8 +141,8 @@ def check_and_set_device_map(device_map: "torch.device | int | str | dict | None


 def compute_module_sizes(
-    model: "PreTrainedModel",
-    hf_quantizer: "HfQuantizer | None" = None,
+    model: PreTrainedModel,
+    hf_quantizer: HfQuantizer | None = None,
     buffers_only: bool = False,
     only_modules: bool = True,
 ) -> tuple[dict[str, int], dict[str, int]]:
@@ -188,7 +190,7 @@ def all_tensors():
     return all_module_sizes, leaves_module_sizes


-def compute_module_total_buffer_size(model: nn.Module, hf_quantizer: "HfQuantizer | None" = None):
+def compute_module_total_buffer_size(model: nn.Module, hf_quantizer: HfQuantizer | None = None):
     """
     Compute the total size of buffers in each submodule of a given model.
     """
@@ -197,10 +199,10 @@ def compute_module_total_buffer_size(model: nn.Module, hf_quantize


 def get_balanced_memory(
-    model: "PreTrainedModel",
+    model: PreTrainedModel,
     max_memory: dict[int | str, int | str] | None = None,
     no_split_module_classes: set[str] | None = None,
-    hf_quantizer: "HfQuantizer | None" = None,
+    hf_quantizer: HfQuantizer | None = None,
     low_zero: bool = False,
 ):
     """
@@ -302,10 +304,10 @@ def get_balanced_memory(


 def _get_device_map(
-    model: "PreTrainedModel",
+    model: PreTrainedModel,
     device_map: dict | str | None,
     max_memory: dict | None,
-    hf_quantizer: "HfQuantizer | None",
+    hf_quantizer: HfQuantizer | None,
 ) -> dict:
     """Compute the final `device_map` to use if we passed a value in ['auto', 'balanced', 'balanced_low_0', 'sequential'].
     Otherwise, we check for any device inconsistencies in the device_map.
@@ -418,7 +420,7 @@ def get_device(device_map: dict | None, param_name: str, valid_torch_device: boo


 def accelerate_disk_offload(
-    model: "PreTrainedModel",
+    model: PreTrainedModel,
     disk_offload_folder: str | None,
     checkpoint_files: list[str] | None,
     device_map: dict,
@@ -497,7 +499,7 @@ def offload_weight(weight: torch.Tensor, weight_name: str, offload_folder: str |
     return offload_index


-def load_offloaded_parameter(model: "PreTrainedModel", param_name: str) -> torch.Tensor:
+def load_offloaded_parameter(model: PreTrainedModel, param_name: str) -> torch.Tensor:
     """Load `param_name` from disk, if it was offloaded due to the device_map, and thus lives as a meta parameter
     inside `model`.
     This is needed when resaving a model, when some parameters were offloaded (we need to load them from disk, to
@@ -528,7 +530,7 @@ def _init_infer_auto_device_map(
     max_memory: dict[int | str, int | str] | None = None,
     no_split_module_classes: set[str] | None = None,
     tied_parameters: list[list[str]] | None = None,
-    hf_quantizer: "HfQuantizer | None" = None,
+    hf_quantizer: HfQuantizer | None = None,
 ) -> tuple[
     list[int | str],
     dict[int | str, int | str],
@@ -601,7 +603,7 @@ def infer_auto_device_map(
     clean_result: bool = True,
     offload_buffers: bool = False,
     tied_parameters: list[list[str]] | None = None,
-    hf_quantizer: "HfQuantizer | None" = None,
+    hf_quantizer: HfQuantizer | None = None,
 ):
     """
     Compute a device map for a given model giving priority to GPUs, then offload on CPU and finally offload to disk,
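The remaining hunks all apply the same change: with `from __future__ import annotations` at the top of the module, annotations are no longer evaluated at runtime (PEP 563), so quoted forward references such as `"PreTrainedModel"` and `"HfQuantizer | None"` can be written as plain names even when those classes are only imported under `TYPE_CHECKING`. A minimal sketch of why this works; the example below is hypothetical and not part of the PR.

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Visible to type checkers only; never imported at runtime.
    from transformers import PreTrainedModel


def describe(model: PreTrainedModel | None = None) -> str:
    # Annotations are stored as strings and not evaluated, so the unquoted
    # reference to PreTrainedModel above is safe even though the class is
    # never imported at runtime.
    return type(model).__name__ if model is not None else "no model"


print(describe())  # prints "no model" without importing transformers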