Skip to content

Commit

Permalink
folder creation fix
Browse files Browse the repository at this point in the history
  • Loading branch information
flaviussn committed Jan 29, 2020
1 parent 489d4f7 commit 36256f1
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 7 deletions.
11 changes: 8 additions & 3 deletions simpletransformers/classification/classification_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,9 @@ def __init__(
if args:
self.args.update(args)

self.tokenizer = tokenizer_class.from_pretrained(model_name, do_lower_case=self.args["do_lower_case"], **kwargs)
self.tokenizer = tokenizer_class.from_pretrained(model_name,
do_lower_case=self.args["do_lower_case"],
**kwargs)

self.args["model_name"] = model_name
self.args["model_type"] = model_type
Expand Down Expand Up @@ -408,7 +410,10 @@ def train(
):
# Only evaluate when single GPU otherwise metrics may not average well
results, _, _ = self.eval_model(
eval_df, verbose=verbose and args["evaluate_during_training_verbose"], silent=True, **kwargs
eval_df,
verbose=verbose and args["evaluate_during_training_verbose"],
silent=True,
**kwargs
)
for key, value in results.items():
tb_writer.add_scalar("eval_{}".format(key), value, global_step)
Expand Down Expand Up @@ -679,7 +684,7 @@ def load_and_cache_examples(
output_mode = "classification"

if not os.path.isdir(self.args["cache_dir"]):
os.mkdir(self.args["cache_dir"])
os.makedirs(self.args["cache_dir"])

mode = "dev" if evaluate else "train"
cached_features_file = os.path.join(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ def load_and_cache_examples(self, examples, evaluate=False, no_cache=False, mult
args = self.args

if not os.path.isdir(self.args["cache_dir"]):
os.mkdir(self.args["cache_dir"])
os.makedirs(self.args["cache_dir"])

mode = "dev" if evaluate else "train"
cached_features_file = os.path.join(args["cache_dir"], "cached_{}_{}_{}_{}_{}".format(mode, args["model_type"], args["max_seq_length"], self.num_labels, len(examples)))
Expand Down
2 changes: 1 addition & 1 deletion simpletransformers/ner/ner_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,7 +730,7 @@ def load_and_cache_examples(
)

if not os.path.isdir(self.args["cache_dir"]):
os.mkdir(self.args["cache_dir"])
os.makedirs(self.args["cache_dir"])

if os.path.exists(cached_features_file) and (
(not args["reprocess_input_data"] and not no_cache)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def load_and_cache_examples(
no_cache = args['no_cache']

if not os.path.isdir(self.args["cache_dir"]):
os.mkdir(self.args["cache_dir"])
os.makedirs(self.args["cache_dir"])

examples = get_examples(examples, is_training=not evaluate)

Expand Down Expand Up @@ -653,7 +653,7 @@ def evaluate(self, eval_data, output_dir):

prefix = "test"
if not os.path.isdir(output_dir):
os.mkdir(output_dir)
os.makedirs(output_dir)

output_prediction_file = os.path.join(
output_dir, "predictions_{}.json".format(prefix)
Expand Down

0 comments on commit 36256f1

Please sign in to comment.