Commit e7f0ba4

Authored by YIWENX14, committed by facebook-github-bot

Fix preq embedding dtype check (#10699)

Summary: Rollback Plan: Differential Revision: D74194622

1 parent 1c2b7ba commit e7f0ba4

File tree

1 file changed: +1 −1 lines changed

examples/models/llama/source_transformation/pre_quantization.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -146,7 +146,7 @@ def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:
     scales_key = f"{cur_fqn}.scales"
     if isinstance(child, nn.Embedding) and scales_key in checkpoint:
         assert checkpoint[f"{cur_fqn}.weight"].dtype == torch.int8
-        assert checkpoint[scales_key].dtype == torch.float32
+        assert checkpoint[scales_key].dtype == dtype
         return True
     return False
```
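To illustrate the effect of the change: the scales tensor's dtype is now compared against the `dtype` argument in scope rather than being hardcoded to `torch.float32`, so pre-quantized checkpoints whose scales are stored in, say, `float16` no longer trip the assertion. The sketch below is a hedged, standalone reconstruction of the check (function name and signature are assumptions for illustration, not the ExecuTorch source):

```python
import torch
import torch.nn as nn


def is_prequantized_embedding(
    child: nn.Module,
    cur_fqn: str,
    checkpoint: dict,
    dtype: torch.dtype,
) -> bool:
    """Illustrative sketch of the fixed filter: detect a pre-quantized
    embedding entry in a checkpoint. Names here are hypothetical."""
    scales_key = f"{cur_fqn}.scales"
    if isinstance(child, nn.Embedding) and scales_key in checkpoint:
        # Quantized weights are stored as int8.
        assert checkpoint[f"{cur_fqn}.weight"].dtype == torch.int8
        # After the fix: scales must match the requested dtype,
        # not hardcoded torch.float32.
        assert checkpoint[scales_key].dtype == dtype
        return True
    return False


# Usage: a float16-scales checkpoint now passes when dtype=torch.float16.
ckpt = {
    "tok.weight": torch.zeros(4, 8, dtype=torch.int8),
    "tok.scales": torch.ones(4, dtype=torch.float16),
}
print(is_prequantized_embedding(nn.Embedding(4, 8), "tok", ckpt, torch.float16))
```

With the old hardcoded `torch.float32` comparison, the same checkpoint would have failed the assertion even though its scales dtype was internally consistent.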

0 commit comments