Added silent mode for console outputs
hlo-world authored Nov 4, 2019
1 parent 2de6dce commit 566e1a1
Showing 1 changed file with 18 additions and 12 deletions.
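
With this change, console output can be suppressed by overriding the new "silent" key in the args dict accepted by the model constructor (the default added below is "silent": False). A minimal usage sketch follows; the TransformerModel class name, the import path, and the two-column pandas DataFrame input format are assumptions about the library at this point in its history and do not appear in this diff:

import pandas as pd

from simpletransformers.model import TransformerModel

# Two toy examples in the assumed (text, label) column layout.
train_df = pd.DataFrame([["first example text", 1], ["second example text", 0]])

# "silent": True turns off the progress bars and print statements that this
# commit gates behind the new flag; use_cuda=False keeps the sketch CPU-only.
model = TransformerModel("roberta", "roberta-base", args={"silent": True}, use_cuda=False)
model.train_model(train_df)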
30 changes: 18 additions & 12 deletions simpletransformers/model.py
@@ -1,7 +1,6 @@
 #!/usr/bin/env python
 # coding: utf-8

-
 from __future__ import absolute_import, division, print_function

 import os
@@ -101,6 +100,8 @@ def __init__(self, model_type, model_name, num_labels=2, args=None, use_cuda=Tru

             "process_count": cpu_count() - 2 if cpu_count() > 2 else 1,
             "n_gpu": 1,
+
+            "silent": False,
         }

         if args:
@@ -151,7 +152,8 @@ def train_model(self, train_df, output_dir=None, show_running_loss=True, args=No
         self.tokenizer.save_pretrained(output_dir)
         torch.save(self.args, os.path.join(output_dir, "training_args.bin"))

-        print("Training of {} model complete. Saved to {}.".format(self.args["model_type"], output_dir))
+        if not args["silent"]:
+            print("Training of {} model complete. Saved to {}.".format(self.args["model_type"], output_dir))


     def eval_model(self, eval_df, output_dir=None, verbose=False, **kwargs):
@@ -180,8 +182,9 @@ def eval_model(self, eval_df, output_dir=None, verbose=False, **kwargs):
         result, model_outputs, wrong_preds = self.evaluate(eval_df, output_dir, **kwargs)
         self.results.update(result)

-        if verbose:
-            print(self.results)
+        if not args["silent"]:
+            if verbose:
+                print(self.results)

         return result, model_outputs, wrong_preds

@@ -219,7 +222,7 @@ def evaluate(self, eval_df, output_dir, prefix="", **kwargs):
         out_label_ids = None
         model.eval()

-        for batch in tqdm(eval_dataloader):
+        for batch in tqdm(eval_dataloader, disable=args["silent"]):
             batch = tuple(t.to(device) for t in batch)

             with torch.no_grad():
@@ -274,9 +277,11 @@ def load_and_cache_examples(self, examples, evaluate=False, no_cache=False):

         if os.path.exists(cached_features_file) and not args["reprocess_input_data"] and not no_cache:
             features = torch.load(cached_features_file)
-            print(f"Features loaded from cache at {cached_features_file}")
+            if not args["silent"]:
+                print(f"Features loaded from cache at {cached_features_file}")
         else:
-            print(f"Converting to features started.")
+            if not args["silent"]:
+                print(f"Converting to features started.")
             features = convert_examples_to_features(
                 examples,
                 args["max_seq_length"],
@@ -360,20 +365,21 @@ def train(self, train_dataset, output_dir, show_running_loss=True):
         global_step = 0
         tr_loss, logging_loss = 0.0, 0.0
         model.zero_grad()
-        train_iterator = trange(int(args["num_train_epochs"]), desc="Epoch")
+        train_iterator = trange(int(args["num_train_epochs"]), desc="Epoch", disable=args["silent"])

         model.train()
         for _ in train_iterator:
-            # epoch_iterator = tqdm(train_dataloader, desc="Iteration")
-            for step, batch in enumerate(tqdm(train_dataloader, desc="Current iteration")):
+            # epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args["silent"])
+            for step, batch in enumerate(tqdm(train_dataloader, desc="Current iteration", disable=args["silent"])):
                 batch = tuple(t.to(device) for t in batch)

                 inputs = self._get_inputs_dict(batch)
                 outputs = model(**inputs)
                 # model outputs are always tuple in pytorch-transformers (see doc)
                 loss = outputs[0]
                 if show_running_loss:
-                    print("\rRunning loss: %f" % loss, end="")
+                    if not args["silent"]:
+                        print("\rRunning loss: %f" % loss, end="")

                 if args["gradient_accumulation_steps"] > 1:
                     loss = loss / args["gradient_accumulation_steps"]
@@ -487,7 +493,7 @@ def predict(self, to_predict):
         preds = None
         out_label_ids = None

-        for batch in tqdm(eval_dataloader):
+        for batch in tqdm(eval_dataloader, disable=args["silent"]):
             model.eval()
             batch = tuple(t.to(device) for t in batch)

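Every change in this commit follows the same two-part pattern: print calls are wrapped in if not args["silent"]: and the tqdm/trange progress bars receive disable=args["silent"], so one flag drives both kinds of output. A standalone sketch of that pattern, using illustrative names that are not taken from the repository:

from tqdm import tqdm, trange

def run_epochs(batches, num_epochs, silent=False):
    # The progress bars accept a disable flag, so the same value controls
    # both the bars and the plain print output.
    for _ in trange(num_epochs, desc="Epoch", disable=silent):
        for batch in tqdm(batches, desc="Current iteration", disable=silent):
            loss = sum(batch)  # stand-in for a real training step
            if not silent:
                print("\rRunning loss: %f" % loss, end="")
    if not silent:
        print("\nTraining complete.")

run_epochs([[0.5, 0.25], [0.125, 0.0625]], num_epochs=2, silent=True)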
