Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 113 additions & 0 deletions vllm/benchmarks/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1085,6 +1085,22 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
"from the ShareGPT dataset.",
)

blazedit_group = parser.add_argument_group("blazedit dataset options")
blazedit_group.add_argument(
"--blazedit-min-distance",
type=float,
default=0.0,
help=
"Minimum distance for blazedit dataset. Min: 0, Max: 1.0",
)
blazedit_group.add_argument(
"--blazedit-max-distance",
type=float,
default=1.0,
help=
"Maximum distance for blazedit dataset. Min: 0, Max: 1.0",
)

random_group = parser.add_argument_group("random dataset options")
random_group.add_argument(
"--random-input-len",
Expand Down Expand Up @@ -1317,6 +1333,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
elif args.dataset_name == "hf":
# all following datasets are implemented from the
# HuggingFaceDataset base class
hf_kwargs = {}
if (
args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS
or args.hf_name in VisionArenaDataset.SUPPORTED_DATASET_PATHS
Expand Down Expand Up @@ -1360,6 +1377,13 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
):
dataset_class = ASRDataset
args.hf_split = "train"
elif args.dataset_path in BlazeditDataset.SUPPORTED_DATASET_PATHS:
dataset_class = BlazeditDataset
args.hf_split = "train"
hf_kwargs = {
"min_distance": args.blazedit_min_distance,
"max_distance": args.blazedit_max_distance,
}
elif (
args.dataset_path in MLPerfDataset.SUPPORTED_DATASET_PATHS
or args.hf_name in MLPerfDataset.SUPPORTED_DATASET_PATHS
Expand Down Expand Up @@ -1399,6 +1423,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
tokenizer=tokenizer,
output_len=args.hf_output_len,
request_id_prefix=args.request_id_prefix,
**hf_kwargs
)

else:
Expand Down Expand Up @@ -2012,6 +2037,94 @@ def sample(
return sampled_requests


# -----------------------------------------------------------------------------
# Blazedit Dataset Implementation
# -----------------------------------------------------------------------------


class BlazeditDataset(HuggingFaceDataset):
    """
    Blazedit Dataset.
    https://github.com/ise-uiuc/blazedit

    5k char version: vdaita/edit_5k_char
    10k char version: vdaita/edit_10k_char

    Each row is read through the keys ``code``, ``change_request`` and
    ``norm_distance``; rows are turned into "apply this change request to
    this file" chat prompts.
    """  # noqa: E501

    # 5k char version will have output as ~5k chars
    # 10k char version will have output as ~10k chars
    # Assuming 3 char per token, 10k chars will be 3333 tokens
    # We set default to 4000 to be safe
    DEFAULT_OUTPUT_LEN = 4000
    SUPPORTED_DATASET_PATHS = {
        "vdaita/edit_5k_char",
        "vdaita/edit_10k_char",
    }

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        output_len: Optional[int] = None,
        request_id_prefix: str = "",
        min_distance: float = 0.0,
        max_distance: float = 1.0,
        **kwargs,
    ) -> list[SampleRequest]:
        """Build up to ``num_requests`` code-edit prompts from the dataset.

        Rows of ``self.data`` are visited in order. A row is kept only when
        its ``norm_distance`` lies in ``[min_distance, max_distance]``
        (both bounds inclusive). Iteration stops as soon as ``num_requests``
        rows have been kept.

        Args:
            tokenizer: Used both to render the chat template and to count
                the prompt tokens.
            num_requests: Number of requests to collect from the dataset.
            output_len: Expected output length recorded on every request;
                falls back to ``DEFAULT_OUTPUT_LEN`` when ``None``.
            request_id_prefix: String prepended to each request id.
            min_distance: Lower bound on a row's ``norm_distance``.
            max_distance: Upper bound on a row's ``norm_distance``.
            **kwargs: Accepted for signature compatibility and ignored.

        Returns:
            The list of sampled requests, after it has been passed to
            ``maybe_oversample_requests``.
        """
        output_len = (output_len
                      if output_len is not None else self.DEFAULT_OUTPUT_LEN)
        sampled_requests = []

        for item in self.data:
            if len(sampled_requests) >= num_requests:
                break
            code = item["code"]
            change_request = item["change_request"]
            norm_distance = item["norm_distance"]

            # compare the levenshtein distance normalized by code length
            if norm_distance < min_distance or norm_distance > max_distance:
                continue

            # template copied from
            # https://github.com/ise-uiuc/blazedit/blob/7765137e656fd62de877422d2e4cf8de51228054/dataset/create_refined_dataset.py#L94-L105 # noqa: E501
            instruction = f"""Given a code file, please apply the change requests and generate the new file.

Original file:
```python
{code}
```

Change request:
{change_request}

Please generate the new code file in the "New file" section below."""  # noqa: E501

            # apply template
            prompt = tokenizer.apply_chat_template(
                [{
                    "role": "user",
                    "content": instruction
                }],
                add_generation_prompt=True,
                tokenize=False,
            )

            prompt_len = len(tokenizer(prompt).input_ids)

            # Number requests by their position in the sampled list rather
            # than by dataset row index: rows skipped by the distance filter
            # would otherwise leave gaps in the ids.
            # NOTE(review): presumably maybe_oversample_requests numbers its
            # extra requests from len(sampled_requests) onwards, in which
            # case row-index ids could also collide with them - confirm.
            request_id = request_id_prefix + str(len(sampled_requests))

            sampled_requests.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=prompt_len,
                    expected_output_len=output_len,
                    request_id=request_id,
                ))

        # Inherited helper, called once after collection; the filter or a
        # small dataset may have produced fewer than num_requests requests.
        self.maybe_oversample_requests(sampled_requests, num_requests,
                                       request_id_prefix)

        return sampled_requests


# -----------------------------------------------------------------------------
# AIMO Dataset Implementation
# -----------------------------------------------------------------------------
Expand Down