From 838dcda162e465b2e84f5b33434e55c1df8f6942 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sun, 3 Nov 2024 03:52:38 -0800 Subject: [PATCH] Simplify tokenizer manager (#1899) --- docs/references/custom_chat_template.md | 13 +++- python/sglang/srt/managers/io_struct.py | 16 +++-- .../sglang/srt/managers/tokenizer_manager.py | 70 +++++++------------ 3 files changed, 50 insertions(+), 49 deletions(-) diff --git a/docs/references/custom_chat_template.md b/docs/references/custom_chat_template.md index 64b33a0a42..2803abc012 100644 --- a/docs/references/custom_chat_template.md +++ b/docs/references/custom_chat_template.md @@ -11,8 +11,10 @@ If needed, you can also override the chat template when launching the server: python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --chat-template llama-2 ``` -If the chat template you are looking for is missing, you are welcome to contribute it. -Meanwhile, you can also temporarily register your chat template as follows: +If the chat template you are looking for is missing, you are welcome to contribute it or load it from a file. + +## JSON Format +You can load the JSON format, which is defined by `conversation.py`. ```json { @@ -28,4 +30,11 @@ Meanwhile, you can also temporarily register your chat template as follows: ``` python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --chat-template ./my_model_template.json +``` + +## Jinja Format +You can also use the Jinja template format, defined by Hugging Face transformers https://huggingface.co/docs/transformers/main/en/chat_templating + +``` +python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --chat-template ./my_model_template.jinja ``` \ No newline at end of file diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index f29a7d3bce..df873035e9 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -114,8 +114,7 @@ def post_init(self): if self.parallel_sample_num == 1: num = self.batch_size else: - # FIXME support cascade inference - # first bs samples are used for caching the prefix for parallel sampling + # The first bs samples are used for caching the prefix for parallel sampling num = self.batch_size + self.parallel_sample_num * self.batch_size if self.image_data is None: @@ -196,6 +195,9 @@ class EmbeddingReqInput: # Dummy sampling params for compatibility sampling_params: Union[List[Dict], Dict] = None + # Whether it is a single request or a batch request + is_single: bool = True + def post_init(self): if (self.text is None and self.input_ids is None) or ( self.text is not None and self.input_ids is not None @@ -241,15 +243,21 @@ class TokenizedEmbeddingReqInput: sampling_params: SamplingParams +RewardReqConv = Union[List[List[Dict]], List[Dict], str, List[str]] + + @dataclass class RewardReqInput: - # The input prompt in the chat format. It can be a single prompt or a batch of prompts. - conv: Union[List[List[Dict]], List[Dict]] + # The input prompt. It can be a single prompt or a batch of prompts. Can be either chat format or a string. + conv: RewardReqConv # The request id. 
rid: Optional[Union[List[str], str]] = None # Dummy sampling params for compatibility sampling_params: Union[List[Dict], Dict] = None + # Whether it is a single request or a batch request + is_single: bool = True + def post_init(self): self.is_single = isinstance(self.conv[0], dict) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 785b18165d..c7d0bc7837 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -51,6 +51,7 @@ GetMemPoolSizeReq, GetMemPoolSizeReqOutput, ProfileReq, + RewardReqConv, RewardReqInput, TokenizedEmbeddingReqInput, TokenizedGenerateReqInput, @@ -89,6 +90,7 @@ def __init__( server_args: ServerArgs, port_args: PortArgs, ): + # Parse args self.server_args = server_args # Init inter-process communication @@ -114,6 +116,7 @@ def __init__( self.context_len = server_args.context_length or get_context_length( self.hf_config ) + # Create image processor placeholder self.image_processor = get_dummy_image_processor() @@ -165,7 +168,8 @@ async def generate_request( if isinstance(obj, EmbeddingReqInput) and self.is_generation: raise ValueError( - "This model does not appear to be an embedding model by default. Please add `--is-embedding` when launching the server or try another model." + "This model does not appear to be an embedding model by default. " + "Please add `--is-embedding` when launching the server or try another model." ) obj.post_init() @@ -187,12 +191,8 @@ async def _send_single_request( if not is_cache_for_prefill: # The normal case with a single prompt if index is None: rid = obj.rid - if hasattr(obj, "conv"): - # reward model - conv = obj.conv - input_text = self.tokenizer.apply_chat_template( - conv, tokenize=False - ) + if isinstance(obj, RewardReqInput): + input_text = self._apply_chat_template(obj.conv) input_ids = self.tokenizer.encode(input_text) elif obj.input_ids is None: input_text = obj.text @@ -213,12 +213,8 @@ async def _send_single_request( top_logprobs_num = obj.top_logprobs_num else: rid = obj.rid[index] - if hasattr(obj, "conv"): - # reward model - conv = obj.conv[index] - input_text = self.tokenizer.apply_chat_template( - conv, tokenize=False - ) + if isinstance(obj, RewardReqInput): + input_text = self._apply_chat_template(obj.conv[input_id_index]) input_ids = self.tokenizer.encode(input_text) elif obj.input_ids is None: input_text = obj.text[input_id_index] @@ -349,8 +345,9 @@ async def _handle_single_request( async for response in self._wait_for_response(state, obj, rid, request): yield response else: - assert self.is_generation - await self._wait_for_cache_prefill_response(state, obj, rid, request) + await state.event.wait() + assert state.finished + del self.rid_to_state[rid] yield input_ids async def _handle_batch_request( @@ -456,6 +453,15 @@ def _get_sampling_params(self, sampling_params_data: dict): sampling_params.verify() return sampling_params + def _apply_chat_template(self, conv: RewardReqConv) -> Union[str, List[str]]: + if isinstance(conv, str): + return conv + elif isinstance(conv, list): + if isinstance(conv[0], str): + return conv + else: + return self.tokenizer.apply_chat_template(conv, tokenize=False) + async def _wait_for_response( self, state: ReqState, @@ -491,12 +497,11 @@ async def _wait_for_response( out["index"] = response_index - # Log requests - if self.server_args.log_requests and state.finished: - logger.info(f"in={obj}, out={out}") - state.out_list = [] if state.finished: + # Log 
requests + if self.server_args.log_requests: + logger.info(f"in={obj}, out={out}") del self.rid_to_state[rid] yield out break @@ -504,27 +509,6 @@ async def _wait_for_response( state.event.clear() yield out - async def _wait_for_cache_prefill_response( - self, - state: ReqState, - obj: GenerateReqInput, - rid: str, - request: Optional[fastapi.Request] = None, - ): - while True: - try: - await asyncio.wait_for(state.event.wait(), timeout=4) - break - except asyncio.TimeoutError: - if request is not None and await request.is_disconnected(): - for rid in obj.rid: - self.abort_request(rid) - raise ValueError(f"Abort request {rid}") - continue - - assert state.finished - del self.rid_to_state[rid] - def flush_cache(self): req = FlushCacheReq() self.send_to_scheduler.send_pyobj(req) @@ -553,6 +537,7 @@ async def get_memory_pool_size(self): self.send_to_scheduler.send_pyobj(req) self.mem_pool_size = asyncio.Future() + # FIXME: Each request should have its own future instead of using `self.mem_pool_size`. if self.server_args.dp_size == 1: res = await self.mem_pool_size return res.size @@ -638,7 +623,7 @@ async def sigterm_watchdog(self): while True: remain_num_req = len(self.rid_to_state) logger.info( - f"gracefully exiting... remaining number of requests {remain_num_req}" + f"Gracefully exiting... remaining number of requests {remain_num_req}" ) if remain_num_req > 0: await asyncio.sleep(5) @@ -695,7 +680,6 @@ async def handle_loop(self): "token_ids": recv_obj.output_ids[i], "meta_info": recv_obj.meta_info[i], } - else: assert isinstance(recv_obj, BatchEmbeddingOut) out_dict = { @@ -747,7 +731,7 @@ def detokenize_logprob_tokens( token_texts = self.tokenizer.batch_decode(token_ids) return [ (logprob, token_id, token_text) - for (logprob, token_id), token_text, in zip(token_logprobs, token_texts) + for (logprob, token_id), token_text in zip(token_logprobs, token_texts) ] def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool):
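
For reference, the new `_apply_chat_template` helper accepts the full `RewardReqConv` union: a raw string or a list of raw strings passes through unchanged, while a single conversation or a batch of conversations in the chat format is rendered with the tokenizer's chat template. Below is a minimal standalone sketch of that dispatch, not part of the patch itself; it assumes `transformers` is installed, and the model name is only illustrative.

```python
from typing import Dict, List, Union

from transformers import AutoTokenizer

# Same union as the RewardReqConv alias added in io_struct.py:
# a raw string, a batch of strings, a single conversation, or a batch of conversations.
RewardReqConv = Union[List[List[Dict]], List[Dict], str, List[str]]


def apply_chat_template(tokenizer, conv: RewardReqConv) -> Union[str, List[str]]:
    """Render reward-model input to plain text, mirroring the dispatch in
    TokenizerManager._apply_chat_template."""
    if isinstance(conv, str):
        # Already plain text.
        return conv
    if isinstance(conv, list) and conv and isinstance(conv[0], str):
        # A batch of plain strings passes through unchanged.
        return conv
    # Chat-format input (a single conversation or a batch of conversations):
    # render it with the tokenizer's built-in chat template.
    return tokenizer.apply_chat_template(conv, tokenize=False)


if __name__ == "__main__":
    # Any chat model with a chat template works here; the name is illustrative.
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

    conversation = [
        {"role": "user", "content": "Is the sky blue?"},
        {"role": "assistant", "content": "Yes, on a clear day."},
    ]
    print(apply_chat_template(tokenizer, conversation))  # rendered with the template
    print(apply_chat_template(tokenizer, "raw text"))    # passed through as-is
```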