Skip to content

Commit 4c7bcd3

Browse files
authored
Fix deep_recommender KerasRS example. (#2114)
Rank of cross feature was incorrect.
1 parent 626114b commit 4c7bcd3

File tree

3 files changed

+9
-3
lines changed

3 files changed

+9
-3
lines changed

examples/keras_rs/deep_recommender.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -433,7 +433,9 @@ def preprocess_rating(x):
433433
"raw_user_age": features["raw_user_age"],
434434
"user_gender": features["user_gender"],
435435
"user_occupation_label": features["user_occupation_label"],
436-
"user_gender_X_raw_user_age": features["user_gender_X_raw_user_age"],
436+
"user_gender_X_raw_user_age": tf.squeeze(
437+
features["user_gender_X_raw_user_age"], axis=-1
438+
),
437439
# Movie inputs are movie ID, vectorized title and genres
438440
"movie_id": int(x["movie_id"]),
439441
"movie_title_vector": features["movie_title"],

examples/keras_rs/ipynb/deep_recommender.ipynb

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -745,7 +745,9 @@
745745
" \"raw_user_age\": features[\"raw_user_age\"],\n",
746746
" \"user_gender\": features[\"user_gender\"],\n",
747747
" \"user_occupation_label\": features[\"user_occupation_label\"],\n",
748-
" \"user_gender_X_raw_user_age\": features[\"user_gender_X_raw_user_age\"],\n",
748+
" \"user_gender_X_raw_user_age\": tf.squeeze(\n",
749+
" features[\"user_gender_X_raw_user_age\"], axis=-1\n",
750+
" ),\n",
749751
" # Movie inputs are movie ID, vectorized title and genres\n",
750752
" \"movie_id\": int(x[\"movie_id\"]),\n",
751753
" \"movie_title_vector\": features[\"movie_title\"],\n",

examples/keras_rs/md/deep_recommender.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -497,7 +497,9 @@ def preprocess_rating(x):
497497
"raw_user_age": features["raw_user_age"],
498498
"user_gender": features["user_gender"],
499499
"user_occupation_label": features["user_occupation_label"],
500-
"user_gender_X_raw_user_age": features["user_gender_X_raw_user_age"],
500+
"user_gender_X_raw_user_age": tf.squeeze(
501+
features["user_gender_X_raw_user_age"], axis=-1
502+
),
501503
# Movie inputs are movie ID, vectorized title and genres
502504
"movie_id": int(x["movie_id"]),
503505
"movie_title_vector": features["movie_title"],

0 commit comments

Comments
 (0)