Skip to content

Commit 78a7612

Browse files
authored
Refine thread in orchestrator (#2)
1 parent 78a0b06 commit 78a7612

File tree

1 file changed

+61
-23
lines changed

1 file changed

+61
-23
lines changed

jetstream/core/orchestrator.py

Lines changed: 61 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -77,24 +77,27 @@
7777
import dataclasses
7878
import functools
7979
import logging
80+
import os
8081
import queue
82+
import signal
8183
import threading
8284
import time
85+
import traceback
8386
from typing import Any, Iterable, Optional, Union
8487

8588
import grpc
8689
import jax
87-
import numpy as np
88-
89-
from jetstream.engine import engine_api
90-
from jetstream.engine import token_utils
9190
from jetstream.core.proto import jetstream_pb2
9291
from jetstream.core.proto import jetstream_pb2_grpc
92+
from jetstream.engine import engine_api
93+
from jetstream.engine import token_utils
94+
import numpy as np
9395

9496

9597
@dataclasses.dataclass
9698
class ActiveRequest:
9799
"""Current state of the driver."""
100+
98101
#################### Information relevant for generation #####################
99102
max_tokens: int
100103
# [num_samples,] which corresponds to whether each sample is complete for the
@@ -129,6 +132,21 @@ def enqueue_tokens(self, generated_tokens: list[str]):
129132
self.return_channel.put(generated_tokens)
130133

131134

135+
class JetThread(threading.Thread):
  """A worker thread whose failure takes the whole process down.

  The orchestrator cannot make progress once any driver thread dies, so
  instead of leaving the process in a half-working state, an uncaught
  exception in the thread body is reported and the process is terminated
  with SIGKILL.
  """

  def run(self):
    """Execute the configured target, killing the process on any error."""
    try:
      # Delegate to threading.Thread, which invokes the target callable.
      super().run()
    except Exception as e:  # pylint: disable=broad-exception-caught
      # Report the failure before tearing everything down; SIGKILL is
      # deliberate — no cleanup is safe once a driver thread has died.
      print(f'Thread {self.name} encountered an error: {e}')
      traceback.print_exc()
      os.kill(os.getpid(), signal.SIGKILL)
132150
class Driver:
133151
"""Drives the engines."""
134152

@@ -213,40 +231,47 @@ def __init__(
213231

214232
# Construct a)
215233
self._generate_slots = [queue.Queue() for _ in generate_engines]
216-
_ = [[self._generate_slots[idx].put(i)
217-
for i in range(engine.max_concurrent_decodes)]
218-
for idx, engine in enumerate(generate_engines)]
234+
_ = [
235+
[
236+
self._generate_slots[idx].put(i)
237+
for i in range(engine.max_concurrent_decodes)
238+
]
239+
for idx, engine in enumerate(generate_engines)
240+
]
219241

220242
# Kick off all our threads
221243
self._prefill_threads = [
222-
threading.Thread(
223-
target=functools.partial(self._prefill_thread, idx, engine)
244+
JetThread(
245+
target=functools.partial(self._prefill_thread, idx, engine),
246+
name=f'prefill-{idx}',
224247
)
225248
for idx, engine in enumerate(self._prefill_engines)
226249
]
227-
self._transfer_thread = threading.Thread(target=self._transfer_thread)
250+
self._transfer_thread = JetThread(target=self._transfer_thread)
228251
self._generate_threads = [
229-
threading.Thread(
252+
JetThread(
230253
target=functools.partial(
231254
self._generate_thread,
232255
idx,
233256
engine,
234257
self._generate_slots[idx],
235258
self._detokenize_backlogs[idx],
236-
)
259+
),
260+
name=f'generate-{idx}',
237261
)
238262
for idx, engine in enumerate(self._generate_engines)
239263
]
240264
# Construct b)
241265
self.detokenize_threads = [
242-
threading.Thread(
266+
JetThread(
243267
target=functools.partial(
244268
self._detokenize_thread,
245269
idx,
246270
engine,
247271
self._generate_slots[idx],
248272
self._detokenize_backlogs[idx],
249-
)
273+
),
274+
name=f'detokenize-{idx}',
250275
)
251276
for idx, engine in enumerate(self._generate_engines)
252277
]
@@ -268,8 +293,12 @@ def _load_cache_history(self, path: str) -> Union[None, Any]:
268293
else:
269294
return None
270295

271-
def _prefill_thread(self, idx: int, prefill_engine: engine_api.Engine,
272-
transfer_backpressure: int = 8):
296+
def _prefill_thread(
297+
self,
298+
idx: int,
299+
prefill_engine: engine_api.Engine,
300+
transfer_backpressure: int = 8,
301+
):
273302
"""Thread which runs in the background performing prefills."""
274303
logging.info('---------Spinning up prefill thread %d.---------', idx)
275304
prefill_params = self._prefill_params[idx]
@@ -280,15 +309,23 @@ def _prefill_thread(self, idx: int, prefill_engine: engine_api.Engine,
280309
while self.live:
281310
# We don't want to keep lots of kv caches live in memory on the prefill
282311
# slice that aren't about to be sent over to a generation slice.
283-
if (self._transfer_backlog.qsize() < transfer_backpressure):
312+
if self._transfer_backlog.qsize() < transfer_backpressure:
284313
# Check if there is anything on the prefill backlog, pop if so.
285314
try:
286315
request = self._prefill_backlog.get(block=True)
287316
# TODO: Implement hot/cold cache for history.
288317
history = self._load_cache_history(request.history_path) # pylint: disable = assignment-from-none
289318
# Tokenize, and introduce a leading dimension
290319
is_bos = not bool(request.history_path)
291-
logging.info('Prefilling on prefill engine %d : "%s", prefill queue size, %d, is_bos: %s, history: %s', idx, request.prefill_text, self._prefill_backlog.qsize(), is_bos, request.history_path) # pylint: disable = line-too-long
320+
logging.info(
321+
'Prefilling on prefill engine %d : "%s", prefill queue size, %d,'
322+
' is_bos: %s, history: %s',
323+
idx,
324+
request.prefill_text,
325+
self._prefill_backlog.qsize(),
326+
is_bos,
327+
request.history_path,
328+
) # pylint: disable = line-too-long
292329
padded_tokens, true_length = token_utils.tokenize_and_pad(
293330
request.prefill_text,
294331
vocab,
@@ -385,13 +422,14 @@ def _generate_thread(
385422
new_request = self._generate_backlogs[idx].get()
386423
slot = my_slots.get()
387424
logging.info(
388-
'Generate slice %d slot %d step %d,'
389-
' generating for : "%s"', idx, slot, generate_timestep,
425+
'Generate slice %d slot %d step %d, generating for : "%s"',
426+
idx,
427+
slot,
428+
generate_timestep,
390429
new_request.prefill_text,
391430
)
392431
decode_state = generate_engine.insert(
393-
new_request.prefill_result, decode_state,
394-
slot=slot
432+
new_request.prefill_result, decode_state, slot=slot
395433
)
396434
new_request.generate_timestep_added = generate_timestep
397435
new_request.complete = np.zeros(
@@ -412,7 +450,7 @@ def _generate_thread(
412450
generate_timestep,
413451
my_slots.qsize(),
414452
generate_engine.max_concurrent_decodes,
415-
(time.time() - time_of_last_generate)*10**3,
453+
(time.time() - time_of_last_generate) * 10**3,
416454
)
417455
time_of_last_generate = time.time()
418456

0 commit comments

Comments
 (0)