 import asyncio

 # This is the model to load for workers
-MODEL_PATH = "/models/vicuna-7b/"
-
-
+MODEL_PATH = "YOUR_MODEL_PATH"
 """
 1. Prepare faked SequenceGroupMetadata.
 2. Start a mocked AsyncLLMEngine and override its step_async function.
 3. Invoke the step_async function manually.
+4. This test kicks off the `model_execution` part of the engine so
+   that we can test it in isolation.
 """

+
 class UglyAsyncLLMEngine(LLMEngine):
     """Extension of LLMEngine to add async methods."""

     async def step_async(self) -> List[RequestOutput]:
-        sampling_para = SamplingParams(n=2, best_of=5, temperature=0.8, top_p=0.95, max_tokens=7)
+        sampling_para = SamplingParams(n=2,
+                                       best_of=5,
+                                       temperature=0.8,
+                                       top_p=0.95,
+                                       max_tokens=7)
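+        # Fake sequence data: a single sequence (id 0) with a hard-coded,
+        # pre-tokenized prompt stands in for real tokenizer output.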
         seq_data = {}
         seq_data[0] = SequenceData(prompt_token_ids=[1, 3087, 8970, 338, 263])
         request_id = "cmpl-7bef75eaa4394a3d895b5508dd5f69f6"

-        seq_group_meta_data = SequenceGroupMetadata(request_id=request_id, is_prompt=True, seq_data=seq_data, sampling_params=sampling_para, block_tables={})
+        seq_group_meta_data = SequenceGroupMetadata(
+            request_id=request_id,
+            is_prompt=True,
+            seq_data=seq_data,
+            sampling_params=sampling_para,
+            block_tables={})
         seq_group_meta_data_lists = [seq_group_meta_data]

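+        # Call execute_model on the workers directly; the scheduler is
+        # bypassed so that only the model-execution path is exercised.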
         output = await self._run_workers_async(
@@ -44,24 +54,66 @@ async def step_async(self) -> List[RequestOutput]:
             blocks_to_copy={},
             finished_seqs=[],
         )
-        print(output)

-        # TODO: change this to real one
-        return RequestOutput(request_id=request_id, prompt="", prompt_token_ids=[1, 3087, 8970, 338, 263], outputs=[], finished=False)
+        # Co(gc): we cannot return the real output here, as building it
+        # relies on private methods that cannot be invoked from this test.
+        return RequestOutput(request_id=request_id,
+                             prompt="",
+                             prompt_token_ids=[1, 3087, 8970, 338, 263],
+                             outputs=[],
+                             finished=False)

     async def step_async_multiple(self) -> List[RequestOutput]:
+        """
+        Same as step_async, but sends two requests in one batch.
+        """
         seq_group_metadata_lists = []
-        request_id_0 = "cmpl-81e2b9767b5b47bca7e649482698d385"
-        seq_data_0 = {0: SequenceData(prompt_token_ids=[1, 3087, 8970, 338, 263])}
-        sampling_params_0 = SamplingParams(n=1, best_of=1, presence_penalty=0.0, frequency_penalty=0.0, temperature=0.0, top_p=1.0, top_k=-1, use_beam_search=False, length_penalty=1.0, early_stopping=False, stop=[], ignore_eos=False, max_tokens=7, logprobs=None, skip_special_tokens=True)
-
-        seq_group_metadata_lists.append(SequenceGroupMetadata(request_id_0, True, seq_data_0, sampling_params_0, {}))
+        request_id_0 = "cmpl-81e2b9767b5b47bca7e649482698d385"
+        seq_data_0 = {
+            0: SequenceData(prompt_token_ids=[1, 3087, 8970, 338, 263])
+        }
+        sampling_params_0 = SamplingParams(n=1,
+                                           best_of=1,
+                                           presence_penalty=0.0,
+                                           frequency_penalty=0.0,
+                                           temperature=0.0,
+                                           top_p=1.0,
+                                           top_k=-1,
+                                           use_beam_search=False,
+                                           length_penalty=1.0,
+                                           early_stopping=False,
+                                           stop=[],
+                                           ignore_eos=False,
+                                           max_tokens=7,
+                                           logprobs=None,
+                                           skip_special_tokens=True)
+
+        seq_group_metadata_lists.append(
+            SequenceGroupMetadata(request_id_0, True, seq_data_0,
+                                  sampling_params_0, {}))

         request_id_1 = "cmpl-81e2b9767b5b47bca7e649482698d385"
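+        # NOTE: this reuses the same request id string as request_id_0;
+        # the distinct sequence ids (0 vs. 1) are what keep the two
+        # sequence groups apart here.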
-        seq_data_1 = {1: SequenceData(prompt_token_ids=[1, 3087, 8970, 338, 263])}
-        sampling_params_1 = SamplingParams(n=1, best_of=1, presence_penalty=0.0, frequency_penalty=0.0, temperature=0.0, top_p=1.0, top_k=-1, use_beam_search=False, length_penalty=1.0, early_stopping=False, stop=[], ignore_eos=False, max_tokens=7, logprobs=None, skip_special_tokens=True)
-
-        seq_group_metadata_lists.append(SequenceGroupMetadata(request_id_1, True, seq_data_1, sampling_params_1, {}))
+        seq_data_1 = {
+            1: SequenceData(prompt_token_ids=[1, 3087, 8970, 338, 263])
+        }
+        sampling_params_1 = SamplingParams(n=1,
+                                           best_of=1,
+                                           presence_penalty=0.0,
+                                           frequency_penalty=0.0,
+                                           temperature=0.0,
+                                           top_p=1.0,
+                                           top_k=-1,
+                                           use_beam_search=False,
+                                           length_penalty=1.0,
+                                           early_stopping=False,
+                                           stop=[],
+                                           ignore_eos=False,
+                                           max_tokens=7,
+                                           logprobs=None,
+                                           skip_special_tokens=True)
+
+        seq_group_metadata_lists.append(
+            SequenceGroupMetadata(request_id_1, True, seq_data_1,
+                                  sampling_params_1, {}))
 
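+        # Execute both sequence groups in one batched call to the workers.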
         output = await self._run_workers_async(
             "execute_model",
@@ -72,9 +124,11 @@ async def step_async_multiple(self) -> List[RequestOutput]:
             finished_seqs=[],
         )

-        # TODO: change this to real one
-        return RequestOutput(request_id=request_id_0, prompt="", prompt_token_ids=[1, 3087, 8970, 338, 263], outputs=[], finished=False)
-
+        return RequestOutput(request_id=request_id_0,
+                             prompt="",
+                             prompt_token_ids=[1, 3087, 8970, 338, 263],
+                             outputs=[],
+                             finished=False)

     async def _run_workers_async(
         self,
@@ -106,13 +160,36 @@ async def _run_workers_async(
             assert output == other_output
         return output

+
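+# Patch AsyncLLMEngine so that from_engine_args builds the test engine
+# class above instead of the stock engine class.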
 setattr(AsyncLLMEngine, "_engine_class", UglyAsyncLLMEngine)


 @pytest.mark.asyncio
 async def test_model_execution():
-    # Let's build an engine_args
-    engine_args = AsyncEngineArgs(model='/models/vicuna-7b/', tokenizer='/models/vicuna-7b/', tokenizer_mode='auto', trust_remote_code=False, download_dir=None, load_format='dummy', dtype='auto', seed=0, max_model_len=None, worker_use_ray=False, pipeline_parallel_size=1, tensor_parallel_size=1, block_size=16, swap_space=16, gpu_memory_utilization=0.9, max_num_batched_tokens=None, max_num_seqs=256, disable_log_stats=False, revision=None, tokenizer_revision=None, quantization=None, engine_use_ray=False, disable_log_requests=True, max_log_len=None)
+    # Build the engine args
+    engine_args = AsyncEngineArgs(model=MODEL_PATH,
+                                  tokenizer=MODEL_PATH,
+                                  tokenizer_mode='auto',
+                                  trust_remote_code=False,
+                                  download_dir=None,
+                                  dtype='auto',
+                                  seed=0,
+                                  max_model_len=None,
+                                  worker_use_ray=False,
+                                  pipeline_parallel_size=1,
+                                  tensor_parallel_size=1,
+                                  block_size=16,
+                                  swap_space=16,
+                                  gpu_memory_utilization=0.9,
+                                  max_num_batched_tokens=None,
+                                  max_num_seqs=256,
+                                  disable_log_stats=False,
+                                  revision=None,
+                                  tokenizer_revision=None,
+                                  quantization=None,
+                                  engine_use_ray=False,
+                                  disable_log_requests=True,
+                                  max_log_len=None)
     # Start the engine
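+    # (from_engine_args should now construct the patched UglyAsyncLLMEngine.)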
     engine = AsyncLLMEngine.from_engine_args(engine_args)

@@ -121,7 +198,3 @@ async def test_model_execution():
     await engine.engine.step_async()
     # Now let's try something more difficult
     await engine.engine.step_async_multiple()
-
-
-
-