reverted separate prefill and decode detokenize queue #176

Merged: 1 commit, Feb 5, 2025
68 changes: 19 additions & 49 deletions jetstream/core/orchestrator.py
@@ -38,7 +38,7 @@
      of the generation loop at the relevant slot.
    - Regardless, it performs a step.
    - It takes the sampled tokens, and places them on a 'detokenizing_queue'.
-7. Within the detokenizing thread (Prefill and Generate separately):
+7. Within the detokenizing thread:
    - Tokens are detokenized for every 'slot' in a given set of sampled tokens.
    - When an end condition is met, the 'slot' integer is returned to the
      respective generation queue.
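For orientation, a minimal self-contained sketch of the loop step 7 describes, using toy stand-ins (strings for sampled tokens, a dict for live requests) rather than the real engine types:

    import queue

    detokenize_backlog: queue.Queue = queue.Queue(8)  # sampled tokens awaiting detokenization
    generate_slots: queue.Queue = queue.Queue()       # free 'slot' integers for the generate loop
    live_requests = {0: []}                           # slot -> detokenized output (toy stand-in)

    detokenize_backlog.put((0, "hello"))   # (slot, sampled token)
    detokenize_backlog.put((0, "<eos>"))   # end condition for slot 0

    while not detokenize_backlog.empty():
        slot, token = detokenize_backlog.get()
        live_requests[slot].append(token)  # 'detokenize' into the request's output
        if token == "<eos>":               # end condition met
            generate_slots.put(slot)       # return the 'slot' integer for reuse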
@@ -220,8 +220,7 @@ class Driver:
   # Stage 4
   # This can be a list because we can pass it as an arg to generate and
   # detokenize threads. It is a list of tokens to be detokenized.
-  _prefill_detokenize_backlogs: list[queue.Queue[engine_api.ResultTokens]] = []
-  _generate_detokenize_backlogs: list[queue.Queue[engine_api.ResultTokens]] = []
+  _detokenize_backlogs: list[queue.Queue[engine_api.ResultTokens]] = []
   _generate_slots: list[queue.Queue[int]] = []
   _active_requests: list[queue.Queue[tuple[int, ActiveRequest]]] = []

@@ -281,11 +280,11 @@ def __init__(
     # one of the generate backlogs.
     # Interleaved Mode: Max size is 1 to increase the HBM utilization
     # during generate.
-    # Disaggregated Mode: Max size is 16 to allow for total 16 prefills to
-    # be enqueued or enqueued while 1 is being transferred.
+    # Disaggregated Mode: Max size is 4 to allow for 2 prefills to be enqueued
+    # while 1 transfer is enqueued while 1 is being transferred.
     # TODO: Make queue size configurable.
     self._transfer_backlogs = [
-        queue.Queue(1 if self._interleaved_mode else 16)
+        queue.Queue(1 if self._interleaved_mode else 4)
         for i in range(len(self._prefill_engines))
     ]
     if self._metrics_collector:
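The queue bound is what supplies the backpressure here: queue.Queue.put() blocks once maxsize items are waiting, which stalls the producer rather than letting work pile up. A runnable illustration of that blocking behavior, with a deliberately slow consumer:

    import queue
    import threading
    import time

    transfer_backlog = queue.Queue(maxsize=4)  # the disaggregated-mode bound from this diff

    def producer():
        for i in range(8):
            transfer_backlog.put(i)  # blocks whenever 4 items are already queued
            print(f"enqueued {i}")

    threading.Thread(target=producer, daemon=True).start()

    time.sleep(0.5)      # let the producer fill the queue and block
    for _ in range(8):
        transfer_backlog.get()
        time.sleep(0.1)  # slow consumer; each get() unblocks one put()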
@@ -313,11 +312,10 @@ def __init__(
           functools.partial(float, backlog.qsize())
       )
     # Stage 4
-    # After prefill and generation, ActiveRequests are placed on the
-    # detokenization backlog for tokens to be sent into each ActiveRequest's
-    # return channel.
-    # We have one of these per prefill / generate engine to simplify
-    # the logic keeping track of which generation engine to replace slots on.
+    # After generation, ActiveRequests are placed on the detokenization backlog
+    # for tokens to be sent into each ActiveRequest's return channel.
+    # We have one of these per generate engine to simplify the logic keeping
+    # track of which generation engine to replace slots on.
     # This is a queue of either - tuple[int, ActiveRequest] which represents our
     # active requests, or tuple[int, sample_tokens]. We combine these into one
     # queue because it allows us to be somewhat clever with how we do
@@ -332,16 +330,7 @@ def __init__(
     # the possibility of race conditions where a slot is made live before the
     # tokens are ready and it receives tokens from a different sequence,
     # or tokens detokenized before the relevant slot is live.
-
-    self._prefill_detokenize_backlogs = [
-        # No need to set maxsize, as transfer queue can
-        # provide the backpressure to the prefill workload
-        # (to avoid the overwhelming prefill).
-        queue.Queue()
-        for _ in self._prefill_engines
-    ]
-
-    self._generate_detokenize_backlogs = [
+    self._detokenize_backlogs = [
         # We don't let detokenization accumulate more than 8 steps to avoid
         # synchronization issues.
         queue.Queue(8)
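With the revert, the prefill thread and the generate thread for engine i both feed _detokenize_backlogs[i], so the 8-step bound now also throttles prefill results (the separate prefill queue it replaces was unbounded and relied on the transfer queue for backpressure). A sketch of the consolidated construction, with an illustrative engine count standing in for len(self._generate_engines):

    import queue

    NUM_GENERATE_ENGINES = 2  # illustrative; the driver derives this from its engine list

    # One bounded backlog per generate engine; both the prefill thread and the
    # generate thread for engine i put onto _detokenize_backlogs[i].
    _detokenize_backlogs = [queue.Queue(8) for _ in range(NUM_GENERATE_ENGINES)]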
@@ -397,25 +386,13 @@ def __init__(
         )
         for idx in range(len(self._generate_engines))
     ]
-    self.prefill_detokenize_threads = [
-        JetThread(
-            target=functools.partial(
-                self._detokenize_thread,
-                is_prefill=True,
-                idx=idx,
-            ),
-            name=f"prefill_detokenize-{idx}",
-        )
-        for idx in range(len(self._prefill_engines))
-    ]
-    self.generate_detokenize_threads = [
+    self.detokenize_threads = [
         JetThread(
             target=functools.partial(
                 self._detokenize_thread,
-                is_prefill=False,
-                idx=idx,
+                idx,
             ),
-            name=f"generate_detokenize-{idx}",
+            name=f"detokenize-{idx}",
         )
         for idx in range(len(self._generate_engines))
     ]
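One detokenize thread is now spawned per generate engine, with idx bound at construction time. A self-contained sketch of the pattern, where threading.Thread stands in for JetThread and the worker body is a stub:

    import functools
    import threading

    def _detokenize_thread(idx: int) -> None:
        print(f"detokenize-{idx} running")  # stub worker body

    # functools.partial binds idx eagerly, avoiding the late-binding surprise
    # that a bare `lambda: _detokenize_thread(idx)` would exhibit here.
    detokenize_threads = [
        threading.Thread(
            target=functools.partial(_detokenize_thread, idx),
            name=f"detokenize-{idx}",
        )
        for idx in range(2)
    ]
    for t in detokenize_threads:
        t.start()
    for t in detokenize_threads:
        t.join()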
@@ -424,8 +401,7 @@ def __init__(
             self._prefill_threads,
             self._transfer_threads,
             self._generate_threads,
-            self.prefill_detokenize_threads,
-            self.generate_detokenize_threads,
+            self.detokenize_threads,
         )
     )
     self.live = True
@@ -444,8 +420,7 @@ def stop(self):
             [self._prefill_backlog],
             self._transfer_backlogs,
             self._generate_backlogs.values(),
-            self._prefill_detokenize_backlogs,
-            self._generate_detokenize_backlogs,
+            self._detokenize_backlogs,
         )
     )

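stop() flattens the per-stage queue collections into one iterable with itertools.chain, and after this change there is one fewer collection to cover. A minimal reproduction of that flattening, with toy queues in place of the driver's:

    import itertools
    import queue

    _prefill_backlog = queue.Queue()
    _transfer_backlogs = [queue.Queue(4)]
    _generate_backlogs = {0: queue.Queue()}  # keyed by generate engine index
    _detokenize_backlogs = [queue.Queue(8)]

    # Flatten every stage's queues into one list, as stop() does before draining.
    all_backlogs = list(
        itertools.chain(
            [_prefill_backlog],
            _transfer_backlogs,
            _generate_backlogs.values(),
            _detokenize_backlogs,
        )
    )
    assert len(all_backlogs) == 4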
@@ -561,7 +536,7 @@ def _prefill_thread(self, idx: int):

       # put first token to detokenize queue
       request.complete = np.zeros((prefill_engine.samples_per_slot,), np.bool_)
-      my_detokenize_backlog = self._prefill_detokenize_backlogs[idx]
+      my_detokenize_backlog = self._detokenize_backlogs[idx]
       request.metadata.transfer_enqueue_time = time.perf_counter()
       my_detokenize_backlog.put(
           (first_token, request, request.metadata.prefill_dequeue_time),
@@ -657,7 +632,7 @@ def _generate_thread(self, idx: int):
     generate_engine = self._generate_engines[idx]
     my_slots = self._generate_slots[idx]
     my_generate_backlog = self._generate_backlogs[idx]
-    my_detokenize_backlog = self._generate_detokenize_backlogs[idx]
+    my_detokenize_backlog = self._detokenize_backlogs[idx]

     # Keep track of what step tokens were generated at.
     generate_timestep = 0
@@ -787,17 +762,12 @@ def _generate_thread(self, idx: int):
       )
       time_of_last_generate = time.time()

-  def _detokenize_thread(self, is_prefill: bool, idx: int):
+  def _detokenize_thread(self, idx: int):
     """Detokenize sampled tokens and returns them to the user."""
     # One of these per generate engine.
     # For all filled my_slots, pop the sampled token onto the relevant
     # requests return channel. If it done, place it back onto free slots.
-
-    if is_prefill:
-      my_detokenize_backlog = self._prefill_detokenize_backlogs[idx]
-    else:
-      my_detokenize_backlog = self._generate_detokenize_backlogs[idx]
-
+    my_detokenize_backlog = self._detokenize_backlogs[idx]
     my_generate_engine = self._generate_engines[idx]
     my_slots = self._generate_slots[idx]

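Because prefill and generate feed the same backlog again, the consumer must distinguish the two item shapes noted in the Stage 4 comment above. A hedged toy of that dispatch: the real items carry engine_api.ResultTokens and ActiveRequest objects, and branching on tuple length is only one plausible way to tell them apart:

    import queue

    backlog: queue.Queue = queue.Queue(8)
    backlog.put(("first_token", "active_request", 0.0))  # from the prefill thread
    backlog.put((7, "step_tokens"))                      # from the generate thread

    while not backlog.empty():
        item = backlog.get()
        if len(item) == 3:    # (first_token, request, prefill_dequeue_time)
            first_token, request, dequeue_time = item
            # ...emit the first token and mark the slot live...
        else:                 # (generate_timestep, sampled_tokens)
            step, tokens = item
            # ...detokenize the step's tokens for every live slot...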