
gemv_4bit silently produces wrong results when weight is quantized in (in_features, out_features) layout #1862

@TimDettmers

Description


Summary

The gemv_4bit CUDA kernel (the fast path for 1D inputs in matmul_4bit) implicitly assumes that quant_state.shape follows the nn.Linear convention of (out_features, in_features). If a weight matrix is instead quantized in the transposed layout (in_features, out_features), the gemv path silently produces output with the wrong shape and wrong values; no error is raised.

This does not affect production usage. All real-world code paths (Linear4bit, HuggingFace transformers, etc.) store weights as (out_features, in_features), so the gemv path works correctly. This is a documentation/validation issue for users of the low-level matmul_4bit API.

Root cause

In bitsandbytes/backends/cuda/ops.py, the gemv kernel uses:

# Line 431
shape = (*A.shape[:-1], shapeB[0])  # output shape

# Lines 482-484
m = ct.c_int32(shapeB[0])   # treated as output dimension
k = ct.c_int32(shapeB[1])   # treated as input dimension

This hardcodes shapeB[0] as the output dimension. When quant_state.shape = (out_features, in_features) (the nn.Linear convention), this is correct. When quant_state.shape = (in_features, out_features), the kernel reads the wrong number of input elements and produces the wrong number of output elements.
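The failure can be traced with the shape arithmetic alone. Below is a minimal sketch of the output-shape logic; gemv_output_shape is a hypothetical stand-in for the expression in ops.py, not a real bitsandbytes function:

```python
def gemv_output_shape(A_shape, shapeB):
    # Mirrors ops.py: shapeB[0] is hardcoded as the output dimension.
    return (*A_shape[:-1], shapeB[0])

# nn.Linear convention, shapeB = (out_features, in_features) = (64, 128):
assert gemv_output_shape((128,), (64, 128)) == (64,)   # correct

# Transposed layout, shapeB = (in_features, out_features) = (128, 64):
assert gemv_output_shape((128,), (128, 64)) == (128,)  # wrong: should be (64,)
```

The second assertion passing is exactly the bug: the kernel happily emits a 128-element output with no indication that the layout was misinterpreted.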

Meanwhile, MatMul4Bit.forward() (the 2D matrix path) handles both conventions transparently:

output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias)

It dequantizes, transposes, and lets F.linear handle the shapes. So the matrix path works with any weight layout, but the gemv fast path does not.
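To illustrate the matrix path on the transposed layout, here is a dense-tensor sketch; W_dq stands in for the result of F.dequantize_4bit(B, quant_state), which restores the stored (in_features, out_features) shape, and no quantization is actually involved:

```python
import torch
import torch.nn.functional as F

K, N = 128, 64  # in_features, out_features
x = torch.randn(K, dtype=torch.float32)

# W_dq stands in for the dequantized weight in the transposed
# (in_features, out_features) layout from the reproduction.
W_dq = torch.randn(K, N, dtype=torch.float32)

# F.linear(A, W.t()) computes A @ W, so x's K input features line up
# with W_dq's first dimension and the output has N = out_features elements.
out = F.linear(x, W_dq.t())
assert out.shape == (N,)  # (64,): the correct out_features dimension
```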

Reproduction

import torch
import bitsandbytes as bnb

K, N = 128, 64  # in_features=128, out_features=64
W = torch.randn(K, N, device='cuda', dtype=torch.float16)
B, qs = bnb.functional.quantize_4bit(W, quant_type='nf4')

x = torch.randn(K, device='cuda', dtype=torch.float16)

# gemv path (1D input, no grad) — WRONG results, no error
result_gemv = bnb.matmul_4bit(x, B, qs)

# matrix path (force via requires_grad) — correct results
x2 = x.clone().requires_grad_(True)
result_correct = bnb.matmul_4bit(x2, B, qs)

print(f'gemv shape: {result_gemv.shape}')       # torch.Size([128]) — WRONG, should be 64
print(f'correct shape: {result_correct.shape}')  # torch.Size([64])  — correct

Why it doesn't affect production

  • Linear4bit stores weights as (out_features, in_features), matching the gemv convention
  • HuggingFace and other frameworks go through Linear4bit
  • The existing test_matmul_4bit uses 2D inputs, so it never exercises the gemv path with non-square weights

Possible fix

Add a shape check at the gemv entry point in matmul_4bit() (_functions.py ~line 394):

if A.shape[-1] != quant_state.shape[1]:
    # Weight convention mismatch — fall back to matrix path
    return MatMul4Bit.apply(A, B, out, bias, quant_state)

Or add an assertion that raises an informative error instead of producing silent corruption.
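The assertion variant could look like the following sketch; validate_gemv_layout is a hypothetical helper name, not existing bitsandbytes code:

```python
def validate_gemv_layout(A_shape, quant_state_shape):
    # Under the nn.Linear convention, quant_state.shape is
    # (out_features, in_features), so the input's last dimension
    # must equal quant_state.shape[1].
    if A_shape[-1] != quant_state_shape[1]:
        raise ValueError(
            "matmul_4bit gemv path expects quant_state.shape == "
            "(out_features, in_features); input has last dim "
            f"{A_shape[-1]} but quant_state.shape[1] == {quant_state_shape[1]}. "
            "Was the weight quantized in (in_features, out_features) layout?"
        )

validate_gemv_layout((128,), (64, 128))  # nn.Linear convention: passes

try:
    validate_gemv_layout((128,), (128, 64))  # transposed layout: raises
except ValueError as e:
    print(e)
```

Either option turns the current silent corruption into an explicit, debuggable failure; the fallback variant additionally preserves correct results for callers who already rely on the transposed layout.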

Found during investigation of #1235.
