Fixed preprocessing and training scripts in Quick-start for Ranking (NVIDIA-Merlin#1017)

* Fixed preprocessing, which was not standardizing and tagging continuous columns properly, and fixed the W&D and DeepFM models to use the updated API

* Updated example command for ranking.py

* Moved indexing to after the filter step
gabrielspmoreira authored Jun 20, 2023
1 parent e131376 commit 8c9fbff
Showing 3 changed files with 25 additions and 15 deletions.
6 changes: 3 additions & 3 deletions examples/quick_start/scripts/preproc/preprocessing.py
@@ -66,14 +66,14 @@ def read_data(self, path):
logging.info("First lines...")
logging.info(ddf.head())

ddf = self.adding_temp_index(ddf)

logging.info(f"Number of rows: {len(ddf)}")
if args.filter_query:
logging.info(f"Filtering rows using filter {args.filter_query}")
ddf = ddf.query(args.filter_query)
logging.info(f"Number of rows after filtering: {len(ddf)}")

ddf = self.adding_temp_index(ddf)

return ddf

def cast_dtypes(self, ddf):
@@ -232,7 +232,7 @@ def generate_nvt_features(self):
feats[col] = feats[col] >> nvt_ops.FillMissing(
args.continuous_features_fillna
)
feats[col] = feats[col] >> nvt_ops.Normalize()
feats[col] = feats[col] >> nvt_ops.Normalize()

if args.target_encoding_features or args.target_encoding_targets:
if not args.target_encoding_features:
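In the hunk above, the two identical-looking `Normalize()` lines are the deleted and added versions of the same statement; the visible counts suggest the change is one of indentation (lost in this text view), likely moving normalization out of the fill-missing branch so it runs for every continuous column. A minimal NVTabular sketch of that intent, using the quick-start argument name `continuous_features_fillna` and placeholder column names (illustrative, not the actual script):

```python
import functools
import operator

import nvtabular as nvt
from nvtabular import ops as nvt_ops

# Placeholder columns and fill value; the real script derives these from args.
continuous_cols = ["video_duration", "watch_ratio"]
continuous_features_fillna = 0.0

feats = {}
for col in continuous_cols:
    feats[col] = [col]
    if continuous_features_fillna is not None:
        feats[col] = feats[col] >> nvt_ops.FillMissing(continuous_features_fillna)
    # The fix: standardize every continuous column, not only those that were filled,
    # so all of them come out normalized and tagged as continuous in the output schema.
    feats[col] = feats[col] >> nvt_ops.Normalize()

workflow = nvt.Workflow(functools.reduce(operator.add, feats.values()))
```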
5 changes: 3 additions & 2 deletions examples/quick_start/scripts/ranking/README.md
@@ -106,9 +106,10 @@ This is an example command line for running the training for the TenRec dataset


```bash
cd /Merlin/examples/quick_start/scripts/ranking/
cd /Merlin/examples/
OUT_DATASET_PATH=/outputs/dataset
CUDA_VISIBLE_DEVICES=0 TF_GPU_ALLOCATOR=cuda_malloc_async python ranking.py --train_data_path $OUT_DATASET_PATH/train --eval_data_path $OUT_DATASET_PATH/eval --output_path ./outputs/ --tasks=click --stl_positive_class_weight 3 --model dlrm --embeddings_dim 64 --l2_reg 1e-4 --embeddings_l2_reg 1e-6 --dropout 0.05 --mlp_layers 64,32 --lr 1e-4 --lr_decay_rate 0.99 --lr_decay_steps 100 --train_batch_size 65536 --eval_batch_size 65536 --epochs 1 --save_model_path ./saved_model

CUDA_VISIBLE_DEVICES=0 TF_GPU_ALLOCATOR=cuda_malloc_async python -m quick_start.scripts.ranking.ranking --train_data_path $OUT_DATASET_PATH/train --eval_data_path $OUT_DATASET_PATH/eval --output_path ./outputs/ --tasks=click --stl_positive_class_weight 3 --model dlrm --embeddings_dim 64 --l2_reg 1e-4 --embeddings_l2_reg 1e-6 --dropout 0.05 --mlp_layers 64,32 --lr 1e-4 --lr_decay_rate 0.99 --lr_decay_steps 100 --train_batch_size 65536 --eval_batch_size 65536 --epochs 1 --save_model_path ./saved_model
```
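The updated command runs training as a Python module (`python -m quick_start.scripts.ranking.ranking`) from `/Merlin/examples` rather than calling `ranking.py` directly; with `-m`, Python prepends the working directory to `sys.path`, presumably so the `quick_start` package imports resolve. The training flags themselves are unchanged.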

### Inputs
29 changes: 19 additions & 10 deletions examples/quick_start/scripts/ranking/ranking_models.py
@@ -43,7 +43,8 @@ def get_mlp_model(schema, args, prediction_tasks):
cat_schema,
embeddings_regularizer=regularizers.l2(args.embeddings_l2_reg),
infer_dim_fn=partial(
infer_embedding_dim, multiplier=args.embedding_sizes_multiplier,
infer_embedding_dim,
multiplier=args.embedding_sizes_multiplier,
),
),
aggregation="concat",
@@ -74,7 +75,8 @@ def get_dcn_model(schema, args, prediction_tasks):
schema.select_by_tag(Tags.CATEGORICAL),
embeddings_regularizer=regularizers.l2(args.embeddings_l2_reg),
infer_dim_fn=partial(
infer_embedding_dim, multiplier=args.embedding_sizes_multiplier,
infer_embedding_dim,
multiplier=args.embedding_sizes_multiplier,
),
),
aggregation="concat",
@@ -158,7 +160,7 @@ def get_deepfm_model(schema, args, prediction_tasks):
if len(cat_schema_multihot) > 0:
wide_inputs_block["categorical_mhe"] = mm.SequentialBlock(
mm.Filter(cat_schema_multihot),
mm.ListToDense(max_seq_length=args.multihot_max_seq_length),
mm.ToDense(cat_schema_multihot),
mm.CategoryEncoding(
cat_schema_multihot, sparse=True, output_mode="multi_hot"
),
@@ -195,7 +197,8 @@ def get_wide_and_deep_model(schema, args, prediction_tasks):
cat_schema,
embeddings_regularizer=regularizers.l2(args.embeddings_l2_reg),
infer_dim_fn=partial(
infer_embedding_dim, multiplier=args.embedding_sizes_multiplier,
infer_embedding_dim,
multiplier=args.embedding_sizes_multiplier,
),
)

@@ -212,7 +215,7 @@ def get_wide_and_deep_model(schema, args, prediction_tasks):
# 2nd level feature interactions of multi-hot features
mm.SequentialBlock(
mm.Filter(cat_schema.remove_by_tag(Tags.USER_ID)),
mm.ListToDense(max_seq_length=args.multihot_max_seq_length),
mm.ToDense(cat_schema.remove_by_tag(Tags.USER_ID)),
mm.HashedCrossAll(
cat_schema.remove_by_tag(Tags.USER_ID),
num_bins=args.wnd_hashed_cross_num_bins,
@@ -227,7 +230,7 @@ def get_wide_and_deep_model(schema, args, prediction_tasks):
wide_preprocess.append(
mm.SequentialBlock(
mm.Filter(cat_schema_multihot),
mm.ListToDense(max_seq_length=args.multihot_max_seq_length),
mm.ToDense(cat_schema_multihot),
mm.CategoryEncoding(
cat_schema_multihot, sparse=True, output_mode="multi_hot"
),
@@ -248,7 +251,10 @@ def get_wide_and_deep_model(schema, args, prediction_tasks):
wide_regularizer=regularizers.l2(args.wnd_wide_l2_reg),
wide_dropout=args.dropout,
deep_dropout=args.dropout,
wide_preprocess=mm.ParallelBlock(wide_preprocess, aggregation="concat",),
wide_preprocess=mm.ParallelBlock(
wide_preprocess,
aggregation="concat",
),
prediction_tasks=prediction_tasks,
)

@@ -273,7 +279,8 @@ def get_mmoe_model(schema, args, prediction_tasks):
schema.select_by_tag(Tags.CATEGORICAL),
embeddings_regularizer=regularizers.l2(args.embeddings_l2_reg),
infer_dim_fn=partial(
infer_embedding_dim, multiplier=args.embedding_sizes_multiplier,
infer_embedding_dim,
multiplier=args.embedding_sizes_multiplier,
),
),
aggregation="concat",
@@ -310,7 +317,8 @@ def get_cgc_model(schema, args, prediction_tasks):
schema.select_by_tag(Tags.CATEGORICAL),
embeddings_regularizer=regularizers.l2(args.embeddings_l2_reg),
infer_dim_fn=partial(
infer_embedding_dim, multiplier=args.embedding_sizes_multiplier,
infer_embedding_dim,
multiplier=args.embedding_sizes_multiplier,
),
),
aggregation="concat",
@@ -348,7 +356,8 @@ def get_ple_model(schema, args, prediction_tasks):
schema.select_by_tag(Tags.CATEGORICAL),
embeddings_regularizer=regularizers.l2(args.embeddings_l2_reg),
infer_dim_fn=partial(
infer_embedding_dim, multiplier=args.embedding_sizes_multiplier,
infer_embedding_dim,
multiplier=args.embedding_sizes_multiplier,
),
),
aggregation="concat",
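Across `ranking_models.py`, the recurring change replaces `mm.ListToDense(max_seq_length=args.multihot_max_seq_length)` with `mm.ToDense(schema)` wherever multi-hot categorical features feed the wide branches. A minimal sketch of the updated pattern, reusing only blocks that appear in the diff (the helper name and standalone usage are illustrative):

```python
import merlin.models.tf as mm


def multihot_wide_inputs(cat_schema_multihot):
    """Densify multi-hot categorical features and multi-hot-encode them for a wide branch."""
    return mm.SequentialBlock(
        mm.Filter(cat_schema_multihot),   # keep only the multi-hot categorical columns
        mm.ToDense(cat_schema_multihot),  # updated API: densify ragged/list features by schema
        mm.CategoryEncoding(
            cat_schema_multihot, sparse=True, output_mode="multi_hot"
        ),
    )
```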
