from typing import List, Optional, Set, Hashable
import time
+ import importlib.util
+ import sys

import torch
import torch.nn as nn

LLAMA_7B_VOCAB_SIZE = 32000

- from vllm.scratch import ScratchAPI
- from vllm.scratch_env import (SCRATCH_TMP_DIR, SCRATCH_WEIGHTS_PREFIX,
+ from vllm.scratch_env import (SCRATCH_EXECUTABLE_PATH, SCRATCH_TMP_DIR, SCRATCH_WEIGHTS_PREFIX,
                              SCRATCH_WEIGHTS_BUCKET_NAME)

# SANG-TODO WORKS?
MODEL_PARAMS_PATH = "/home/ray/default/weights"


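+ # Load the Scratch extension as an ordinary Python module from an explicit
+ # file path (SCRATCH_EXECUTABLE_PATH), replacing the previous static
+ # `from vllm.scratch import ScratchAPI` import.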
+ def import_scratch(path: Path):
+     SCRATCH_MODULE_NAME = "scratch"
+     logger.info(f"Importing scratch module from {path}")
+     spec = importlib.util.spec_from_file_location(SCRATCH_MODULE_NAME, path.resolve())
+     scratch = importlib.util.module_from_spec(spec)
+     sys.modules[SCRATCH_MODULE_NAME] = scratch
+     spec.loader.exec_module(scratch)
+     return scratch
+
+
class ScratchSession:

    def __init__(self, scratch_session_id: int):
@@ -54,7 +65,7 @@ def __init__(self, scratch_session_id: int):

class ScratchLRUCache(LRUCache[ScratchSession]):

-     def __init__(self, capacity: int, scratch_api: ScratchAPI):
+     def __init__(self, capacity: int, scratch_api):
        self._scratch_api = scratch_api
        super().__init__(capacity)

@@ -77,7 +88,7 @@ class ScratchSessionManager:
    information to model runner in a few weeks.
    """

-     def __init__(self, scratch_api: ScratchAPI, max_num_seqs: int):
+     def __init__(self, scratch_api, max_num_seqs: int):
        # ScratchAPI used to create/delete sessions.
        self._scratch_api = scratch_api
        # Set capacity to max_num_seqs * 2 so that old sequences are
@@ -134,7 +145,7 @@ def __init__(
        self.pin_memory = is_pin_memory_available()

        # Lazily initialized.
-         self.scratch: ScratchAPI
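+         # ScratchAPI is provided by the dynamically imported scratch module
+         # (see import_scratch above), so only a string annotation is kept here.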
+         self.scratch: "ScratchAPI"  # type: ignore
        # Scratch only returns embedding. We need to multiply it to lm_head
        # to get the final logits, and that happens in vLLM. In order to
        # do that, we create a torch module with lm_head weights loaded.
@@ -155,8 +166,10 @@ def _verify_scratch_config(self):
155166 "Vision model not supported" )
156167 assert self .kv_cache_dtype == "auto" , (
157168 "Currently, Scratch doesn't use kv cache." )
158- assert "llama-2" in self .model_config .model .lower (), (
159- "Only Llama 7B is supported." )
169+ # SANG-TODO Support only llama 2 and 3.
170+ assert ("llama-2" in self .model_config .model .lower ()
171+ or "llama-3" in self .model_config .model .lower ()), (
172+ "Only Llama 2 7B or llama 3 8B is supported." )
160173 assert self .lora_manager is None , ("lora is not supported." )
161174 assert self .model_config .enforce_eager is True , (
162175 "cuda graph is not needed for Scratch." )
@@ -171,7 +184,12 @@ def load_model(self) -> None:
        weights_dir = tmp_dir / "parameters"
        weights_dir.mkdir(exist_ok=True)
        # TODO(sang): Need to obtain this programmatically.
-         download_dir = weights_dir / "ll27b-s1-cuda-f16-fullopt"
+         # download_dir = weights_dir / "ll27b-s1-cuda-f16-fullopt"
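+         # Build ScratchAPI from the dynamically loaded module and ask it for
+         # its parameter subdirectory instead of hard-coding the name above.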
+         scratch_mod = import_scratch(Path(SCRATCH_EXECUTABLE_PATH))
+         base_dir = str(weights_dir.resolve())
+         self.scratch = scratch_mod.ScratchAPI(base_dir)
+         scratch_subdir = self.scratch.get_param_subdir()
+         download_dir = weights_dir / scratch_subdir
        download_dir.mkdir(exist_ok=True)
        download_dir_path = str(download_dir.absolute())
        self.load_config.download_dir = str(weights_dir.absolute())
@@ -190,7 +208,6 @@ def load_model(self) -> None:
            scheduler_config=self.scheduler_config,
            cache_config=self.cache_config,
        )
-         self.scratch = ScratchAPI(str(weights_dir.absolute()))
        self.scratch.start()
        self._scratch_session_manager = ScratchSessionManager(
            self.scratch, self.scheduler_config.max_num_seqs)
@@ -223,7 +240,8 @@ def _download_scratch_weights(self, prefix: str, target_dir: str,
                dirs.append(k)
            next_token = results.get('NextContinuationToken')
        # Assume there's no subdirectories.
-         assert len(dirs) == 1
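+         # Derive the (single) weight directory prefix from the listed object
+         # keys; including `dirs` in the assertion makes failures easier to debug.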
+         dirs = {p.rsplit("/", 1)[0] for p in files}
+         assert len(dirs) == 1, dirs

        # NOTE(sang): Versioning is not supported now. We assume the
        # weights are always the same.
@@ -285,8 +303,8 @@ def execute_model(
            self.device,
            self.pin_memory)
        return self._execute_and_vllm_sample(prefill_groups, decode_groups,
-                                             input_tokens, session_ids,
-                                             parent_ids, sampling_metadata)
+                                              input_tokens, session_ids,
+                                              parent_ids, sampling_metadata)
        # return self._execute_and_scratch_sample(
        #     prefill_groups, decode_groups, input_tokens, session_ids, parent_ids)

@@ -327,7 +345,7 @@ def _execute_and_vllm_sample(
            input_tokens_tensor = torch.tensor(input_tokens[i],
                                               device="cuda",
                                               dtype=torch.int)
-             print(f"SANG-TODO {input_tokens_tensor=}")
+             # print(f"SANG-TODO {input_tokens_tensor=}")
            assert input_tokens_tensor.is_contiguous()
            # print(f"SANG-TODO {input_tokens_tensor.shape=}")

@@ -338,7 +356,7 @@ def _execute_and_vllm_sample(
            hidden_states_end_index = (len_prefix_before_this + len(input_tokens[i])) * self.model_config.get_hidden_size()
            # print(f"SANG-TODO {hidden_states_start_index=} {hidden_states_end_index=}")
            # print(f"SANG-TODO {hidden_states.shape=}")
-             print(f"SANG-TODO {hidden_states[hidden_states_start_index: hidden_states_end_index].shape=}")
+             # print(f"SANG-TODO {hidden_states[hidden_states_start_index: hidden_states_end_index].shape=}")
            assert hidden_states[hidden_states_start_index: hidden_states_end_index].is_contiguous()
            self.scratch.prefill(
                session_id,
@@ -363,9 +381,9 @@ def _execute_and_vllm_sample(
            hidden_states.data_ptr(),
        )

-         print(
-             f"SANG-TODO forward takes {(time.time() - s) * 1000} ms. Batch size: {len(session_ids)=} is_prefill: {len(prefill_groups) > 0}"
-         )
+         # print(
+         #     f"SANG-TODO forward takes {(time.time() - s) * 1000} ms. Batch size: {len(session_ids)=} is_prefill: {len(prefill_groups) > 0}"
+         # )
        # print(hidden_states)
        # print(f"SANG-TODO {hidden_states.shape=}")
        # Post process Scratch embeddings.
@@ -375,16 +393,16 @@ def _execute_and_vllm_sample(
        # is this expected?
        hidden_states = hidden_states.view(-1,
                                           self.model_config.get_hidden_size())
-         if len(prefill_groups) > 0:
-             print(f"SANG-TODO before norm {hidden_states=}")
-             print(f"SANG-TODO {hidden_states.shape=}")
+         # if len(prefill_groups) > 0:
+         #     print(f"SANG-TODO before norm {hidden_states=}")
+         #     print(f"SANG-TODO {hidden_states.shape=}")
        # Scratch doesn't apply rms norm in its output, so we should do it ourselves.
        # Residual is set to None because it is already added from Scratch output.
        hidden_states = self.model.norm(hidden_states, None)
-         if len(prefill_groups) > 0:
-             print(f"SANG-TODO norm weights: {self.model.norm.weight=}")
-             print(f"SANG-TODO {hidden_states.shape=}")
-             print(f"SANG-TODO after norm {hidden_states=}")
+         # if len(prefill_groups) > 0:
+         #     print(f"SANG-TODO norm weights: {self.model.norm.weight=}")
+         #     print(f"SANG-TODO {hidden_states.shape=}")
+         #     print(f"SANG-TODO after norm {hidden_states=}")
        # print(f"{hidden_states.shape=}")

        # SANG-TODO remove it. Hack. It will work once scratch returns embedding of all tokens correctly.
@@ -401,14 +419,14 @@ def _execute_and_vllm_sample(
            logits=logits,
            sampling_metadata=sampling_metadata,
        )
-         if len(prefill_groups) > 0:
-             print(
-                 f"SANG-TODO prefill takes {(time.time() - s) * 1000} ms. Batch size: {len(session_ids)=}"
-             )
-         else:
-             print(
-                 f"SANG-TODO decode takes {(time.time() - s) * 1000} ms. Batch size: {len(session_ids)=}"
-             )
+         # if len(prefill_groups) > 0:
+         #     print(
+         #         f"SANG-TODO prefill takes {(time.time() - s) * 1000} ms. Batch size: {len(session_ids)=}"
+         #     )
+         # else:
+         #     print(
+         #         f"SANG-TODO decode takes {(time.time() - s) * 1000} ms. Batch size: {len(session_ids)=}"
+         #     )
        # print(output)
        return output

@@ -443,7 +461,7 @@ def _execute_and_scratch_sample(
            batch_size,
            tokens_out.data_ptr(),
        )
-         print(f"SANG-TODO token: {tokens_out}")
+         # print(f"SANG-TODO token: {tokens_out}")

        result_tokens = tokens_out.tolist()
        outputs = []
@@ -462,7 +480,7 @@ def _execute_and_scratch_sample(
                )
            )
        output = SamplerOutput(outputs=outputs)
-         print(output)
+         # print(output)
        return output

    @torch.inference_mode()