
[DeepSeek-V3/R1] Anyone had success quantizing DeepSeek-V3 using llm-compressor? #1482

Closed
@ashgold

Description


Environment
Include all relevant environment information:

  1. OS: Ubuntu 22.04
  2. Python version: 3.12.9
  3. LLM Compressor version or commit hash: v0.5.1
  4. ML framework version(s): torch==2.5.1
  5. Other Python package versions: compressed-tensors v0.9.1
  6. Other relevant environment information: 8x NVIDIA H100 80GB

Has anyone succeeded in quantizing DeepSeek-V3 using llm-compressor?
I tried to quantize DeepSeek-V3 on 8xH100 nodes, following all the related GitHub issues, but kept hitting OOMs.

I used the mit-han-lab/pile-val-backup dataset (512 calibration samples, 2048 sequence length).
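
For reference, the calibration/oneshot setup looks roughly like this (a sketch rather than my exact script; DATASET_ID, NUM_CALIBRATION_SAMPLES, and MAX_SEQUENCE_LENGTH are stand-in names, tokenization/preprocessing is omitted, and model/recipe are defined per the attempts below):

from datasets import load_dataset
from llmcompressor import oneshot

# stand-in names for the values described above
DATASET_ID = "mit-han-lab/pile-val-backup"
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048

# load and subsample the calibration data
ds = load_dataset(DATASET_ID, split="validation")
ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES))

# run one-shot GPTQ calibration with the model and recipe from the attempts below
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)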

Below are my attempts; every one of them ended in an OOM.

Preparing intermediates cache: 100%|██████████| 512/512 [00:02<00:00, 186.39it/s]
(1/61): Calibrating: 100%|██████████| 512/512 [00:01<00:00, 286.37it/s]
2025-05-27T15:58:31.207270-0700 | on_sequential_batch_end | INFO - Quantizing model.layers.3.mlp.experts.0.gate_proj using 1 samples
2025-05-27T15:58:31.213474-0700 | compress | METRIC - time 0.01s
2025-05-27T15:58:31.213871-0700 | compress | METRIC - GPU 0 | usage: 46.09% | total memory: 85 GB
2025-05-27T15:58:31.213931-0700 | compress | METRIC - GPU 1 | usage: 99.97% | total memory: 85 GB
2025-05-27T15:58:31.213964-0700 | compress | METRIC - GPU 2 | usage: 29.38% | total memory: 85 GB
2025-05-27T15:58:31.213989-0700 | compress | METRIC - GPU 3 | usage: 29.38% | total memory: 85 GB
2025-05-27T15:58:31.214014-0700 | compress | METRIC - GPU 4 | usage: 29.38% | total memory: 85 GB
2025-05-27T15:58:31.214036-0700 | compress | METRIC - GPU 5 | usage: 29.38% | total memory: 85 GB
2025-05-27T15:58:31.214056-0700 | compress | METRIC - GPU 6 | usage: 29.38% | total memory: 85 GB
2025-05-27T15:58:31.214074-0700 | compress | METRIC - GPU 7 | usage: 29.38% | total memory: 85 GB
2025-05-27T15:58:31.214124-0700 | compress | METRIC - Compressed module size: 29.704192 MB
Traceback (most recent call last):
  File "/home/jovyan/.local/lib/python3.12/site-packages/llmcompressor/pipelines/sequential/helpers.py", line 53, in forward
    outputs = forward_fn(*args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<string>", line 7, in forward
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/accelerate/hooks.py", line 176, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jovyan/.local/lib/python3.12/site-packages/transformers/models/deepseek_v3/modeling_deepseek_v3.py", line 504, in forward
    hidden_states = self.mlp(hidden_states)
                    ^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/accelerate/hooks.py", line 176, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jovyan/.local/lib/python3.12/site-packages/transformers/models/deepseek_v3/modeling_deepseek_v3.py", line 215, in forward
    hidden_states = self.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jovyan/.local/lib/python3.12/site-packages/transformers/models/deepseek_v3/modeling_deepseek_v3.py", line 201, in moe
    expert_output = expert(expert_input)
                    ^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/accelerate/hooks.py", line 176, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jovyan/.local/lib/python3.12/site-packages/transformers/models/deepseek_v3/modeling_deepseek_v3.py", line 114, in forward
    down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
                                                                ^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1844, in _call_impl
    return inner()
           ^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1803, in inner
    hook_result = hook(self, args, result)
                  ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jovyan/.local/lib/python3.12/site-packages/llmcompressor/modifiers/utils/hooks.py", line 93, in wrapped_hook
    return hook(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/jovyan/.local/lib/python3.12/site-packages/llmcompressor/modifiers/quantization/gptq/base.py", line 322, in calibrate_module
    self._hessians[module], self._num_samples[module] = accumulate_hessian(
                                                        ^^^^^^^^^^^^^^^^^^^
  File "/home/jovyan/.local/lib/python3.12/site-packages/llmcompressor/modifiers/quantization/gptq/gptq_quantize.py", line 66, in accumulate_hessian
    H += inp.matmul(inp.t())
         ^^^^^^^^^^^^^^^^^^^
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 196.00 MiB. GPU 1 has a total capacity of 79.11 GiB of which 26.88 MiB is free. Process 549622 has 79.07 GiB memory in use. Of the allocated memory 78.28 GiB is allocated by PyTorch, and 136.10 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
 
The above exception was the direct cause of the following exception:
 
Traceback (most recent call last):
  File "/home/jovyan/.local/lib/python3.12/site-packages/llmcompressor/modifiers/quantization/gptq/base.py", line 234, in on_initialize
    run_sequential(
  File "/home/jovyan/.local/lib/python3.12/site-packages/llmcompressor/pipelines/sequential/pipeline.py", line 67, in run_pipeline
    subgraph.forward(model, **inputs)
  File "/home/jovyan/.local/lib/python3.12/site-packages/llmcompressor/pipelines/sequential/helpers.py", line 55, in forward
    raise RuntimeError(
RuntimeError: Raised an exception during execution of the following code:

def forward(self, unsqueeze, arange, getitem_6, model_rotary_emb, clone, masked_fill, model_layers_2):
    clone[(slice(None, None, None), slice(None, None, None), slice(None, None, None), slice(None, getitem_6, None))] = masked_fill; setitem = clone; getitem_6 = masked_fill = setitem = None
    getitem_12 = model_layers_2[0]; model_layers_2 = None
    model_layers_3 = getattr(self.model.layers, "3")(getitem_12, attention_mask = clone, position_ids = unsqueeze, past_key_value = None, output_attentions = False, use_cache = False, cache_position = arange, position_embeddings = model_rotary_emb); getitem_12 = clone = unsqueeze = arange = model_rotary_emb = None
    return {'model_layers_3': model_layers_3}

This is likely due to a violation of shape assumptions made when tracing
 
During handling of the above exception, another exception occurred:
 
Traceback (most recent call last):
  File "/home/jovyan/workspace/awq/quantize_gptq.py", line 103, in <module>
     
  File "/home/jovyan/.local/lib/python3.12/site-packages/llmcompressor/entrypoints/oneshot.py", line 179, in oneshot
    one_shot()
  File "/home/jovyan/.local/lib/python3.12/site-packages/llmcompressor/entrypoints/oneshot.py", line 131, in __call__
    self.apply_recipe_modifiers(
  File "/home/jovyan/.local/lib/python3.12/site-packages/llmcompressor/entrypoints/oneshot.py", line 173, in apply_recipe_modifiers
    session.initialize(**session_kwargs)
  File "/home/jovyan/.local/lib/python3.12/site-packages/llmcompressor/core/session.py", line 118, in initialize
    mod_data = self._lifecycle.initialize(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jovyan/.local/lib/python3.12/site-packages/llmcompressor/core/lifecycle.py", line 105, in initialize
    data = mod.initialize(state=self.state, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jovyan/.local/lib/python3.12/site-packages/llmcompressor/modifiers/stage.py", line 83, in initialize
    modifier.initialize(state, **kwargs)
  File "/home/jovyan/.local/lib/python3.12/site-packages/llmcompressor/modifiers/modifier.py", line 90, in initialize
    self.initialized_ = self.on_initialize(state=state, **kwargs)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jovyan/.local/lib/python3.12/site-packages/llmcompressor/modifiers/quantization/gptq/base.py", line 268, in on_initialize
    raise exception
  File "/home/jovyan/.local/lib/python3.12/site-packages/llmcompressor/modifiers/quantization/gptq/base.py", line 256, in on_initialize
    run_layer_sequential(
  File "/home/jovyan/.local/lib/python3.12/site-packages/llmcompressor/pipelines/layer_sequential/pipeline.py", line 73, in run_pipeline
    callback_modifier.on_sequential_batch_end()
  File "/home/jovyan/.local/lib/python3.12/site-packages/llmcompressor/modifiers/quantization/gptq/base.py", line 345, in on_sequential_batch_end
    loss, quantized_weight, scale, zero_point, g_idx = quantize_weight(
                                                       ^^^^^^^^^^^^^^^^
  File "/home/jovyan/.local/lib/python3.12/site-packages/llmcompressor/modifiers/quantization/gptq/gptq_quantize.py", line 108, in quantize_weight
    W = W.to(dtype=GPTQ_PRECISION)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 56.00 MiB. GPU 1 has a total capacity of 79.11 GiB of which 26.88 MiB is free. Process 549622 has 79.07 GiB memory in use. Of the allocated memory 78.31 GiB is allocated by PyTorch, and 108.11 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
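
Side note: both OOM messages suggest PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True, but the reserved-but-unallocated memory they report is only ~100-140 MiB, so fragmentation does not look like the root cause here. For completeness, one way to set it (it must happen before the first CUDA allocation):

import os

# must be set before torch initializes the CUDA allocator to take effect
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"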

Attempt 1: calculate_offload_device_map

With reserve_for_hessians=True, we couldn't use the GPUs at all, so it is set to False here.

import torch
from transformers import AutoModelForCausalLM
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.transformers.compression.helpers import calculate_offload_device_map

MODEL_ID = "deepseek-ai/DeepSeek-V3"

device_map = calculate_offload_device_map(
    MODEL_ID, num_gpus=8, reserve_for_hessians=False, torch_dtype=torch.bfloat16
)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map=device_map, torch_dtype=torch.bfloat16)

recipe = [
    GPTQModifier(
        ignore=["lm_head", "re:.*self_attn.*", "re:.*shared_experts.*", "re:.*mlp\\.(gate|up|gate_up|down)_proj.*"],
        targets=["Linear"],
        scheme="W4A16",
        group_size=128,
        offload_hessians=False,  # offloading Hessians makes the calibration process terribly slow
    ),
]
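
For comparison, the memory-relieving variant that the comment above refers to only flips the offload_hessians flag (a sketch; everything else is unchanged, and as noted it makes calibration terribly slow):

recipe = [
    GPTQModifier(
        ignore=["lm_head", "re:.*self_attn.*", "re:.*shared_experts.*", "re:.*mlp\\.(gate|up|gate_up|down)_proj.*"],
        targets=["Linear"],
        scheme="W4A16",
        group_size=128,
        offload_hessians=True,  # keep Hessians in CPU memory, trading calibration speed for GPU memory
    ),
]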

Attempt 2: 40GiB x 8 H100s

from accelerate import infer_auto_device_map, init_empty_weights

with init_empty_weights():
    dummy_model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16)
    device_map = infer_auto_device_map(
        dummy_model,
        offload_buffers=True,
        max_memory={
            0: "40GiB",
            1: "40GiB",
            2: "40GiB",
            3: "40GiB",
            4: "40GiB",
            5: "40GiB",
            6: "40GiB",
            7: "40GiB",
            "cpu": "1600GiB"
        },
        no_split_module_classes=dummy_model._no_split_modules,
    )
    del dummy_model
 
 
recipe = [
    GPTQModifier(
        ignore=["lm_head", "re:.*self_attn.*", "re:.*shared_experts.*", "re:.*mlp\\.(gate|up|gate_up|down)_proj.*"],
        targets=["Linear"],
        scheme="W4A16",
        group_size=128,
        offload_hessians=False
    ),
]

Attempt 3: 10GiB x 8 H100s

with init_empty_weights():
    dummy_model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16)
    device_map = infer_auto_device_map(
        dummy_model,
        offload_buffers=True,
        max_memory={
            0: "10GiB",
            1: "10GiB",
            2: "10GiB",
            3: "10GiB",
            4: "10GiB",
            5: "10GiB",
            6: "10GiB",
            7: "10GiB",
            "cpu": "1600GiB"
        },
        # no_split_module_classes=dummy_model._no_split_modules,
    )
    del dummy_model
 
# Select model and load it.
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map=device_map, torch_dtype=torch.bfloat16)
print(f"hf_device_map: {model.hf_device_map}")
 
 
# layers not to quantize; reference: https://huggingface.co/ISTA-DASLab/DeepSeek-R1-GPTQ-4b-128g-experts/blob/main/config.json
# Configure the quantization algorithm to run.
recipe = [
    GPTQModifier(
        ignore=["lm_head", "re:.*self_attn.*", "re:.*shared_experts.*", "re:.*mlp\\.(gate|up|gate_up|down)_proj.*"],
        targets=["Linear"],
        scheme="W4A16",
        group_size=128,
        offload_hessians=False
    ),
]

Attempt 4: dampening_frac=0.1

with init_empty_weights():
    dummy_model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16)
    device_map = infer_auto_device_map(
        dummy_model,
        # offload_buffers=True,
        # fallback_allocation=True,
        max_memory={
            0: "30GiB",
            1: "30GiB",
            2: "30GiB",
            3: "30GiB",
            4: "30GiB",
            5: "30GiB",
            6: "30GiB",
            7: "30GiB",
            "cpu": "1600GiB"
        },
        no_split_module_classes=dummy_model._no_split_modules,
    )
    del dummy_model
 
recipe = [
    GPTQModifier(
        ignore=["lm_head", "re:.*self_attn.*", "re:.*shared_experts.*", "re:.*mlp\\.(gate|up|gate_up|down)_proj.*"],
        targets=["Linear"],
        scheme="W4A16",
        group_size=128,
        dampening_frac=0.1,  # this can reduce the memory needed for the Hessian
        offload_hessians=False,  # offloading Hessians makes the calibration process terribly slow
    ),
]
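
For completeness, the tail of the script (not shown in the traceback above) saves the result in compressed form, roughly like this (a sketch; SAVE_DIR is a made-up name):

SAVE_DIR = MODEL_ID.split("/")[-1] + "-W4A16-gptq"

# write the quantized weights together with the compressed-tensors quantization config
model.save_pretrained(SAVE_DIR, save_compressed=True)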
