This repository has been archived by the owner on Nov 22, 2022. It is now read-only.

batch packing for LM #413

Closed
wants to merge 1 commit into from
1 change: 1 addition & 0 deletions pytext/data/data.py

```diff
@@ -190,6 +190,7 @@ class Data(Component):
     """

     __COMPONENT_TYPE__ = ComponentType.DATA_HANDLER
+    __EXPANSIBLE__ = True

     class Config(Component.Config):
         #: Specify where training/test/eval data come from. The default value
```
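The one-line change above marks `Data` as expansible, which in PyText lets a subclass of the component be supplied where the base class is declared in a config (useful here so a packed-batch LM data class can stand in for `Data`). The sketch below is an illustrative stand-in for that pattern, not the library's actual registry code; `PackedLMData` is used as a hypothetical subclass name.

```python
# Illustrative sketch only, not PyText's registry: a class-level __EXPANSIBLE__
# flag lets a slot typed as the base class accept subclasses.
class Component:
    __EXPANSIBLE__ = False


class Data(Component):
    __EXPANSIBLE__ = True  # subclasses of Data may be used where Data is expected


class PackedLMData(Data):  # hypothetical subclass for LM batch packing
    pass


def resolve(expected_cls, requested_cls):
    """Accept requested_cls for an expected_cls slot only if the base opts in."""
    if requested_cls is expected_cls:
        return requested_cls
    if issubclass(requested_cls, expected_cls) and expected_cls.__EXPANSIBLE__:
        return requested_cls
    raise TypeError(f"{requested_cls.__name__} cannot replace {expected_cls.__name__}")


assert resolve(Data, PackedLMData) is PackedLMData
```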
8 changes: 6 additions & 2 deletions pytext/metric_reporters/language_model_metric_reporter.py

```diff
@@ -67,8 +67,12 @@ def add_batch_stats(
         self.batch_size.append(len(m_input[0]))
         self.aggregate_data(self.all_num_tokens, targets[1])
         now = time.time()
-        total_tokens = float(sum(targets[2]))
-        print(f"Tokens/s: {total_tokens / (now - self.time)}, ppl: {math.exp(loss)}")
+        if not n_batches % 1000:
+            total_tokens = float(sum(targets[2]))
+            print(
+                f"Tokens/s: {total_tokens / (now - self.time):.0f}, ppl: {math.exp(loss):.2f}",
+                flush=True,
+            )
         self.time = now

     def _reset(self):
```
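This hunk throttles the progress print to every 1000 batches and rounds the reported numbers instead of printing on every batch. A minimal standalone sketch of the same throttling pattern, with stand-in arguments rather than the reporter's real attributes:

```python
import math
import time


def maybe_report(n_batches, total_tokens, loss, last_time, every=1000):
    """Print throughput and perplexity only when n_batches is a multiple of `every`."""
    now = time.time()
    if not n_batches % every:  # fires on batch 0, every, 2*every, ...
        tokens_per_sec = total_tokens / (now - last_time)
        print(f"Tokens/s: {tokens_per_sec:.0f}, ppl: {math.exp(loss):.2f}", flush=True)
    return now  # caller stores this as the new reference time
```

`flush=True` pushes the line out immediately, which matters when stdout is buffered, for example when training logs are redirected to a file.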
7 changes: 5 additions & 2 deletions pytext/utils/precision.py

```diff
@@ -115,8 +115,11 @@ def backward(optimizer, loss):
         else:
             # 1. Use automatic loss scaling to best use fp16 range
             # 2. Clear handle's cache of casted parameters
-            with optimizer.scale_loss(loss) as scaled_loss:
-                scaled_loss.backward()
+            if loss > 0:
+                with optimizer.scale_loss(loss) as scaled_loss:
+                    scaled_loss.backward()
+            else:
+                loss.backward()
     else:
         loss.backward()

```
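This hunk guards the fp16 loss-scaling path so a non-positive loss falls back to a plain `backward()` instead of going through `optimizer.scale_loss(...)`. Below is a hedged, standalone sketch of that guard; the `scale_loss` context manager is a stub standing in for the real mixed-precision optimizer wrapper, not its actual implementation.

```python
from contextlib import contextmanager

import torch


class StubFP16Optimizer:
    """Stand-in for a mixed-precision optimizer wrapper exposing scale_loss()."""

    loss_scale = 128.0

    @contextmanager
    def scale_loss(self, loss):
        yield loss * self.loss_scale  # backward() then produces scaled gradients


def backward(optimizer, loss, fp16_enabled=True):
    # Guarded pattern from the diff: only a positive loss goes through scaling.
    if fp16_enabled and loss > 0:
        with optimizer.scale_loss(loss) as scaled_loss:
            scaled_loss.backward()
    else:
        loss.backward()


x = torch.ones(3, requires_grad=True)
backward(StubFP16Optimizer(), (x * 0.0).sum())  # zero loss: plain backward()
backward(StubFP16Optimizer(), (x * 1.0).sum())  # positive loss: scaled backward()
```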
2 changes: 1 addition & 1 deletion pytext/workflow.py

```diff
@@ -102,7 +102,7 @@ def prepare_task(

     print("\nParameters: {}\n".format(config))
     _set_cuda(config.use_cuda_if_available, device_id, world_size)
-    _set_fp16(config.use_fp16 and world_size == 1)
+    _set_fp16(config.use_fp16)
     if config.random_seed is not None:
         set_random_seeds(config.random_seed)
```