deprecate xrt_world_size #7679


Merged
merged 25 commits into master on Jul 24, 2024
Conversation

zpcore
Member

@zpcore zpcore commented Jul 12, 2024

Deprecate torch_xla.xla_model.xrt_world_size and use torch_xla.runtime.world_size instead.

Add a run_once decorator to runtime.using_pjrt, since we only need to run it once per process. This helps get rid of a dynamo compilation issue with xm.all_reduce.

@zpcore zpcore marked this pull request as ready for review July 12, 2024 22:15
@zpcore zpcore requested a review from will-cromar July 12, 2024 22:15
Collaborator

@will-cromar will-cromar left a comment


Can you use xr.world_size instead? I don't think we need to rename the function in xla_model

xla/torch_xla/runtime.py

Lines 148 to 152 in 5b8e8e0

def world_size() -> int:
"""Returns the total number of processes participating in the job."""
if torch_xla._XLAC._xla_get_replication_devices_count() == 0:
return 1
return global_device_count()

@zpcore
Member Author

zpcore commented Jul 15, 2024

> Can you use xr.world_size instead? I don't think we need to rename the function in xla_model

xrt_world_size consults the cached _WORLD_SIZE global in addition to runtime.world_size():

global _WORLD_SIZE
if _WORLD_SIZE is not None:
  return _WORLD_SIZE

There is a concern that removing _WORLD_SIZE will cause recompilation. cc @alanwaketan for insights.

@zpcore zpcore requested a review from alanwaketan July 15, 2024 19:49
@alanwaketan
Collaborator

You will know if there is a recompilation from the test.

@zpcore
Member Author

zpcore commented Jul 17, 2024

Hi @will-cromar, I tried functools.lru_cache and it crashes under multiprocessing. I noticed that using functools.wraps and assigning attributes to the wrapped function causes the crash; lru_cache probably relies on those function attributes in this case.

I created run_once and moved the global variables into runtime.py in order to deprecate xrt_world_size.
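As a sketch, a run_once decorator along these lines sidesteps lru_cache's reliance on function attributes by caching the result in closure variables instead (this illustrates the idea and is not necessarily the exact code merged in this PR; the example function name is hypothetical):

```python
import functools


def run_once(func):
  """Run `func` on the first call, then return the cached result.

  The cache lives in closure variables rather than in attributes on the
  wrapped function, unlike functools.lru_cache.
  """
  ran = False
  result = None

  @functools.wraps(func)
  def wrapper(*args, **kwargs):
    nonlocal ran, result
    if not ran:
      result = func(*args, **kwargs)
      ran = True
    return result

  return wrapper


calls = []


@run_once
def using_pjrt_like_init():
  # Stands in for expensive one-time setup, e.g. selecting a default device.
  calls.append(1)
  return True
```

Repeated calls return the first result without re-running the body, so after the first call the wrapper behaves like a cheap constant-returning function.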

_ORDINAL = runtime.global_ordinal()


def run_once(func):
Collaborator


What is this run_once for? It's a neat idea, but I see you opted for global variables for world size and ordinal.

Member Author


Without @run_once on using_pjrt, the test test_mp_replication fails while building the dynamic graph for dynamo compilation.

Collaborator


Do you know why that is? Is it because now all calls to get e.g. world_size go through functions wrapped in requires_pjrt, which in turn actually is checking an env var (device_type and _maybe_select_default_device)? Whereas before the call would have been stopped by xm.world_size.

Member Author


Check my reply here: #7679 (comment)

@@ -70,6 +101,7 @@ def device_type() -> Optional[str]:
return pjrt_device.split('_')[0] if pjrt_device else pjrt_device


@run_once
def using_pjrt() -> bool:
Collaborator


Hah, this function also needs to get deprecated, since I assume it is always True now.

Member Author


We still need to call _maybe_select_default_device(); that is exactly why I call using_pjrt() only once.

Collaborator

@will-cromar will-cromar left a comment


General question: if run_once makes dynamo happy for using_pjrt, do you know why it doesn't work for world_size and global_ordinal?


@zpcore
Member Author

zpcore commented Jul 23, 2024

> General question: if run_once makes dynamo happy for using_pjrt, do you know why it doesn't work for world_size and global_ordinal?

Take unittest test/test_mp_replication.py as the example.

  • Why it works for using_pjrt:
    Without run_once, this function gets called from the requires_pjrt() wrapper every time we call a function like world_size():

    xla/torch_xla/runtime.py

    Lines 36 to 38 in 42aa7bd

    def _maybe_select_default_device():
      if xu.getenv_as(xenv.PJRT_SELECT_DEFAULT_DEVICE, str,
                      '1') == '0' or xenv.PJRT_DEVICE in os.environ:

    This checks the env var, which makes the graph dynamic. With @run_once, the function has already executed and is skipped when we do the dynamo compile.

  • Why we need a global variable like _WORLD_SIZE:
    Without it, we see an error like:

  from user code:
   File "/home/piz/pytorch/xla/test/test_mp_replication.py", line 11, in all_reduce
    return xm.all_reduce(xm.REDUCE_SUM, tensor)
  File "/home/piz/pytorch/xla/torch_xla/core/xla_model.py", line 428, in all_reduce
    if runtime.world_size() == 1 and not xu.getenv_as('XLA_ALWAYS_ALLREDUCE',
  File "/home/piz/pytorch/xla/torch_xla/runtime.py", line 128, in wrapper
    return fn(*args, **kwargs)
  File "/home/piz/pytorch/xla/torch_xla/runtime.py", line 186, in world_size
    if torch_xla._XLAC._xla_get_replication_devices_count() == 0:
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

The _XLAC binding _xla_get_replication_devices_count would be called, which fails the dynamo compile.

@zpcore zpcore enabled auto-merge (squash) July 24, 2024 22:21
Collaborator

@will-cromar will-cromar left a comment


Thanks for the explanation! Filed a follow-up bug to clean up use_pjrt and requires_pjrt and fix a concrete usability issue at #7730

@zpcore zpcore merged commit 7f8ef79 into master Jul 24, 2024
23 checks passed
@JackCaoG
Collaborator

The TPU CI failure seems relevant; can we fix forward or revert this PR?

@zpcore
Member Author

zpcore commented Jul 25, 2024

I forgot to update the TPU CI test. Let me make a follow-up PR now.

@@ -36,7 +36,7 @@ dist.init_process_group("xla", rank=rank, world_size=world_size)

```
new_rank = xm.get_ordinal()
-world_size = xm.xrt_world_size()
+world_size = xr.world_size()
```
Collaborator

@miladm miladm Jul 26, 2024


@zpcore please add import torch_xla.runtime as xr to section 1. This line comes out of left field without the import in the documentation.

Member Author


xm.world_size() is also deprecated; they all point to the same thing. We should only use xr.world_size().

5 participants