
Int8DynActInt4WeightQATQuantizer doesn't support qwen series #1080

Open
@elfisworking

Description


I use Int8DynActInt4WeightQATQuantizer to quantize the qwen2-1.5B model, but after calling the prepare function, I find that bias is set to False.
This is my code:

from torchtune.models.qwen2 import qwen2_1_5b
model = qwen2_1_5b()
from torchao.quantization.prototype.qat.linear import Int8DynActInt4WeightQATQuantizer
qat_quantizer = Int8DynActInt4WeightQATQuantizer()
print("before prepare: ", model)
model = qat_quantizer.prepare(model)
print("after prepare: ", model)

The output is:

before prepare:  TransformerDecoder(
  (tok_embeddings): Embedding(151936, 1536)
  (layers): ModuleList(
    (0-27): 28 x TransformerSelfAttentionLayer(
      (attn): MultiHeadAttention(
        (q_proj): Linear(in_features=1536, out_features=1536, bias=True)
        (k_proj): Linear(in_features=1536, out_features=256, bias=True)
        (v_proj): Linear(in_features=1536, out_features=256, bias=True)
        (output_proj): Linear(in_features=1536, out_features=1536, bias=False)
        (pos_embeddings): Qwen2RotaryPositionalEmbeddings()
      )
      (mlp): FeedForward(
        (w1): Linear(in_features=1536, out_features=8960, bias=False)
        (w2): Linear(in_features=8960, out_features=1536, bias=False)
        (w3): Linear(in_features=1536, out_features=8960, bias=False)
        (activation): SiLU()
      )
      (sa_norm): RMSNorm()
      (mlp_norm): RMSNorm()
      (sa_scale): Identity()
      (mlp_scale): Identity()
    )
  )
  (norm): RMSNorm()
)
after prepare:  TransformerDecoder(
  (tok_embeddings): Embedding(151936, 1536)
  (layers): ModuleList(
    (0-27): 28 x TransformerSelfAttentionLayer(
      (attn): MultiHeadAttention(
        (q_proj): Int8DynActInt4WeightQATLinear(in_features=1536, out_features=1536, bias=False)
        (k_proj): Int8DynActInt4WeightQATLinear(in_features=1536, out_features=256, bias=False)
        (v_proj): Int8DynActInt4WeightQATLinear(in_features=1536, out_features=256, bias=False)
        (output_proj): Int8DynActInt4WeightQATLinear(in_features=1536, out_features=1536, bias=False)
        (pos_embeddings): Qwen2RotaryPositionalEmbeddings()
      )
      (mlp): FeedForward(
        (w1): Int8DynActInt4WeightQATLinear(in_features=1536, out_features=8960, bias=False)
        (w2): Int8DynActInt4WeightQATLinear(in_features=8960, out_features=1536, bias=False)
        (w3): Int8DynActInt4WeightQATLinear(in_features=1536, out_features=8960, bias=False)
        (activation): SiLU()
      )
      (sa_norm): RMSNorm()
      (mlp_norm): RMSNorm()
      (sa_scale): Identity()
      (mlp_scale): Identity()
    )
  )
  (norm): RMSNorm()
)

We can see that after the prepare function, (q_proj): Linear(in_features=1536, out_features=1536, bias=True) has been replaced by (q_proj): Int8DynActInt4WeightQATLinear(in_features=1536, out_features=1536, bias=False).
From the torchao code, we can see that in the replacement function:
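For reference, a quick scan of the original model (plain PyTorch, nothing torchao-specific) shows which Linear layers actually carry a bias and therefore lose parameters during prepare:

import torch
from torchtune.models.qwen2 import qwen2_1_5b

model = qwen2_1_5b()

# Print every nn.Linear that carries a bias; these are the parameters that
# prepare() drops. Based on the printout above, this should be the
# q_proj/k_proj/v_proj of each attention block.
for name, module in model.named_modules():
    if isinstance(module, torch.nn.Linear) and module.bias is not None:
        print(name, tuple(module.bias.shape))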

def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
    new_linear = linear_class(
        child.in_features,
        child.out_features,
        bias=False,
        device=child.weight.device,
        groupsize=groupsize,
        precision=precision,
        scales_precision=scales_precision,
    )

bias is hardcoded to False.
Is there any solution to this problem?
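As a possible stopgap, I am considering the following untested sketch (not an official torchao API): stash the original biases before prepare and re-add them with forward hooks afterwards. I am not sure this matches the intended fake-quant numerics, and the bias stays a frozen tensor rather than a trainable parameter:

import torch
from torchtune.models.qwen2 import qwen2_1_5b
from torchao.quantization.prototype.qat.linear import Int8DynActInt4WeightQATQuantizer

model = qwen2_1_5b()

# Stash the bias of every nn.Linear before prepare() replaces the layer
# and drops it.
saved_biases = {
    name: module.bias.detach().clone()
    for name, module in model.named_modules()
    if isinstance(module, torch.nn.Linear) and module.bias is not None
}

qat_quantizer = Int8DynActInt4WeightQATQuantizer()
model = qat_quantizer.prepare(model)

# Add the stashed bias back onto the output of the replaced module via a
# forward hook. Note: this bias is not trained during QAT and convert()
# will not know about it, so this is only a workaround sketch.
def make_bias_hook(bias):
    def hook(module, inputs, output):
        return output + bias.to(device=output.device, dtype=output.dtype)
    return hook

for name, module in model.named_modules():
    if name in saved_biases:
        module.register_forward_hook(make_bias_hook(saved_biases[name]))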
