Skip to content

(Partial) Bug fix for test_xnnpack_dq4_kv_fp32_llama #2691

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into the base branch from the author's branch
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions examples/models/llama2/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,16 +147,16 @@ def convert_for_runtime(self) -> nn.Module:

def replace_linear_weight_only_int8_per_channel(module, node_type):
for name, child in module.named_children():
print(f"name: {name}")
# print(f"name: {name}")
if isinstance(child, nn.Linear):
if (
(node_type == "*")
or (node_type == "output" and name == "output")
or (node_type == "!output" and name != "output")
):
print(f"{name, child}")
print(f"in_features: {child.in_features}")
print(f"out_features: {child.out_features}")
# print(f"{name, child}")
# print(f"in_features: {child.in_features}")
# print(f"out_features: {child.out_features}")
setattr(
module,
name,
Expand Down Expand Up @@ -276,10 +276,10 @@ def replace_embedding_weight_only_grouped_int8_per_channel(
module, bitwidth: int = 8, group_size: Optional[int] = None
):
for name, child in module.named_children():
print(f"name: {name}")
# print(f"name: {name}")
if isinstance(child, nn.Embedding):
print(f"{name, child}")
print(f"weights size: {child.weight.size()}")
# print(f"{name, child}")
# print(f"weights size: {child.weight.size()}")
setattr(
module,
name,
Expand Down Expand Up @@ -320,9 +320,9 @@ def create_quantized_state_dict(self) -> Dict:
or isinstance(mod, fsEmbedding)
or isinstance(mod, fsStandardEmbedding)
):
print("****")
print(f"Embedding identified: {fqn, mod}")
print(f"weights size: {mod.weight.size()}")
# print("****")
# print(f"Embedding identified: {fqn, mod}")
# print(f"weights size: {mod.weight.size()}")
# print(f"quantize {fqn}...")

print(
Expand Down Expand Up @@ -516,9 +516,9 @@ def create_quantized_state_dict(self):
assert not mod.bias
out_features = mod.out_features
in_features = mod.in_features
print("in features:", in_features, " out features:", out_features)
# print("in features:", in_features, " out features:", out_features)
# assert out_features % 8 == 0, "require out_features % 8 == 0"
print(f"linear: {fqn}, in={in_features}, out={out_features}")
# print(f"linear: {fqn}, in={in_features}, out={out_features}")

assert (
in_features % self.group_size == 0
Expand Down