Skip to content

Commit

Permalink
Fix conflict between Tutel and top-2 gate in MoE layer (#2053)
Browse files Browse the repository at this point in the history
* fix conflit between tutel and top-2 gate
Co-authored-by: Ammar Ahmad Awan <ammar.awan@microsoft.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
  • Loading branch information
yetiansh authored Jul 26, 2022
1 parent 0e49b19 commit 31582d7
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion deepspeed/moe/sharded_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,13 +471,17 @@ def __init__(self,
self.timers = SynchronizedWallClockTimer()
self.wall_clock_breakdown = False

self.use_tutel = use_tutel and TUTEL_INSTALLED
self.use_tutel = use_tutel and TUTEL_INSTALLED and gate.k == 1

if self.use_tutel:
logger.info('Using Tutel optimizations.')
elif use_tutel and not TUTEL_INSTALLED:
logger.warning("Tutel optimization requested but not installed. "
"Proceeding without Tutel.")
elif use_tutel and TUTEL_INSTALLED and gate.k != 1:
logger.warning(
"To enable Tutel optimization, use top-1 instead of top-2 gate. "
"Proceeding without Tutel.")

def _set_ep_group(self, ep_group):
self.ep_group = ep_group
Expand Down

0 comments on commit 31582d7

Please sign in to comment.