77
77
import dataclasses
78
78
import functools
79
79
import logging
80
+ import os
80
81
import queue
82
+ import signal
81
83
import threading
82
84
import time
85
+ import traceback
83
86
from typing import Any , Iterable , Optional , Union
84
87
85
88
import grpc
86
89
import jax
87
- import numpy as np
88
-
89
- from jetstream .engine import engine_api
90
- from jetstream .engine import token_utils
91
90
from jetstream .core .proto import jetstream_pb2
92
91
from jetstream .core .proto import jetstream_pb2_grpc
92
+ from jetstream .engine import engine_api
93
+ from jetstream .engine import token_utils
94
+ import numpy as np
93
95
94
96
95
97
@dataclasses .dataclass
96
98
class ActiveRequest :
97
99
"""Current state of the driver."""
100
+
98
101
#################### Information relevant for generation #####################
99
102
max_tokens : int
100
103
# [num_samples,] which corresponds to whether each sample is complete for the
@@ -129,6 +132,21 @@ def enqueue_tokens(self, generated_tokens: list[str]):
129
132
self .return_channel .put (generated_tokens )
130
133
131
134
135
class JetThread(threading.Thread):
  """Thread that kills the program if it fails.

  If a driver thread goes down, we can't operate, so any uncaught exception
  raised by the thread's target is reported and the whole process is then
  terminated with SIGKILL.
  """

  def run(self):
    """Runs the thread's target; on an uncaught exception, kills the process.

    The error message is written to stdout and the traceback to stderr. Both
    streams are flushed explicitly before the kill: SIGKILL ends the process
    immediately, without running interpreter shutdown, so any text still held
    in a buffer (stdout is block-buffered when redirected to a pipe or file)
    would otherwise be lost.
    """
    try:
      super().run()
    except Exception as e:  # pylint: disable=broad-exception-caught
      # Local import keeps this block self-contained.
      import sys  # pylint: disable=import-outside-toplevel

      print(f'Thread {self.name} encountered an error: {e}', flush=True)
      traceback.print_exc()
      sys.stderr.flush()
      # SIGKILL cannot be caught or ignored, so the whole process goes down
      # even if other threads are blocked.
      os.kill(os.getpid(), signal.SIGKILL)
