
Commit 892f76c

artitwkpe authored and committed
internal merge of PR tensorflow#1290
PiperOrigin-RevId: 224943245
1 parent 18c2b3c commit 892f76c

File tree

4 files changed (+67 -40 lines)


docs/walkthrough.md

Lines changed: 20 additions & 0 deletions
@@ -47,6 +47,7 @@ pip install tensor2tensor && t2t-trainer \
 ### Contents
 
 * [Suggested Datasets and Models](#suggested-datasets-and-models)
+  * [Mathematical Language Understanding](#mathematical-language-understanding)
   * [Story, Question and Answer](#story-question-and-answer)
   * [Image Classification](#image-classification)
   * [Image Generation](#image-generation)
@@ -79,6 +80,24 @@ hyperparameters that we know works well in our setup. We usually
 run either on Cloud TPUs or on 8-GPU machines; you might need
 to modify the hyperparameters if you run on a different setup.
 
+### Mathematical Language Understanding
+
+For evaluating mathematical expressions at the character level involving addition, subtraction and multiplication of both positive and negative decimal numbers with variable digits assigned to symbolic variables, use
+
+* the [MLU](https://art.wangperawong.com/mathematical_language_understanding_train.tar.gz) data-set:
+  `--problem=mathematical_language_understanding`
+
+You can try solving the problem with different transformer models and hyperparameters as described in the [paper](https://arxiv.org/abs/1812.02825):
+* Standard transformer:
+  `--model=transformer`
+  `--hparams_set=transformer_tiny`
+* Universal transformer:
+  `--model=universal_transformer`
+  `--hparams_set=universal_transformer_tiny`
+* Adaptive universal transformer:
+  `--model=universal_transformer`
+  `--hparams_set=adaptive_universal_transformer_tiny`
+
 ### Story, Question and Answer
 
 For answering questions based on a story, use
@@ -464,5 +483,6 @@ T2T](https://research.googleblog.com/2017/06/accelerating-deep-learning-research
 * [Fast Decoding in Sequence Models using Discrete Latent Variables](https://arxiv.org/abs/1803.03382)
 * [Adafactor: Adaptive Learning Rates with Sublinear Memory Cost](https://arxiv.org/abs/1804.04235)
 * [Universal Transformers](https://arxiv.org/abs/1807.03819)
+* [Attending to Mathematical Language with Transformers](https://arxiv.org/abs/1812.02825)
 
 *Note: This is not an official Google product.*
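
For context on the documentation above: the flags map directly onto the T2T Python API, since this commit registers the problem under the name `mathematical_language_understanding`. A minimal sketch of generating the data from Python rather than via `t2t-datagen`, assuming tensor2tensor is installed; the directory paths are illustrative:

# Sketch: generate the MLU dataset through the T2T Python API.
# Importing tensor2tensor.problems pulls in the bundled problem modules,
# so the registry lookup by name below succeeds.
from tensor2tensor import problems

mlu = problems.problem("mathematical_language_understanding")

# Downloads the tarball from the MLU URL, extracts it into tmp_dir, and
# writes TFRecord files into data_dir (both paths here are placeholders).
mlu.generate_data(data_dir="/tmp/t2t_data", tmp_dir="/tmp/t2t_tmp")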

tensor2tensor/data_generators/babi_qa.py

Lines changed: 9 additions & 4 deletions
@@ -109,9 +109,11 @@ def _prepare_babi_data(tmp_dir, data_dir):
   tf.gfile.MakeDirs(data_dir)
 
   file_path = os.path.join(tmp_dir, _TAR)
-  headers = {'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_13_1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/63.0.3239.132 Safari/537.36'}
+  headers = {"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_13_1) "
+                           "AppleWebKit/537.36 (KHTML, like Gecko) "
+                           "Chrome/63.0.3239.132 Safari/537.36"}
   resp = requests.get(_URL, headers=headers)
-  with open(file_path, 'wb') as f:
+  with open(file_path, "wb") as f:
     f.write(resp.content)
 
   tar = tarfile.open(file_path)
@@ -192,10 +194,12 @@ def _all_task_raw_data_generator(tmp_dir, data_file, dataset_split):
 
   tf.logging.info("Preparing dataset of all task together")
   globe_name = ("*_{}.txt")
+  mode_name = "test"
+  if dataset_split == problem.DatasetSplit.TRAIN:
+    mode_name = "train"
   files_name = os.path.join(
       tmp_dir, _DIR_NAME, subset,
-      globe_name.format("train" if dataset_split == problem.DatasetSplit.TRAIN
-                        else "test"))
+      globe_name.format(mode_name))
   with tf.gfile.GFile(data_file, "wb") as outfile:
     for filename in tf.gfile.Glob(files_name):
       if filename == data_file:
@@ -459,6 +463,7 @@ def hparams(self, defaults, unused_model_hparams):
     if "context" in p.vocab_size:
      del p.vocab_size["context"]
 
+
 def _problems_to_register():
   """Problems for which we want to create datasets.
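
Beyond the quote-style cleanup, the User-Agent rewrite above leans on Python's implicit concatenation of adjacent string literals, so the wrapped header is byte-for-byte identical to the old one-liner. A standalone sketch of that equivalence, with the literals copied from the diff:

# Adjacent string literals inside brackets are joined at compile time,
# so wrapping a long literal across lines changes only the source layout.
one_line = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_13_1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/63.0.3239.132 Safari/537.36"
wrapped = ("Mozilla/5.0 (Macintosh; Intel Mac OS X 10_13_1) "
           "AppleWebKit/537.36 (KHTML, like Gecko) "
           "Chrome/63.0.3239.132 Safari/537.36")
assert one_line == wrapped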
Lines changed: 33 additions & 33 deletions
@@ -1,5 +1,5 @@
 # coding=utf-8
-# Copyright 2018 Artit Wangperawong artitw@gmail.com
+# Copyright 2018 The Tensor2Tensor Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,31 +15,30 @@
 
 r"""Data generators for the Mathematical Language Understanding dataset.
 
-The training and test data were generated by assigning symbolic variables
-either positive or negative decimal integers and then describing the algebraic
-operation to perform. We restrict our variable assignments to the range
-x,y->[-1000,1000) and the operations to the set {+,-,*}. To ensure that the
-model embraces symbolic variables, the order in which x and y appears in the
-expression is randomly chosen. For instance, an input string contrasting from
-the example shown above might be y=129,x=531,x-y. Each input string is
-accompanied by its target string, which is the evaluation of the mathematical
-expression. For this study, all targets considered are decimal integers
-represented at the character level. About 12 million unique samples were thus
-generated and randomly split into training and test sets at an approximate
-ratio of 9:1, respectively.
+The training and test data were generated by assigning symbolic variables
+either positive or negative decimal integers and then describing the algebraic
+operation to perform. We restrict our variable assignments to the range
+x,y->[-1000,1000) and the operations to the set {+,-,*}. To ensure that the
+model embraces symbolic variables, the order in which x and y appears in the
+expression is randomly chosen. For instance, an input string contrasting from
+the example shown above might be y=129,x=531,x-y. Each input string is
+accompanied by its target string, which is the evaluation of the mathematical
+expression. For this study, all targets considered are decimal integers
+represented at the character level. About 12 million unique samples were thus
+generated and randomly split into training and test sets at an approximate
+ratio of 9:1, respectively.
 
 For more information check the following paper:
-Artit Wangperawong. Attending to Mathematical Language with Transformers,
-arXiv:1812.02825.
-Available at: https://arxiv.org/abs/1812.02825
-
+Artit Wangperawong. Attending to Mathematical Language with Transformers,
+arXiv:1812.02825 (https://arxiv.org/abs/1812.02825).
 """
 
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
 import os
+import tarfile
 
 from tensor2tensor.data_generators import generator_utils
 from tensor2tensor.data_generators import problem
@@ -48,9 +47,13 @@
 
 import tensorflow as tf
 
+
 @registry.register_problem
 class MathematicalLanguageUnderstanding(text_problems.Text2TextProblem):
-  URL = "https://art.wangperawong.com/mathematical_language_understanding_train.tar.gz"
+  """Mathematical language understanding, see arxiv.org/abs/1812.02825."""
+
+  URL = ("https://art.wangperawong.com/mathematical_language_understanding"
+         "_train.tar.gz")
 
   @property
   def vocab_type(self):
@@ -71,34 +74,31 @@ def is_generate_per_split(self):
     return False
 
   def generate_samples(self, data_dir, tmp_dir, dataset_split):
-    """Downloads and extracts the dataset and generates examples
+    """Downloads and extracts the dataset and generates examples.
 
     Args:
-      tmp_dir: temp directory to download and extract the dataset
       data_dir: The base directory where data and vocab files are stored.
+      tmp_dir: temp directory to download and extract the dataset.
+      dataset_split: split of the data-set.
 
-    Returns:
-      data generator
+    Yields:
+      The data examples.
     """
-
     if not tf.gfile.Exists(tmp_dir):
       tf.gfile.MakeDirs(tmp_dir)
 
     if not tf.gfile.Exists(data_dir):
       tf.gfile.MakeDirs(data_dir)
 
-    # Download and extract
+    # Download and extract.
     compressed_filename = os.path.basename(self.URL)
-    download_path = generator_utils.maybe_download(tmp_dir, compressed_filename,
-                                                   self.URL)
-
+    download_path = generator_utils.maybe_download(
+        tmp_dir, compressed_filename, self.URL)
     with tarfile.open(download_path, "r:gz") as tar:
       tar.extractall(tmp_dir)
-
-    filepath = os.path.join(tmp_dir, "mathematical_language_understanding_train.txt")
-
-    with open(filepath, 'r') as fp:
+    filepath = os.path.join(tmp_dir,
+                            "mathematical_language_understanding_train.txt")
+    with open(filepath, "r") as fp:
       for l in fp:
-        prob, ans = l.strip().split(':')
+        prob, ans = l.strip().split(":")
         yield {"inputs": prob, "targets": ans}
-
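
To make the tail of `generate_samples` concrete: each line of `mathematical_language_understanding_train.txt` stores an expression and its evaluation separated by a colon. A standalone sketch of the parsing step; the sample line is constructed from the module docstring's example (x-y with y=129, x=531 evaluates to 402), not taken from the actual file:

# Each dataset line is "<expression>:<answer>"; splitting on the colon
# recovers the model's inputs and targets.
line = "y=129,x=531,x-y:402\n"
prob, ans = line.strip().split(":")
sample = {"inputs": prob, "targets": ans}
assert sample == {"inputs": "y=129,x=531,x-y", "targets": "402"}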

tensor2tensor/models/research/universal_transformer.py

Lines changed: 5 additions & 3 deletions
@@ -240,10 +240,12 @@ def _greedy_infer(self, features, decode_length, use_tpu=False):
     Raises:
       NotImplementedError: If there are multiple data shards.
     """
-    return (self._slow_greedy_infer_tpu(features, decode_length) if use_tpu else
-            self._slow_greedy_infer(features, decode_length))
+    if use_tpu:
+      return self._slow_greedy_infer_tpu(features, decode_length)
+    return self._slow_greedy_infer(features, decode_length)
 
-  def _beam_decode(self, features, decode_length, beam_size, top_beams, alpha, use_tpu=False):
+  def _beam_decode(self, features, decode_length, beam_size, top_beams, alpha,
+                   use_tpu=False):
     """Beam search decoding.
 
     Args:
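
The `_greedy_infer` edit swaps a wrapped conditional expression for an explicit branch without changing behavior. A tiny standalone sketch of the equivalence, using stand-in functions for the two infer paths:

# Stand-ins for self._slow_greedy_infer_tpu / self._slow_greedy_infer.
def infer_tpu():
  return "tpu result"

def infer_cpu():
  return "cpu result"

def greedy_infer_old(use_tpu):
  # Old form: parenthesized conditional expression.
  return (infer_tpu() if use_tpu else infer_cpu())

def greedy_infer_new(use_tpu):
  # New form: explicit early return, as in the diff above.
  if use_tpu:
    return infer_tpu()
  return infer_cpu()

assert all(greedy_infer_old(t) == greedy_infer_new(t) for t in (True, False))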
