Skip to content

Commit

Permalink
batch packing for LM (facebookresearch#413)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebookresearch#413

Take a stream of tokens and pack it into a square batch of size batch_size x max_seq_len with no padding (except for the last batch).

Reviewed By: jingfeidu

Differential Revision: D14518399

fbshipit-source-id: 8de48a688d1525350c3d87d059244f8f9f1c9070
  • Loading branch information
borguz authored and facebook-github-bot committed Apr 2, 2019
1 parent 37c6468 commit 19e6274
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 5 deletions.
1 change: 1 addition & 0 deletions pytext/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ class Data(Component):
"""

__COMPONENT_TYPE__ = ComponentType.DATA_HANDLER
__EXPANSIBLE__ = True

class Config(Component.Config):
#: Specify where training/test/eval data come from. The default value
Expand Down
8 changes: 6 additions & 2 deletions pytext/metric_reporters/language_model_metric_reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,12 @@ def add_batch_stats(
self.batch_size.append(len(m_input[0]))
self.aggregate_data(self.all_num_tokens, targets[1])
now = time.time()
total_tokens = float(sum(targets[2]))
print(f"Tokens/s: {total_tokens / (now - self.time)}, ppl: {math.exp(loss)}")
if not n_batches % 1000:
total_tokens = float(sum(targets[2]))
print(
f"Tokens/s: {total_tokens / (now - self.time):.0f}, ppl: {math.exp(loss):.2f}",
flush=True,
)
self.time = now

def _reset(self):
Expand Down
7 changes: 5 additions & 2 deletions pytext/utils/precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,11 @@ def backward(optimizer, loss):
else:
# 1. Use automatic loss scaling to best use fp16 range
# 2. Clear handle's cache of casted parameters
with optimizer.scale_loss(loss) as scaled_loss:
scaled_loss.backward()
if loss > 0:
with optimizer.scale_loss(loss) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
else:
loss.backward()

Expand Down
2 changes: 1 addition & 1 deletion pytext/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def prepare_task(

print("\nParameters: {}\n".format(config))
_set_cuda(config.use_cuda_if_available, device_id, world_size)
_set_fp16(config.use_fp16 and world_size == 1)
_set_fp16(config.use_fp16)
if config.random_seed is not None:
set_random_seeds(config.random_seed)

Expand Down

0 comments on commit 19e6274

Please sign in to comment.