deprecate xrt_world_size #7679
Conversation
Can you use xr.world_size instead? I don't think we need to rename the function in xla_model.
Lines 148 to 152 in 5b8e8e0:

```python
def world_size() -> int:
  """Returns the total number of processes participating in the job."""
  if torch_xla._XLAC._xla_get_replication_devices_count() == 0:
    return 1
  return global_device_count()
```

xla/torch_xla/core/xla_model.py, Lines 129 to 132 in 5b8e8e0
_WORLD_SIZE will cause recompilation. @alanwaketan for insights.

You will know if there is a recompilation from the test.

Hi @will-cromar, I tried functools.lru_cache and it crashes in multiprocess. I noticed that if we use functools.wraps and assign attributes to the function being wrapped, it causes a crash; lru_cache probably relies on those function attributes in this case. I created run_once instead.
torch_xla/runtime.py (Outdated)

```python
_ORDINAL = runtime.global_ordinal()


def run_once(func):
```
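The excerpt above only shows the decorator's signature, so for context here is a rough sketch of what a process-local run_once could look like; this is my assumption, not necessarily the exact implementation in the PR:

```python
import functools


def run_once(func):
  """Run func at most once per process and cache its return value."""
  result = None
  has_run = False

  @functools.wraps(func)
  def wrapper(*args, **kwargs):
    nonlocal result, has_run
    if not has_run:
      result = func(*args, **kwargs)
      has_run = True
    return result

  return wrapper
```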
What is this run_once for? It's a neat idea, but I see you opted for global variables for world size and ordinal.
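For reference, the global-variable pattern the reviewer mentions presumably has roughly this shape, based on the _ORDINAL line in the excerpt above; the rest is an assumption for illustration, not the PR's exact code:

```python
import torch_xla.runtime as runtime

# Queried once and cached at module level, so later lookups return a plain
# Python int instead of calling back into the runtime.
_WORLD_SIZE = runtime.world_size()
_ORDINAL = runtime.global_ordinal()


def get_ordinal() -> int:
  # Hypothetical accessor returning the cached ordinal.
  return _ORDINAL
```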
Without the @run_once on using_pjrt, the test test_mp_replication will fail because it builds a dynamic graph during dynamo compilation.
Do you know why that is? Is it because all calls to get e.g. world_size now go through functions wrapped in requires_pjrt, which in turn actually checks an env var (device_type and _maybe_select_default_device)? Whereas before, the call would have been stopped by xm.world_size.
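To make that hypothesis concrete, the wrapper being described would look roughly like this; the body below is an assumption for illustration (it leans on the _maybe_select_default_device and using_pjrt helpers discussed elsewhere in this thread), not the actual torch_xla code:

```python
import functools


def requires_pjrt(fn):

  @functools.wraps(fn)
  def wrapper(*args, **kwargs):
    # If this runs on every call, each world_size()/global_ordinal() lookup
    # re-reads environment state (e.g. PJRT_DEVICE) instead of using a cache.
    _maybe_select_default_device()
    if not using_pjrt():
      raise NotImplementedError(f'{fn.__name__} requires the PJRT runtime')
    return fn(*args, **kwargs)

  return wrapper
```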
Check my reply here: #7679 (comment)
@@ -70,6 +101,7 @@ def device_type() -> Optional[str]:

```python
  return pjrt_device.split('_')[0] if pjrt_device else pjrt_device


@run_once
def using_pjrt() -> bool:
```
Hah, this function also needs to get deprecated since I assume this is always True
We still need to call _maybe_select_default_device(); that is the point of calling using_pjrt() only once.
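In other words, the side effect lives inside using_pjrt, and @run_once limits it to a single call per process. A hedged sketch of that shape, pieced together from the excerpts above rather than copied from the PR:

```python
@run_once
def using_pjrt() -> bool:
  # _maybe_select_default_device() may mutate process state (e.g. pick a
  # default PJRT device from the environment), so it should happen only once.
  # @run_once caches the returned bool, so later callers skip the side effect.
  _maybe_select_default_device()
  return device_type() is not None
```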
General question: if run_once makes dynamo happy for using_pjrt, do you know why it doesn't work for world_size and global_ordinal?
> Do you know why that is? Is it because all calls to get e.g. world_size now go through functions wrapped in requires_pjrt, which in turn actually checks an env var (device_type and _maybe_select_default_device)? Whereas before, the call would have been stopped by xm.world_size.
Take unittest
The C API binding
Thanks for the explanation! Filed a follow-up bug to clean up use_pjrt and requires_pjrt and fix a concrete usability issue at #7730.
TPU CI failure seems relevant; can we fix forward or revert this PR?

I forgot to update the TPU CI test. Let me make a follow-up PR now.
@@ -36,7 +36,7 @@ dist.init_process_group("xla", rank=rank, world_size=world_size)

```
new_rank = xm.get_ordinal()
-world_size = xm.xrt_world_size()
+world_size = xr.world_size()
```
@zpcore please add import torch_xla.runtime as xr to section 1; it feels like this line comes from left field without the import in the documentation.
xm.world_size() is also deprecated. They all point to the same thing. We should only use xr.world_size().
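Taken together, the documentation snippet would presumably read like this once the import is added; variable names follow the diff excerpt above:

```python
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

new_rank = xm.get_ordinal()
world_size = xr.world_size()  # replaces the deprecated xm.xrt_world_size()
```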
Deprecate torch_xla.xla_model.xrt_world_size and use torch_xla.runtime.world_size instead.

Add the run_once decorator to the function runtime.using_pjrt, since we only need to run it once per process. This helps get rid of a dynamo compilation issue with xm.all_reduce.
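For readers unfamiliar with the pattern, a deprecation of this kind typically keeps the old name as a thin forwarding wrapper that warns. The snippet below is a generic illustration, not necessarily the exact code in this PR (torch_xla may use its own deprecation helper):

```python
import warnings

import torch_xla.runtime as xr


def xrt_world_size() -> int:
  """Deprecated: use torch_xla.runtime.world_size() instead."""
  warnings.warn(
      'xrt_world_size() is deprecated; use torch_xla.runtime.world_size().',
      DeprecationWarning,
      stacklevel=2)
  return xr.world_size()
```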