148
+
149
+
132
150
class Driver :
133
151
"""Drives the engines."""
134
152
@@ -213,40 +231,47 @@ def __init__(
213
231
214
232
# Construct a)
215
233
self ._generate_slots = [queue .Queue () for _ in generate_engines ]
216
- _ = [[self ._generate_slots [idx ].put (i )
217
- for i in range (engine .max_concurrent_decodes )]
218
- for idx , engine in enumerate (generate_engines )]
234
+ _ = [
235
+ [
236
+ self ._generate_slots [idx ].put (i )
237
+ for i in range (engine .max_concurrent_decodes )
238
+ ]
239
+ for idx , engine in enumerate (generate_engines )
240
+ ]
219
241
220
242
# Kick off all our threads
221
243
self ._prefill_threads = [
222
- threading .Thread (
223
- target = functools .partial (self ._prefill_thread , idx , engine )
244
+ JetThread (
245
+ target = functools .partial (self ._prefill_thread , idx , engine ),
246
+ name = f'prefill-{ idx } ' ,
224
247
)
225
248
for idx , engine in enumerate (self ._prefill_engines )
226
249
]
227
- self ._transfer_thread = threading . Thread (target = self ._transfer_thread )
250
+ self ._transfer_thread = JetThread (target = self ._transfer_thread )
228
251
self ._generate_threads = [
229
- threading . Thread (
252
+ JetThread (
230
253
target = functools .partial (
231
254
self ._generate_thread ,
232
255
idx ,
233
256
engine ,
234
257
self ._generate_slots [idx ],
235
258
self ._detokenize_backlogs [idx ],
236
- )
259
+ ),
260
+ name = f'generate-{ idx } ' ,
237
261
)
238
262
for idx , engine in enumerate (self ._generate_engines )
239
263
]
240
264
# Construct b)
241
265
self .detokenize_threads = [
242
- threading . Thread (
266
+ JetThread (
243
267
target = functools .partial (
244
268
self ._detokenize_thread ,
245
269
idx ,
246
270
engine ,
247
271
self ._generate_slots [idx ],
248
272
self ._detokenize_backlogs [idx ],
249
- )
273
+ ),
274
+ name = f'detokenize-{ idx } ' ,
250
275
)
251
276
for idx , engine in enumerate (self ._generate_engines )
252
277
]
@@ -268,8 +293,12 @@ def _load_cache_history(self, path: str) -> Union[None, Any]:
268
293
else :
269
294
return None
270
295
271
- def _prefill_thread (self , idx : int , prefill_engine : engine_api .Engine ,
272
- transfer_backpressure : int = 8 ):
296
+ def _prefill_thread (
297
+ self ,
298
+ idx : int ,
299
+ prefill_engine : engine_api .Engine ,
300
+ transfer_backpressure : int = 8 ,
301
+ ):
273
302
"""Thread which runs in the background performing prefills."""
274
303
logging .info ('---------Spinning up prefill thread %d.---------' , idx )
275
304
prefill_params = self ._prefill_params [idx ]
@@ -280,15 +309,23 @@ def _prefill_thread(self, idx: int, prefill_engine: engine_api.Engine,
280
309
while self .live :
281
310
# We don't want to keep lots of kv caches live in memory on the prefill
282
311
# slice that aren't about to be sent over to a generation slice.
283
- if ( self ._transfer_backlog .qsize () < transfer_backpressure ) :
312
+ if self ._transfer_backlog .qsize () < transfer_backpressure :
284
313
# Check if there is anything on the prefill backlog, pop if so.
285
314
try :
286
315
request = self ._prefill_backlog .get (block = True )
287
316
# TODO: Implement hot/cold cache for history.
288
317
history = self ._load_cache_history (request .history_path ) # pylint: disable = assignment-from-none
289
318
# Tokenize, and introduce a leading dimension
290
319
is_bos = not bool (request .history_path )
291
- logging .info ('Prefilling on prefill engine %d : "%s", prefill queue size, %d, is_bos: %s, history: %s' , idx , request .prefill_text , self ._prefill_backlog .qsize (), is_bos , request .history_path ) # pylint: disable = line-too-long
320
+ logging .info (
321
+ 'Prefilling on prefill engine %d : "%s", prefill queue size, %d,'
322
+ ' is_bos: %s, history: %s' ,
323
+ idx ,
324
+ request .prefill_text ,
325
+ self ._prefill_backlog .qsize (),
326
+ is_bos ,
327
+ request .history_path ,
328
+ ) # pylint: disable = line-too-long
292
329
padded_tokens , true_length = token_utils .tokenize_and_pad (
293
330
request .prefill_text ,
294
331
vocab ,
@@ -385,13 +422,14 @@ def _generate_thread(
385
422
new_request = self ._generate_backlogs [idx ].get ()
386
423
slot = my_slots .get ()
387
424
logging .info (
388
- 'Generate slice %d slot %d step %d,'
389
- ' generating for : "%s"' , idx , slot , generate_timestep ,
425
+ 'Generate slice %d slot %d step %d, generating for : "%s"' ,
426
+ idx ,
427
+ slot ,
428
+ generate_timestep ,
390
429
new_request .prefill_text ,
391
430
)
392
431
decode_state = generate_engine .insert (
393
- new_request .prefill_result , decode_state ,
394
- slot = slot
432
+ new_request .prefill_result , decode_state , slot = slot
395
433
)
396
434
new_request .generate_timestep_added = generate_timestep
397
435
new_request .complete = np .zeros (
@@ -412,7 +450,7 @@ def _generate_thread(
412
450
generate_timestep ,
413
451
my_slots .qsize (),
414
452
generate_engine .max_concurrent_decodes ,
415
- (time .time () - time_of_last_generate )* 10 ** 3 ,
453
+ (time .time () - time_of_last_generate ) * 10 ** 3 ,
416
454
)
417
455
time_of_last_generate = time .time ()
418
456
0 commit comments