-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
9b8acb4
commit 66184f7
Showing
38 changed files
with
139,027 additions
and
1,549 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,211 @@ | ||
import argparse | ||
|
||
|
||
def parse_args(argv=None):
    """Parse command-line options for synthetic-instruction generation.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            argparse falls back to ``sys.argv[1:]`` as usual; passing an
            explicit list makes the function testable without touching
            ``sys.argv``.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser()
    # NOTE: the original declared the path options with required=True *and* a
    # default. argparse ignores the default on a required option, so the
    # documented defaults were dead code and the flags were needlessly
    # mandatory. Dropping required=True makes the defaults effective while
    # keeping every previously valid invocation working unchanged.
    parser.add_argument(
        "--seed_tasks_path",
        type=str,
        default="data/seed_tasks.jsonl",
        help="The path to the human written data. The keys of the dictionary should be `instruction`, `input` and `output`.",
    )
    parser.add_argument(
        "--output_data_path",
        type=str,
        default="data/output.jsonl",
        help="The path to output data file.",
    )
    parser.add_argument(
        "--num_instructions_to_generate",
        type=int,
        default=100,
        help="The number of instructions to generate.",
    )
    parser.add_argument(
        "--template_name",
        default="default",
        help="Name of the template to use for in-context learning.",
    )
    parser.add_argument(
        "--use_tgi",
        action="store_true",  # default False
        help="Whether or not to use text-generation inference. In this case you should have your HF_TOKEN and API_URL stored as env variables.",
    )
    parser.add_argument(
        "--keep_programming",
        action="store_true",  # default False
        help="Whether or not to keep programming tasks.",
    )
    # No default: when the flag is omitted, args.format is None (unchanged
    # behavior); callers are expected to handle that case.
    parser.add_argument(
        "--format", choices=[2, 3], type=int, help="biprompt or triprompt."
    )
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default="bigcode/starcoder",
        help="The name or path of the model to use.",
    )
    parser.add_argument(
        "--num_prompt_instructions",
        type=int,
        default=8,
        help="The number of instructions to use in the prompt.",
    )
    parser.add_argument(
        "--request_batch_size",
        type=int,
        default=4,
        help="The number of requests to send to the model at a time.",
    )
    parser.add_argument(
        "--num_prompt_synthetic_instructions",
        type=int,
        default=2,
        help="The number of synthetic (model-generated instructions) to use in the prompt.",
    )
    parser.add_argument(
        "--max_new_tokens",
        default=4096,
        type=int,
        help="The max_new_tokens parameter of the generate function. It is the maximum number of tokens to generate",
    )
    parser.add_argument(
        "--temperature",
        default=0.2,
        type=float,
        help="The temperature of the generation.",
    )
    parser.add_argument(
        "--top_p",
        default=0.9,
        type=float,
        help="The `top_p` parameter for the generation.",
    )
    parser.add_argument(
        "--stop_words",
        default=["\n20", "20.", "20 ."],
        nargs="+",
        help="The `stop_words` that are considered during the generation.",
    )
    parser.add_argument(
        "--num_beams",
        default=1,
        type=int,
        help="The beam size used during the generation.",
    )
    parser.add_argument(
        "--repetition_penalty",
        type=float,
        default=1.2,
        help="The repetition penalty parameter to use for the generation.",
    )
    parser.add_argument(
        "--threshold",
        type=float,
        default=0.7,
        help="The similarity threshold for filtering.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
    )
    return parser.parse_args(argv)
|
||
|
||
def parse_args_for_post_processing(argv=None):
    """Parse command-line options for the post-processing step.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            argparse falls back to ``sys.argv[1:]`` as usual; passing an
            explicit list makes the function testable without touching
            ``sys.argv``.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser()
    # NOTE: the original declared the three path options with required=True
    # *and* a default. argparse ignores the default on a required option, so
    # the defaults were dead code. Dropping required=True makes them
    # effective while keeping every previously valid invocation unchanged.
    parser.add_argument(
        "--seed_tasks_path",
        type=str,
        default="data/seed_tasks.jsonl",
        help="The path to the human written data. The keys of the dictionary should be `instruction`, `input` and `output`.",
    )
    parser.add_argument(
        "--input_data_path",
        type=str,
        default="data/output.jsonl",
        help="The path to data we want to post-process.",
    )
    parser.add_argument(
        "--output_data_path",
        type=str,
        default="data/output.jsonl",
        help="Path where we want to store the processed data.",
    )
    parser.add_argument(
        "--template_name",
        default="default",
        help="Name of the template to use for in-context learning.",
    )
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default="bigcode/starcoder",
        help="The name or path of the model to use.",
    )
    parser.add_argument(
        "--num_prompt_instructions",
        type=int,
        default=8,
        help="The number of instructions to use in the prompt.",
    )
    parser.add_argument(
        "--num_trials",
        type=int,
        default=8,
        help="The number of trials.",
    )
    parser.add_argument(
        "--max_new_tokens",
        default=4096,
        type=int,
        help="The max_new_tokens parameter of the generate function. It is the maximum number of tokens to generate",
    )
    parser.add_argument(
        "--temperature",
        default=0.2,
        type=float,
        help="The temperature of the generation.",
    )
    parser.add_argument(
        "--top_p",
        default=0.9,
        type=float,
        help="The `top_p` parameter for the generation.",
    )
    parser.add_argument(
        "--stop_words",
        default=["\n20", "20.", "20 ."],
        nargs="+",
        help="The `stop_words` that are considered during the generation.",
    )
    parser.add_argument(
        "--num_beams",
        default=1,
        type=int,
        help="The beam size used during the generation.",
    )
    parser.add_argument(
        "--repetition_penalty",
        type=float,
        default=1.2,
        help="The repetition penalty parameter to use for the generation.",
    )
    parser.add_argument(
        "--threshold",
        type=float,
        default=0.7,
        help="The similarity threshold for filtering.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
    )
    return parser.parse_args(argv)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
