
Commit 61568e6
Merge pull request #2146 from siyerp/patch-3
Updated function signatures to comply with new TensorFlow requirements
miguelgfierro authored Aug 15, 2024
2 parents 34b7b09 + 551ec31 commit 61568e6
Showing 2 changed files with 5 additions and 5 deletions.
6 changes: 3 additions & 3 deletions recommenders/models/sasrec/model.py
@@ -240,7 +240,7 @@ def call(self, x, training, mask):
         Args:
             x (tf.Tensor): Input tensor.
-            training (tf.Tensor): Training tensor.
+            training (Boolean): True if in training mode.
             mask (tf.Tensor): Mask tensor.

         Returns:
@@ -305,15 +305,15 @@ def call(self, x, training, mask):
         Args:
             x (tf.Tensor): Input tensor.
-            training (tf.Tensor): Training tensor.
+            training (Boolean): True if in training mode.
             mask (tf.Tensor): Mask tensor.

         Returns:
             tf.Tensor: Output tensor.
         """

         for i in range(self.num_layers):
-            x = self.enc_layers[i](x, training, mask)
+            x = self.enc_layers[i](x, training=training, mask=mask)

         return x  # (batch_size, input_seq_len, d_model)

4 changes: 2 additions & 2 deletions recommenders/models/sasrec/ssept.py
@@ -122,7 +122,7 @@ def call(self, x, training):
         # --- ATTENTION BLOCKS ---
         seq_attention = seq_embeddings  # (b, s, h1 + h2)

-        seq_attention = self.encoder(seq_attention, training, mask)
+        seq_attention = self.encoder(seq_attention, training=training, mask=mask)
         seq_attention = self.layer_normalization(seq_attention)  # (b, s, h1+h2)

         # --- PREDICTION LAYER ---
@@ -197,7 +197,7 @@ def predict(self, inputs):

         seq_embeddings *= mask
         seq_attention = seq_embeddings
-        seq_attention = self.encoder(seq_attention, training, mask)
+        seq_attention = self.encoder(seq_attention, training=training, mask=mask)
         seq_attention = self.layer_normalization(seq_attention)  # (b, s, h1+h2)
         seq_emb = tf.reshape(
             seq_attention,
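
The common thread across both files: `training` and `mask` are now forwarded to sub-layers as keyword arguments rather than positionally. Keras treats these as special arguments to `Layer.__call__`, and newer TensorFlow releases are stricter about how they are routed, which appears to be the "new tensorflow requirements" the commit message refers to. Below is a minimal runnable sketch of the same pattern; the `TinyEncoder` and `TinyEncoderLayer` names are illustrative stand-ins, not the recommenders API.

# Illustrative sketch only (not code from this PR). Shows why sub-layer
# calls should pass `training`/`mask` as keywords in recent TensorFlow.
import tensorflow as tf


class TinyEncoderLayer(tf.keras.layers.Layer):
    """Stand-in for one encoder layer (hypothetical)."""

    def __init__(self):
        super().__init__()
        self.dropout = tf.keras.layers.Dropout(0.1)

    def call(self, x, training=None, mask=None):
        # Dropout is active only in training mode, so `training`
        # must be forwarded explicitly as a keyword argument.
        return self.dropout(x, training=training)


class TinyEncoder(tf.keras.layers.Layer):
    """Stand-in for the stacked encoder (hypothetical)."""

    def __init__(self, num_layers=2):
        super().__init__()
        self.enc_layers = [TinyEncoderLayer() for _ in range(num_layers)]

    def call(self, x, training=None, mask=None):
        for layer in self.enc_layers:
            # Keyword form, matching this commit's fix; positional
            # `training`/`mask` can be misinterpreted as extra inputs
            # by Keras' __call__ machinery in newer releases.
            x = layer(x, training=training, mask=mask)
        return x


x = tf.random.uniform((2, 5, 8))  # (batch_size, seq_len, hidden_dim)
print(TinyEncoder()(x, training=True).shape)  # (2, 5, 8)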
