Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MS-712] Support calculation for percentage slider #135

Open
wants to merge 5 commits into
base: dev_main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion results-modules/benchmarking-result.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,9 @@ def _generate_metadata(self, results_args: ResultArguments) -> dict:
"recipes": results_args.params.get("recipes"),
"cookbooks": results_args.params.get("cookbooks"),
"endpoints": results_args.params.get("endpoints"),
"num_of_prompts": results_args.params.get("num_of_prompts"),
"prompt_selection_percentage": results_args.params.get(
"prompt_selection_percentage"
),
"random_seed": results_args.params.get("random_seed"),
"system_prompt": results_args.params.get("system_prompt"),
}
Expand Down
36 changes: 25 additions & 11 deletions runners-modules/benchmarking.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,10 @@ class Benchmarking:
prompt_index,prompt,target,predicted_results,duration,random_seed,system_prompt)
VALUES(?,?,?,?,?,?,?,?,?,?,?,?)
"""

sql_read_runner_cache_record = """
SELECT * from runner_cache_table WHERE connection_id=? AND recipe_id=? AND prompt_template_id=? AND prompt=?
SELECT * from runner_cache_table WHERE connection_id=? AND recipe_id=?
AND dataset_id=? AND prompt_template_id=? AND prompt=?
"""
BATCH_SIZE = 10
QUEUE_SIZE = 10
Expand Down Expand Up @@ -90,10 +92,21 @@ async def generate(
# Get required arguments from runner_args
self.cookbooks = self.runner_args.get("cookbooks", None)
self.recipes = self.runner_args.get("recipes", None)
self.num_of_prompts = self.runner_args.get("num_of_prompts", 0)
self.prompt_selection_percentage = self.runner_args.get(
"prompt_selection_percentage", 100
)
self.random_seed = self.runner_args.get("random_seed", 0)
self.system_prompt = self.runner_args.get("system_prompt", "")

# Perform validation on prompt_selection_percentage
if (
self.prompt_selection_percentage < 1
or self.prompt_selection_percentage > 100
):
raise RuntimeError(
"The 'prompt_selection_percentage' argument must be between 1 - 100."
)

# ------------------------------------------------------------------------------
# Part 0: Load common instances
# ------------------------------------------------------------------------------
Expand Down Expand Up @@ -230,7 +243,7 @@ async def generate(
"recipes": self.recipes,
"cookbooks": self.cookbooks,
"endpoints": self.endpoints,
"num_of_prompts": self.num_of_prompts,
"prompt_selection_percentage": self.prompt_selection_percentage,
"random_seed": self.random_seed,
"system_prompt": self.system_prompt,
},
Expand Down Expand Up @@ -706,7 +719,7 @@ async def _get_dataset_prompts(
"""
Asynchronously retrieves prompts from a dataset based on the specified dataset ID.

This method determines the total number of prompts in the dataset and generates a list of prompt indices.
This method calculates the total number of prompts in the dataset and generates a list of prompt indices.
If a specific number of prompts is requested (num_of_prompts), it will randomly select that many prompts
using the provided random seed. Otherwise, it will retrieve all prompts. Each prompt is then fetched and
yielded along with its index.
Expand All @@ -717,14 +730,14 @@ async def _get_dataset_prompts(
Yields:
tuple[int, dict]: A tuple containing the index of the prompt and the prompt data itself.
"""
# Get dataset arguments
# Retrieve dataset arguments
ds_args = Dataset.read(ds_id)

# Generate a list of prompt indices based on num_of_prompts and random_seed
if (
self.num_of_prompts == 0
or self.num_of_prompts > ds_args.num_of_dataset_prompts
):
# Generate a list of prompt indices based on prompt_selection_percentage and random_seed
self.num_of_prompts = int(
(self.prompt_selection_percentage / 100) * ds_args.num_of_dataset_prompts
)
if self.num_of_prompts == ds_args.num_of_dataset_prompts:
prompt_indices = range(ds_args.num_of_dataset_prompts)
else:
random.seed(self.random_seed)
Expand All @@ -735,7 +748,7 @@ async def _get_dataset_prompts(
f"[Benchmarking] Dataset {ds_id}, using {len(prompt_indices)} of {ds_args.num_of_dataset_prompts} prompts."
)

# Use for loop to iterate over the async generator
# Iterate over the dataset examples and yield prompts based on the generated indices
prompts_gen_index = 0
for prompts_data in ds_args.examples:
if prompts_gen_index in prompt_indices:
Expand Down Expand Up @@ -865,6 +878,7 @@ async def _process_single_prompt(
(
new_prompt_info.conn_id,
new_prompt_info.rec_id,
new_prompt_info.ds_id,
new_prompt_info.pt_id,
new_prompt_info.connector_prompt.prompt,
),
Expand Down
Loading