forked from pytorch/translate
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path char_source_hybrid.py
365 lines (319 loc) · 12.9 KB
/
char_source_hybrid.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
#!/usr/bin/env python3
import logging
import math
from ast import literal_eval
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq.models import (
FairseqEncoder,
register_model,
register_model_architecture,
transformer as fairseq_transformer,
)
from fairseq.modules import SinusoidalPositionalEmbedding
from pytorch_translate import (
char_encoder,
char_source_model,
hybrid_transformer_rnn,
model_constants,
rnn,
transformer as pytorch_translate_transformer,
utils,
vocab_constants,
)
from pytorch_translate.common_layers import (
TransformerEncoderGivenEmbeddings,
VariableTracker,
)
from pytorch_translate.data.dictionary import TAGS
logger = logging.getLogger(__name__)
@register_model("char_source_hybrid")
class CharSourceHybridModel(hybrid_transformer_rnn.HybridTransformerRNNModel):
    """
    An architecture combining hybrid Transformer/RNN with character-based
    inputs (token embeddings created via character-input CNN).

    The encoder is a CharCNNEncoder (character-CNN word representations
    concatenated with token embeddings, fed to a transformer encoder); the
    decoder is hybrid_transformer_rnn.HybridRNNDecoder.
    """

    def __init__(self, task, encoder, decoder):
        super().__init__(task, encoder, decoder)

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to ``parser``.

        Registers every HybridTransformerRNNModel argument, then the
        character-CNN and pretrained-weight options consumed by
        build_encoder().
        """
        hybrid_transformer_rnn.HybridTransformerRNNModel.add_args(parser)
        parser.add_argument(
            "--char-embed-dim",
            type=int,
            default=128,
            metavar="N",
            help=("Character embedding dimension."),
        )
        # No default here: base_architecture() supplies one when the
        # attribute is absent from args.
        parser.add_argument(
            "--char-cnn-params",
            type=str,
            metavar="EXPR",
            help=("String expression, [(dim, kernel_size), ...]."),
        )
        parser.add_argument(
            "--char-cnn-nonlinear-fn",
            type=str,
            default="tanh",
            metavar="EXPR",
            help=("Nonlinearity applied to char conv outputs. Values: relu, tanh."),
        )
        parser.add_argument(
            "--char-cnn-num-highway-layers",
            type=int,
            default=0,
            metavar="N",
            help=("Char cnn encoder highway layers."),
        )
        parser.add_argument(
            "--char-cnn-output-dim",
            type=int,
            default=-1,
            metavar="N",
            help="Output dim of the CNN layer. If set to -1, this is computed "
            "from char-cnn-params.",
        )
        parser.add_argument(
            "--use-pretrained-weights",
            type=utils.bool_flag,
            nargs="?",
            const=True,
            default=False,
            help="Use pretrained weights for the character model including "
            "the char embeddings, CNN filters, highway networks",
        )
        parser.add_argument(
            "--finetune-pretrained-weights",
            type=utils.bool_flag,
            nargs="?",
            const=True,
            default=False,
            help="Boolean flag to specify whether or not to update the "
            "pretrained weights as part of training",
        )
        parser.add_argument(
            "--pretrained-weights-file",
            type=str,
            default="",
            help=("Weights file for loading pretrained weights"),
        )

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance.

        Applies the char_source_hybrid architecture defaults to ``args``,
        validates the character-related arguments, then builds the encoder
        and decoder through ``cls`` so subclasses can override either
        builder.

        Args:
            args: parsed argument namespace. Must carry
                ``char_source_dict_size`` (set by load_binarized_dataset()).
                ``args.embed_bytes`` is defaulted to False when absent.
            task: object providing ``source_dictionary`` and
                ``target_dictionary``.

        Returns:
            A new ``cls(task, encoder, decoder)``.

        Raises:
            AssertionError: if ``args.char_source_dict_size`` is missing. The
                ``char_cnn_params`` check cannot fail in practice, because
                base_architecture() always sets that attribute.
        """
        src_dict, dst_dict = task.source_dictionary, task.target_dictionary
        base_architecture(args)
        assert hasattr(args, "char_source_dict_size"), (
            "args.char_source_dict_size required. "
            "should be set by load_binarized_dataset()"
        )
        assert hasattr(
            args, "char_cnn_params"
        ), "Only char CNN is supported for the char encoder hybrid model"
        args.embed_bytes = getattr(args, "embed_bytes", False)
        # Pretrained char-model parameters are only verified for byte-level
        # inputs with use_pretrained_weights enabled.
        if args.embed_bytes and getattr(args, "use_pretrained_weights", False):
            char_source_model.verify_pretrain_params(args)
        encoder = cls.build_encoder(args=args, src_dict=src_dict)
        decoder = cls.build_decoder(args=args, src_dict=src_dict, dst_dict=dst_dict)
        return cls(task, encoder, decoder)

    def forward(
        self, src_tokens, src_lengths, char_inds, word_lengths, prev_output_tokens
    ):
        """
        Overriding FairseqEncoderDecoderModel.forward() due to different encoder
        inputs.

        The encoder additionally receives ``char_inds`` and ``word_lengths``.
        Its output is passed to the decoder together with
        ``prev_output_tokens``, and the decoder output is returned unchanged.
        """
        encoder_out = self.encoder(src_tokens, src_lengths, char_inds, word_lengths)
        decoder_out = self.decoder(prev_output_tokens, encoder_out)
        return decoder_out

    @classmethod
    def build_encoder(cls, args, src_dict):
        """Build the CharCNNEncoder for this model.

        The character vocabulary size is fixed when ``args.embed_bytes`` is
        set (byte indices plus tag indices plus one extra index). Otherwise it
        is ``args.char_source_dict_size``. Token embeddings are built from
        ``src_dict`` with ``args.encoder_embed_dim`` dimensions.
        """
        if args.embed_bytes:
            # Byte embeddings do not depend on the dictionary.
            # NOTE(review): the meaning of the trailing +1 index is defined by
            # char_encoder / vocab_constants — confirm there.
            num_chars = vocab_constants.NUM_BYTE_INDICES + len(TAGS) + 1
        else:
            num_chars = args.char_source_dict_size
        encoder_embed_tokens = pytorch_translate_transformer.build_embedding(
            dictionary=src_dict,
            embed_dim=args.encoder_embed_dim,
            path=args.encoder_pretrained_embed,
            freeze=args.encoder_freeze_embed,
        )
        return CharCNNEncoder(
            args,
            src_dict,
            encoder_embed_tokens,
            num_chars=num_chars,
            embed_dim=args.char_embed_dim,
            char_cnn_params=args.char_cnn_params,
            char_cnn_nonlinear_fn=args.char_cnn_nonlinear_fn,
            char_cnn_num_highway_layers=args.char_cnn_num_highway_layers,
            char_cnn_output_dim=getattr(args, "char_cnn_output_dim", -1),
            use_pretrained_weights=getattr(args, "use_pretrained_weights", False),
            finetune_pretrained_weights=getattr(
                args, "finetune_pretrained_weights", False
            ),
            weights_file=getattr(args, "pretrained_weights_file", ""),
        )

    @classmethod
    def build_decoder(cls, args, src_dict, dst_dict):
        """Build the HybridRNNDecoder with target-side token embeddings.

        Embeddings are built from ``dst_dict`` with ``args.decoder_embed_dim``
        dimensions, optionally loaded from ``args.decoder_pretrained_embed``
        and frozen per ``args.decoder_freeze_embed``.
        """
        decoder_embed_tokens = pytorch_translate_transformer.build_embedding(
            dictionary=dst_dict,
            embed_dim=args.decoder_embed_dim,
            path=args.decoder_pretrained_embed,
            freeze=args.decoder_freeze_embed,
        )
        return hybrid_transformer_rnn.HybridRNNDecoder(
            args, src_dict, dst_dict, decoder_embed_tokens
        )
