Skip to content

Commit 5f2c5b9

Browse files
authored
fix moe apis (#41650)
1 parent d95280c commit 5f2c5b9

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

python/paddle/incubate/distributed/models/moe/gate/gshard_gate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,6 @@ def forward(self, x):
6262
if self.random_routing:
6363
rand_routing_prob = paddle.rand(
6464
shape=[gate_score.shape[0]], dtype="float32")
65-
topk_idx = paddle.distributed.utils.random_routing(
65+
topk_idx = paddle.distributed.models.moe.utils._random_routing(
6666
topk_idx, topk_val, rand_routing_prob)
6767
return topk_val, topk_idx

python/paddle/incubate/distributed/models/moe/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
from paddle.distributed.models.moe.utils import *
14+
from paddle.distributed.models.moe.utils import _number_count, _limit_by_capacity, _prune_gate_by_capacity, _assign_pos
15+
import paddle
1516

1617

1718
def _alltoall(in_tensor_list, group=None, use_calc_stream=True):

0 commit comments

Comments
 (0)