forked from keras-team/keras-io
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtensorflow_numpy_models.py
351 lines (277 loc) · 9.5 KB
/
tensorflow_numpy_models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
"""
Title: Writing Keras Models With TensorFlow NumPy
Author: [lukewood](https://lukewood.xyz)
Date created: 2021/08/28
Last modified: 2021/08/28
Description: Overview of how to use the TensorFlow NumPy API to write Keras models.
Accelerator: GPU
"""
"""
## Introduction
[NumPy](https://numpy.org/) is a hugely successful Python library for numerical computing with n-dimensional arrays.
TensorFlow recently launched [tf_numpy](https://www.tensorflow.org/guide/tf_numpy), a
TensorFlow implementation of a large subset of the NumPy API.
Thanks to `tf_numpy`, you can write Keras layers or models in the NumPy style!
The TensorFlow NumPy API has full integration with the TensorFlow ecosystem.
Features such as automatic differentiation, TensorBoard, Keras model callbacks,
TPU distribution and model exporting are all supported.
Let's run through a few examples.
"""
"""
## Setup
TensorFlow NumPy requires TensorFlow 2.5 or later.
"""
import tensorflow as tf
import tensorflow.experimental.numpy as tnp
import keras
import keras.layers as layers
import numpy as np
"""
Optionally, you can call `tnp.experimental_enable_numpy_behavior()` to enable type promotion in TensorFlow.
This allows TNP to more closely follow the NumPy standard.
"""
# Opt in to NumPy-style type promotion so TNP arithmetic follows the NumPy
# standard more closely. It is called once, up front, so every model defined
# below runs with this behavior enabled.
tnp.experimental_enable_numpy_behavior()
"""
To test our models, we will use the Boston Housing price regression dataset.
"""
# Load the Boston Housing regression data through the same standalone `keras`
# package that the rest of this example uses (rather than `tf.keras`), so every
# Keras API here comes from a single installation. `test_split=0.2` holds out
# 20% of the samples for evaluation; the fixed `seed` keeps that shuffled split
# the same across runs.
(x_train, y_train), (x_test, y_test) = keras.datasets.boston_housing.load_data(
    path="boston_housing.npz", test_split=0.2, seed=113
)
def evaluate_model(model: keras.Model, epochs: int = 200):
    """Report a model's test-set error before and after training.

    Evaluates ``model`` on the module-level ``x_test`` / ``y_test`` split,
    fits it on ``x_train`` / ``y_train``, then evaluates it again and prints
    both results.

    The model must already be compiled with exactly one metric (this example
    uses mean absolute percentage error), because the result of
    ``model.evaluate`` is unpacked into ``(loss, metric)``.

    Args:
        model: A compiled Keras model.
        epochs: Number of training epochs passed to ``model.fit``.
    """
    # Only the metric is reported; the loss value is discarded.
    _, percent_error = model.evaluate(x_test, y_test, verbose=0)
    print("Mean absolute percent error before training:", percent_error)
    model.fit(x_train, y_train, epochs=epochs, verbose=0)
    _, percent_error = model.evaluate(x_test, y_test, verbose=0)
    print("Mean absolute percent error after training:", percent_error)
"""
## Subclassing keras.Model with TNP
The most flexible way to make use of the Keras API is to subclass the
[`keras.Model`](https://keras.io/api/models/model/) class. Subclassing the Model class
gives you the ability to fully customize what occurs in the training loop. This makes
subclassing Model a popular option for researchers.
In this example, we will implement a `Model` subclass that performs regression over the
Boston Housing dataset using the TNP API. Note that differentiation and gradient
descent are handled automatically when using the TNP API alongside Keras.
First, let's define a simple `TNPForwardFeedRegressionNetwork` class.
"""
class TNPForwardFeedRegressionNetwork(keras.Model):
    """Fully connected ReLU regression network written with the TNP API.

    Each entry of ``blocks`` adds one hidden layer of that width, computed as
    ``relu(x @ W + b)``. A final linear projection (no bias) maps the last
    hidden activation to a single output value.

    Args:
        blocks: Hidden-layer widths, e.g. ``[3, 3]``. Must be a list or
            tuple; an empty sequence yields a purely linear model.
        **kwargs: Forwarded to ``keras.Model``.

    Raises:
        ValueError: If ``blocks`` is not a list or tuple.
    """

    def __init__(self, blocks=None, **kwargs):
        super().__init__(**kwargs)
        if not isinstance(blocks, (list, tuple)):
            raise ValueError(
                f"blocks must be a list or tuple, got blocks={blocks}"
            )
        # Keep a private list copy so that later mutation of the caller's
        # sequence cannot change the architecture after construction.
        self.blocks = list(blocks)
        # All weights are created lazily in build(), once the input width
        # is known.
        self.block_weights = None
        self.biases = None
        self.linear_layer = None

    def build(self, input_shape):
        # The last axis is the feature axis. Using it (instead of index 1)
        # also allows inputs with extra leading axes, which tnp.matmul
        # broadcasts over.
        current_shape = input_shape[-1]
        self.block_weights = []
        self.biases = []
        for i, block in enumerate(self.blocks):
            self.block_weights.append(
                self.add_weight(
                    shape=(current_shape, block), trainable=True, name=f"block-{i}"
                )
            )
            self.biases.append(
                self.add_weight(shape=(block,), trainable=True, name=f"bias-{i}")
            )
            # The next block consumes this block's output width.
            current_shape = block
        self.linear_layer = self.add_weight(
            shape=(current_shape, 1), name="linear_projector", trainable=True
        )

    def call(self, inputs):
        activations = inputs
        for w, b in zip(self.block_weights, self.biases):
            activations = tnp.matmul(activations, w) + b
            # ReLU activation function
            activations = tnp.maximum(activations, 0.0)
        return tnp.matmul(activations, self.linear_layer)
