[TPU] update torch_xla pin #19231
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[TPU] update torch_xla pin #19231
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
@@ -100,7 +100,8 @@ def init_device(self):
         # `xla_tpu_force_1d_allreduce_at_chunk_count` is a temporary solution to
         # fix this. It will be removed after the bug in XLA compiler is fixed.
         os.environ["LIBTPU_INIT_ARGS"] = (
-            "--xla_tpu_force_1d_allreduce_at_chunk_count=1")
+            os.environ.get("LIBTPU_INIT_ARGS", "") +
+            " --xla_tpu_force_1d_allreduce_at_chunk_count=1")

         torch.set_grad_enabled(False)
         torch.set_default_dtype(self.model_config.dtype)

Review discussion on this hunk:

- Comment: nit: We can add a comment saying the additional libtpu arg is needed due to pytorch/xla#9084.
- Reply: I think it's fine. We're not adding any specific libtpu arg here; we just inherit all existing args, if any.
- Reply: I see, makes sense, thanks!
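For context, here is a minimal standalone sketch (not part of the PR) of the append-instead-of-overwrite pattern used in the hunk above. The helper name and the pre-existing user flag are hypothetical, used only to show that an already-set `LIBTPU_INIT_ARGS` value survives:

```python
import os

def append_libtpu_arg(flag: str) -> None:
    # Hypothetical helper mirroring the pattern in the hunk above:
    # keep whatever LIBTPU_INIT_ARGS is already set and append the
    # extra flag, rather than overwriting the variable.
    existing = os.environ.get("LIBTPU_INIT_ARGS", "")
    os.environ["LIBTPU_INIT_ARGS"] = (existing + " " + flag).strip()

# A hypothetical user-provided arg set before device initialization.
os.environ["LIBTPU_INIT_ARGS"] = "--some_user_flag=true"
append_libtpu_arg("--xla_tpu_force_1d_allreduce_at_chunk_count=1")
print(os.environ["LIBTPU_INIT_ARGS"])
# --some_user_flag=true --xla_tpu_force_1d_allreduce_at_chunk_count=1
```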
- Comment: Consider adding a test case where k is chosen such that num_tokens * topk is a multiple of 16, given the constraint mentioned in line 27. This would provide more confidence in the kernel's correctness under the required condition.
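A rough sketch of what the suggested parametrization could look like. This is not the actual vLLM test: torch.topk stands in for the kernel under test so the example stays self-contained, and the test name and shapes are illustrative only.

```python
import pytest
import torch

# Cases chosen so that num_tokens * topk is a multiple of 16,
# per the reviewer's suggestion (products: 16, 64, 96).
@pytest.mark.parametrize("num_tokens,topk", [(8, 2), (16, 4), (48, 2)])
def test_topk_with_product_multiple_of_16(num_tokens: int, topk: int) -> None:
    assert (num_tokens * topk) % 16 == 0
    # torch.topk stands in for the real kernel here; in the actual test
    # suite the kernel output would be compared against a reference
    # implementation at this point.
    logits = torch.randn(num_tokens, 128)
    values, indices = torch.topk(logits, topk, dim=-1)
    assert values.shape == (num_tokens, topk)
    assert indices.shape == (num_tokens, topk)
```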