|
38 | 38 | from typing import Any, AsyncGenerator, Collection, Dict, List, Optional, Tuple
|
39 | 39 |
|
40 | 40 | import numpy as np
|
| 41 | +import pandas as pd |
41 | 42 | from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput,
|
42 | 43 | RequestFuncOutput)
|
43 | 44 | from datasets import load_dataset
|
@@ -131,6 +132,35 @@ def sample_sharegpt_requests(
|
131 | 132 | return filtered_dataset
|
132 | 133 |
|
133 | 134 |
|
def sample_burstgpt_requests(
    dataset_path: str,
    num_requests: int,
    random_seed: int,
    tokenizer: "PreTrainedTokenizerBase",
) -> List[Tuple[str, int, int, None]]:
    """Sample (prompt, input_len, output_len, None) tuples from a BurstGPT trace.

    Reads the BurstGPT CSV trace, keeps only successful GPT-4 requests,
    samples ``num_requests`` rows (with replacement when the trace has fewer
    rows than requested), and synthesizes a prompt of exactly ``input_len``
    tokens for each sampled row.

    Args:
        dataset_path: Path to the BurstGPT trace CSV file.
        num_requests: Number of requests to sample.
        random_seed: Seed for pandas' sampler, for reproducibility.
        tokenizer: Tokenizer used to decode synthetic token ids into a prompt.

    Returns:
        A list of ``(prompt, input_len, output_len, None)`` tuples; the final
        ``None`` matches the multi-modal slot used by the other samplers.
    """
    df = pd.read_csv(dataset_path)
    gpt4_df = df[df["Model"] == "GPT-4"]
    # Drop failed requests (a response length of 0 indicates an error).
    gpt4_df = gpt4_df[gpt4_df["Response tokens"] > 0]
    # Sample exactly num_requests rows; fall back to sampling with
    # replacement when the filtered trace is smaller than requested.
    gpt4_df = gpt4_df.sample(n=num_requests,
                             random_state=random_seed,
                             replace=len(gpt4_df) < num_requests)
    input_requests: List[Tuple[str, int, int, None]] = []
    # Access columns by name (not positional index) so the function is
    # robust to column reordering in the trace file.
    for i, (req_tokens, resp_tokens) in enumerate(
            zip(gpt4_df["Request tokens"], gpt4_df["Response tokens"])):
        input_len = int(req_tokens)
        output_len = int(resp_tokens)
        # Build a synthetic prompt of exactly input_len token ids; the ids
        # themselves are arbitrary (only the length matters for this
        # benchmark), cycled modulo the vocab size to stay valid.
        prompt = tokenizer.decode([(i + j) % tokenizer.vocab_size
                                   for j in range(input_len)])
        input_requests.append((prompt, input_len, output_len, None))
    return input_requests
| 162 | + |
| 163 | + |
134 | 164 | def sample_sonnet_requests(
|
135 | 165 | dataset_path: str,
|
136 | 166 | num_requests: int,
|
@@ -830,6 +860,14 @@ def main(args: argparse.Namespace):
|
830 | 860 | fixed_output_len=args.sharegpt_output_len,
|
831 | 861 | )
|
832 | 862 |
|
| 863 | + elif args.dataset_name == "burstgpt": |
| 864 | + input_requests = sample_burstgpt_requests( |
| 865 | + dataset_path=args.dataset_path, |
| 866 | + num_requests=args.num_prompts, |
| 867 | + random_seed=args.seed, |
| 868 | + tokenizer=tokenizer, |
| 869 | + ) |
| 870 | + |
833 | 871 | elif args.dataset_name == "sonnet":
|
834 | 872 | # Do not format the prompt, pass to message directly
|
835 | 873 | if args.backend == "openai-chat":
|
@@ -995,7 +1033,7 @@ def main(args: argparse.Namespace):
|
995 | 1033 | "--dataset-name",
|
996 | 1034 | type=str,
|
997 | 1035 | default="sharegpt",
|
998 |
| - choices=["sharegpt", "sonnet", "random", "hf"], |
| 1036 | + choices=["sharegpt", "burstgpt", "sonnet", "random", "hf"], |
999 | 1037 | help="Name of the dataset to benchmark on.",
|
1000 | 1038 | )
|
1001 | 1039 | parser.add_argument("--dataset-path",
|
|
0 commit comments