109 changes: 0 additions & 109 deletions src/transformers/integrations/accelerate.py
@@ -21,7 +21,6 @@
import os
import re
from collections import OrderedDict, defaultdict
from contextlib import contextmanager
from typing import TYPE_CHECKING

from safetensors import safe_open
@@ -55,114 +54,6 @@
logger = logging.get_logger(__name__)


@contextmanager
def init_empty_weights(include_buffers: bool = False):
"""
A context manager under which models are initialized with all parameters on the meta device, therefore creating an
empty model. Useful when just initializing the model would blow the available RAM.

Args:
include_buffers (`bool`, *optional*):
Whether or not to also put all buffers on the meta device while initializing.

Example:

```python
import torch.nn as nn
from accelerate import init_empty_weights

# Initialize a model with 100 billions parameters in no time and without using any RAM.
with init_empty_weights():
tst = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)])
```

<Tip warning={true}>

Any model created under this context manager has no weights. As such you can't do something like
`model.to(some_device)` with it. To load weights inside your empty model, see [`load_checkpoint_and_dispatch`].
Make sure to overwrite the default device_map param for [`load_checkpoint_and_dispatch`], otherwise dispatch is not
called.

</Tip>
"""
with init_on_device(torch.device("meta"), include_buffers=include_buffers) as f:
yield f


@contextmanager
def init_on_device(device: "torch.device", include_buffers: bool = False):
"""
A context manager under which models are initialized with all parameters on the specified device.

Args:
device (`torch.device`):
Device to initialize all parameters on.
include_buffers (`bool`, *optional*):
Whether or not to also put all buffers on the meta device while initializing.

Example:

```python
import torch.nn as nn
from accelerate import init_on_device

with init_on_device(device=torch.device("cuda")):
tst = nn.Linear(100, 100) # on `cuda` device
```
"""
if include_buffers:
with device:
yield
return

old_register_parameter = nn.Module.register_parameter
if include_buffers:
old_register_buffer = nn.Module.register_buffer

def register_empty_parameter(module, name, param):
old_register_parameter(module, name, param)
if param is not None:
param_cls = type(module._parameters[name])
kwargs = module._parameters[name].__dict__
kwargs["requires_grad"] = param.requires_grad
module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)

def register_empty_buffer(module, name, buffer, persistent=True):
old_register_buffer(module, name, buffer, persistent=persistent)
if buffer is not None:
module._buffers[name] = module._buffers[name].to(device)

# Patch tensor creation
if include_buffers:
tensor_constructors_to_patch = {
torch_function_name: getattr(torch, torch_function_name)
for torch_function_name in ["empty", "zeros", "ones", "full"]
}
else:
tensor_constructors_to_patch = {}

def patch_tensor_constructor(fn):
def wrapper(*args, **kwargs):
kwargs["device"] = device
return fn(*args, **kwargs)

return wrapper

try:
nn.Module.register_parameter = register_empty_parameter
if include_buffers:
nn.Module.register_buffer = register_empty_buffer
for torch_function_name in tensor_constructors_to_patch:
setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
yield
finally:
nn.Module.register_parameter = old_register_parameter
if include_buffers:
nn.Module.register_buffer = old_register_buffer
for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
setattr(torch, torch_function_name, old_torch_function)


def check_and_set_device_map(device_map: "torch.device | int | str | dict | None") -> dict | str | None:
from ..modeling_utils import get_torch_context_manager_or_global_device

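The removal above is the core of this change: `init_empty_weights` / `init_on_device` worked by monkey-patching `nn.Module.register_parameter` (and, with `include_buffers=True`, `register_buffer` plus the `torch.empty`/`zeros`/`ones`/`full` constructors) so that freshly created tensors land on the target device. In recent PyTorch, `torch.device("meta")` used directly as a context manager gives the same result, which is what every integration below switches to. A minimal sketch, assuming a PyTorch version where `torch.device` acts as a context manager; note this matches the removed `include_buffers=True` path, which simply did `with device:`:

```python
import torch
import torch.nn as nn

# Parameters created under the meta device context carry only shape/dtype
# metadata; no real storage is allocated.
with torch.device("meta"):
    layer = nn.Linear(10_000, 10_000)

print(layer.weight.device)  # meta
print(layer.weight.shape)   # torch.Size([10000, 10000])

# The layer cannot run a forward pass yet: its weights must be materialized
# first, e.g. via layer.to_empty(device="cpu") followed by loading real values.
```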
8 changes: 3 additions & 5 deletions src/transformers/integrations/aqlm.py
@@ -14,13 +14,11 @@
"AQLM (Additive Quantization of Language Model) integration file"

from ..quantizers.quantizers_utils import should_convert_module
from ..utils import is_accelerate_available, is_torch_available, logging
from ..utils import is_torch_available, logging


if is_accelerate_available():
from accelerate import init_empty_weights

if is_torch_available():
import torch
import torch.nn as nn

logger = logging.get_logger(__name__)
@@ -46,7 +44,7 @@ def replace_with_aqlm_linear(model, modules_to_not_convert: list[str] | None = N
for module_name, module in model.named_modules():
if not should_convert_module(module_name, modules_to_not_convert):
continue
with init_empty_weights():
with torch.device("meta"):
if isinstance(module, nn.Linear):
new_module = QuantizedLinear(
module.in_features,
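The integration files below all apply the same mechanical change: the quantized replacement module is constructed under `with torch.device("meta"):` instead of `init_empty_weights()`, then swapped into the model, and the real weights are only materialized later when the checkpoint is loaded. A rough, self-contained sketch of that shared loop, where `new_cls` and its constructor arguments are stand-ins rather than the actual quantizer classes:

```python
import torch
import torch.nn as nn

def replace_linears(model: nn.Module, new_cls=nn.Linear, **new_kwargs) -> nn.Module:
    """Swap every nn.Linear for `new_cls`, built on the meta device."""
    for module_name, module in model.named_modules():
        if not module_name or not isinstance(module, nn.Linear):
            continue
        # No memory is allocated here; the checkpoint-loading step fills in
        # the real (quantized) weights afterwards.
        with torch.device("meta"):
            new_module = new_cls(
                module.in_features,
                module.out_features,
                bias=module.bias is not None,
                **new_kwargs,
            )
        model.set_submodule(module_name, new_module)
    return model

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
replace_linears(model)  # both Linear layers are now meta-device placeholders
```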
7 changes: 2 additions & 5 deletions src/transformers/integrations/awq.py
@@ -16,12 +16,9 @@
from typing import Optional, Union

from ..quantizers.quantizers_utils import should_convert_module
from ..utils import is_accelerate_available, is_torch_available, logging
from ..utils import is_torch_available, logging


if is_accelerate_available():
from accelerate import init_empty_weights

if is_torch_available():
import torch
import torch.nn as nn
@@ -97,7 +94,7 @@ def replace_with_awq_linear(
for module_name, module in model.named_modules():
if not should_convert_module(module_name, modules_to_not_convert):
continue
with init_empty_weights():
with torch.device("meta"):
if isinstance(module, nn.Linear):
new_module = target_cls(
bits=quantization_config.bits,
7 changes: 2 additions & 5 deletions src/transformers/integrations/bitnet.py
@@ -1,10 +1,7 @@
from ..quantizers.quantizers_utils import should_convert_module
from ..utils import is_accelerate_available, is_torch_available, logging
from ..utils import is_torch_available, logging


if is_accelerate_available():
from accelerate import init_empty_weights

if is_torch_available():
import torch
import torch.nn as nn
@@ -334,7 +331,7 @@ def replace_with_bitnet_linear(model, modules_to_not_convert: list[str] | None =
for module_name, module in model.named_modules():
if not should_convert_module(module_name, modules_to_not_convert):
continue
with init_empty_weights():
with torch.device("meta"):
if isinstance(module, nn.Linear):
if quantization_config and quantization_config.linear_class == "autobitlinear":
new_module = AutoBitLinear(
5 changes: 2 additions & 3 deletions src/transformers/integrations/bitsandbytes.py
@@ -22,7 +22,6 @@

if is_accelerate_available():
import accelerate
from accelerate import init_empty_weights
from accelerate.hooks import add_hook_to_module, remove_hook_from_module

logger = logging.get_logger(__name__)
@@ -181,7 +180,7 @@ def replace_with_bnb_linear(
if not should_convert_module(module_name, modules_to_not_convert):
continue
new_module = None
with init_empty_weights():
with torch.device("meta"):
if isinstance(module, (nn.Linear, Conv1D)):
if isinstance(module, Conv1D):
in_features, out_features = module.weight.shape
@@ -293,7 +292,7 @@ def dequantize_and_replace(model, quantization_config=None, dtype=None):
target_cls = bnb.nn.Linear8bitLt if quant_method == "llm_int8" else bnb.nn.Linear4bit
for module_name, module in model.named_modules():
if isinstance(module, target_cls):
with init_empty_weights():
with torch.device("meta"):
bias = getattr(module, "bias", None)
new_module = torch.nn.Linear(module.in_features, module.out_features, bias=bias is not None)
state = module.state if quant_method == "llm_int8" else None
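The `dequantize_and_replace` hunk above goes the other way, turning a quantized module back into a plain `nn.Linear`, but relies on the same property: a module created under the meta context costs nothing until real tensors are assigned to its parameters. A small illustration of that materialization step, where the random tensor stands in for whatever dequantized weight the real code produces:

```python
import torch
import torch.nn as nn

with torch.device("meta"):
    new_module = nn.Linear(4096, 4096, bias=False)  # placeholder only

# Assigning a real Parameter replaces the meta one and materializes the module.
dequantized_weight = torch.randn(4096, 4096)  # stand-in for the dequantized tensor
new_module.weight = nn.Parameter(dequantized_weight)

assert new_module.weight.device.type == "cpu"
x = torch.randn(2, 4096)
print(new_module(x).shape)  # torch.Size([2, 4096])
```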
6 changes: 2 additions & 4 deletions src/transformers/integrations/eetq.py
@@ -14,15 +14,13 @@
# limitations under the License.
from ..core_model_loading import ConversionOps
from ..quantizers.quantizers_utils import should_convert_module
from ..utils import is_accelerate_available, is_torch_available, logging
from ..utils import is_torch_available, logging


if is_torch_available():
import torch
import torch.nn as nn

if is_accelerate_available():
from accelerate import init_empty_weights

logger = logging.get_logger(__name__)

@@ -108,7 +106,7 @@ def replace_with_eetq_linear(model, modules_to_not_convert: list[str] | None = N
for module_name, module in model.named_modules():
if not should_convert_module(module_name, modules_to_not_convert):
continue
with init_empty_weights():
with torch.device("meta"):
if isinstance(module, nn.Linear):
new_module = EetqLinear(
module.in_features, module.out_features, bias=module.bias is not None, **module_kwargs
7 changes: 2 additions & 5 deletions src/transformers/integrations/finegrained_fp8.py
@@ -15,7 +15,7 @@

from ..core_model_loading import ConversionOps
from ..quantizers.quantizers_utils import should_convert_module
from ..utils import is_accelerate_available, is_torch_accelerator_available, is_torch_available, logging
from ..utils import is_torch_accelerator_available, is_torch_available, logging


if is_torch_available():
@@ -25,9 +25,6 @@
import triton.language as tl
from torch.nn import functional as F

if is_accelerate_available():
from accelerate import init_empty_weights


logger = logging.get_logger(__name__)
try:
@@ -618,7 +615,7 @@ def replace_with_fp8_linear(
# we need this to correctly materialize the weights during quantization
module_kwargs = {} if pre_quantized else {"dtype": None}
new_module = None
with init_empty_weights():
with torch.device("meta"):
if module_name.endswith(".experts"):
new_module = FP8Expert(
config=model.config, block_size=quantization_config.weight_block_size, **module_kwargs
7 changes: 2 additions & 5 deletions src/transformers/integrations/higgs.py
@@ -16,12 +16,9 @@
from math import sqrt

from ..quantizers.quantizers_utils import should_convert_module
from ..utils import is_accelerate_available, is_flute_available, is_hadamard_available, is_torch_available, logging
from ..utils import is_flute_available, is_hadamard_available, is_torch_available, logging


if is_accelerate_available():
from accelerate import init_empty_weights

if is_torch_available():
import torch
import torch.nn as nn
@@ -569,7 +566,7 @@ def replace_with_higgs_linear(model, modules_to_not_convert: list[str] | None =
for module_name, module in model.named_modules():
if not should_convert_module(module_name, modules_to_not_convert):
continue
with init_empty_weights():
with torch.device("meta"):
if isinstance(module, nn.Linear):
new_module = HiggsLinear(
module.in_features,
12 changes: 3 additions & 9 deletions src/transformers/integrations/mxfp4.py
@@ -12,22 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from ..utils import is_accelerate_available, is_torch_available, is_torch_xpu_available, logging
from ..utils import is_torch_available, is_torch_xpu_available, logging


if is_torch_available():
import torch
from torch import nn
from contextlib import contextmanager
from typing import Optional

from ..core_model_loading import ConversionOps


if is_accelerate_available():
from accelerate import init_empty_weights

from contextlib import contextmanager

from ..quantizers.quantizers_utils import get_module_from_name, should_convert_module


@@ -620,7 +614,7 @@ def replace_with_mxfp4_linear(model, quantization_config=None, modules_to_not_co
if not should_convert_module(module_name, modules_to_not_convert):
continue
if module.__class__.__name__ == "GptOssExperts" and not quantization_config.dequantize:
with init_empty_weights():
with torch.device("meta"):
model.set_submodule(module_name, Mxfp4GptOssExperts(model.config))
has_been_replaced = True
if module.__class__.__name__ == "GptOssMLP" and not quantization_config.dequantize:
3 changes: 1 addition & 2 deletions src/transformers/integrations/quanto.py
@@ -73,7 +73,6 @@ def replace_with_quanto_layers(
A list of modules to not convert. If a module name is in the list (e.g. `lm_head`), it will not be
converted.
"""
from accelerate import init_empty_weights
from optimum.quanto import QLayerNorm, QLinear, qfloat8, qint2, qint4, qint8

w_mapping = {"float8": qfloat8, "int8": qint8, "int4": qint4, "int2": qint2}
@@ -83,7 +82,7 @@
for module_name, module in model.named_modules():
if not should_convert_module(module_name, modules_to_not_convert):
continue
with init_empty_weights():
with torch.device("meta"):
new_module = None
if isinstance(module, nn.Linear):
new_module = QLinear(
8 changes: 3 additions & 5 deletions src/transformers/integrations/spqr.py
@@ -14,13 +14,11 @@
"SpQR (Sparse-Quantized Representation) integration file"

from ..quantizers.quantizers_utils import should_convert_module
from ..utils import is_accelerate_available, is_spqr_available, is_torch_available, logging
from ..utils import is_spqr_available, is_torch_available, logging


if is_accelerate_available():
from accelerate import init_empty_weights

if is_torch_available():
import torch
import torch.nn as nn

logger = logging.get_logger(__name__)
@@ -47,7 +45,7 @@ def replace_with_spqr_linear(model, modules_to_not_convert: list[str] | None = N
for module_name, module in model.named_modules():
if not should_convert_module(module_name, modules_to_not_convert):
continue
with init_empty_weights():
with torch.device("meta"):
if isinstance(module, nn.Linear):
shapes = quantization_config.shapes
