-
Notifications
You must be signed in to change notification settings - Fork 369
Pathways: reuse setup_spmd for pathways init #1248
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
19617d5
dd4425f
935db46
f7d548d
c7783cc
83a875b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,28 +22,37 @@ def setup( | |
"""Sets up the JAX environment for SPMD. | ||
|
||
Args: | ||
jax_backend: The distributed backend, which can be "cpu", "gpu", or "tpu". | ||
distributed_coordinator: The distributed coordinator address (in the form of <host>:<port>). | ||
Needed only for `jax_backend != "tpu"` and `num_processes > 1`. Otherwise, the | ||
coordinator will be configured automatically when `num_processes` and `process_id` are | ||
provided. | ||
num_processes: The number of processes. Needed only if distributed initialization is desired | ||
for `jax_backend != "tpu"`. | ||
process_id: The process ID (the process rank). Needed only if distributed initialization is | ||
desired for `jax_backend != "tpu"`. | ||
initialization_timeout: The jax distributed initialization timeout in seconds. If None, uses | ||
jax default. | ||
jax_backend: The distributed backend. Can be "cpu", "gpu", "tpu", or "proxy". | ||
distributed_coordinator: The distributed coordinator address (e.g., "<host>:<port>"). | ||
If jax_backend is "tpu", this may be automatically inferred by JAX. | ||
If jax_backend is "proxy", this is ignored. | ||
num_processes: The number of processes. | ||
If jax_backend is "tpu", this may be automatically inferred by JAX. | ||
If jax_backend is "proxy", this is ignored. | ||
process_id: The process ID (the process rank). | ||
If jax_backend is "tpu", this may be automatically inferred by JAX. | ||
If jax_backend is "proxy", this is ignored. | ||
initialization_timeout: The jax distributed initialization timeout in seconds. | ||
If None, uses jax default. | ||
If jax_backend is "proxy", this is ignored. | ||
|
||
Raises: | ||
ValueError: If any of the following conditions are met: | ||
* distributed_coordinator, num_processes, or process_id are not None when | ||
jax_backend is "tpu"; | ||
* one of num_processes or process_id is None when jax_backend is not "tpu"; | ||
* distributed_coordinator is None when jax_backend is not "tpu" and num_processes > 1. | ||
* `jax_backend` not in ("tpu", "proxy") and (`num_processes` is None or `process_id` is | ||
None). | ||
* `jax_backend` not in ("tpu", "proxy"), `num_processes` > 1, and | ||
`distributed_coordinator` is None. | ||
""" | ||
# Use a GSPMD-friendly PRNG implementation. | ||
jax.config.update("jax_default_prng_impl", "rbg") | ||
|
||
if jax_backend == "proxy": | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Update the docstring for Also clarify the expected inputs for the other args in this case? Should they always be None like in the TPU case? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I updated the docstring to explicitely call out that the other args will be ignored. They don't have to be None though. |
||
# pylint: disable-next=import-error,import-outside-toplevel | ||
import pathwaysutils # pytype: disable=import-error | ||
|
||
pathwaysutils.initialize() | ||
return | ||
|
||
global _jax_distributed_initialized # pylint: disable=global-statement | ||
if not _jax_distributed_initialized: | ||
init_kwargs = {} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I suppose
setup_spmd
isn't the most appropriate name anymore given pathways, but this seems OK for now.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
setup_controller? but happy to stick with setup_spmd for now too. I'm afraid of breaking others by renaming setup functions like these. Not sure if people have custom launch.py functions.