Skip to content

int8 quantization with FSDP for inference error #2127

Open
@Andy0422

Description

@Andy0422

Can FSDP be used together with torchao for inference?

I would like to use torchao to quantize a model to int8 and then shard it with FSDP to save memory.

The following minimal toy example tests this:

`import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight, float8_dynamic_activation_float8_weight, int8_weight_only
import copy
from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard

class FFN(nn.Module):
    """Two-layer feed-forward network: Linear -> ReLU -> Linear.

    Args:
        input_dim: Number of input features (``in_features`` of ``linear1``).
        hidden_dim: Width of the hidden layer.
        output_dim: Number of output features (``out_features`` of ``linear2``).
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        # This must be the dunder constructor. A method called ``init`` is never
        # invoked by Python, so the layers below would never be created and
        # nn.Module.__init__ would receive the three dims instead.
        super().__init__()
        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Apply ``linear1``, ReLU, then ``linear2`` to ``x``.

        The last dimension of ``x`` must equal ``input_dim``; the output's last
        dimension is ``output_dim``.
        """
        x = self.linear1(x)
        x = self.relu(x)
        return self.linear2(x)

import os

from torchao.quantization.quant_api import Int8DynamicActivationInt8WeightConfig

weight_path = "xxx/ffn_weights.pth"

# torchrun exports LOCAL_RANK for each process. Bind this process to its own GPU
# *before* creating the NCCL group; otherwise torch.cuda.current_device() is 0
# on every rank and all ranks end up on the same GPU.
local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)

dist.init_process_group(backend='nccl')

try:
    input_dim = 10
    hidden_dim = 20
    output_dim = 10

    base_model = FFN(input_dim, hidden_dim, output_dim).to(device)
    # Load the checkpoint straight onto this rank's GPU. weights_only=True
    # restricts unpickling to tensors/containers, which is all a state_dict needs.
    state_dict = torch.load(weight_path, map_location=device, weights_only=True)
    base_model.load_state_dict(state_dict)
    print("model structure", base_model)

    for name, module in base_model.named_modules():
        if isinstance(module, nn.Linear):
            print(f"Before quantization: {name}, Size = {module.weight.size()}, Stride = {module.weight.stride()}")

    # Quantize exactly once. Applying both the legacy
    # int8_dynamic_activation_int8_weight() and the config-based API to the same
    # model would try to re-quantize weights that are already quantized.
    quantize_(base_model, Int8DynamicActivationInt8WeightConfig())
    print("q_model", base_model)

    for name, module in base_model.named_modules():
        if isinstance(module, nn.Linear):
            print(f"after quantization: {name}, Size = {module.weight.size()}, Stride = {module.weight.stride()}")

    # Inference only: no parameter needs gradients.
    for param in base_model.parameters():
        param.requires_grad = False

    for name, param in base_model.named_parameters():
        print(f"Parameter: {name}, requires_grad={param.requires_grad}")

    # Apply a single sharding API. Wrapping with FullyShardedDataParallel (FSDP1)
    # and then calling fully_shard (FSDP2) on the same module would try to shard
    # its parameters twice.
    # NOTE(review): with the torchao version from the report, this call raises
    # NotImplementedError. FSDP2 splits each parameter along dim 0 via
    # torch.chunk (aten.split), and LinearActivationQuantizedTensor does not
    # implement that op.
    fully_shard(base_model)
finally:
    dist.destroy_process_group()
`

Running it produces the following error:

[rank1]: Traceback (most recent call last):
[rank1]: File "fsdp_test.py", line 77, in <module>
[rank1]: fully_shard(base_model)
[rank1]: File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_composable/contract.py", line 125, in wrapper
[rank1]: updated = func(inp_module, *args, **kwargs)
[rank1]: File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_composable/fsdp/fully_shard.py", line 129, in fully_shard
[rank1]: state._fsdp_param_group = FSDPParamGroup(
[rank1]: File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_composable/fsdp/_fsdp_param_group.py", line 114, in __init__
[rank1]: self.fsdp_params = [
[rank1]: File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_composable/fsdp/_fsdp_param_group.py", line 115, in <listcomp>
[rank1]: FSDPParam(
[rank1]: File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_composable/fsdp/_fsdp_param.py", line 226, in __init__
[rank1]: self._init_sharded_param(param, device)
[rank1]: File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank1]: return func(*args, **kwargs)
[rank1]: File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_composable/fsdp/_fsdp_param.py", line 310, in _init_sharded_param
[rank1]: chunks = _chunk_with_empty(param_data, shard_world_size, dim=0)
[rank1]: File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_composable/fsdp/_fsdp_common.py", line 94, in _chunk_with_empty
[rank1]: chunks = list(torch.chunk(tensor, num_chunks, dim=dim))
[rank1]: File "/ao/torchao/utils.py", line 425, in _dispatch__torch_function__
[rank1]: return func(*args, **kwargs)
[rank1]: File "/ao/torchao/utils.py", line 444, in _dispatch__torch_dispatch__
[rank1]: raise NotImplementedError(
[rank1]: NotImplementedError: LinearActivationQuantizedTensor dispatch: attempting to run unimplemented operator/function: func=<OpOverload(op='aten.split', overload='Tensor')>, types=(<class 'torchao.quantization.linear_activation_quantized_tensor.LinearActivationQuantizedTensor'>,), arg_types=(<class 'torchao.quantization.linear_activation_quantized_tensor.LinearActivationQuantizedTensor'>, <class 'int'>), kwarg_types={}

Metadata

Metadata

Assignees

No one assigned

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions