@@ -5,12 +5,110 @@ version](https://badge.fury.io/py/tensor2tensor.svg)](https://badge.fury.io/py/t
[![GitHub
Issues](https://img.shields.io/github/issues/tensorflow/tensor2tensor.svg)](https://github.com/tensorflow/tensor2tensor/issues)
[![Contributions
- welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CONTRIBUTING.md)
+ welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](../CONTRIBUTING.md)
[![Gitter](https://img.shields.io/gitter/room/nwjs/nw.js.svg)](https://gitter.im/tensor2tensor/Lobby)
[![License](https://img.shields.io/badge/License-Apache%202.0-brightgreen.svg)](https://opensource.org/licenses/Apache-2.0)

Here we show how to create your own model in T2T.

- ## The T2TModel class
+ ## The T2TModel class - abstract base class for models

- TODO: complete.
+ `T2TModel` has three typical usages:
+
+ 1. Estimator: The method `make_estimator_model_fn` builds a `model_fn` for
+    the tf.Estimator workflow of training, evaluation, and prediction.
+    It invokes the method `call`, which performs the core computation,
+    followed by `estimator_spec_train`, `estimator_spec_eval`, or
+    `estimator_spec_predict` depending on the tf.Estimator mode.
+ 2. Layer: The method `call` enables `T2TModel` to be used as a callable by
+    itself. It calls the following methods:
+
+    * `bottom`, which transforms features according to `problem_hparams`' input
+      and target `Modality`s;
+    * `body`, which takes features and performs the core model computation to
+      return output and any auxiliary loss terms;
+    * `top`, which takes features and the body output, and transforms them
+      according to `problem_hparams`' input and target `Modality`s to return
+      the final logits;
+    * `loss`, which takes the logits, forms any missing training loss, and sums
+      all loss terms.
+ 3. Inference: The method `infer` enables `T2TModel` to make sequence
+    predictions by itself.
+
+
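To make the `bottom` -> `body` -> `top` -> `loss` order concrete, here is a minimal pure-Python sketch. `ToyModel` and the arithmetic inside each method are hypothetical stand-ins chosen for illustration; this is not the real `T2TModel` code, only the call order described above.

```python
class ToyModel:
    """Toy stand-in for the bottom -> body -> top -> loss flow (not the real T2TModel)."""

    def bottom(self, features):
        # Hypothetical input transform: scale each raw input.
        return {"inputs": [2.0 * v for v in features["inputs"]],
                "targets": features["targets"]}

    def body(self, features):
        # Core computation: returns the output plus an auxiliary loss term.
        output = [v + 1.0 for v in features["inputs"]]
        return output, {"extra": 0.5}

    def top(self, body_output, features):
        # Hypothetical output transform that produces the "logits".
        return [0.5 * v for v in body_output]

    def loss(self, logits, features):
        # Mean squared error against the targets (an assumption for this toy).
        targets = features["targets"]
        return sum((l - t) ** 2 for l, t in zip(logits, targets)) / len(targets)

    def __call__(self, features):
        transformed = self.bottom(features)
        body_output, losses = self.body(transformed)
        logits = self.top(body_output, transformed)
        # Form the missing training loss, then let the caller sum all terms.
        losses["training"] = self.loss(logits, transformed)
        return logits, losses


logits, losses = ToyModel()({"inputs": [1.0, 3.0], "targets": [1.5, 3.5]})
total_loss = sum(losses.values())
```

In this toy, the auxiliary term from `body` and the training loss from `loss` end up in the same dictionary, so the caller gets the total by summing its values.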
+ ## Creating your own model
+
+ 1. Create a class that extends `T2TModel`. In this example it is a copy of
+    the existing basic fully connected network:
+
+ ```python
+ from tensor2tensor.utils import t2t_model
+
+ class MyFC(t2t_model.T2TModel):
+   pass
+ ```
+
+
+ 2. Implement the `body` method:
+
+ ```python
+ import tensorflow as tf
+
+ from tensor2tensor.layers import common_layers
+ from tensor2tensor.utils import t2t_model
+
+ class MyFC(t2t_model.T2TModel):
+
+   def body(self, features):
+     hparams = self.hparams
+     x = features["inputs"]
+     shape = common_layers.shape_list(x)
+     # Flatten the input: in T2T all inputs are 4D tensors.
+     x = tf.reshape(x, [-1, shape[1] * shape[2] * shape[3]])
+     for i in range(hparams.num_hidden_layers):  # Create the hidden layers.
+       x = tf.layers.dense(x, hparams.hidden_size, name="layer_%d" % i)
+       x = tf.nn.dropout(x, keep_prob=1.0 - hparams.dropout)
+       x = tf.nn.relu(x)
+     return tf.expand_dims(tf.expand_dims(x, axis=1), axis=1)  # 4D for T2T.
+ ```
+
+
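The reshaping in `body` is easiest to follow as shape bookkeeping. The sketch below tracks only the shapes, in plain Python with no TensorFlow; `flatten_shape` and `body_output_shape` are hypothetical helper names, and the 28x28x1 input is just an example.

```python
def flatten_shape(shape):
    """Shape after tf.reshape(x, [-1, h * w * c]) for a 4D [batch, h, w, c] input."""
    batch, height, width, channels = shape
    return [batch, height * width * channels]


def body_output_shape(input_shape, hidden_size):
    """Shape returned by the example body for a given 4D input shape."""
    batch, _ = flatten_shape(input_shape)
    dense_shape = [batch, hidden_size]  # shape after the dense layers
    # Two expand_dims calls at axis=1 restore a 4D shape.
    return [dense_shape[0], 1, 1, dense_shape[1]]


flat = flatten_shape([32, 28, 28, 1])
out_shape = body_output_shape([32, 28, 28, 1], hidden_size=64)
```

The body therefore takes a 4D tensor in and hands a 4D tensor back, with the hidden size in the last dimension.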
+ Method signature:
+
+ * Args:
+   * features: dict of str to Tensor, where each Tensor has shape [batch_size,
+     ..., hidden_size]. It typically contains keys `inputs` and `targets`.
+
+ * Returns one of:
+   * output: Tensor of pre-logit activations with shape [batch_size, ...,
+     hidden_size].
+   * losses: Either single loss as a scalar, a list, a Tensor (to be averaged),
+     or a dictionary of losses. If losses is a dictionary with the key
+     "training", losses["training"] is considered the final training
+     loss and output is considered logits; self.top and self.loss will
+     be skipped.
+
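The return contract above can be read as a small normalization step. The following is a sketch under assumptions, not the library's code: `split_body_output` is a hypothetical helper, the `"extra"` key is an illustrative name, and it assumes that when losses are returned they come paired with the output as a tuple `(output, losses)`.

```python
def split_body_output(body_result):
    """Normalize a body result into (output, losses_dict, skip_top_and_loss)."""
    if isinstance(body_result, tuple):
        output, losses = body_result
    else:
        output, losses = body_result, None

    if losses is None:
        losses = {}
    elif isinstance(losses, dict):
        losses = dict(losses)
    elif isinstance(losses, (list, tuple)):
        # A collection of losses is averaged into a single term.
        losses = {"extra": sum(losses) / len(losses)}
    else:
        # A single scalar loss is kept as-is.
        losses = {"extra": losses}

    # With a "training" key, the output is treated as logits and
    # top/loss would be skipped.
    return output, losses, "training" in losses


only_output = split_body_output("out")
averaged = split_body_output(("out", [1.0, 3.0]))
final = split_body_output(("logits", {"training": 0.25}))
```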
+ 3. Register your model:
+
+ ```python
+ from tensor2tensor.utils import registry
+ from tensor2tensor.utils import t2t_model
+
+ @registry.register_model
+ class MyFC(t2t_model.T2TModel):
+   pass  # ... body as defined above
+ ```
+
+
+ 4. Use it with the T2T tools like any other model
+
+ Keep in mind that names are translated from CamelCase to snake_case (`MyFC` -> `my_fc`)
+ and that you need to point T2T to the directory containing your model with the `--t2t_usr_dir` flag.
+ For example, to train the model on gcloud with 1 GPU worker on the IMDB sentiment task, run
+ the following command from the directory containing your model class:
+
+ ```bash
+ t2t-trainer \
+   --model=my_fc \
+   --t2t_usr_dir=. \
+   --cloud_mlengine --worker_gpu=1 \
+   --generate_data \
+   --data_dir='gs://data' \
+   --output_dir='gs://out' \
+   --problem=sentiment_imdb \
+   --hparams_set=basic_fc_small \
+   --train_steps=10000 \
+   --eval_steps=10
+ ```