
Conversation

lukebaumann

@lukebaumann lukebaumann commented Aug 22, 2025

  • Added the changes to the jobset for elastic training to enable elasticity.
  • Added changes to launch_trainer so that the pause_resume decorator is used.
  • Set logging.raiseExceptions=True so that DATA_LOSS errors that occur in debug/info/other log calls raise exceptions immediately.

As written, this will use Pause-Resume elasticity if Pathways is enabled.
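A rough sketch of what the launch_trainer side of this might look like; the decorator arguments and the wrapped function below are assumptions for illustration, not the exact code in this PR:

import logging

from pathwaysutils.elastic import manager

# Per the description above: let DATA_LOSS errors raised inside
# debug/info/other log calls surface immediately instead of being
# suppressed by the logging machinery.
logging.raiseExceptions = True

elastic_manager = manager.Manager()


# Assumed decorator usage: pause the workload when a slice goes down, wait
# for the slices to come back, and resume training instead of crashing.
# Whether pause_resume takes these arguments directly is an assumption.
@elastic_manager.pause_resume(max_retries=5, timeout=10 * 60)
def train_loop(config):
  ...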

This should be merged after #1335

@lukebaumann lukebaumann requested review from a team as code owners August 22, 2025 23:40
# For elasticity, we want the slices to be able to restart many times.
# There is no way to set this to be unlimited so we set the backoffLimit
# very high.
backoffLimit *= 1000
Contributor

Should this only be done when elasticity is enabled?

Author

It is not strictly necessary to restrict it, though I do think restricting it makes sense.

_PATHWAYS_BACK_OFF_LIMIT = 32, so once either a slice has been restarted 32 times or the workload has been restarted 32 times, GKE will fail the job due to the backoff limit.

Why it is not necessary:
For non-elastic workloads: a worker fails and is restarted, the RM kills all of the slices, the workload eventually tries to access data, hits a DATA_LOSS exception that is not caught, and exits, and the JobSet restarts it.

What happens today:
Exactly the same as above, except GKE will only fail due to the backoff limit after the workload has been restarted 32 times (a slice may restart more than 32 times).
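For concreteness, the arithmetic being discussed (the constant is from this comment, the multiplier from the quoted diff; the variable name for the result is illustrative):

_PATHWAYS_BACK_OFF_LIMIT = 32                       # existing backoff limit
worker_backoff_limit = _PATHWAYS_BACK_OFF_LIMIT * 1000  # 32000 restarts for the slices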

Contributor

The only concern is that if there is a user code error, it will take a very long time for the job to eventually report failure, which may cause user confusion.

Author

This backoff limit increase applies to the Pathways worker containers only. If there is a user code error, the JAX container will still fail with the existing backoff limit. This allows worker containers to fail more times than the JAX container. If a JAX container is connected to a worker container that fails, it will also fail unless elasticity is turned on.
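A hypothetical sketch of that distinction; the helper and job names are made up for illustration and are not the actual jobset builder:

# Hypothetical helper: only the Pathways worker job gets the inflated backoff
# limit so slices can restart many times; the JAX (user code) job keeps the
# original limit so genuine user errors still fail quickly.
def backoff_limit_for(job: str, base_limit: int) -> int:
  if job == "pathways-worker":
    return base_limit * 1000
  return base_limit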


from pathwaysutils.elastic import manager
elastic_manager = manager.Manager()
max_retries = 5
timeout = 10 * 60 # ten minutes
Contributor

What does this timeout impact? How does it affect the multi-host inference case?

Author

This timeout argument is passed to wait_for_slices and is how long the JAX workload will wait for all slices to be ready again, after one (or more) of them fails, before raising a TimeoutError. See wait_for_slices for more details.

The multi-host inference case should not enable elasticity or use elastic_manager.pause_resume. Instead, it should rely on LeaderWorkerSet for its resiliency mechanisms.
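A minimal sketch of the retry behavior described above, assuming wait_for_slices is exposed on the manager and that pause_resume drives a loop roughly like this; the real control flow lives in pathwaysutils and may differ:

for attempt in range(max_retries):
  try:
    run_training()  # placeholder for the wrapped training workload
    break
  except Exception:  # e.g. a DATA_LOSS error after a slice goes down
    # Block until every slice is ready again, or raise TimeoutError if they
    # are not all back within `timeout` seconds (ten minutes here).
    elastic_manager.wait_for_slices(timeout=timeout)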
