batch packing for LM
Summary: Take a stream of tokens and pack it into square batches of size batch_size x max_seq_len with no padding (except for the last batch).
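A minimal sketch of the packing idea described in the summary, not the actual PyText implementation; the function name pack_tokens and its parameters are illustrative only:

    # Illustrative sketch: pack a flat token stream into batch_size x max_seq_len
    # batches with no padding (only the trailing batch may be ragged).
    from typing import Iterable, Iterator, List

    def pack_tokens(
        stream: Iterable[int], batch_size: int, max_seq_len: int
    ) -> Iterator[List[List[int]]]:
        tokens_per_batch = batch_size * max_seq_len
        buffer: List[int] = []
        for token in stream:
            buffer.append(token)
            if len(buffer) == tokens_per_batch:
                # Reshape the flat buffer into a batch_size x max_seq_len batch.
                yield [
                    buffer[i * max_seq_len : (i + 1) * max_seq_len]
                    for i in range(batch_size)
                ]
                buffer = []
        if buffer:
            # Trailing tokens form a final, possibly ragged batch; real code
            # could pad this batch to the full shape.
            yield [
                buffer[i : i + max_seq_len]
                for i in range(0, len(buffer), max_seq_len)
            ]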

Differential Revision: D14518399

fbshipit-source-id: b7fcfa5af729b3b35ebaff604a12ed724916fd27
borguz authored and facebook-github-bot committed Mar 19, 2019
1 parent 8edf377 commit 4fce7af
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions pytext/metric_reporters/language_model_metric_reporter.py
@@ -1,5 +1,6 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import time

from pytext.common.constants import DatasetFieldName, Stage
from pytext.data import CommonMetadata
@@ -53,19 +54,21 @@ def get_model_select_metric(self, metrics) -> float:


class MaskedLMMetricReporter(LanguageModelMetricReporter):
    UTTERANCE_COLUMN = "raw_text"

    @classmethod
    def from_config(cls, config, meta: CommonMetadata = None, tensorizers=None):
        return cls([ConsoleChannel()])

    def aggregate_targets(self, new_batch):
        self.aggregate_data(self.all_targets, new_batch[0])
        self.aggregate_data(self.all_num_tokens, new_batch[1])
        now = time.time()
        print(f"Words/sec: {float(sum(new_batch[2])) / (now - self.time)}")
        self.time = now

    def _reset(self):
        super()._reset()
        self.all_num_tokens = []
        self.time = time.time()

    def _get_target_seq_lens(self):
        print(self.all_num_tokens)
        return self.all_num_tokens

    def batch_context(self, batch):
        return {"utterance": batch[self.UTTERANCE_COLUMN]}
