Skip to content
This repository has been archived by the owner on Nov 22, 2022. It is now read-only.

Commit

Permalink
Masked LM (#404)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #404

Implements model and masking logic for BERT style masked LM training.

Differential Revision: D14494507

fbshipit-source-id: 45d4fed25dbee4af688d78647dfc785e8cbead64
  • Loading branch information
borguz authored and facebook-github-bot committed Mar 20, 2019
1 parent 1465244 commit 236bc61
Show file tree
Hide file tree
Showing 8 changed files with 59 additions and 43 deletions.
1 change: 1 addition & 0 deletions pytext/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def __eq__(self, other):
PAD = SpecialToken("__PAD__")
BOS = SpecialToken("__BEGIN_OF_SENTENCE__")
EOS = SpecialToken("__END_OF_SENTENCE__")
MASK = SpecialToken("__MASK__")


class Vocabulary:
Expand Down
9 changes: 7 additions & 2 deletions pytext/metric_reporters/classification_metric_reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,13 @@ def __init__(
self.target_label = target_label

@classmethod
def from_config(cls, config, meta: CommonMetadata):
return cls.from_config_and_label_names(config, meta.target.vocab.itos)
def from_config(cls, config, meta: CommonMetadata, tensorizers=None):
# TODO: refactor metric reporting and remove this hack
if tensorizers:
labels = list(tensorizers["labels"].labels)
else:
labels = meta.target.vocab.itos
return cls.from_config_and_label_names(config, labels)

@classmethod
def from_config_and_label_names(cls, config, label_names: List[str]):
Expand Down
27 changes: 23 additions & 4 deletions pytext/metric_reporters/language_model_metric_reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class LanguageModelMetricReporter(MetricReporter):
lower_is_better = True

@classmethod
def from_config(cls, config, meta: CommonMetadata):
def from_config(cls, config, meta: CommonMetadata = None, tensorizers=None):
return cls(
[ConsoleChannel(), LanguageModelChannel((Stage.TEST,), config.output_path)]
)
Expand All @@ -34,13 +34,14 @@ def calculate_metric(self) -> LanguageModelMetric:
# In language model self.total_loss is the loss per word
return compute_language_model_metric(self.total_loss)

def _get_target_seq_lens(self):
return self.all_context[DatasetFieldName.TARGET_SEQ_LENS]

def calculate_loss(self) -> float:
total_loss = n_words = pos = 0
for loss, batch_size in zip(self.all_loss, self.batch_size):
num_words_in_batch = sum(
self.all_context[DatasetFieldName.TARGET_SEQ_LENS][
pos : pos + batch_size
]
self._get_target_seq_lens()[pos : pos + batch_size]
)
pos = pos + batch_size
total_loss += loss * num_words_in_batch
Expand All @@ -49,3 +50,21 @@ def calculate_loss(self) -> float:

def get_model_select_metric(self, metrics) -> float:
return metrics.perplexity_per_word


class MaskedLMMetricReporter(LanguageModelMetricReporter):
UTTERANCE_COLUMN = "raw_text"

def aggregate_targets(self, new_batch):
self.aggregate_data(self.all_targets, new_batch[0])
self.aggregate_data(self.all_num_tokens, new_batch[1])

def _reset(self):
super()._reset()
self.all_num_tokens = []

def _get_target_seq_lens(self):
return self.all_num_tokens

def batch_context(self, batch):
return {"utterance": batch[self.UTTERANCE_COLUMN]}
5 changes: 4 additions & 1 deletion pytext/metric_reporters/metric_reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def add_batch_stats(
self.all_context[key] = []
self.aggregate_data(self.all_context[key], val)
self.all_loss.append(loss)
self.batch_size.append(len(targets))
self.batch_size.append(len(m_input[0]))

def aggregate_preds(self, new_batch):
self.aggregate_data(self.all_preds, new_batch)
Expand Down Expand Up @@ -115,6 +115,9 @@ def _make_simple_list(cls, data):
def add_channel(self, channel):
self.channels.append(channel)

def batch_context(self, batch):
return {}

def calculate_loss(self):
"""
Calculate the average loss for all aggregated batch
Expand Down
2 changes: 1 addition & 1 deletion pytext/metric_reporters/regression_metric_reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class Config(MetricReporter.Config):
pass

@classmethod
def from_config(cls, config):
def from_config(cls, config, tensorizers=None):
return cls([ConsoleChannel()])

def calculate_metric(self):
Expand Down
7 changes: 5 additions & 2 deletions pytext/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,9 @@ def get_param_groups_for_optimizer(self) -> List[Dict[str, List[nn.Parameter]]]:
def train_batch(self, batch):
model_inputs = self.arrange_model_inputs(batch)
model_outputs = self(*model_inputs)
loss = self.get_loss(model_outputs, self.arrange_targets(batch), None)
predictions, scores = self.get_pred(model_outputs)
targets = self.arrange_targets(batch)
loss = self.get_loss(model_outputs, targets, None)
predictions, scores = self.get_pred(model_outputs)
metric_data = (predictions, targets, scores, loss, model_inputs)
return loss, metric_data

Expand All @@ -159,6 +159,9 @@ def arrange_targets(self, tensor_dict):
# should raise NotImplementedError after migration is done
pass

def caffe2_export(self, tensorizers, tensor_dict, path, export_onnx_path=None):
pass


class Model(BaseModel):
"""
Expand Down
33 changes: 13 additions & 20 deletions pytext/models/output_layers/lm_output_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@

import torch
import torch.nn.functional as F
from pytext.common.constants import DatasetFieldName
from pytext.config.component import create_loss
from pytext.fields import FieldMeta
from pytext.data.utils import PAD
from pytext.loss import CrossEntropyLoss, Loss

from .output_layer_base import OutputLayerBase
Expand All @@ -29,11 +28,13 @@ class Config(OutputLayerBase.Config):
loss: CrossEntropyLoss.Config = CrossEntropyLoss.Config()

@classmethod
def from_config(cls, config: Config, metadata: FieldMeta):
def from_config(cls, config: Config, metadata=None, labels=None):
vocab = labels or metadata.vocab.itos
pad_token_idx = metadata.pad_token_idx if metadata else vocab.idx[PAD]
return cls(
metadata.vocab.itos,
create_loss(config.loss, ignore_index=metadata.pad_token_idx),
pad_token_idx=metadata.pad_token_idx,
vocab,
create_loss(config.loss, ignore_index=pad_token_idx),
pad_token_idx=pad_token_idx,
)

def __init__(
Expand Down Expand Up @@ -67,12 +68,14 @@ def get_loss(
torch.Tensor: Word prediction loss.
"""
if isinstance(target, tuple):
target = target[0]
# flatten the logit from [batch_size, seq_lens, dim] to
# [batch_size * seq_lens, dim]
return self.loss_fn(logit.view(-1, logit.size()[-1]), target.view(-1), reduce)

def get_pred(
self, logit: torch.Tensor, target: torch.Tensor, context: Dict[str, Any]
self, logit: torch.Tensor, *args, **kwargs
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute and return prediction and scores from the model.
Prediction is computed using argmax over the word label/target space.
Expand All @@ -91,19 +94,9 @@ def get_pred(
Tuple[torch.Tensor, torch.Tensor]: Model prediction and scores.
"""
# Shape of logit: (bsize x seq_len x vocab)
# Reshape m_out to (bsize x vocab x seq_len) for cross_entropy_loss
logit = logit.transpose(1, 2)
# loss dim: (bsize x seq_len)
loss = F.cross_entropy(
logit, target, reduce=False, ignore_index=self.pad_token_idx
)
# context[DatasetFieldName.SEQ_LENS] s the length of each sequence
# sequence_loss is the loss per word for each sequence in the batch
# sequence_loss dim: (bsize,)
sequence_loss = loss.sum(1) / context[DatasetFieldName.TARGET_SEQ_LENS].float()
scores = self.calculate_perplexity(sequence_loss)
return scores, scores
preds = torch.max(logit, 2)[1]
scores = F.log_softmax(logit, 2)
return preds, scores

@staticmethod
def calculate_perplexity(sequence_loss: torch.Tensor) -> torch.Tensor:
Expand Down
18 changes: 5 additions & 13 deletions pytext/task/new_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,11 @@ def from_config(cls, config: Config, unused_metadata=None, model_state=None):
# This is the only place right now that the task actually cares about which
# features and tensors are being used. This is a strong tie between
# the implementation of the model and the metric reporter.
metric_reporter = cls.create_metric_reporter(config, tensorizers)
metric_reporter = create_component(
ComponentType.METRIC_REPORTER,
config.metric_reporter,
tensorizers=tensorizers,
)
trainer = create_trainer(config.trainer, model)
return cls(data, model, metric_reporter, trainer)

Expand Down Expand Up @@ -206,22 +210,10 @@ class Config(NewTask.Config):
ClassificationMetricReporter.Config()
)

# The existence of this function is a pretty good argument for having
# the metric reporter be owned internally at least in some way by the model
@classmethod
def create_metric_reporter(cls, config: Config, tensorizers: Dict[str, Tensorizer]):
return ClassificationMetricReporter.from_config_and_label_names(
config.metric_reporter, list(tensorizers["labels"].labels)
)


class NewDocumentRegression(NewTask):
class Config(NewTask.Config):
model: Model.Config = DocRegressionModel.Config()
metric_reporter: RegressionMetricReporter.Config = (
RegressionMetricReporter.Config()
)

@classmethod
def create_metric_reporter(cls, config: Config, tensorizers: Dict[str, Tensorizer]):
return RegressionMetricReporter.from_config(config.metric_reporter)

0 comments on commit 236bc61

Please sign in to comment.