"""
Just like with any other Keras model, we can use any supported optimizer, loss,
metrics, or callbacks that we want.
Let's see how the model performs!
"""
# Train the subclassed TNP network (two hidden blocks of width 3) with a
# built-in Keras optimizer, loss and metric, and report its test error
# before and after fitting.
model = TNPForwardFeedRegressionNetwork(blocks=[3, 3])
model.compile(
    loss="mean_squared_error",
    optimizer="adam",
    metrics=[keras.metrics.MeanAbsolutePercentageError()],
)
evaluate_model(model)
"""
Great! Our model seems to be effectively learning to solve the problem at hand.
We can also write our own custom loss function using TNP.
"""
def tnp_mse(y_true, y_pred):
    """Mean squared error loss written with the TNP API.

    Keras expects a custom loss function to return one loss value per sample.
    Keras then applies any ``sample_weight`` and reduces over the batch
    itself. Averaging over the last axis therefore yields the per-sample
    vector Keras expects. Averaging over ``axis=0`` would instead collapse the
    batch before Keras sees it, which breaks per-sample weighting.

    NOTE(review): this assumes ``y_true`` and ``y_pred`` have matching rank
    when the loss is called. Keras normally aligns a rank-1 target with a
    ``(batch, 1)`` prediction first; confirm for the Keras version in use.
    """
    return tnp.mean(tnp.square(y_true - y_pred), axis=-1)
# Reset Keras' global session state, then train the same architecture using
# the hand-written TNP loss in place of the built-in "mean_squared_error".
keras.backend.clear_session()
model = TNPForwardFeedRegressionNetwork(blocks=[3, 3])
model.compile(
    loss=tnp_mse,
    optimizer="adam",
    metrics=[keras.metrics.MeanAbsolutePercentageError()],
)
evaluate_model(model)
"""
## Implementing a Keras Layer Based Model with TNP
If desired, TNP can also be used in layer-oriented Keras code. Let's
implement the same model, but using a layered approach!
"""
def tnp_relu(x):
    """Apply the rectified linear unit element-wise: ``max(x, 0)`` via TNP."""
    # Keep `x` as the first argument and the scalar 0 second, as in the
    # original call.
    rectified = tnp.maximum(x, 0)
    return rectified
