Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix conflict between Tutel and top-2 gate in MoE layer #2053

Merged
merged 6 commits into from
Jul 26, 2022
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion deepspeed/moe/sharded_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,13 +474,17 @@ def __init__(self,
self.timers = SynchronizedWallClockTimer()
self.wall_clock_breakdown = False

self.use_tutel = use_tutel and TUTEL_INSTALLED
self.use_tutel = use_tutel and TUTEL_INSTALLED and gate.k == 1

if self.use_tutel:
logger.info('Using Tutel optimizations.')
elif use_tutel and not TUTEL_INSTALLED:
logger.warning("Tutel optimization requested but not installed. "
"Proceeding without Tutel.")
elif use_tutel and TUTEL_INSTALLED and gate.k != 1:
logger.warning(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we wrap this in a if torch.distributed.get_rank() ==0:?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah it is possible. But I wonder should we also wrap other warnings and infos? For example, L480 and L482-483?

"To enable Tutel optimization, use top-1 instead of top-2 gate. "
"Proceeding without Tutel.")

def _set_ep_group(self, ep_group):
self.ep_group = ep_group
Expand Down