Skip dW/dB computation in norm ops when weight/bias is frozen (LoRA/PEFT optimization) #1067

@yukiu00

🚀 The feature, motivation and pitch

Problem Statement

When using LoRA, PEFT, or other parameter-efficient fine-tuning methods, the base model's normalization layer weights are typically frozen (requires_grad=False). However, Liger's norm ops currently compute gradients for these frozen parameters unconditionally, resulting in:

  1. Wasted computation: The backward pass computes dW/dB even when they will be discarded
  2. Unnecessary memory allocation: Temporary buffers for gradient accumulation are allocated but never used
  3. Suboptimal training throughput: Most noticeable at very large hidden sizes (16K–32K in the benchmarks below)

This matters because LoRA/PEFT has become the de facto standard for fine-tuning large language models.

Affected Operations

  • RMSNorm
  • FusedAddRMSNorm
  • LayerNorm
  • GroupNorm
  • PolyNorm

Proposed Solution

Leverage PyTorch's ctx.needs_input_grad in the backward pass to conditionally skip:

  1. Weight/bias gradient computation in the Triton kernel (compute_dW, compute_dB flags)
  2. Temporary buffer allocation for gradient accumulation

This approach:

  • Requires no public API changes
  • Is fully backward compatible (unfrozen weights work exactly as before)
  • Automatically benefits all existing LoRA/PEFT users without code changes
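
As a rough illustration of the idea, here is a minimal plain-PyTorch sketch (no Triton) of an RMSNorm autograd function that consults ctx.needs_input_grad and skips the weight gradient when the weight is frozen. The class name RMSNormSketch and the eps argument are hypothetical and chosen for this example only; this is not Liger's actual implementation.

```python
import torch


class RMSNormSketch(torch.autograd.Function):
    """Illustrative RMSNorm that skips dW when the weight is frozen."""

    @staticmethod
    def forward(ctx, x, weight, eps):
        # x: (M, H), weight: (H,)
        rstd = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
        ctx.save_for_backward(x, weight, rstd)
        return x * rstd * weight

    @staticmethod
    def backward(ctx, dy):
        x, weight, rstd = ctx.saved_tensors
        # One flag per forward input, in order: (x, weight, eps).
        need_dx, need_dw, _ = ctx.needs_input_grad

        dx = dw = None
        x_hat = x * rstd
        if need_dx:
            g = dy * weight
            dx = rstd * (g - x_hat * (g * x_hat).mean(dim=-1, keepdim=True))
        if need_dw:
            # Reduction over all M rows; skipped entirely for frozen weights.
            dw = (dy * x_hat).sum(dim=0)
        return dx, dw, None


x = torch.randn(4, 8, requires_grad=True)
frozen_w = torch.ones(8)  # requires_grad=False, as under LoRA/PEFT
RMSNormSketch.apply(x, frozen_w, 1e-6).sum().backward()
print(frozen_w.grad)  # None: no weight gradient was computed or stored
```

Because the decision is made from autograd's own bookkeeping, the caller does not pass any extra flag; freezing the parameter is enough.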

Benchmark Results

Environment: RTX 3090, bf16, M=2048 (batch × seq_len)

RMSNorm Only (freeze_weight=True)

| Hidden Size | Backward Speedup | Full (fwd+bwd) Speedup |
|-------------|------------------|------------------------|
| H=1024      | 1.25× (−20.1%)   | 1.12× (−10.3%)         |
| H=2048      | 1.15× (−12.8%)   | 1.09× (−8.3%)          |
| H=4096      | 1.11× (−10.1%)   | 1.05× (−4.7%)          |
| H=8192      | 1.07× (−6.2%)    | 1.04× (−4.2%)          |
| H=16384     | 1.37× (−27.1%)   | 1.22× (−18.1%)         |
| H=32768     | 1.12× (−67.9%)   | 2.41× (−58.5%)         |
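
For reading the two figures in each cell, the following small sketch converts a speedup factor into a relative runtime change. It assumes the parenthesised percentage is the change in runtime, i.e. 1/speedup − 1; the issue does not state this explicitly, and the helper name runtime_change is made up for the example. Small differences from the table are expected because the speedups shown are rounded.

```python
def runtime_change(speedup: float) -> float:
    """Relative change in runtime implied by a speedup factor (negative = faster)."""
    return 1.0 / speedup - 1.0


# Backward speedups for the smallest and largest hidden size in the table.
for hidden_size, speedup in [(1024, 1.25), (32768, 3.12)]:
    print(f"H={hidden_size}: {speedup:.2f}x -> {runtime_change(speedup):+.1%}")
```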

The backward speedup falls from 1.25× at H=1024 to 1.07× at H=8192, then rises sharply at H=16384 and H=32768, where the dW reduction (summing partial gradients across SMs) becomes the dominant cost.

Mixed Workload: RMSNorm + LoRA Linear (freeze_norm_weight=True)

| Hidden Size  | Backward    | Full        |
|--------------|-------------|-------------|
| H=1024–32768 | 1.00×–1.05× | 1.00×–1.04× |

In realistic LoRA scenarios, the linear layers dominate runtime, so the norm optimization provides modest but consistent gains.
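
The reasoning above is Amdahl's law: if the norm backward accounts for only a small fraction of total runtime, even a large norm speedup moves the end-to-end number very little. The sketch below makes that concrete. The 5% norm fraction is a hypothetical value picked for illustration, not a measurement from these benchmarks.

```python
def overall_speedup(norm_fraction: float, norm_speedup: float) -> float:
    """Amdahl's law: end-to-end speedup when only the norm part gets faster."""
    return 1.0 / ((1.0 - norm_fraction) + norm_fraction / norm_speedup)


# Hypothetical: norm takes 5% of runtime and gets the 3.12x backward speedup.
print(f"{overall_speedup(0.05, 3.12):.3f}x")
```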

Implementation Details

Internal API changes (not public-facing):

rms_norm_backward(..., compute_dW: bool)
fused_add_rms_norm_backward(..., compute_dW: bool)
layer_norm_backward(..., compute_dW: bool, compute_dB: bool)
group_norm_backward(..., compute_dW: bool, compute_dB: bool)
poly_norm_backward(..., compute_dW: bool, compute_dB: bool)

Kernel changes:

  • Added compute_dW/compute_dB as tl.constexpr parameters, enabling Triton to eliminate dead code at compile time
  • Skip buffer allocation when gradients are not needed
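
To show what the flag and the skipped allocation might look like on the host side, here is a plain-PyTorch emulation of a row-parallel RMSNorm backward. Each loop iteration stands in for one kernel program writing its partial dW into a scratch buffer. The function name rms_norm_backward_sketch, its argument list, and num_programs are assumptions for this example; the real Liger signatures are elided in the list above and are not reproduced here.

```python
import torch


def rms_norm_backward_sketch(dy, x, weight, rstd, compute_dW, num_programs=4):
    """Emulate a row-parallel backward; rstd has shape (M, 1)."""
    M, H = x.shape
    dx = torch.empty_like(x)
    # The (num_programs, H) scratch buffer is allocated only when dW is needed.
    partial_dw = (
        torch.zeros((num_programs, H), dtype=x.dtype, device=x.device)
        if compute_dW
        else None
    )

    row_chunks = torch.arange(M, device=x.device).chunk(num_programs)
    for pid, rows in enumerate(row_chunks):
        x_hat = x[rows] * rstd[rows]
        g = dy[rows] * weight
        dx[rows] = rstd[rows] * (g - x_hat * (g * x_hat).mean(dim=-1, keepdim=True))
        if compute_dW:
            # In a Triton kernel this branch would be guarded by a
            # tl.constexpr flag, so the compiler can drop it when False.
            partial_dw[pid] = (dy[rows] * x_hat).sum(dim=0)

    # Final cross-program reduction, also skipped for frozen weights.
    dw = partial_dw.sum(dim=0) if compute_dW else None
    return dx, dw
```

With compute_dW=False the function neither allocates the scratch buffer nor runs the final reduction, which are the two costs the proposal targets.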

Why This Matters

  1. Growing LoRA/PEFT adoption: Most LLM fine-tuning now uses parameter-efficient methods
  2. Larger models = bigger impact: At the hidden sizes of 4K–16K used by modern LLMs, the standalone RMSNorm backward speedup measured above ranges from 1.07× to 1.37×, and it reaches 3.12× at H=32768
  3. Zero user effort: Existing code automatically benefits
  4. Memory savings: Reduced temporary buffer allocation helps with tight GPU memory budgets

Reproduction

# Run benchmarks
PYTHONPATH=$(pwd)/src python benchmark/scripts/benchmark_rms_norm.py --overwrite
PYTHONPATH=$(pwd)/src python benchmark/scripts/benchmark_rms_norm_mixed.py --overwrite
