Description
Can FSDP work with torchao for inference?
I would like to use torchao to get an int8 model and then wrap it with FSDP to save memory. The following code is a tiny toy example to test this:
```python
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight, float8_dynamic_activation_float8_weight, int8_weight_only
import copy
from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard


class FFN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(FFN, self).__init__()
        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        return x


weight_path = "xxx/ffn_weights.pth"
dist.init_process_group(backend="nccl")

input_dim = 10
hidden_dim = 20
output_dim = 10
base_model = FFN(input_dim, hidden_dim, output_dim).to(torch.cuda.current_device())
base_model.load_state_dict(torch.load(weight_path))
print("model structure", base_model)

fsdp_model = copy.deepcopy(base_model)

for name, module in base_model.named_modules():
    if isinstance(module, nn.Linear):
        print(f"Before quantization: {name}, Size = {module.weight.size()}, Stride = {module.weight.stride()}")

# Quantize in place with the functional API.
quantize_(base_model, int8_dynamic_activation_int8_weight())
print("q_model", base_model)

# Same quantization through the config-based API.
from torchao.quantization.quant_api import (
    quantize_,
    Int8DynamicActivationInt8WeightConfig,
    Int4WeightOnlyConfig,
    Int8WeightOnlyConfig,
)

quantize_(base_model, Int8DynamicActivationInt8WeightConfig())
print("q_model_new_api", base_model)

for name, module in base_model.named_modules():
    if isinstance(module, nn.Linear):
        # print(f"Before quantization: {name}, Size = {module.weight.size()}, Stride = {module.weight.stride()}")
        # setattr(model, name, quantize_(module, int8_dynamic_activation_int8_weight()))
        print(f"after quantization: {name}, Size = {module.weight.size()}, Stride = {module.weight.stride()}")

for name, param in base_model.named_parameters():
    print(f"Parameter: {name}, requires_grad={param.requires_grad}")

# Freeze all parameters since this is inference only.
for param in base_model.parameters():
    param.requires_grad = False
for name, param in base_model.named_parameters():
    print(f"Parameter: {name}, requires_grad={param.requires_grad}")

# FSDP1 wrapping.
wrap_policy = ModuleWrapPolicy({nn.Linear})
model = FSDP(base_model, auto_wrap_policy=wrap_policy, use_orig_params=True)

# FSDP2 per-parameter sharding -- this is line 77 in the traceback below.
fully_shard(base_model)
```
Running it, I get the following error:
```
[rank1]: Traceback (most recent call last):
[rank1]:   File "fsdp_test.py", line 77, in <module>
[rank1]:     fully_shard(base_model)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_composable/contract.py", line 125, in wrapper
[rank1]:     updated = func(inp_module, *args, **kwargs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_composable/fsdp/fully_shard.py", line 129, in fully_shard
[rank1]:     state._fsdp_param_group = FSDPParamGroup(
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_composable/fsdp/_fsdp_param_group.py", line 114, in __init__
[rank1]:     self.fsdp_params = [
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_composable/fsdp/_fsdp_param_group.py", line 115, in <listcomp>
[rank1]:     FSDPParam(
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_composable/fsdp/_fsdp_param.py", line 226, in __init__
[rank1]:     self._init_sharded_param(param, device)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank1]:     return func(*args, **kwargs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_composable/fsdp/_fsdp_param.py", line 310, in _init_sharded_param
[rank1]:     chunks = _chunk_with_empty(param_data, shard_world_size, dim=0)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_composable/fsdp/_fsdp_common.py", line 94, in _chunk_with_empty
[rank1]:     chunks = list(torch.chunk(tensor, num_chunks, dim=dim))
[rank1]:   File "/ao/torchao/utils.py", line 425, in _dispatch__torch_function__
[rank1]:     return func(*args, **kwargs)
[rank1]:   File "/ao/torchao/utils.py", line 444, in _dispatch__torch_dispatch__
[rank1]:     raise NotImplementedError(
[rank1]: NotImplementedError: LinearActivationQuantizedTensor dispatch: attempting to run unimplemented operator/function: func=<OpOverload(op='aten.split', overload='Tensor')>, types=(<class 'torchao.quantization.linear_activation_quantized_tensor.LinearActivationQuantizedTensor'>,), arg_types=(<class 'torchao.quantization.linear_activation_quantized_tensor.LinearActivationQuantizedTensor'>, <class 'int'>), kwarg_types={}
```
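
If it helps narrow this down: the traceback points at FSDP2's parameter sharding itself, where `_chunk_with_empty` calls `torch.chunk` on each parameter and `aten.split.Tensor` has no implementation for `LinearActivationQuantizedTensor`. Below is a minimal single-process sketch (my own reduction, assuming one CUDA device and the same torchao int8 API as above; no distributed setup) that I would expect to raise the same `NotImplementedError`:

```python
import torch
import torch.nn as nn
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight

# A single linear layer, quantized the same way as the FFN above.
# Assumes a CUDA device is available, matching the nccl setup in the repro.
model = nn.Sequential(nn.Linear(10, 20)).cuda()
quantize_(model, int8_dynamic_activation_int8_weight())

# model[0].weight is now backed by LinearActivationQuantizedTensor. Chunking it along
# dim 0 mirrors what fully_shard's _chunk_with_empty does when sharding a parameter
# across ranks, so this should hit the same unimplemented aten.split.Tensor dispatch.
chunks = torch.chunk(model[0].weight, 2, dim=0)
```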