@@ -28,8 +28,8 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
28
28
29
29
@abstractmethod
30
30
def apply (self , layer : torch .nn .Module , x : torch .Tensor ,
31
- topk_weights : torch .Tensor ,
32
- topk_ids : torch . Tensor ) -> torch .Tensor :
31
+ topk_weights : torch .Tensor , topk_ids : torch . Tensor ,
32
+ ** kwargs ) -> torch .Tensor :
33
33
raise NotImplementedError
34
34
35
35
@@ -59,21 +59,18 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
59
59
set_weight_attrs (w2_weight , extra_weight_attrs )
60
60
61
61
def apply (self , layer : torch .nn .Module , x : torch .Tensor ,
62
- topk_weights : torch .Tensor ,
63
- topk_ids : torch . Tensor ) -> torch .Tensor :
62
+ topk_weights : torch .Tensor , topk_ids : torch . Tensor ,
63
+ ** kwargs ) -> torch .Tensor :
64
64
65
65
return self .forward (x = x ,
66
66
layer = layer ,
67
67
topk_weights = topk_weights ,
68
- topk_ids = topk_ids )
68
+ topk_ids = topk_ids ,
69
+ ** kwargs )
69
70
70
- def forward_cuda (
71
- self ,
72
- layer : torch .nn .Module ,
73
- x : torch .Tensor ,
74
- topk_weights : torch .Tensor ,
75
- topk_ids : torch .Tensor ,
76
- ) -> torch .Tensor :
71
+ def forward_cuda (self , layer : torch .nn .Module , x : torch .Tensor ,
72
+ topk_weights : torch .Tensor , topk_ids : torch .Tensor ,
73
+ ** kwargs ) -> torch .Tensor :
77
74
return fused_experts (hidden_states = x ,
78
75
w1 = layer .w13_weight ,
79
76
w2 = layer .w2_weight ,
@@ -85,17 +82,11 @@ def forward_cpu(self, *args, **kwargs):
85
82
raise NotImplementedError (
86
83
"The CPU backend currently does not support MoE." )
87
84
88
- def forward_tpu (
89
- self ,
90
- layer : torch .nn .Module ,
91
- x : torch .Tensor ,
92
- topk_weights : torch .Tensor ,
93
- topk_ids : torch .Tensor ,
94
- ) -> torch .Tensor :
95
-
96
- #assert not use_grouped_topk
97
- #assert num_expert_group is None
98
- #assert topk_group is None
85
+ def forward_tpu (self , layer : torch .nn .Module , x : torch .Tensor ,
86
+ topk_weights : torch .Tensor , topk_ids : torch .Tensor ,
87
+ use_grouped_topk : bool ) -> torch .Tensor :
88
+
89
+ assert not use_grouped_topk
99
90
return fused_experts (hidden_states = x ,
100
91
w1 = layer .w13_weight ,
101
92
w2 = layer .w2_weight ,
@@ -294,7 +285,8 @@ def forward(self, hidden_states: torch.Tensor,
294
285
layer = self ,
295
286
x = hidden_states ,
296
287
topk_weights = topk_weights ,
297
- topk_ids = topk_ids )
288
+ topk_ids = topk_ids ,
289
+ use_grouped_topk = self .use_grouped_topk )
298
290
299
291
# Optionally reduce.
300
292
if self .reduce_results and self .tp_size > 1 :
0 commit comments