Description
TL;DR - Proposal for an API to support launching expensive computations in Serve (e.g., model fine-tuning, long-running inference) using an asynchronous request API
Problem statement
With the rise of generative models, the Serve team has seen growing interest in supporting "expensive computations" in Serve. For example, users have asked to launch Stable Diffusion fine-tuning jobs and long-running inference tasks that run not for seconds, but for several minutes to an hour. These tasks are often too long to run as a stateless inference request, but too short to justify launching an entirely new Ray job / cluster.
As a workaround, users often connect external queueing systems, such as Celery, to Ray. The purpose of this RFC is to gather feedback on APIs for handling such workloads natively in Serve without needing a full-blown queueing system.
Previous proposals
A previous RFC (#21161) proposed using Ray Workflows as a wholesale replacement for queueing systems. This solution works, but it is heavyweight and relies on Workflows, a relatively new library.
Below are two alternative proposals that aim to provide a simpler API.
Proposal 1 -- add async requests API to Serve
Add an "async_request" decorator to Serve deployments. For methods decorated with it, Serve will generate queueing preamble/postamble logic and APIs for listing, resuming, and checking the status of async requests. The API would look something like this:
@serve.deployment
class FineTuningApp:
    @serve.async_request(
        max_queued=1000,
        max_running=10,
        priority=2,
        idempotency_key="request_id")
    def fine_tune_request(self, request_id: str, num_epochs: int, dataset: str) -> str:
        # User-implemented checkpoint and recovery.
        if checkpoint_exists(request_id):
            model, epoch = restore_from_checkpoint(request_id)
        else:
            model, epoch = new_model(), 0
        for i in range(epoch, num_epochs):
            train_one_epoch(model, dataset)
            save_checkpoint(model, i, request_id)
        return model
Here are examples of generated API methods for managing async requests:
POST /app/fine_tune_request?request_id=12345&dataset=foo1&num_epochs=5
POST /app/fine_tune_request/cancel?request_id=12345
GET /app/fine_tune_request/status?request_id=12345
GET /app/fine_tune_request/result?request_id=12345
GET /app/fine_tune_request/list?filter_status=RUNNING&limit=1000
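To make the intended behaviour of these endpoints concrete, here is a minimal in-memory sketch of the bookkeeping they might sit on top of. It is not Serve code: `AsyncRequestTable`, `Status`, and every method name are hypothetical, the status values other than RUNNING are assumptions, and `priority` is not modelled. It only illustrates how `max_queued`, `max_running`, and the idempotency key could interact.

```python
from collections import OrderedDict
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, Optional


class Status(str, Enum):
    QUEUED = "QUEUED"
    RUNNING = "RUNNING"
    SUCCEEDED = "SUCCEEDED"
    CANCELLED = "CANCELLED"


@dataclass
class AsyncRequest:
    key: str                      # value of the idempotency key, e.g. request_id
    kwargs: Dict[str, Any]        # remaining request arguments
    status: Status = Status.QUEUED
    result: Optional[Any] = None


class AsyncRequestTable:
    """Hypothetical bookkeeping behind the generated endpoints."""

    def __init__(self, max_queued: int, max_running: int):
        self.max_queued = max_queued
        self.max_running = max_running
        self._requests: "OrderedDict[str, AsyncRequest]" = OrderedDict()

    def _count(self, status: Status) -> int:
        return sum(1 for r in self._requests.values() if r.status is status)

    def submit(self, key: str, **kwargs: Any) -> AsyncRequest:
        # Idempotency: re-submitting a known key returns the existing record.
        if key in self._requests:
            return self._requests[key]
        if self._count(Status.QUEUED) >= self.max_queued:
            raise RuntimeError("queue is full")
        req = AsyncRequest(key=key, kwargs=kwargs)
        self._requests[key] = req
        return req

    def start_next(self) -> Optional[AsyncRequest]:
        # Promote the oldest queued request if a running slot is free.
        if self._count(Status.RUNNING) >= self.max_running:
            return None
        for req in self._requests.values():
            if req.status is Status.QUEUED:
                req.status = Status.RUNNING
                return req
        return None

    def complete(self, key: str, result: Any) -> None:
        req = self._requests[key]
        # A request cancelled while running keeps its CANCELLED status.
        if req.status is Status.RUNNING:
            req.status, req.result = Status.SUCCEEDED, result

    def cancel(self, key: str) -> None:
        req = self._requests[key]
        if req.status in (Status.QUEUED, Status.RUNNING):
            req.status = Status.CANCELLED

    def status(self, key: str) -> Status:
        return self._requests[key].status

    def list_requests(self, filter_status: Optional[Status] = None,
                      limit: int = 1000) -> List[str]:
        matches = [r.key for r in self._requests.values()
                   if filter_status is None or r.status is filter_status]
        return matches[:limit]
```

In this sketch, a second POST with the same `request_id` is a no-op that returns the original record, which is one plausible reading of what `idempotency_key` is meant to guarantee.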
Fault tolerance: Serve would persist the queue of requests in its coordinator actor / persistent storage. When recovering from a cluster failure, Serve can load and resume previously running async requests, which can in turn resume from any checkpoints they have taken.
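As a rough illustration of the persistence step, the sketch below writes the request table to a JSON file and, after a restart, picks out the requests that should be re-queued. The file-based storage, the record layout, and the function names `save_requests` and `load_resumable` are all assumptions made for the example; the actual storage backend is not specified in this RFC.

```python
import json
import os
import tempfile
from typing import Dict, List


def save_requests(path: str, requests: Dict[str, dict]) -> None:
    """Write the request table atomically so a crash never leaves a partial file."""
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=directory)
    with os.fdopen(fd, "w") as f:
        json.dump(requests, f)
    # os.replace is atomic when source and destination are on the same filesystem.
    os.replace(tmp_path, path)


def load_resumable(path: str) -> List[str]:
    """Return the keys of requests that should be re-queued after a restart."""
    if not os.path.exists(path):
        return []
    with open(path) as f:
        requests = json.load(f)
    # Requests that were queued or mid-flight at failure time get another run;
    # the user's own checkpoint logic decides where each one picks up.
    return [key for key, rec in requests.items()
            if rec["status"] in ("QUEUED", "RUNNING")]
```

Because resumed requests re-enter the user function from the top, this only works if the function is safe to re-run, which is why the example method above checks for an existing checkpoint first.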
Pros:
- Simple extension of existing Serve handlers, so requests can take advantage of Serve's existing scheduling, autoscaling, and observability logic
Cons:
- May not interact well with other Serve APIs like DeploymentGraphs
Proposal 2 -- create a simplified TaskQueue API backed by Ray Workflows
Instead of extending Serve's API, create a separate TaskQueue API that users call from a hand-written Serve handler implementing the management methods. For example, the example above could instead be implemented as:
# Define the processing function separately from the handler.
def fine_tune_request(request_id: str, num_epochs: int, dataset: str) -> str:
    # User-implemented checkpoint and recovery.
    if checkpoint_exists(request_id):
        model, epoch = restore_from_checkpoint(request_id)
    else:
        model, epoch = new_model(), 0
    for i in range(epoch, num_epochs):
        train_one_epoch(model, dataset)
        save_checkpoint(model, i, request_id)
    return model

@serve.deployment
class FineTuningApp:
    def __init__(self):
        TaskQueue.create_if_not_exists("my_queue")
        TaskQueue.resume_all("my_queue")

    def get_or_create_request(self, request_id: str, num_epochs: int, dataset: str):
        TaskQueue.submit_task("my_queue", fine_tune_request, request_id, args=[request_id, num_epochs, dataset])
        return "ok"

    def get_status(self, request_id: str):
        return TaskQueue.get_status("my_queue", request_id)

    def get_result(self, request_id: str):
        return TaskQueue.get_result("my_queue", request_id)

    def cancel_task(self, request_id: str):
        return TaskQueue.cancel("my_queue", request_id)

    def list_tasks(self):
        return TaskQueue.list("my_queue")
Fault tolerance: can be implemented in a similar way to Proposal 1.
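For experimenting with the handler shape before any real implementation exists, a toy stand-in for the proposed TaskQueue might look like the following. `InMemoryTaskQueue` is hypothetical: it mirrors the method names used in the handler above, runs tasks on a local thread pool rather than Ray Workflows, keeps no state across restarts (so there is no `resume_all`), and its status strings and the `timeout` parameter are assumptions.

```python
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Any, Callable, Dict, List, Optional, Sequence


class InMemoryTaskQueue:
    """Toy stand-in for the proposed TaskQueue; not backed by Ray Workflows."""

    _queues: Dict[str, Dict[str, Future]] = {}
    _pool = ThreadPoolExecutor(max_workers=2)

    @classmethod
    def create_if_not_exists(cls, queue: str) -> None:
        cls._queues.setdefault(queue, {})

    @classmethod
    def submit_task(cls, queue: str, fn: Callable[..., Any], task_id: str,
                    args: Sequence[Any] = ()) -> None:
        tasks = cls._queues[queue]
        # Submitting an existing task_id is a no-op, giving get-or-create semantics.
        if task_id not in tasks:
            tasks[task_id] = cls._pool.submit(fn, *args)

    @classmethod
    def get_status(cls, queue: str, task_id: str) -> str:
        fut = cls._queues[queue][task_id]
        if fut.cancelled():
            return "CANCELLED"
        if fut.running():
            return "RUNNING"
        if fut.done():
            return "FAILED" if fut.exception() else "SUCCEEDED"
        return "QUEUED"

    @classmethod
    def get_result(cls, queue: str, task_id: str,
                   timeout: Optional[float] = None) -> Any:
        # Blocks until the task finishes (or the timeout elapses).
        return cls._queues[queue][task_id].result(timeout=timeout)

    @classmethod
    def cancel(cls, queue: str, task_id: str) -> bool:
        # Future.cancel only succeeds for tasks that have not started yet.
        return cls._queues[queue][task_id].cancel()

    @classmethod
    def list(cls, queue: str) -> List[str]:
        return sorted(cls._queues[queue])
```

A real implementation would also need to cancel tasks that are already running and to persist the queue, both of which this stand-in leaves out.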
Pros:
- Serve API remains unchanged
Cons:
- More boilerplate / less clean story for users
- More complex autoscaling and resource allocation story, since both Serve and the TaskQueue library would be requesting resources from the Ray scheduler
- More new concepts and things to know