class CharCNNEncoder(FairseqEncoder):
    """
    Character-level CNN encoder to generate word representations, as input to
    transformer encoder.

    Each source position is represented by the concatenation of:

    * a character-CNN embedding of that word's characters, layer-normalized
      and scaled by ``sqrt(char_embed_dim / word_dim)``;
    * its token embedding, layer-normalized and scaled by
      ``sqrt(token_embed_dim / word_dim)``.

    Here ``word_dim = char_embed_dim + token_embed_dim``. When ``word_dim``
    differs from ``args.encoder_embed_dim``, the concatenation is linearly
    projected to ``args.encoder_embed_dim``. Positional embeddings and
    dropout are then applied, and the result is run through a
    TransformerEncoderGivenEmbeddings stack.
    """

    def __init__(
        self,
        args,
        dictionary,
        embed_tokens,
        num_chars=50,
        embed_dim=32,
        char_cnn_params="[(128, 3), (128, 5)]",
        char_cnn_nonlinear_fn="tanh",
        char_cnn_num_highway_layers=0,
        char_cnn_output_dim=-1,
        use_pretrained_weights=False,
        finetune_pretrained_weights=False,
        weights_file=None,
    ):
        """
        Args:
            args: namespace providing encoder_embed_dim, dropout,
                encoder_learned_pos and the transformer-encoder settings read
                by TransformerEncoderGivenEmbeddings.
            dictionary: source dictionary (its pad() index marks padding).
            embed_tokens: token embedding module (exposes ``embedding_dim``).
            num_chars: character vocabulary size for the char CNN.
            embed_dim: character embedding dimension.
            char_cnn_params: string literal of ``[(out_dim, kernel_size), ...]``.
            char_cnn_nonlinear_fn: nonlinearity name forwarded to CharCNNModel.
            char_cnn_num_highway_layers: highway layers in the char CNN.
            char_cnn_output_dim: char CNN output size. -1 means the sum of the
                convolution output dims in ``char_cnn_params``.
            use_pretrained_weights, finetune_pretrained_weights, weights_file:
                forwarded unchanged to char_encoder.CharCNNModel.
        """
        super().__init__(dictionary)
        # literal_eval only accepts Python literals, so the config string
        # cannot execute code.
        convolutions_params = literal_eval(char_cnn_params)
        self.char_cnn_encoder = char_encoder.CharCNNModel(
            dictionary,
            num_chars,
            embed_dim,
            convolutions_params,
            char_cnn_nonlinear_fn,
            char_cnn_num_highway_layers,
            char_cnn_output_dim,
            use_pretrained_weights,
            finetune_pretrained_weights,
            weights_file,
        )
        self.embed_tokens = embed_tokens
        token_embed_dim = embed_tokens.embedding_dim
        self.word_layer_norm = nn.LayerNorm(token_embed_dim)
        char_embed_dim = (
            char_cnn_output_dim
            if char_cnn_output_dim != -1
            else sum(out_dim for (out_dim, _) in convolutions_params)
        )
        self.char_layer_norm = nn.LayerNorm(char_embed_dim)
        self.word_dim = char_embed_dim + token_embed_dim
        # Each half is scaled by sqrt(its share of word_dim).
        self.char_scale = math.sqrt(char_embed_dim / self.word_dim)
        self.word_scale = math.sqrt(token_embed_dim / self.word_dim)
        if self.word_dim != args.encoder_embed_dim:
            self.word_to_transformer_embed = fairseq_transformer.Linear(
                self.word_dim, args.encoder_embed_dim
            )
        else:
            # forward() tests this attribute against None, so it must exist
            # even when no projection is needed.
            self.word_to_transformer_embed = None
        self.dropout = args.dropout
        self.padding_idx = dictionary.pad()
        self.embed_positions = fairseq_transformer.PositionalEmbedding(
            1024,
            args.encoder_embed_dim,
            self.padding_idx,
            learned=args.encoder_learned_pos,
        )
        self.transformer_encoder_given_embeddings = TransformerEncoderGivenEmbeddings(
            args=args, proj_to_decoder=False
        )
        # Records intermediate tensors (e.g. "token_embeddings") for
        # adversarial training.
        self.tracker = VariableTracker()
        # Adversarial modes start disabled.
        self.set_gradient_tracking_mode(False)
        self.set_embed_noising_mode(False)
        # When True, forward() returns an empty tensor instead of None for
        # an absent padding mask, because ONNX tracing requires tensor
        # outputs.
        self.onnx_export_model = False

    def prepare_for_onnx_export_(self):
        """Switch this encoder into ONNX-export mode (in place)."""
        self.onnx_export_model = True

    def set_gradient_tracking_mode(self, mode=True):
        """ This allows AdversarialTrainer to turn on retain_grad when
        running adversarial example generation model.

        Resets the variable tracker and stores ``mode``. forward() passes it
        as ``retain_grad`` when tracking the token embeddings.
        """
        self.tracker.reset()
        self.track_gradients = mode

    def set_embed_noising_mode(self, mode=True):
        """This allows adversarial trainer to turn on and off embedding noising
        layers. In regular training, this mode is off, and it is not included
        in forward pass.

        Only records the flag. forward() in this class does not read it.
        """
        self.embed_noising_mode = mode

    def forward(self, src_tokens, src_lengths, char_inds, word_lengths):
        """Encode a batch of source sentences.

        Args:
            src_tokens: token indices, (batch_size, seqlen).
            src_lengths: not read by this method. It is accepted for interface
                compatibility.
            char_inds: character indices, (batch_size, max_words_per_sent,
                max_word_len). max_words_per_sent must equal seqlen, because
                the char and token embeddings are concatenated position by
                position.
            word_lengths: not read by this method.

        Returns:
            Tuple ``(x, src_tokens, encoder_padding_mask)``:

            * ``x`` is the TransformerEncoderGivenEmbeddings output (presumably
              seqlen x batch_size x encoder_embed_dim — TODO confirm).
            * ``encoder_padding_mask`` is a (batch_size, seqlen) bool tensor
              that is True at padding positions. It is ``None`` when nothing
              is padded, or an empty tensor in ONNX-export mode.
        """
        self.tracker.reset()
        bsz, seqlen, maxchars = char_inds.size()
        # char_cnn_encoder expects (max_word_len, total_words).
        char_inds_flat = char_inds.view(-1, maxchars).t()
        # Output is (total_words, char_cnn_output_dim).
        char_cnn_output = self.char_cnn_encoder(char_inds_flat)
        x = char_cnn_output.view(bsz, seqlen, char_cnn_output.shape[-1])
        x = x.transpose(0, 1)  # (seqlen, bsz, char_cnn_output_dim)
        x = self.char_layer_norm(x)
        x = self.char_scale * x

        embedded_tokens = self.embed_tokens(src_tokens)
        # (seqlen, bsz, token_embed_dim)
        embedded_tokens = embedded_tokens.transpose(0, 1)
        embedded_tokens = self.word_layer_norm(embedded_tokens)
        embedded_tokens = self.word_scale * embedded_tokens

        x = torch.cat([x, embedded_tokens], dim=2)
        self.tracker.track(x, "token_embeddings", retain_grad=self.track_gradients)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)
        if self.word_to_transformer_embed is not None:
            x = self.word_to_transformer_embed(x)
        positions = self.embed_positions(src_tokens)
        # Out-of-place add: without the projection above, x is a view of the
        # tensor just tracked as "token_embeddings". An in-place += would
        # silently add the positions into that tracked tensor.
        x = x + positions
        x = F.dropout(x, p=self.dropout, training=self.training)
        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # Padding mask (B x T); None when the batch has no padding.
        encoder_padding_mask = src_tokens.eq(self.padding_idx)
        if not encoder_padding_mask.any():
            encoder_padding_mask = None

        x = self.transformer_encoder_given_embeddings(
            x=x, positions=positions, encoder_padding_mask=encoder_padding_mask
        )

        # ONNX tracing requires a tensor value rather than None.
        if self.onnx_export_model and encoder_padding_mask is None:
            encoder_padding_mask = torch.Tensor([]).type_as(src_tokens)

        return x, src_tokens, encoder_padding_mask

    def reorder_encoder_out(self, encoder_out, new_order):
        """Reorder the batch dimension of ``encoder_out`` by ``new_order``.

        ``x`` is batch-indexed on dim 1. ``src_tokens`` and
        ``encoder_padding_mask`` are batch-indexed on dim 0. ``None`` entries
        are passed through unchanged.
        """
        (x, src_tokens, encoder_padding_mask) = encoder_out
        if x is not None:
            x = x.index_select(1, new_order)
        if src_tokens is not None:
            src_tokens = src_tokens.index_select(0, new_order)
        if encoder_padding_mask is not None:
            encoder_padding_mask = encoder_padding_mask.index_select(0, new_order)
        return (x, src_tokens, encoder_padding_mask)

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return self.embed_positions.max_positions()

    def upgrade_state_dict(self, state_dict):
        """Adapt ``state_dict`` in place for sinusoidal positional embeddings.

        For sinusoidal positions, any stored "encoder.embed_positions.weights"
        entry is dropped and a one-element "_float_tensor" placeholder is
        inserted. The (possibly modified) ``state_dict`` is returned.
        """
        if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
            if "encoder.embed_positions.weights" in state_dict:
                del state_dict["encoder.embed_positions.weights"]
            state_dict["encoder.embed_positions._float_tensor"] = torch.FloatTensor(1)
        return state_dict
@register_model_architecture("char_source_hybrid", "char_source_hybrid")
def base_architecture(args):
    """Fill in default hyperparameters for the char_source_hybrid architecture.

    First applies hybrid_transformer_rnn.base_architecture(args). Then, for
    each character-CNN setting, keeps any value already present on ``args``
    and otherwise assigns the default below. ``args`` is modified in place.
    """
    hybrid_transformer_rnn.base_architecture(args)
    args.char_cnn_params = getattr(args, "char_cnn_params", "[(50, 1), (100,2)]")
    # The lookup name must match the attribute being assigned. A misspelled
    # name would never be found, and any configured value would be silently
    # replaced by the default.
    args.char_cnn_nonlinear_fn = getattr(args, "char_cnn_nonlinear_fn", "relu")
    # Integer default, matching the int type of --char-cnn-num-highway-layers.
    args.char_cnn_num_highway_layers = getattr(
        args, "char_cnn_num_highway_layers", 2
    )