-
Notifications
You must be signed in to change notification settings - Fork 8
/
wide_and_deep.py
77 lines (63 loc) · 2.68 KB
/
wide_and_deep.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
from typing import Dict, Text
import tensorflow as tf
import tensorflow_recommenders as tfrs
from trainer.models.common.basic_layers import MLPLayer
from trainer.util.tools import ObjectDict
class WideAndDeepTFRS(tfrs.Model):
    """TFRS model base whose training step uses two separate optimizers.

    Subclasses must expose ``self.wide`` and ``self.deep`` (objects with a
    ``trainable_variables`` attribute) and implement ``compute_loss``.
    ``self.optimizer`` must be indexable: element 0 is applied to the wide
    variables and element 1 to the deep variables.
    """

    def train_step(self, inputs):
        """Run one optimization step and report the current metrics.

        Args:
            inputs: A batch, forwarded unchanged to ``compute_loss``.

        Returns:
            A dict mapping each metric name to its current result, plus the
            entries ``"loss"``, ``"regularization_loss"`` and ``"total_loss"``.
        """
        with tf.GradientTape() as tape:
            loss = self.compute_loss(inputs, training=True)
            # Handle regularization losses as well. They are evaluated inside
            # the tape so that they can contribute to the gradients below.
            regularization_loss = sum(self.losses)
            total_loss = loss + regularization_loss

        linear_vars = self.wide.trainable_variables
        dnn_vars = self.deep.trainable_variables
        # NOTE(review): only the variables owned by self.wide and self.deep are
        # updated here; trainable variables of any other sub-layer get no
        # gradient step -- confirm that this is intended.

        # Differentiate the total loss (task loss + regularization), not the
        # bare task loss; otherwise the regularizers are reported as a metric
        # but never influence the weight updates.
        linear_grads, dnn_grads = tape.gradient(total_loss, (linear_vars, dnn_vars))

        linear_optimizer = self.optimizer[0]
        dnn_optimizer = self.optimizer[1]
        linear_optimizer.apply_gradients(zip(linear_grads, linear_vars))
        dnn_optimizer.apply_gradients(zip(dnn_grads, dnn_vars))

        metrics = {metric.name: metric.result() for metric in self.metrics}
        metrics["loss"] = loss
        metrics["regularization_loss"] = regularization_loss
        metrics["total_loss"] = total_loss
        return metrics
class WideAndDeep(WideAndDeepTFRS):
    """Ranking model that sums a wide (linear) branch and a deep (MLP) branch.

    Each branch receives the output of its own embedding model. The two
    branch outputs are added and passed through a single-unit sigmoid layer.
    """

    def __init__(
        self, hparams: ObjectDict, deep_emb: tf.keras.Model, wide_emb: tf.keras.Model
    ):
        """Store the embedding models and build the task, branches and head.

        Args:
            hparams: Hyper-parameter object; ``hparams.label`` is read in
                ``compute_loss`` to select the label feature.
            deep_emb: Model applied to the features before the deep branch.
            wide_emb: Model applied to the features before the wide branch.
        """
        super().__init__()
        self.deep_emb = deep_emb
        self.wide_emb = wide_emb
        self.hparams = hparams

        # Ranking task: binary cross-entropy loss, tracked with BCE and AUC.
        bce_loss = tf.keras.losses.BinaryCrossentropy()
        ranking_metrics = [
            tf.keras.metrics.BinaryCrossentropy(),
            tf.keras.metrics.AUC(),
        ]
        self.task: tf.keras.layers.Layer = tfrs.tasks.Ranking(
            loss=bce_loss, metrics=ranking_metrics
        )

        # Wide branch: linear model with an L2 penalty on its kernel.
        wide_l2 = tf.keras.regularizers.l2(l2=0.0001)
        self.wide = tf.keras.experimental.LinearModel(kernel_regularizer=wide_l2)

        # Deep branch: project MLP layer built with its default arguments.
        self.deep = MLPLayer()

        # Output head: one sigmoid unit, also L2-regularized.
        head_l2 = tf.keras.regularizers.l2(l2=0.0001)
        self.prediction = tf.keras.layers.Dense(
            1, activation="sigmoid", kernel_regularizer=head_l2
        )

    def call(self, features: Dict[Text, tf.Tensor], training=False) -> tf.Tensor:
        """Return the sigmoid prediction for a batch of features.

        Args:
            features: Mapping from feature name to tensor, passed to both
                embedding models.
            training: Forwarded to the deep and wide branches.

        Returns:
            Output of the prediction layer applied to ``deep + wide``.
        """
        deep_out = self.deep(self.deep_emb(features), training=training)
        wide_out = self.wide(self.wide_emb(features), training=training)
        return self.prediction(deep_out + wide_out)

    def compute_loss(
        self, features: Dict[Text, tf.Tensor], training=False
    ) -> tf.Tensor:
        """Binarize the label feature, score the batch and return the task loss.

        Args:
            features: Mapping from feature name to tensor; must contain the
                key named by ``self.hparams.label``.
            training: Forwarded to the model call and to the task.

        Returns:
            The value returned by the ranking task for this batch.
        """
        # A label value strictly greater than 3 is a positive (1), else 0.
        raw_label = features[self.hparams.label]
        is_positive = raw_label > 3
        labels = tf.expand_dims(tf.where(is_positive, 1, 0), axis=-1)

        rating_predictions = self(features, training=training)

        # The task computes the loss and the metrics.
        return self.task(
            labels=labels, predictions=rating_predictions, training=training
        )