
Torchao's CPU overhead counteracts the performance benefit of quantization kernels #1930


Description

@LuFinch

Hi,

I benchmarked some LLM models with int4_weight_only on CPU/GPU/XPU, expecting E2E speedups compared with pure bf16/fp16.

At the kernel level, int4 GEMM kernels are generally 2x~3x faster than bf16/fp16 GEMM.

However, I saw no E2E performance improvement, and even slowdowns in some models.

After profiling, I found that Torchao's CPU overhead is too high; it can exceed the time saved by the int4 GEMM kernel. The reason is that Torchao uses a Tensor subclass with __torch_function__ to redispatch nn.Linear to the custom int4 matmul op.
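
Reduced to a minimal runnable sketch, the mechanism looks like this (Int4Weight is a hypothetical stand-in for torchao's AffineQuantizedTensor, and the real handler calls the packed int4 kernel instead of falling back to a plain F.linear):

import torch
import torch.nn.functional as F

class Int4Weight(torch.Tensor):
    # Hypothetical stand-in for torchao's AffineQuantizedTensor.
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is F.linear:
            x, w = args[0], args[1]
            # torchao redispatches here to the packed kernel, e.g.
            # aten::_weight_int4pack_mm_cpu; this Python-level hop runs on
            # every forward call and is the eager-mode overhead below.
            return F.linear(x, w.as_subclass(torch.Tensor), *args[2:])
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)

w = torch.randn(1024, 512).as_subclass(Int4Weight)
x = torch.randn(1, 512)
y = F.linear(x, w)  # takes the __torch_function__ detour above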

  • In eager mode, the dispatching time (the blue region in the trace below) takes longer than aten::_weight_int4pack_mm_cpu itself. I had expected torch.compile to optimize this redispatching away.

[Profiler trace: eager mode; Python dispatch overhead (blue) next to aten::_weight_int4pack_mm_cpu]

  • However, in compile mode, the redispatching disappears from the compiled model.forward region, but extra host work appears in dynamo/inductor instead. Every AffineQuantizedTensor (torchao/dtypes/affine_quantized_tensor.py) is flattened by torch/_functorch/_aot_autograd/subclass_utils.py(233): flatten_subclass, and the flatten time is close to the time of aten::_weight_int4pack_mm_cpu. A sketch of this flatten protocol follows the trace below.

[Profiler trace: compile mode; flatten_subclass host work in dynamo/inductor]
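
For reference, this flattening goes through the traceable-subclass hooks; here is a minimal sketch of the protocol (PackedWeight and its fields are illustrative, not torchao's actual layout):

import torch

class PackedWeight(torch.Tensor):
    # Illustrative subclass; torchao's AffineQuantizedTensor implements
    # the same __tensor_flatten__/__tensor_unflatten__ hooks.
    @staticmethod
    def __new__(cls, packed, scales):
        return torch.Tensor._make_wrapper_subclass(
            cls, packed.shape, dtype=scales.dtype, device=packed.device)

    def __init__(self, packed, scales):
        self.packed = packed  # int4-packed weight data
        self.scales = scales  # per-group scales

    def __tensor_flatten__(self):
        # AOT Autograd (subclass_utils.py) calls this to pull the inner
        # dense tensors out of the subclass; this per-call Python work is
        # what the compile-mode trace attributes to flatten_subclass.
        return ["packed", "scales"], None

    @staticmethod
    def __tensor_unflatten__(inner, meta, outer_size, outer_stride):
        return PackedWeight(inner["packed"], inner["scales"])

w = PackedWeight(torch.zeros(1024, 256, dtype=torch.uint8),
                 torch.ones(1024, 8, dtype=torch.bfloat16))
attrs, ctx = w.__tensor_flatten__()  # -> (["packed", "scales"], None)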

Both eager mode and compile mode suffer from this device-agnostic Torchao CPU overhead, which can counteract the performance benefit of int4 GEMM in host-bound models, such as small models or cases where the GPU/XPU is very fast (we hit this with Qwen2-0.5b and Phi3-3.8b from Hugging Face on an Nvidia A100 GPU and an Intel Data Center GPU Max Series).

Could you optimize this CPU overhead?

Reproducer

import contextlib
import time

import torch
from torchao.quantization.quant_api import (
    int4_weight_only,
    quantize_,
)

class Linear_Gate_Up(torch.nn.Module):
    def __init__(self, in_feature, out_feature):
        super().__init__()
        self.gate_proj = torch.nn.Linear(in_feature, out_feature, bias=False)
        self.gate_proj2 = torch.nn.Linear(out_feature, out_feature, bias=False)
 
    def forward(self, x):
        return self.gate_proj2(self.gate_proj(x))
 
if __name__ == "__main__":
    # device = "cpu"
    device = "cuda"
    # device = "xpu"
    quantize_model = True
    compile_model = True
    run_with_profiler = False

    with torch.no_grad():
        model = Linear_Gate_Up(512, 1024).eval().to(device).to(torch.bfloat16)
        x = torch.randn(1, 512).to(device).to(torch.bfloat16)
        if quantize_model:
            if device == "cpu":
                from torchao.dtypes import Int4CPULayout
                quantize_(model, int4_weight_only(layout=Int4CPULayout()))
            elif device == "cuda":
                quantize_(model, int4_weight_only())
            elif device == "xpu":
                from torchao.dtypes import Int4XPULayout
                quantize_(model, int4_weight_only(layout=Int4XPULayout()))

        if compile_model:
            model = torch.compile(model)
          
        # warmup run to trigger the actual torch.compile compilation
        model(x)
        if device == "cuda":
            torch.cuda.synchronize()
        if device == "xpu":
            torch.xpu.synchronize()
        
        latencies = []
        with (torch.profiler.profile(
                    activities=[torch.profiler.ProfilerActivity.CPU,
                                torch.profiler.ProfilerActivity.CUDA,
                                torch.profiler.ProfilerActivity.XPU],
                    with_stack=True) 
            if run_with_profiler else contextlib.nullcontext()) as prof:
            for i in range(10):
                start = time.time()
                model(x)
                if device == "cuda":
                    torch.cuda.synchronize()
                if device == "xpu":
                    torch.xpu.synchronize()
                if i > 3:  # skip the first iterations as warmup
                    latencies.append(time.time() - start)

        print("Latency: {} ms".format(sum(latencies) / len(latencies) * 1000))
        if run_with_profiler:
            prof.export_chrome_trace("./{}.json".format(device))
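
As a side note, the host-side hotspots can also be printed directly instead of exporting a chrome trace (a small optional addition, using the same prof object; with_stack=True above enables the grouped stack view):

if run_with_profiler:
    # Sort by self CPU time to surface Python dispatch/flatten overhead.
    print(prof.key_averages(group_by_stack_n=5).table(
        sort_by="self_cpu_time_total", row_limit=15))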
