diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py
index bbe712223a530..996a92d2a8b3d 100644
--- a/benchmarks/benchmark_serving.py
+++ b/benchmarks/benchmark_serving.py
@@ -89,8 +89,6 @@ def sample_sharegpt_requests(
     tokenizer: PreTrainedTokenizerBase,
     fixed_output_len: Optional[int] = None,
 ) -> List[Tuple[str, int, int, None]]:
-    if fixed_output_len is not None and fixed_output_len < 4:
-        raise ValueError("output_len too small")
     # Load the dataset.
     with open(dataset_path) as f:
         dataset = json.load(f)
@@ -117,7 +115,7 @@ def sample_sharegpt_requests(
         prompt_len = len(prompt_token_ids)
         output_len = len(completion_token_ids
                          ) if fixed_output_len is None else fixed_output_len
-        if prompt_len < 4 or output_len < 4:
+        if prompt_len < 4 or (fixed_output_len is None and output_len < 4):
             # Prune too short sequences.
             continue
         if prompt_len > 1024 or prompt_len + output_len > 2048:
@@ -228,10 +226,11 @@ def sample_hf_requests(
         prompt_len = len(prompt_token_ids)
         output_len = len(completion_token_ids
                          ) if fixed_output_len is None else fixed_output_len
-        if prompt_len < 4 or output_len < 4:
+        if fixed_output_len is None and (prompt_len < 4 or output_len < 4):
             # Prune too short sequences.
             continue
-        if prompt_len > 1024 or prompt_len + output_len > 2048:
+        if fixed_output_len is None and \
+           (prompt_len > 1024 or prompt_len + output_len > 2048):
             # Prune too long sequences.
             continue
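
For review, a minimal standalone sketch of the new pruning behavior in sample_hf_requests: when the caller pins the output length (fixed_output_len is not None), the dataset's own completion lengths are irrelevant, so neither length filter applies. Note that sample_sharegpt_requests differs: it still prunes on prompt_len even when fixed_output_len is set. should_prune below is a hypothetical helper written for illustration, not part of this diff.

from typing import Optional


def should_prune(prompt_len: int,
                 output_len: int,
                 fixed_output_len: Optional[int] = None) -> bool:
    # Mirrors the post-change filter in sample_hf_requests.
    if fixed_output_len is None and (prompt_len < 4 or output_len < 4):
        return True  # Prune too short sequences.
    if fixed_output_len is None and \
       (prompt_len > 1024 or prompt_len + output_len > 2048):
        return True  # Prune too long sequences.
    return False


# A 1-token completion is pruned by default, but kept once the output
# length is fixed via the benchmark's output-length flag:
assert should_prune(prompt_len=100, output_len=1)
assert not should_prune(prompt_len=100, output_len=1, fixed_output_len=256)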