
Default LLM.int8() mixed-precision decomposition causes 17-147% energy overhead across consumer and datacenter GPUs #1867

@hongping-zh

Description

Summary

Through NVML-based power monitoring (10 Hz) across two GPU architectures and 4 models, we found that the default LLM.int8() configuration (llm_int8_threshold=6.0) systematically increases energy consumption by 17-147% compared to FP16 baseline. Ablation experiments confirm the root cause is the mixed-precision decomposition pathway (continuous INT8↔FP16 type conversion for outlier features).

Setting llm_int8_threshold=0.0 eliminates the energy overhead and recovers roughly 80% more throughput than the default INT8 path (still below the FP16 baseline).

Measured Data

Table 1: Default INT8 vs FP16 — RTX 4090D (Ada Lovelace)

| Model | FP16 (tok/s) | INT8 Default (tok/s) | Throughput Loss | Energy Δ vs FP16 |
|---|---:|---:|---:|---:|
| Yi-1.5-6B | 34.72 | 8.42 | −75.7% | +32.7% |
| Mistral-7B | 29.06 | 7.88 | −72.9% | +30.7% |
| Phi-3-mini (3.8B) | 29.19 | 13.15 | −54.9% | +31.2% |
| Qwen2.5-7B | 37.64 | 9.56 | −74.6% | +17.4% |

Table 2: Default INT8 vs FP16 — A800 (Ampere, datacenter)

| Batch Size | FP16 (tok/s) | INT8 Default (tok/s) | Throughput Loss | Energy Δ vs FP16 |
|---|---:|---:|---:|---:|
| BS=1 | 36.18 | 9.87 | −72.7% | +122% |
| BS=4 | 145.35 | 35.91 | −75.3% | +147% |
| BS=8 | 290.59 | 69.88 | −75.9% | +126% |

Root Cause: Ablation with llm_int8_threshold=0.0

Disabling the outlier detection (no FP16 fallback) isolates the energy contribution of mixed-precision decomposition:

Table 3: Ablation — RTX 4090D

| Model | Config | Throughput (tok/s) | Energy (J/1k tok) | Δ Energy vs FP16 |
|---|---|---:|---:|---:|
| Yi-1.5-6B | FP16 | 34.72 | 4,716 | (baseline) |
| Yi-1.5-6B | INT8 Default (threshold=6.0) | 8.42 | 6,258 | +32.7% |
| Yi-1.5-6B | INT8 Pure (threshold=0.0) | 15.47 | 4,568 | −3.1% |
| Mistral-7B | FP16 | 29.06 | 5,661 | (baseline) |
| Mistral-7B | INT8 Default (threshold=6.0) | 7.88 | 7,401 | +30.7% |
| Mistral-7B | INT8 Pure (threshold=0.0) | 14.15 | 5,212 | −7.9% |

Table 4: Ablation — A800 (datacenter)

| BS | Config | Throughput (tok/s) | Energy Δ vs FP16 | Δ vs Default INT8 |
|---|---|---:|---:|---:|
| 1 | INT8 Default | 9.87 | +122% | (baseline) |
| 1 | INT8 Pure | 18.09 | +33% | −40% |
| 4 | INT8 Default | 35.91 | +147% | (baseline) |
| 4 | INT8 Pure | 72.96 | +44% | −42% |
| 8 | INT8 Default | 69.88 | +126% | (baseline) |
| 8 | INT8 Pure | 144.32 | +32% | −42% |

Key Findings

  1. The energy penalty comes from type conversion, not INT8 arithmetic. When mixed-precision decomposition is disabled, INT8 matches or beats FP16 energy on RTX 4090D.
  2. Cross-architecture consistency: The pattern reproduces across consumer (RTX 4090D) and datacenter (A800) GPUs.
  3. Cross-model consistency: Two different model architectures (Yi-1.5-6B, Mistral-7B) show nearly identical improvements when switching to pure INT8: +79-81% throughput and −34-37% energy relative to the default configuration.
  4. Batch size amplifies the overhead: A800 at BS=4 shows +147% energy penalty (worst case).
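
To make finding 1 concrete, here is a minimal numpy sketch of the mixed-precision decomposition idea behind LLM.int8() (an illustration of the technique, not the bitsandbytes CUDA kernel): feature columns whose magnitude exceeds the threshold take a high-precision path, the rest take a quantize/int8-matmul/dequantize path. The function name and the fp32-instead-of-fp16 simplification are mine.

```python
import numpy as np

def llm_int8_matmul(X, W, threshold=6.0):
    """Sketch of LLM.int8()-style decomposition (illustrative, simplified)."""
    if threshold > 0:
        # Feature dims where any activation exceeds the threshold are "outliers"
        outlier = np.abs(X).max(axis=0) > threshold
    else:
        # threshold=0.0 in bitsandbytes disables the outlier path entirely
        outlier = np.zeros(X.shape[1], dtype=bool)
    regular = ~outlier

    # High-precision path for outlier features (fp32 here for simplicity)
    Y_hi = X[:, outlier] @ W[outlier, :]
    if not regular.any():
        return Y_hi

    # INT8 path: symmetric absmax quantization, per-row for X, per-column for W
    Xr, Wr = X[:, regular], W[regular, :]
    sx = np.abs(Xr).max(axis=1, keepdims=True) / 127.0
    sw = np.abs(Wr).max(axis=0, keepdims=True) / 127.0
    sx[sx == 0] = 1.0
    sw[sw == 0] = 1.0
    Xq = np.round(Xr / sx).astype(np.int8)
    Wq = np.round(Wr / sw).astype(np.int8)
    # int32 accumulate, then dequantize back to float. This INT8<->float
    # round trip on every matmul is the conversion traffic the ablation isolates.
    Y_lo = (Xq.astype(np.int32) @ Wq.astype(np.int32)) * sx * sw

    return Y_hi + Y_lo
```

With threshold=6.0 both paths run and their results are summed per matmul; with threshold=0.0 only the INT8 path runs, which is why the ablation removes the conversion overhead.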

Reproduction

```python
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
import torch

model_name = "mistralai/Mistral-7B-Instruct-v0.2"

# Default INT8 (high energy overhead); llm_int8_threshold defaults to 6.0
config_default = BitsAndBytesConfig(load_in_8bit=True)

# Pure INT8 (eliminates the overhead, with an accuracy trade-off)
config_pure = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_threshold=0.0,  # disable outlier detection / FP16 fallback
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=config_default,  # or config_pure
    device_map="cuda",
)
```

Power monitoring via pynvml at 10 Hz. Full benchmark scripts and raw data:

Measurement Protocol

  • n=10 per configuration, CV < 1% (throughput), CV < 3% (power)
  • 3 warmup runs, 30s thermal stabilization per model load
  • NVML idle power subtracted (RTX 4090D: ~17W, A800: ~65W)
  • Greedy decoding, 256 output tokens
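
The energy figures above follow from this accounting; a minimal sketch of it (my function name and signature, not the authors' benchmark script) integrates the 10 Hz power samples, subtracts idle draw, and normalizes to joules per 1,000 tokens:

```python
def energy_per_1k_tokens(power_w, idle_w, sample_hz, tokens):
    """Integrate power samples (watts) taken at sample_hz over a run,
    subtract idle draw, and normalize to J per 1000 generated tokens.

    In practice power_w would come from pynvml, e.g.:
        pynvml.nvmlInit()
        handle = pynvml.nvmlDeviceGetHandleByIndex(0)
        mw = pynvml.nvmlDeviceGetPowerUsage(handle)  # reports milliwatts
    """
    dt = 1.0 / sample_hz                      # seconds between samples
    energy_j = sum(max(p - idle_w, 0.0) * dt for p in power_w)
    return energy_j * 1000.0 / tokens
```

For example, a 60 s run sampled at 10 Hz holding 217 W on the RTX 4090D (idle ~17 W) while generating 3,000 tokens works out to 12,000 J net, i.e. 4,000 J/1k tok.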

Discussion

This is not a bug report — the mixed-precision decomposition in LLM.int8() exists for good accuracy reasons (preserving outlier features in FP16). However, the energy implications are significant and may not be obvious to users who choose INT8 expecting energy savings.

Possible actions:

  1. Documentation: Add a note in the README/docs that default INT8 may increase energy consumption due to mixed-precision overhead, and that llm_int8_threshold=0.0 can be used when energy efficiency is prioritized (with accuracy trade-off).
  2. Performance: Investigate whether the type conversion pathway can be optimized to reduce overhead while maintaining accuracy.

Environment

  • RTX 4090D: PyTorch 2.4.1 + CUDA 12.1 + bitsandbytes latest
  • A800: PyTorch 2.x + CUDA 12.x + bitsandbytes latest
  • transformers: 4.47.0
  • Models: Yi-1.5-6B-Chat, Mistral-7B-Instruct-v0.2/v0.3, Phi-3-mini-4k, Qwen2.5-7B

Related


Full paper draft with complete methodology available upon request.
