Skip to content

Commit

Permalink
Cli fixes and improvements (#25)
Browse files Browse the repository at this point in the history
* Revamp cli args (#45)

* Rachel/follow (#46)

* Add fine_tunes.follow. Add better error handling for disconnected streams

* return early

* fix an oops

* lint

* Nicer strings

* ensure end token is not applied to classification (#44)

* ensure end token is not applied to classification

* black

Co-authored-by: Boris Power <81998504+BorisPower@users.noreply.github.com>
  • Loading branch information
rachellim and BorisPower authored Jun 30, 2021
1 parent dc15660 commit d92502f
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 47 deletions.
110 changes: 65 additions & 45 deletions openai/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,34 +286,26 @@ def create(cls, args):
create_args["validation_file"] = cls._get_or_upload(
args.validation_file, args.check_if_files_exist
)
if args.model:
create_args["model"] = args.model
if args.n_epochs:
create_args["n_epochs"] = args.n_epochs
if args.batch_size:
create_args["batch_size"] = args.batch_size
if args.learning_rate_multiplier:
create_args["learning_rate_multiplier"] = args.learning_rate_multiplier
create_args["use_packing"] = args.use_packing
if args.prompt_loss_weight:
create_args["prompt_loss_weight"] = args.prompt_loss_weight
if args.compute_classification_metrics:
create_args[
"compute_classification_metrics"
] = args.compute_classification_metrics
if args.classification_n_classes:
create_args["classification_n_classes"] = args.classification_n_classes
if args.classification_positive_class:
create_args[
"classification_positive_class"
] = args.classification_positive_class
if args.classification_betas:
betas = [float(x) for x in args.classification_betas.split(",")]
create_args["classification_betas"] = betas

for hparam in (
"model",
"n_epochs",
"batch_size",
"learning_rate_multiplier",
"prompt_loss_weight",
"use_packing",
"compute_classification_metrics",
"classification_n_classes",
"classification_positive_class",
"classification_betas",
):
attr = getattr(args, hparam)
if attr is not None:
create_args[hparam] = attr

resp = openai.FineTune.create(**create_args)

if args.no_wait:
if args.no_follow:
print(resp)
return

Expand Down Expand Up @@ -345,20 +337,32 @@ def results(cls, args):

@classmethod
def events(cls, args):
if not args.stream:
resp = openai.FineTune.list_events(id=args.id) # type: ignore
print(resp)
return
if args.stream:
raise openai.error.OpenAIError(
message=(
"The --stream parameter is deprecated, use fine_tunes.follow "
"instead:\n\n"
" openai api fine_tunes.follow -i {id}\n".format(id=args.id)
),
)

resp = openai.FineTune.list_events(id=args.id) # type: ignore
print(resp)

@classmethod
def follow(cls, args):
cls._stream_events(args.id)

@classmethod
def _stream_events(cls, job_id):
def signal_handler(sig, frame):
status = openai.FineTune.retrieve(job_id).status
sys.stdout.write(
"\nStream interrupted. Job is still {status}. "
"\nStream interrupted. Job is still {status}.\n"
"To resume the stream, run:\n\n"
" openai api fine_tunes.follow -i {job_id}\n\n"
"To cancel your job, run:\n\n"
"openai api fine_tunes.cancel -i {job_id}\n".format(
" openai api fine_tunes.cancel -i {job_id}\n\n".format(
status=status, job_id=job_id
)
)
Expand All @@ -368,16 +372,24 @@ def signal_handler(sig, frame):

events = openai.FineTune.stream_events(job_id)
# TODO(rachel): Add a nifty spinner here.
for event in events:
sys.stdout.write(
"[%s] %s"
% (
datetime.datetime.fromtimestamp(event["created_at"]),
event["message"],
try:
for event in events:
sys.stdout.write(
"[%s] %s"
% (
datetime.datetime.fromtimestamp(event["created_at"]),
event["message"],
)
)
sys.stdout.write("\n")
sys.stdout.flush()
except Exception:
sys.stdout.write(
"\nStream interrupted (client disconnected).\n"
"To resume the stream, run:\n\n"
" openai api fine_tunes.follow -i {job_id}\n\n".format(job_id=job_id)
)
sys.stdout.write("\n")
sys.stdout.flush()
return

resp = openai.FineTune.retrieve(id=job_id)
status = resp["status"]
Expand Down Expand Up @@ -688,9 +700,9 @@ def help(args):
help="The model to start fine-tuning from",
)
sub.add_argument(
"--no_wait",
"--no_follow",
action="store_true",
help="If set, returns immediately after creating the job. Otherwise, waits for the job to complete.",
help="If set, returns immediately after creating the job. Otherwise, streams events and waits for the job to complete.",
)
sub.add_argument(
"--n_epochs",
Expand Down Expand Up @@ -727,7 +739,7 @@ def help(args):
dest="use_packing",
help="Disables the packing flag (see --use_packing for description)",
)
sub.set_defaults(use_packing=True)
sub.set_defaults(use_packing=None)
sub.add_argument(
"--prompt_loss_weight",
type=float,
Expand All @@ -741,6 +753,7 @@ def help(args):
help="If set, we calculate classification-specific metrics such as accuracy "
"and F-1 score using the validation set at the end of every epoch.",
)
sub.set_defaults(compute_classification_metrics=None)
sub.add_argument(
"--classification_n_classes",
type=int,
Expand All @@ -755,10 +768,11 @@ def help(args):
)
sub.add_argument(
"--classification_betas",
type=float,
nargs="+",
help="If this is provided, we calculate F-beta scores at the specified beta "
"values. The F-beta score is a generalization of F-1 score. This is only "
"used for binary classification. The expected format is a comma-separated "
"list - e.g. 1,1.5,2",
"used for binary classification.",
)
sub.set_defaults(func=FineTune.create)

Expand All @@ -772,15 +786,21 @@ def help(args):

sub = subparsers.add_parser("fine_tunes.events")
sub.add_argument("-i", "--id", required=True, help="The id of the fine-tune job")

# TODO(rachel): Remove this in 1.0
sub.add_argument(
"-s",
"--stream",
action="store_true",
help="If set, events will be streamed until the job is done. Otherwise, "
help="[DEPRECATED] If set, events will be streamed until the job is done. Otherwise, "
"displays the event history to date.",
)
sub.set_defaults(func=FineTune.events)

sub = subparsers.add_parser("fine_tunes.follow")
sub.add_argument("-i", "--id", required=True, help="The id of the fine-tune job")
sub.set_defaults(func=FineTune.follow)

sub = subparsers.add_parser("fine_tunes.cancel")
sub.add_argument("-i", "--id", required=True, help="The id of the fine-tune job")
sub.set_defaults(func=FineTune.cancel)
2 changes: 1 addition & 1 deletion openai/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ def common_completion_suffix_validator(df):
display_suggested_suffix = suggested_suffix.replace("\n", "\\n")

ft_type = infer_task_type(df)
if ft_type == "open-ended generation":
if ft_type == "open-ended generation" or ft_type == "classification":
return Remediation(name="common_suffix")

def add_suffix(x, suffix):
Expand Down
2 changes: 1 addition & 1 deletion openai/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
VERSION = "0.9.2"
VERSION = "0.9.3"

0 comments on commit d92502f

Please sign in to comment.