class TNPDense(keras.layers.Layer):
    """Dense (fully connected) layer implemented with the TNP API.

    Computes ``activation(inputs @ w + bias)``. Both ``w`` and ``bias`` are
    trainable and initialized with ``"random_normal"``.

    Args:
        units: Output width of the layer.
        activation: Optional callable applied to the affine output (for
            example ``tnp_relu``). ``None`` means no activation.
        **kwargs: Forwarded to ``keras.layers.Layer`` (e.g. ``name``,
            ``dtype``).
    """

    def __init__(self, units, activation=None, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.activation = activation

    def build(self, input_shape):
        # The last axis is the feature axis. Using it (rather than index 1)
        # lets the layer also accept inputs with extra leading axes.
        self.w = self.add_weight(
            name="weights",
            shape=(input_shape[-1], self.units),
            initializer="random_normal",
            trainable=True,
        )
        self.bias = self.add_weight(
            name="bias",
            shape=(self.units,),
            initializer="random_normal",
            trainable=True,
        )

    def call(self, inputs):
        outputs = tnp.matmul(inputs, self.w) + self.bias
        if self.activation is not None:
            return self.activation(outputs)
        return outputs
def create_layered_tnp_model():
    """Return an unbuilt Sequential stack of TNPDense layers (widths 3, 3, 1).

    The two hidden layers use ``tnp_relu``; the single-unit output layer has
    no activation. No input shape is declared, so callers build the model
    explicitly or let ``fit`` build it.
    """
    hidden_layers = [TNPDense(3, activation=tnp_relu) for _ in range(2)]
    return keras.Sequential(hidden_layers + [TNPDense(1)])
# Train the layer-based TNP model. The Sequential model declares no input
# shape, so it is built explicitly before `summary()`. The input width is
# taken from the training data instead of a hard-coded feature count.
model = create_layered_tnp_model()
model.compile(
    optimizer="adam",
    loss="mean_squared_error",
    metrics=[keras.metrics.MeanAbsolutePercentageError()],
)
model.build((None, x_train.shape[-1]))
model.summary()
evaluate_model(model)
"""
You can also seamlessly switch between TNP layers and native Keras layers!
"""
def create_mixed_model():
    """Return an unbuilt Sequential model that interleaves TNP and Keras layers.

    The layer widths are 3, 3 and 1. The middle layer is a stock
    ``layers.Dense`` with ReLU, sandwiched between two ``TNPDense`` layers.
    """
    entry = TNPDense(3, activation=tnp_relu)
    # A normal Keras Dense layer works fine between TNP layers...
    middle = layers.Dense(3, activation="relu")
    # ...and the stack can switch back to a TNP layer afterwards.
    head = TNPDense(1)
    return keras.Sequential([entry, middle, head])
# Train the mixed TNP/Keras model. The model declares no input shape, so it is
# built explicitly before `summary()`, using the feature width of the
# training data rather than a hard-coded constant.
model = create_mixed_model()
model.compile(
    optimizer="adam",
    loss="mean_squared_error",
    metrics=[keras.metrics.MeanAbsolutePercentageError()],
)
model.build((None, x_train.shape[-1]))
model.summary()
evaluate_model(model)
"""
The Keras API offers a wide variety of layers. The ability to use them alongside NumPy
code can be a huge time saver in projects.
"""
"""
## Distribution Strategy
TensorFlow NumPy and Keras integrate with
[TensorFlow Distribution Strategies](https://www.tensorflow.org/guide/distributed_training).
This makes it simple to perform distributed training across multiple GPUs,
or even an entire TPU Pod.
"""
# Mirror training across every visible GPU; when none are present, fall back
# to the default strategy so the same code still runs.
gpus = tf.config.list_logical_devices("GPU")
if gpus:
    strategy = tf.distribute.MirroredStrategy(gpus)
else:
    # We can fallback to a no-op CPU strategy.
    strategy = tf.distribute.get_strategy()
print("Running with strategy:", type(strategy).__name__)
with strategy.scope():
    # Create, compile and build the model inside the strategy scope so the
    # strategy controls where its variables are placed.
    model = create_layered_tnp_model()
    model.compile(
        optimizer="adam",
        loss="mean_squared_error",
        metrics=[keras.metrics.MeanAbsolutePercentageError()],
    )
    # Input width comes from the data rather than a hard-coded constant.
    model.build((None, x_train.shape[-1]))
    model.summary()
evaluate_model(model)
"""
## TensorBoard Integration
One of the many benefits of using the Keras API is the ability to monitor training
through TensorBoard. Using the TensorFlow NumPy API alongside Keras allows you to easily
leverage TensorBoard.
"""
# Reset Keras' global session state before creating the models whose training
# runs are logged to TensorBoard.
keras.backend.clear_session()
"""
To use TensorBoard from a Jupyter notebook, first load the TensorBoard notebook extension with the following magic:
```
%load_ext tensorboard
```
"""
# Train every model variant, logging each run to its own subdirectory of
# `logs/` (e.g. logs/mixed_model) so TensorBoard shows them as separate runs.
models = [
    (TNPForwardFeedRegressionNetwork(blocks=[3, 3]), "TNPForwardFeedRegressionNetwork"),
    (create_layered_tnp_model(), "layered_tnp_model"),
    (create_mixed_model(), "mixed_model"),
]
for model, model_name in models:
    model.compile(
        optimizer="adam",
        loss="mean_squared_error",
        metrics=[keras.metrics.MeanAbsolutePercentageError()],
    )
    log_callback = keras.callbacks.TensorBoard(log_dir=f"logs/{model_name}")
    model.fit(x_train, y_train, epochs=200, verbose=0, callbacks=[log_callback])
"""
To launch TensorBoard from a Jupyter notebook, use the `%tensorboard` magic:
```
%tensorboard --logdir logs
```
TensorBoard lets you monitor metrics and examine the training curve.
![Tensorboard training graph](https://i.imgur.com/wsOuFnz.png)
TensorBoard also allows you to explore the computation graph used in your models.
![Tensorboard graph exploration](https://i.imgur.com/tOrezDL.png)
The ability to introspect into your models can be valuable during debugging.
"""
"""
## Conclusion
Porting existing NumPy code to Keras models using the TensorFlow NumPy (`tf.experimental.numpy`) API is easy!
By integrating with Keras, you gain the ability to use existing Keras callbacks, metrics
and optimizers, easily distribute your training, and use TensorBoard.
Migrating a more complex model, such as a ResNet, to the TensorFlow NumPy API would be a
great follow-up learning exercise.
Several open source NumPy ResNet implementations are available online.
"""