Add num_batch_threads and max_enqueued_batches to OBM batching configs. #1868

Merged · 1 commit · Apr 28, 2025
19 changes: 19 additions & 0 deletions export/orbax/export/obm_configs.py
@@ -40,6 +40,11 @@ class BatchOptions:
largest value in the list. Otherwise, this must be provided.
batch_timeout_micros: Maximum number of microseconds to wait before
outputting an incomplete batch.
num_batch_threads: Number of scheduling threads for processing batches of
work. Determines the number of batches processed in parallel. This should
be roughly in line with the number of TPU cores available.
max_enqueued_batches: Maximum number of batches enqueued for processing
before requests are failed fast. Default is 250.
allowed_batch_sizes: Optional list of allowed batch sizes. If left empty,
all batch sizes no larger than `max_batch_size` are allowed. Otherwise,
supplies a list of batch sizes. The entries must increase monotonically.
@@ -49,6 +54,8 @@ class BatchOptions:
max_batch_size: int | None = None
batch_timeout_micros: int = 0
allowed_batch_sizes: Sequence[int] | None = None
num_batch_threads: int = 1
max_enqueued_batches: int = 250
disable_large_batch_splitting: bool = False

def __post_init__(self):
@@ -118,3 +125,15 @@ def __post_init__(self):
"`batch_timeout_micros` must be non-negative. Got:"
f" {self.batch_timeout_micros}"
)

if self.num_batch_threads <= 0:
raise ValueError(
"`num_batch_threads` must be at least 1. Got:"
f" {self.num_batch_threads}"
)

if self.max_enqueued_batches <= 0:
raise ValueError(
"`max_enqueued_batches` must be at least 1. Got:"
f" {self.max_enqueued_batches}"
)
9 changes: 9 additions & 0 deletions export/orbax/export/oex_orchestration.proto
@@ -87,4 +87,13 @@ message BatchOptions {
// If true, each input task is put into one batch as a whole for processing.
// More padding will be needed.
bool disable_large_batch_splitting = 5;

// Number of scheduling threads for processing batches of work. Determines
// the number of batches processed in parallel. This should be roughly in line
// with the number of TPU cores available.
int32 num_batch_threads = 6;

// Maximum number of batches enqueued for processing before requests are
// failed fast.
int32 max_enqueued_batches = 7;
}
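A text-format sketch of the extended message, using only the field names visible in this diff (the values are illustrative, not recommendations):

```textproto
# Hypothetical textproto for the extended BatchOptions message.
disable_large_batch_splitting: false
num_batch_threads: 4        # roughly the number of TPU cores available
max_enqueued_batches: 100   # requests beyond this queue depth fail fast
```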