Skip to content

Pathways: reuse setup_spmd for pathways init #1248

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 7 additions & 19 deletions axlearn/common/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,25 +114,13 @@ def setup():
logging.info("LIBTPU_INIT_ARGS='%s'", os.environ["LIBTPU_INIT_ARGS"])

with _init_context():
if FLAGS.jax_backend == "proxy":
# AXLearn assumes rbg PRNG implementation and restore from checkpoint
# will fail on pathways if this isn't set. This is due shape of [4]
# being hardcoded here:
# https://github.com/apple/axlearn/blob/8bb4421e62c815ef9f1ba3679c3277b8bbc6a449/axlearn/common/trainer.py#L330
jax.config.update("jax_default_prng_impl", "rbg")

# pylint: disable-next=import-error,import-outside-toplevel
import pathwaysutils # pytype: disable=import-error

pathwaysutils.initialize()
else:
setup_spmd(
distributed_coordinator=FLAGS.distributed_coordinator,
num_processes=FLAGS.num_processes,
process_id=FLAGS.process_id,
jax_backend=FLAGS.jax_backend,
initialization_timeout=FLAGS.initialization_timeout,
)
setup_spmd(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suppose setup_spmd isn't the most appropriate name anymore given pathways, but this seems OK for now.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

setup_controller? but happy to stick with setup_spmd for now too. I'm afraid of breaking others by renaming setup functions like these. Not sure if people have custom launch.py functions.

distributed_coordinator=FLAGS.distributed_coordinator,
num_processes=FLAGS.num_processes,
process_id=FLAGS.process_id,
jax_backend=FLAGS.jax_backend,
initialization_timeout=FLAGS.initialization_timeout,
)

if FLAGS.jax_profiler_port is not None:
# Start jax.profiler for Tensorboard and profiling in open source.
Expand Down
39 changes: 24 additions & 15 deletions axlearn/common/utils_spmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,28 +22,37 @@ def setup(
"""Sets up the JAX environment for SPMD.

Args:
jax_backend: The distributed backend, which can be "cpu", "gpu", or "tpu".
distributed_coordinator: The distributed coordinator address (in the form of <host>:<port>).
Needed only for `jax_backend != "tpu"` and `num_processes > 1`. Otherwise, the
coordinator will be configured automatically when `num_processes` and `process_id` are
provided.
num_processes: The number of processes. Needed only if distributed initialization is desired
for `jax_backend != "tpu"`.
process_id: The process ID (the process rank). Needed only if distributed initialization is
desired for `jax_backend != "tpu"`.
initialization_timeout: The jax distributed initialization timeout in seconds. If None, uses
jax default.
jax_backend: The distributed backend. Can be "cpu", "gpu", "tpu", or "proxy".
distributed_coordinator: The distributed coordinator address (e.g., "<host>:<port>").
If jax_backend is "tpu", this may be automatically inferred by JAX.
If jax_backend is "proxy", this is ignored.
num_processes: The number of processes.
If jax_backend is "tpu", this may be automatically inferred by JAX.
If jax_backend is "proxy", this is ignored.
process_id: The process ID (the process rank).
If jax_backend is "tpu", this may be automatically inferred by JAX.
If jax_backend is "proxy", this is ignored.
initialization_timeout: The jax distributed initialization timeout in seconds.
If None, uses jax default.
If jax_backend is "proxy", this is ignored.

Raises:
ValueError: If any of the following conditions are met:
* distributed_coordinator, num_processes, or process_id are not None when
jax_backend is "tpu";
* one of num_processes or process_id is None when jax_backend is not "tpu";
* distributed_coordinator is None when jax_backend is not "tpu" and num_processes > 1.
* `jax_backend` not in ("tpu", "proxy") and (`num_processes` is None or `process_id` is
None).
* `jax_backend` not in ("tpu", "proxy"), `num_processes` > 1, and
`distributed_coordinator` is None.
"""
# Use a GSPMD-friendly PRNG implementation.
jax.config.update("jax_default_prng_impl", "rbg")

if jax_backend == "proxy":
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update the docstring for jax_backend?

Also clarify the expected inputs for the other args in this case? Should they always be None like in the TPU case?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I updated the docstring to explicitely call out that the other args will be ignored. They don't have to be None though.

# pylint: disable-next=import-error,import-outside-toplevel
import pathwaysutils # pytype: disable=import-error

pathwaysutils.initialize()
return

global _jax_distributed_initialized # pylint: disable=global-statement
if not _jax_distributed_initialized:
init_kwargs = {}
Expand Down