@@ -192,25 +192,23 @@ def fused_marlin_moe(
         is_zp_float=False,
     )

-    if activation_func is not None:
-        activation_func(
-            activation, intermediate_cache2, intermediate_cache1.view(-1, 2 * N)
-        )
-    else:
-        if activation == "silu":
-            torch.ops._C.silu_and_mul(
-                intermediate_cache2, intermediate_cache1.view(-1, 2 * N)
-            )
-        elif activation == "swigluoai":
-            # alpha = 1.702, limit = 7.0
-            torch.ops._C.swigluoai_and_mul(
-                intermediate_cache2, intermediate_cache1.view(-1, 2 * N)
-            )
-        else:
-            raise ValueError(
-                f"Unsupported activation: {activation}. "
-                "Only silu and swigluoai activations are supported."
-            )
+    if activation_func is None:
+        def activation_func(activation: str, output: torch.Tensor, input: torch.Tensor) -> None:
+            if activation == "silu":
+                torch.ops._C.silu_and_mul(
+                    output, input
+                )
+            elif activation == "swigluoai":
+                # alpha = 1.702, limit = 7.0
+                torch.ops._C.swigluoai_and_mul(
+                    output, input
+                )
+            else:
+                raise ValueError(
+                    f"Unsupported activation: {activation}. "
+                    "Only silu and swigluoai activations are supported."
+                )
+    activation_func(activation, intermediate_cache2, intermediate_cache1.view(-1, 2 * N))

     if expert_map is not None:
         intermediate_cache3.zero_()
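
For context on the callback contract: the hunk above makes fused_marlin_moe synthesize a default activation_func when the caller passes None, and then always invokes it as activation_func(activation, output_buffer, gate_up_input). A minimal caller-supplied implementation could look like the sketch below; it assumes only the signature visible in the diff and substitutes an eager PyTorch silu-and-mul for the fused torch.ops._C kernel, so the name eager_silu_and_mul and the gate/up split convention are illustrative, not part of vLLM's API.

import torch
import torch.nn.functional as F

def eager_silu_and_mul(activation: str, output: torch.Tensor, input: torch.Tensor) -> None:
    # Matches the callback signature used above: write the gated activation of
    # `input` (concatenated gate/up projections, last dim 2*N) into `output` (last dim N).
    if activation != "silu":
        raise ValueError(f"Unsupported activation: {activation}")
    d = input.shape[-1] // 2
    gate, up = input[..., :d], input[..., d:]
    # Copy in place so a preallocated output buffer (e.g. intermediate_cache2) is filled.
    output.copy_(F.silu(gate) * up)

# Hypothetical usage: fused_marlin_moe(..., activation_func=eager_silu_and_mul)
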
@@ -425,7 +423,6 @@ def apply(
     def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None:
         ops.moe_sum(input, output)

-
 def modular_marlin_fused_moe(
     quant_config: FusedMoEQuantConfig, shared_experts: Optional[torch.nn.Module] = None
 ) -> mk.FusedMoEModularKernel: