Description
Hi,
I ran some benchmarks on LLM models with int4_weight_only on CPU/GPU/XPU and expected to see end-to-end (E2E) speedups compared with pure bf16/fp16. At the kernel level, int4 GEMM kernels are generally 2x~3x faster than bf16/fp16 GEMM. However, I did not see any E2E performance improvement, and some models even slowed down.
After profiling, I found that Torchao's CPU overhead is too high; it can be larger than the time saved by the int4 GEMM kernel. The reason is that Torchao uses a Tensor subclass and __torch_function__ to redispatch nn.Linear to the custom int4 matmul op (a minimal sketch of the pattern is shown below).
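For reference, here is a minimal sketch of that redispatch pattern (the WrappedWeight class, its inner attribute, and the plain matmul are made up for illustration; Torchao's real subclass lives in torchao/dtypes/affine_quantized_tensor.py and calls the int4 GEMM op instead). All of the interception and unwrapping below is Python-level host work that runs on every linear call, before any kernel is launched:

import torch
import torch.nn.functional as F

# Minimal sketch of the subclass + __torch_function__ redispatch pattern.
# "WrappedWeight" and "inner" are hypothetical names; in Torchao the subclass
# holds packed int4 data plus scales/zeros and calls the int4 GEMM op here.
class WrappedWeight(torch.Tensor):
    @staticmethod
    def __new__(cls, inner):
        return torch.Tensor._make_wrapper_subclass(
            cls, inner.shape, dtype=inner.dtype, device=inner.device)

    def __init__(self, inner):
        self.inner = inner  # stands in for the packed/quantized weight data

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is F.linear:
            x, w = args[0], args[1]
            # This Python-level branch and unwrapping is pure host-side work.
            return torch.mm(x, w.inner.t())
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)

w = WrappedWeight(torch.randn(1024, 512))
x = torch.randn(1, 512)
y = F.linear(x, w)  # intercepted by __torch_function__ and redispatched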
- In eager mode, the dispatching time (the blue square in the trace screenshot below) is longer than aten::_weight_int4pack_mm_cpu itself. I thought torch.compile could optimize this redispatching away.
- However, in compile mode, the redispatching disappears from the compiled model.forward region, but extra host work appears in dynamo/inductor instead: every tensor subclass from torchao/dtypes/affine_quantized_tensor.py is flattened by torch/_functorch/_aot_autograd/subclass_utils.py(233): flatten_subclass, and the flattening time is close to the time of aten::_weight_int4pack_mm_cpu (see the profiling sketch after this list).
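One way to attribute that host time (an illustrative sketch, not part of the reproducer below; the helper name is mine): profile CPU activity with Python stacks and group the results by stack, so the frames that issue each op, e.g. the __torch_function__ handler in eager mode or flatten_subclass in compile mode, show up next to aten::_weight_int4pack_mm_cpu in the table.

import torch

def profile_host_overhead(model, x, steps=10):
    # Hypothetical helper: CPU-only profile with Python stacks, then attribute
    # op time to the Python frames that issued each op.
    with torch.no_grad(), torch.profiler.profile(
        activities=[torch.profiler.ProfilerActivity.CPU],
        with_stack=True,
    ) as prof:
        for _ in range(steps):
            model(x)
    print(prof.key_averages(group_by_stack_n=5).table(
        sort_by="self_cpu_time_total", row_limit=20))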
Both eager mode and compile mode suffer from this device-agnostic Torchao CPU overhead, which can counteract the performance benefit we get from int4 GEMM in host-bound cases, such as small models or setups where the GPU/XPU is so fast that the host cannot keep up (we hit this issue with Qwen2-0.5b and Phi3-3.8b from Hugging Face on an NVIDIA A100 GPU and on the Intel Data Center GPU Max Series).
Could you optimize this CPU overhead?
Reproducer
import contextlib
import time

import torch
from torchao.quantization.quant_api import (
    int4_weight_only,
    quantize_,
)


class Linear_Gate_Up(torch.nn.Module):
    def __init__(self, in_feature, out_feature):
        super().__init__()
        self.gate_proj = torch.nn.Linear(in_feature, out_feature, bias=False)
        self.gate_proj2 = torch.nn.Linear(out_feature, out_feature, bias=False)

    def forward(self, x):
        return self.gate_proj2(self.gate_proj(x))


if __name__ == "__main__":
    # device = "cpu"
    device = "cuda"
    # device = "xpu"
    quantize_model = True
    compile_model = True
    run_with_profiler = False

    with torch.no_grad():
        model = Linear_Gate_Up(512, 1024).eval().to(device).to(torch.bfloat16)
        x = torch.randn(1, 512).to(device).to(torch.bfloat16)

        if quantize_model:
            if device == "cpu":
                from torchao.dtypes import Int4CPULayout
                quantize_(model, int4_weight_only(layout=Int4CPULayout()))
            elif device == "cuda":
                quantize_(model, int4_weight_only())
            elif device == "xpu":
                from torchao.dtypes import Int4XPULayout
                quantize_(model, int4_weight_only(layout=Int4XPULayout()))

        if compile_model:
            model = torch.compile(model)

        # Warmup run to trigger the actual torch.compile compilation
        model(x)
        if device == "cuda":
            torch.cuda.synchronize()
        if device == "xpu":
            torch.xpu.synchronize()

        # Only enable profiler activities for the selected device
        activities = [torch.profiler.ProfilerActivity.CPU]
        if device == "cuda":
            activities.append(torch.profiler.ProfilerActivity.CUDA)
        if device == "xpu":
            activities.append(torch.profiler.ProfilerActivity.XPU)

        latencies = []
        profiler_ctx = (torch.profiler.profile(activities=activities, with_stack=True)
                        if run_with_profiler else contextlib.nullcontext())
        with profiler_ctx as prof:
            for i in range(10):
                start = time.time()
                model(x)
                if device == "cuda":
                    torch.cuda.synchronize()
                if device == "xpu":
                    torch.xpu.synchronize()
                if i > 3:  # skip the first iterations as warmup
                    latencies.append(time.time() - start)

        print("Latency: {} ms".format(sum(latencies) / len(latencies) * 1000))
        if run_with_profiler:
            prof.export_chrome_trace("./{}.json".format(device))