109 changes: 86 additions & 23 deletions src/transformers/integrations/finegrained_fp8.py
@@ -26,7 +26,7 @@
from torch.nn import functional as F

if is_accelerate_available():
from accelerate import init_empty_weights
pass


logger = logging.get_logger(__name__)
@@ -409,7 +409,7 @@ def w8a8_block_fp8_matmul_compile(
return output.to(output_dtype)


class FP8Linear(nn.Linear):
class FP8Linear(nn.Module):
def __init__(
self,
in_features: int,
@@ -419,10 +419,13 @@ def __init__(
block_size: tuple[int, int] | None = None,
activation_scheme="dynamic",
):
super().__init__(in_features, out_features)
super().__init__()

# If block size is None, it means that we are doing per-tensor quantization
self.block_size = block_size
self.dtype = dtype
self.in_features = in_features
self.out_features = out_features
# If block size is None, it means that we are doing per-tensor quantization
self.activation_scheme = activation_scheme

self.weight = torch.nn.Parameter(torch.empty(out_features, in_features, dtype=dtype))
@@ -444,6 +447,30 @@ def __init__(
else:
self.register_parameter("bias", None)

# Even without this check, initialization is a no-op on meta device
# but it's still a good practice to check
if self.weight.device != torch.device("meta"):
self.reset_parameters()

# TODO: look into other initialization methods for FP8Linear especially for training when we have QAT
def reset_parameters(self) -> None:
"""Initialize weights using Xavier uniform initialization, clamped to FP8 range."""
if self.block_size is None:
self.weight_scale_inv.data.fill_(1.0)
else:
nn.init.ones_(self.weight_scale_inv)

dtype = torch.float8_e4m3fn

# Initialize in float32, clamp to FP8 range, then convert to FP8
# Note: We create a new Parameter because copy_() doesn't change dtype
weight_f32 = torch.empty(self.out_features, self.in_features, dtype=torch.float32, device=self.weight.device)
nn.init.xavier_uniform_(weight_f32)
self.weight = nn.Parameter(weight_f32.clamp(min=torch.finfo(dtype).min, max=torch.finfo(dtype).max).to(dtype))

if self.bias is not None:
nn.init.zeros_(self.bias)

def forward(self, input: torch.Tensor) -> torch.Tensor:
if self.weight.element_size() > 1:
return F.linear(input, self.weight, self.bias)
@@ -497,6 +524,7 @@ def __init__(self, config, block_size, dtype=torch.float8_e4m3fn):
from ..activations import ACT2FN

self.block_size = block_size
self.dtype = dtype
self.num_experts = config.num_local_experts
self.hidden_dim = config.hidden_size
self.intermediate_dim = config.intermediate_size
@@ -531,6 +559,37 @@ def __init__(self, config, block_size, dtype=torch.float8_e4m3fn):
# Keep a handle here; actual usage happens in forward of your MoE block
self.act_fn = ACT2FN[config.hidden_act]

# Even without this check, initialization is a no-op on meta device
# but it's still a good practice to check
if self.gate_up_proj.device != torch.device("meta"):
self.reset_parameters()

def reset_parameters(self) -> None:
"""Initialize weights using Xavier uniform initialization, clamped to FP8 range."""
dtype = torch.float8_e4m3fn

# Initialize gate_up_proj in float32, clamp to FP8 range, then convert to FP8
# Note: We create new Parameters because copy_() doesn't change dtype
gate_up_f32 = torch.empty_like(self.gate_up_proj, dtype=torch.float32)
nn.init.xavier_uniform_(gate_up_f32)
self.gate_up_proj = nn.Parameter(
gate_up_f32.clamp(min=torch.finfo(dtype).min, max=torch.finfo(dtype).max).to(dtype)
)

# Initialize down_proj in float32, clamp to FP8 range, then convert to FP8
down_f32 = torch.empty_like(self.down_proj, dtype=torch.float32)
nn.init.xavier_uniform_(down_f32)
self.down_proj = nn.Parameter(down_f32.clamp(min=torch.finfo(dtype).min, max=torch.finfo(dtype).max).to(dtype))

# Initialize scale tensors
nn.init.ones_(self.gate_up_proj_scale_inv)
nn.init.ones_(self.down_proj_scale_inv)

if self.gate_up_proj_bias is not None:
nn.init.zeros_(self.gate_up_proj_bias)
if self.down_proj_bias is not None:
nn.init.zeros_(self.down_proj_bias)

def forward(
self,
hidden_states: torch.Tensor,
@@ -588,7 +647,10 @@ def linear(self, input: torch.Tensor, weight: torch.Tensor, weight_scale_inv: to


def replace_with_fp8_linear(
model, modules_to_not_convert: list[str] | None = None, quantization_config=None, pre_quantized=False
model,
modules_to_not_convert: list[str] | None = None,
quantization_config=None,
pre_quantized=False,
):
"""
A helper function to replace all `torch.nn.Linear` modules by `FP8Linear` modules.
@@ -600,7 +662,7 @@ def replace_with_fp8_linear(
Names of the modules to not convert. In practice we keep the `lm_head` in full precision for numerical stability reasons.
quantization_config (`FbgemmFp8Config`):
The quantization config object that contains the quantization parameters.
pre_quantized (`book`, defaults to `False`):
pre_quantized (`bool`, defaults to `False`):
Whether the model is pre-quantized or not
"""

@@ -614,23 +676,24 @@
# we need this to correctly materialize the weights during quantization
module_kwargs = {} if pre_quantized else {"dtype": None}
new_module = None
with init_empty_weights():
if module_name.endswith(".experts"):
new_module = FP8Expert(
config=model.config, block_size=quantization_config.weight_block_size, **module_kwargs
)
elif isinstance(module, nn.Linear):
new_module = FP8Linear(
in_features=module.in_features,
out_features=module.out_features,
bias=module.bias is not None,
activation_scheme=quantization_config.activation_scheme,
block_size=quantization_config.weight_block_size,
**module_kwargs,
)
if new_module is not None:
model.set_submodule(module_name, new_module)
has_been_replaced = True

if module_name.endswith(".experts"):
new_module = FP8Expert(
config=model.config, block_size=quantization_config.weight_block_size, **module_kwargs
)

elif isinstance(module, nn.Linear):
new_module = FP8Linear(
in_features=module.in_features,
out_features=module.out_features,
bias=module.bias is not None,
activation_scheme=quantization_config.activation_scheme,
block_size=quantization_config.weight_block_size,
**module_kwargs,
)
if new_module is not None:
model.set_submodule(module_name, new_module)
has_been_replaced = True

if not has_been_replaced:
logger.warning(
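A standalone sketch of the Xavier-init, clamp, cast-to-FP8 pattern that `FP8Linear.reset_parameters` and `FP8Expert.reset_parameters` use above. This is illustrative only and not part of the diff; it assumes a PyTorch build with `torch.float8_e4m3fn` support (roughly torch >= 2.1).

import torch
from torch import nn

out_features, in_features = 128, 256
fp8 = torch.float8_e4m3fn

# Initialize in float32 first: nn.init does not operate on FP8 tensors directly.
w = torch.empty(out_features, in_features, dtype=torch.float32)
nn.init.xavier_uniform_(w)

# torch.finfo reports the representable range of e4m3 (about +/-448);
# clamping keeps every value inside that range before the downcast.
lo, hi = torch.finfo(fp8).min, torch.finfo(fp8).max
w_fp8 = nn.Parameter(w.clamp(min=lo, max=hi).to(fp8))

print(w_fp8.dtype, w_fp8.element_size())  # torch.float8_e4m3fn 1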
37 changes: 35 additions & 2 deletions src/transformers/modeling_utils.py
@@ -1343,6 +1343,7 @@ def __init__(self, config: PreTrainedConfig, *inputs, **kwargs):
self._keep_in_fp32_modules = copy.copy(self.__class__._keep_in_fp32_modules)
self._keep_in_fp32_modules_strict = copy.copy(self.__class__._keep_in_fp32_modules_strict)
self.dtype_plan = {}
self.quantization_config = kwargs.pop("quantization_config", None)

if isinstance(self._keep_in_fp32_modules, list):
self.dtype_plan.update(dict.fromkeys(self._keep_in_fp32_modules, torch.float32))
@@ -1495,6 +1496,7 @@ def _from_config(cls, config, **kwargs):
dtype (`torch.dtype`, *optional*):
Override the default `dtype` and load the model under this dtype.
"""

# For BC on the old `torch_dtype`
dtype = kwargs.pop("dtype", config.dtype)
if (torch_dtype := kwargs.pop("torch_dtype", None)) is not None:
@@ -1508,10 +1510,30 @@
if "attn_implementation" in kwargs:
config._attn_implementation = kwargs.pop("attn_implementation")

# Handle FP8 quantization config if passed
quantization_config = kwargs.pop("quantization_config", None)
hf_quantizer = None
is_quantized = False

if quantization_config is not None:
quant_method = getattr(quantization_config, "quant_method", None)
# Only FP8 quantization methods are supported in _from_config
if quant_method in (QuantizationMethod.FBGEMM_FP8, QuantizationMethod.FP8):
hf_quantizer, config, _ = get_hf_quantizer(
config, quantization_config, device_map=None, weights_only=False, user_agent=None
)
else:
logger.warning_once(
f"Quantization method `{quant_method}` is not supported in `_from_config`. "
"Only FP8 quantization methods (`fbgemm_fp8` and `fp8`) are supported. "
"Please use `from_pretrained` for other quantization methods."
)

init_contexts = []
if dtype is not None:
init_contexts.append(local_torch_dtype(dtype, cls.__name__))
if is_deepspeed_zero3_enabled() and not _is_quantized and not _is_ds_init_called:

if is_deepspeed_zero3_enabled() and not is_quantized and not _is_ds_init_called:
logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
# this immediately partitions the model across all gpus, to avoid the overhead in time
# and memory copying it on CPU or each GPU first
@@ -1523,6 +1545,18 @@
with ContextManagers(init_contexts):
model = cls(config, **kwargs)

# For FP8 quantization, preprocess the model to replace linears with quantized versions
if hf_quantizer is not None:
hf_quantizer.preprocess_model(
model=model,
device_map=None,
keep_in_fp32_modules=model._keep_in_fp32_modules,
config=config,
)

model.hf_quantizer = hf_quantizer
config.quantization_config = quantization_config

return model

@property
@@ -3974,7 +4008,6 @@ def from_pretrained(
checkpoint_files=checkpoint_files,
use_kernels=use_kernels,
)

# Obtain the weight conversion mapping for this model if any are registered
weight_conversions = get_model_conversion_mapping(model, key_mapping, hf_quantizer)

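With the `_from_config` changes above, an FP8-quantized model skeleton can be materialized straight from a config rather than through `from_pretrained`. The sketch below is a hypothetical usage example, assuming the `FineGrainedFP8Config` quantization config (quant_method `fp8`) and a Llama-style checkpoint name; neither is taken from this diff.

from transformers import AutoConfig, AutoModelForCausalLM, FineGrainedFP8Config

config = AutoConfig.from_pretrained("meta-llama/Llama-3.2-1B")
quant_config = FineGrainedFP8Config()  # quant_method == "fp8"

# from_config forwards kwargs to _from_config, which now builds the quantizer,
# runs preprocess_model(...) and swaps eligible nn.Linear modules for FP8Linear.
model = AutoModelForCausalLM.from_config(config, quantization_config=quant_config)

print(type(model.model.layers[0].mlp.down_proj).__name__)  # expected: FP8Linear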
4 changes: 3 additions & 1 deletion src/transformers/quantizers/auto.py
@@ -334,5 +334,7 @@ def get_hf_quantizer(config, quantization_config, device_map, weights_only, user
# In order to ensure popular quantization methods are supported. Can be disable with `disable_telemetry`
if not getattr(hf_quantizer.quantization_config, "dequantize", False):
quant_method = hf_quantizer.quantization_config.quant_method
user_agent["quant"] = getattr(quant_method, "value", quant_method)
if user_agent is not None:
user_agent["quant"] = getattr(quant_method, "value", quant_method)

return hf_quantizer, config, device_map
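For context, the `user_agent` guard above pairs with the new `_from_config` path, which calls `get_hf_quantizer(..., user_agent=None)`. A hypothetical sketch (not from this diff) of the failure mode the guard avoids:

# Telemetry tagging previously assumed a dict-like user_agent:
user_agent = None
quant_method = "fp8"

# user_agent["quant"] = quant_method   # raises TypeError when user_agent is None
if user_agent is not None:             # guarded form added in this diff
    user_agent["quant"] = quant_method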