Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 0e0d449

Browse files
hbrylkowskicopybara-github
authored andcommitted
Merge of PR #1589
PiperOrigin-RevId: 252130487
1 parent 7ad6c7f commit 0e0d449

File tree

1 file changed

+101
-3
lines changed

1 file changed

+101
-3
lines changed

docs/new_model.md

Lines changed: 101 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,110 @@ version](https://badge.fury.io/py/tensor2tensor.svg)](https://badge.fury.io/py/t
55
[![GitHub
66
Issues](https://img.shields.io/github/issues/tensorflow/tensor2tensor.svg)](https://github.com/tensorflow/tensor2tensor/issues)
77
[![Contributions
8-
welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CONTRIBUTING.md)
8+
welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](../CONTRIBUTING.md)
99
[![Gitter](https://img.shields.io/gitter/room/nwjs/nw.js.svg)](https://gitter.im/tensor2tensor/Lobby)
1010
[![License](https://img.shields.io/badge/License-Apache%202.0-brightgreen.svg)](https://opensource.org/licenses/Apache-2.0)
1111

1212
Here we show how to create your own model in T2T.
1313

14-
## The T2TModel class
14+
## The T2TModel class - abstract base class for models
1515

16-
TODO: complete.
16+
`T2TModel` has three typical usages:
17+
18+
1. Estimator: The method `make_estimator_model_fn` builds a `model_fn` for
19+
the tf.Estimator workflow of training, evaluation, and prediction.
20+
It performs the method `call`, which performs the core computation,
21+
followed by `estimator_spec_train`, `estimator_spec_eval`, or
22+
`estimator_spec_predict` depending on the tf.Estimator mode.
23+
2. Layer: The method `call` enables `T2TModel` to be used a callable by
24+
itself. It calls the following methods:
25+
26+
* `bottom`, which transforms features according to `problem_hparams`' input
27+
and target `Modality`s;
28+
* `body`, which takes features and performs the core model computation to
29+
return output and any auxiliary loss terms;
30+
* `top`, which takes features and the body output, and transforms them
31+
according to `problem_hparams`' input and target `Modality`s to return
32+
the final logits;
33+
* `loss`, which takes the logits, forms any missing training loss, and sums
34+
all loss terms.
35+
3. Inference: The method `infer` enables `T2TModel` to make sequence
36+
predictions by itself.
37+
38+
39+
## Creating your own model
40+
41+
1. Create class that extends T2TModel
42+
in this example it will be a copy of existing basic fully connected network:
43+
44+
```python
45+
from tensor2tensor.utils import t2t_model
46+
47+
class MyFC(t2t_model.T2TModel):
48+
pass
49+
```
50+
51+
52+
2. Implement body method:
53+
54+
```python
55+
class MyFC(t2t_model.T2TModel):
56+
def body(self, features):
57+
hparams = self.hparams
58+
x = features["inputs"]
59+
shape = common_layers.shape_list(x)
60+
x = tf.reshape(x, [-1, shape[1] * shape[2] * shape[3]]) # Flatten input as in T2T they are all 4D vectors
61+
for i in range(hparams.num_hidden_layers): # create layers
62+
x = tf.layers.dense(x, hparams.hidden_size, name="layer_%d" % i)
63+
x = tf.nn.dropout(x, keep_prob=1.0 - hparams.dropout)
64+
x = tf.nn.relu(x)
65+
return tf.expand_dims(tf.expand_dims(x, axis=1), axis=1) # 4D For T2T.
66+
```
67+
68+
69+
Method Signature:
70+
71+
* Args:
72+
* features: dict of str to Tensor, where each Tensor has shape [batch_size,
73+
..., hidden_size]. It typically contains keys `inputs` and `targets`.
74+
75+
* Returns one of:
76+
* output: Tensor of pre-logit activations with shape [batch_size, ...,
77+
hidden_size].
78+
* losses: Either single loss as a scalar, a list, a Tensor (to be averaged),
79+
or a dictionary of losses. If losses is a dictionary with the key
80+
"training", losses["training"] is considered the final training
81+
loss and output is considered logits; self.top and self.loss will
82+
be skipped.
83+
84+
3. Register your model
85+
86+
```python
87+
from tensor2tensor.utils import registry
88+
89+
@registry.register_model
90+
class MyFC(t2t_model.T2TModel):
91+
# ...
92+
```
93+
94+
95+
3. Use it with t2t tools as any other model
96+
97+
Have in mind that names are translated from camel case to snake_case `MyFC` -> `my_fc`
98+
and that you need to point t2t to directory containing your model with `t2t_usr_dir` switch.
99+
For example if you want to train model on gcloud with 1 GPU worker on IMDB sentiment task you can run your model
100+
by executing following command from your model class directory.
101+
102+
```bash
103+
t2t-trainer \
104+
--model=my_fc \
105+
--t2t_usr_dir=.
106+
--cloud_mlengine --worker_gpu=1 \
107+
--generate_data \
108+
--data_dir='gs://data' \
109+
--output_dir='gs://out' \
110+
--problem=sentiment_imdb \
111+
--hparams_set=basic_fc_small \
112+
--train_steps=10000 \
113+
--eval_steps=10 \
114+
```

0 commit comments

Comments
 (0)