Commit 75328ea

use filter_fn in quantize_
1 parent 9325c13 commit 75328ea

2 files changed (+3, -2)

torchtitan/components/quantization/float8.py

Lines changed: 2 additions & 2 deletions
@@ -129,8 +129,8 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
                 return True
             return False
 
-        config = MoETrainingConfig(module_filter_fn=moe_module_filter_fn)
-        quantize_(model, config=config)
+        config = MoETrainingConfig()
+        quantize_(model, config=config, filter_fn=moe_module_filter_fn)
         logger.info("Converted MoE to float8")
 
     def post_optimizer_hook(self, model: nn.Module | list[nn.Module]):
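
For context, a minimal, self-contained sketch of the call pattern after this commit: the module filter moves off the config and onto quantize_ itself. The import path for MoETrainingConfig and the "experts" name check are assumptions (torchao's prototype MoE-training module layout may differ across versions), not part of the commit:

import torch.nn as nn
from torchao.quantization import quantize_
# Assumed import path; torchao's prototype module layout may differ by version.
from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig

def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
    # Illustrative filter: select expert modules by fully-qualified name.
    return "experts" in cur_fqn

model = nn.Sequential(nn.Linear(8, 8))  # stand-in for a real MoE model
config = MoETrainingConfig()
# After this commit the filter is passed to quantize_, not to the config.
quantize_(model, config=config, filter_fn=moe_module_filter_fn)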

torchtitan/experiments/llama4/model/moe.py

Lines changed: 1 addition & 0 deletions
@@ -82,6 +82,7 @@ def forward(
         assert (
             x.dtype == self.w1.dtype == self.w2.dtype == self.w3.dtype == torch.bfloat16
         ), "torch._grouped_mm only supports bf16 dtypes"
+
         h = F.silu(torch._grouped_mm(x, self.w1, offs=offsets))
         h = h * torch._grouped_mm(x, self.w3, offs=offsets)
         out = torch._grouped_mm(h, self.w2, offs=offsets)
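
The code around this hunk routes expert-sorted tokens through torch._grouped_mm, a private PyTorch API that multiplies contiguous row groups of x against per-expert weight matrices, with offs giving the cumulative end offset of each group. A rough sketch of the shapes involved, assuming a recent PyTorch build with CUDA (the API is private and, as the assert notes, bf16-only); the sizes below are illustrative:

import torch

num_experts, dim, hidden = 4, 64, 128
tokens_per_expert = torch.tensor([3, 5, 2, 6], device="cuda")
offsets = torch.cumsum(tokens_per_expert, dim=0).to(torch.int32)  # [3, 8, 10, 16]

x = torch.randn(16, dim, dtype=torch.bfloat16, device="cuda")  # all tokens, sorted by expert
w1 = torch.randn(num_experts, dim, hidden, dtype=torch.bfloat16, device="cuda")

# Rows x[0:3] go to expert 0, x[3:8] to expert 1, and so on.
h = torch._grouped_mm(x, w1, offs=offsets)  # -> (16, hidden)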
