Skip to content

Commit

Permalink
minor fixes to tools prepare_data validators (#47) (#26)
Browse files Browse the repository at this point in the history
* Ensure that only a single whitespace is prepended to each completion when the common prefix is removed. Ensure the message regarding the prompt separator is displayed only if a prompt separator exists.

* Change the pandas `str.contains` calls to pass `regex=False`, so that a suffix containing regex metacharacters is matched literally instead of being interpreted as a pattern.

Co-authored-by: Boris Power <81998504+BorisPower@users.noreply.github.com>
  • Loading branch information
rachellim and BorisPower authored Jul 12, 2021
1 parent d92502f commit fc1d9db
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 17 deletions.
49 changes: 33 additions & 16 deletions openai/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def common_prompt_suffix_validator(df):
if suffix_option == " ->":
if df.prompt.str.contains("\n").any():
continue
if df.prompt.str.contains(suffix_option).any():
if df.prompt.str.contains(suffix_option, regex=False).any():
continue
suggested_suffix = suffix_option
break
Expand All @@ -202,7 +202,11 @@ def add_suffix(x, suffix):
)
if len(common_suffix) > 10:
immediate_msg += f". This suffix seems very long. Consider replacing with a shorter suffix, such as `{display_suggested_suffix}`"
if df.prompt.str[: -len(common_suffix)].str.contains(common_suffix).any():
if (
df.prompt.str[: -len(common_suffix)]
.str.contains(common_suffix, regex=False)
.any()
):
immediate_msg += f"\n WARNING: Some of your prompts contain the suffix `{common_suffix}` more than once. We strongly suggest that you review your prompts and add a unique suffix"

else:
Expand Down Expand Up @@ -271,11 +275,15 @@ def common_completion_prefix_validator(df):
MAX_PREFIX_LEN = 5

common_prefix = get_common_xfix(df.completion, xfix="prefix")
ws_prefix = len(common_prefix) > 0 and common_prefix[0] == " "
if len(common_prefix) < MAX_PREFIX_LEN:
return Remediation(name="common_prefix")

def remove_common_prefix(x, prefix):
def remove_common_prefix(x, prefix, ws_prefix):
x["completion"] = x["completion"].str[len(prefix) :]
if ws_prefix:
# keep the single whitespace as prefix
x["completion"] = " " + x["completion"]
return x

if (df.completion == common_prefix).all():
Expand All @@ -286,7 +294,7 @@ def remove_common_prefix(x, prefix):
optional_msg = f"Remove prefix `{common_prefix}` from all completions"

def optional_fn(x):
return remove_common_prefix(x, common_prefix)
return remove_common_prefix(x, common_prefix, ws_prefix)

return Remediation(
name="common_completion_prefix",
Expand All @@ -305,6 +313,15 @@ def common_completion_suffix_validator(df):
optional_msg = None
optional_fn = None

ft_type = infer_task_type(df)
if ft_type == "open-ended generation" or ft_type == "classification":
return Remediation(name="common_suffix")

common_suffix = get_common_xfix(df.completion, xfix="suffix")
if (df.completion == common_suffix).all():
error_msg = f"All completions are identical: `{common_suffix}`\nEnsure completions are different, otherwise the model will just repeat `{common_suffix}`"
return Remediation(name="common_suffix", error_msg=error_msg)

# Find a suffix which is not contained within the completion otherwise
suggested_suffix = " [END]"
suffix_options = [
Expand All @@ -319,33 +336,28 @@ def common_completion_suffix_validator(df):
"%%%",
]
for suffix_option in suffix_options:
if df.completion.str.contains(suffix_option).any():
if df.completion.str.contains(suffix_option, regex=False).any():
continue
suggested_suffix = suffix_option
break
display_suggested_suffix = suggested_suffix.replace("\n", "\\n")

ft_type = infer_task_type(df)
if ft_type == "open-ended generation" or ft_type == "classification":
return Remediation(name="common_suffix")

def add_suffix(x, suffix):
x["completion"] += suffix
return x

common_suffix = get_common_xfix(df.completion, xfix="suffix")
if (df.completion == common_suffix).all():
error_msg = f"All completions are identical: `{common_suffix}`\nEnsure completions are different, otherwise the model will just repeat `{common_suffix}`"
return Remediation(name="common_suffix", error_msg=error_msg)

if common_suffix != "":
common_suffix_new_line_handled = common_suffix.replace("\n", "\\n")
immediate_msg = (
f"\n- All completions end with suffix `{common_suffix_new_line_handled}`"
)
if len(common_suffix) > 10:
immediate_msg += f". This suffix seems very long. Consider replacing with a shorter suffix, such as `{display_suggested_suffix}`"
if df.completion.str[: -len(common_suffix)].str.contains(common_suffix).any():
if (
df.completion.str[: -len(common_suffix)]
.str.contains(common_suffix, regex=False)
.any()
):
immediate_msg += f"\n WARNING: Some of your completions contain the suffix `{common_suffix}` more than once. We suggest that you review your completions and add a unique ending"

else:
Expand Down Expand Up @@ -617,8 +629,13 @@ def write_out_file(df, fname, any_remediations):
# Add -v VALID_FILE if we split the file into train / valid
files_string = ("s" if split else "") + " to `" + ("` and `".join(outfnames))
valid_string = f' -v "{outfnames[1]}"' if split else ""
separator_reminder = (
""
if len(common_prompt_suffix_new_line_handled) == 0
else f"After you’ve fine-tuned a model, remember that your prompt has to end with the indicator string `{common_prompt_suffix_new_line_handled}` for the model to start generating completions, rather than continuing with the prompt."
)
sys.stdout.write(
f'\nWrote modified file{files_string}`\nFeel free to take a look!\n\nNow use that file when fine-tuning:\n> openai api fine_tunes.create -t "{outfnames[0]}"{valid_string}{packing_param}\n\nAfter you’ve fine-tuned a model, remember that your prompt has to end with the indicator string `{common_prompt_suffix_new_line_handled}` for the model to start generating completions, rather than continuing with the prompt.{optional_ending_string}\n'
f'\nWrote modified file{files_string}`\nFeel free to take a look!\n\nNow use that file when fine-tuning:\n> openai api fine_tunes.create -t "{outfnames[0]}"{valid_string}{packing_param}\n\n{separator_reminder}{optional_ending_string}\n'
)
else:
sys.stdout.write("Aborting... did not write the file\n")
Expand Down
2 changes: 1 addition & 1 deletion openai/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
VERSION = "0.9.3"
VERSION = "0.9.4"

0 comments on commit fc1d9db

Please sign in to comment.