File tree Expand file tree Collapse file tree 2 files changed +3
-2
lines changed Expand file tree Collapse file tree 2 files changed +3
-2
lines changed Original file line number Diff line number Diff line change @@ -129,8 +129,8 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
129
129
return True
130
130
return False
131
131
132
- config = MoETrainingConfig (module_filter_fn = moe_module_filter_fn )
133
- quantize_ (model , config = config )
132
+ config = MoETrainingConfig ()
133
+ quantize_ (model , config = config , filter_fn = moe_module_filter_fn )
134
134
logger .info ("Converted MoE to float8" )
135
135
136
136
def post_optimizer_hook (self , model : nn .Module | list [nn .Module ]):
Original file line number Diff line number Diff line change @@ -82,6 +82,7 @@ def forward(
82
82
assert (
83
83
x .dtype == self .w1 .dtype == self .w2 .dtype == self .w3 .dtype == torch .bfloat16
84
84
), "torch._grouped_mm only supports bf16 dtypes"
85
+
85
86
h = F .silu (torch ._grouped_mm (x , self .w1 , offs = offsets ))
86
87
h = h * torch ._grouped_mm (x , self .w3 , offs = offsets )
87
88
out = torch ._grouped_mm (h , self .w2 , offs = offsets )
You can’t perform that action at this time.
0 commit comments