from dataclasses import dataclass | ||
|
||
|
||
@dataclass
class Template:
    """Prompting format used to generate instructions and test cases.

    The format is built from three section tokens, producing prompts shaped
    like::

        Code:
        {code}

        Instruction:
        {instruction}

        Test cases:
        {test case 1}
        {test case 2}
        {test case 3}
    """

    # Section headers inserted before each part of the prompt.
    code_token: str = "Code:"
    instruction_token: str = "Instruction:"
    cases_token: str = "Test cases:"

    def get_triprompt(self, example) -> str:
        """Render a three-section prompt from a seed example.

        `example` is a dict with "code" and "instruction" keys. The result
        ends with the test-cases token (no cases appended), leaving the model
        to complete them, e.g.::

            Code:
            def maximum(arr):
                return max(arr)

            Instruction:
            write a function which takes as input a list arr and return its maximum.

            Test cases:
        """
        sections = (
            f"{self.code_token}\n{example['code']}",
            f"{self.instruction_token}\n{example['instruction']}",
            self.cases_token,
        )
        return "\n\n".join(sections).strip()

    def get_biprompt(self, example) -> str:
        """Render a two-section prompt: the code followed by its test cases.

        `example` is a dict with "code" and "cases" keys, e.g.::

            Code:
            def maximum(arr):
                return max(arr)

            Test cases:
            assert maximum([1, 2, 3]) == 3
            assert maximum([1]) == 1
        """
        pieces = [self.code_token, example["code"], "", self.cases_token]
        pieces.extend(example["cases"])
        return "\n".join(pieces).strip()

    def copy(self):
        """Return an independent Template carrying the same three tokens."""
        return Template(
            code_token=self.code_token,
            instruction_token=self.instruction_token,
            cases_token=self.cases_token,
        )
|
||
# Module-level instance carrying the default prompt tokens.
default_template = Template()

# Registry of selectable templates; keys are the names accepted by the
# --template_name CLI option (looked up via get_dialogue_template).
SUPPORTED_TEMPLATES = {
    "default": default_template,
}
|
||
|
||
def get_dialogue_template(template: str) -> Template:
    """Look up a prompt template by name and return an independent copy.

    Args:
        template: Key into SUPPORTED_TEMPLATES (e.g. "default").

    Returns:
        A fresh Template copy, safe for the caller to mutate.

    Raises:
        ValueError: If the name is not registered.
    """
    # EAFP single lookup instead of the original redundant
    # `template not in SUPPORTED_TEMPLATES.keys()` check followed by a
    # second dict access.
    try:
        return SUPPORTED_TEMPLATES[template].copy()
    except KeyError:
        raise ValueError(f"Template {template} is not supported!") from None
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Oops, something went wrong.