Skip to content

Commit

Permalink
push fix for MS-753
Browse files Browse the repository at this point in the history
  • Loading branch information
imda-lionelteo committed Nov 13, 2024
1 parent fa9d787 commit acfa30e
Showing 1 changed file with 17 additions and 10 deletions.
27 changes: 17 additions & 10 deletions runners-modules/benchmarking.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class Benchmarking:
"""

sql_read_runner_cache_record = """
SELECT * from runner_cache_table WHERE connection_id=? AND recipe_id=?
SELECT * from runner_cache_table WHERE connection_id=? AND recipe_id=?
AND dataset_id=? AND prompt_template_id=? AND prompt=?
"""
BATCH_SIZE = 10
Expand Down Expand Up @@ -733,17 +733,24 @@ async def _get_dataset_prompts(
# Retrieve dataset arguments
ds_args = Dataset.read(ds_id)

# Generate a list of prompt indices based on prompt_selection_percentage and random_seed
self.num_of_prompts = int(
(self.prompt_selection_percentage / 100) * ds_args.num_of_dataset_prompts
)
if self.num_of_prompts == ds_args.num_of_dataset_prompts:
prompt_indices = range(ds_args.num_of_dataset_prompts)
if ds_args.num_of_dataset_prompts == 0:
prompt_indices = []
else:
random.seed(self.random_seed)
prompt_indices = random.sample(
range(ds_args.num_of_dataset_prompts), self.num_of_prompts
# Generate a list of prompt indices based on prompt_selection_percentage and random_seed
self.num_of_prompts = max(
1,
int(
(self.prompt_selection_percentage / 100)
* ds_args.num_of_dataset_prompts
),
)
if self.num_of_prompts == ds_args.num_of_dataset_prompts:
prompt_indices = range(ds_args.num_of_dataset_prompts)
else:
random.seed(self.random_seed)
prompt_indices = random.sample(
range(ds_args.num_of_dataset_prompts), self.num_of_prompts
)
logger.debug(
f"[Benchmarking] Dataset {ds_id}, using {len(prompt_indices)} of {ds_args.num_of_dataset_prompts} prompts."
)
Expand Down

0 comments on commit acfa30e

Please sign in to comment.