Quick-start for ranking - Added --shuffled_train arg to ranking.py script (NVIDIA-Merlin#985)

* Added --shuffled_train arg to ranking.py script

* Made --train_metrics_steps 1 by default, set train shuffling correctly, dropped the last training batch, and disabled validation metrics during training
gabrielspmoreira authored May 19, 2023
1 parent 3aa13a6 commit d34f280
Showing 3 changed files with 31 additions and 30 deletions.
2 changes: 2 additions & 0 deletions examples/quick_start/scripts/ranking/README.md
@@ -330,6 +330,8 @@ CUDA_VISIBLE_DEVICES=0 TF_GPU_ALLOCATOR=cuda_malloc_async python ranking.py --t
   --train_steps_per_epoch
                         Number of train steps per epoch. Set this for
                         quick debugging.
+  --shuffled_train
+                        Shuffles data during training.
 ```
 
 ### Logging
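To make the new flag's semantics concrete: with `type=str2bool, nargs="?", const=True, default=True` (as defined in args_parsing.py below), shuffling stays on unless it is explicitly turned off. A minimal, self-contained sketch; the `str2bool` here is an assumed stand-in for the repo's own helper:

```python
import argparse


def str2bool(v):
    # Assumed stand-in for the str2bool helper used by args_parsing.py.
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if v.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")


parser = argparse.ArgumentParser()
parser.add_argument(
    "--shuffled_train", type=str2bool, nargs="?", const=True, default=True
)

print(parser.parse_args([]).shuffled_train)                          # True (default)
print(parser.parse_args(["--shuffled_train"]).shuffled_train)        # True (bare flag -> const)
print(parser.parse_args(["--shuffled_train", "no"]).shuffled_train)  # False
```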
11 changes: 10 additions & 1 deletion examples/quick_start/scripts/ranking/args_parsing.py
@@ -435,7 +435,7 @@ def build_arg_parser():
 
     parser.add_argument(
         "--train_metrics_steps",
-        default=10,
+        default=1,
         type=int,
         help="How often should train metrics be computed during training. "
         "You might increase this number to reduce the frequency and increase a bit the "
@@ -462,6 +462,15 @@ def build_arg_parser():
         help="Number of train steps per epoch. Set this for quick debugging.",
     )
 
+    parser.add_argument(
+        "--shuffled_train",
+        type=str2bool,
+        nargs="?",
+        const=True,
+        default=True,
+        help="Shuffles data during training.",
+    )
+
     # In-batch negatives
     parser.add_argument(
         "--in_batch_negatives_train",
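A hedged check of the two behavior changes in this file, assuming `build_arg_parser()` is importable and defines no required arguments:

```python
from args_parsing import build_arg_parser  # parser factory shown in the diff above

parser = build_arg_parser()
args = parser.parse_args(["--shuffled_train", "false"])
print(args.shuffled_train)       # False: shuffling explicitly disabled
print(args.train_metrics_steps)  # 1: new default, train metrics every step
```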
48 changes: 19 additions & 29 deletions examples/quick_start/scripts/ranking/ranking.py
@@ -6,7 +6,6 @@
 import merlin.models.tf as mm
 import numpy as np
 import tensorflow as tf
-from args_parsing import Task, parse_arguments
 from merlin.io.dataset import Dataset
 from merlin.models.tf.logging.callbacks import ExamplesPerSecondCallback, WandbLogger
 from merlin.models.tf.transforms.negative_sampling import InBatchNegatives
@@ -17,7 +16,6 @@
 from ranking_models import get_model
 
-
 
 def get_datasets(args):
     train_ds = (
         Dataset(os.path.join(args.train_data_path, "*.parquet"), part_size="500MB")
@@ -131,6 +129,7 @@ def set_dataloaders(self, train_schema, eval_schema):
             self.train_ds,
             batch_size=args.train_batch_size,
             schema=train_schema,
+            shuffle=args.shuffled_train,
             **train_loader_kwargs,
         )
 
@@ -146,14 +145,14 @@ def set_dataloaders(self, train_schema, eval_schema):
             self.eval_ds,
             batch_size=args.eval_batch_size,
             schema=eval_schema,
+            shuffle=False,
             **eval_loader_kwargs,
         )
 
         self.predict_loader = None
         if self.predict_ds:
             self.predict_loader = mm.Loader(
-                self.predict_ds,
-                batch_size=args.eval_batch_size,
+                self.predict_ds, batch_size=args.eval_batch_size, shuffle=False,
             )
 
     def get_metrics(self):
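The net effect of the loader hunks: only the training loader's ordering is configurable, while the evaluation and prediction loaders are pinned to `shuffle=False` so metrics and saved predictions stay deterministic. A condensed sketch (schema and extra kwargs omitted; names as in the diff):

```python
# Condensed from the diff above; illustration only, not the full method.
train_loader = mm.Loader(
    self.train_ds, batch_size=args.train_batch_size, shuffle=args.shuffled_train
)
eval_loader = mm.Loader(self.eval_ds, batch_size=args.eval_batch_size, shuffle=False)
predict_loader = mm.Loader(self.predict_ds, batch_size=args.eval_batch_size, shuffle=False)
```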
@@ -198,13 +197,9 @@ def get_optimizer(self):
         )
 
         if self.args.optimizer == "adam":
-            opt = tf.keras.optimizers.Adam(
-                learning_rate=lerning_rate,
-            )
+            opt = tf.keras.optimizers.Adam(learning_rate=lerning_rate,)
         elif self.args.optimizer == "adagrad":
-            opt = tf.keras.optimizers.legacy.Adagrad(
-                learning_rate=lerning_rate,
-            )
+            opt = tf.keras.optimizers.legacy.Adagrad(learning_rate=lerning_rate,)
         else:
             raise ValueError("Invalid optimizer")
 
@@ -228,9 +223,7 @@ def build_stl_model(self):
     def train_eval_stl(self, model):
         metrics = self.get_metrics()
         model.compile(
-            self.get_optimizer(),
-            run_eagerly=False,
-            metrics=metrics,
+            self.get_optimizer(), run_eagerly=False, metrics=metrics,
         )
 
         callbacks = self.get_callbacks(self.args)
@@ -240,19 +233,18 @@ def train_eval_stl(self, model):
         logging.info("Starting to train the model")
 
         fit_kwargs = {}
-        if self.eval_loader:
-            fit_kwargs = {
-                "validation_data": self.eval_loader,
-                "validation_steps": self.args.validation_steps,
-            }
+        # if self.eval_loader:
+        #     fit_kwargs = {
+        #         "validation_data": self.eval_loader,
+        #         "validation_steps": self.args.validation_steps,
+        #     }
 
         model.fit(
             self.train_loader,
             epochs=self.args.epochs,
             batch_size=self.args.train_batch_size,
             steps_per_epoch=self.args.train_steps_per_epoch,
-            shuffle=False,
-            drop_last=False,
+            drop_last=True,
             callbacks=callbacks,
             train_metrics_steps=self.args.train_metrics_steps,
             class_weight=class_weights,
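Two behavior changes land in this hunk. In-`fit()` validation is commented out, matching the commit message's "disabled validation metrics during training", and `drop_last=False` becomes `drop_last=True`, so the trailing partial batch is skipped and every step sees a full, uniformly shaped batch. The `shuffle=False` argument also disappears from `fit()` because shuffling is now handled by the `mm.Loader` itself (see the `set_dataloaders` hunk above). A toy calculation with hypothetical sizes:

```python
# Illustration only: what drop_last=True discards (hypothetical numbers).
num_examples, batch_size = 1_000_000, 65_536
full_batches, remainder = divmod(num_examples, batch_size)
print(full_batches)  # 15 full training steps per epoch
print(remainder)     # 16960 examples in the dropped final batch
```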
@@ -297,19 +289,18 @@ def train_eval_mtl(self, model):
         if self.train_loader:
             logging.info("Starting to train the model (fit())")
             fit_kwargs = {}
-            if self.eval_loader:
-                fit_kwargs = {
-                    "validation_data": self.eval_loader,
-                    "validation_steps": self.args.validation_steps,
-                }
+            # if self.eval_loader:
+            #     fit_kwargs = {
+            #         "validation_data": self.eval_loader,
+            #         "validation_steps": self.args.validation_steps,
+            #     }
 
             model.fit(
                 self.train_loader,
                 epochs=args.epochs,
                 batch_size=args.train_batch_size,
                 steps_per_epoch=args.train_steps_per_epoch,
-                shuffle=False,
-                drop_last=False,
+                drop_last=True,
                 callbacks=callbacks,
                 train_metrics_steps=args.train_metrics_steps,
                 **fit_kwargs,
@@ -356,8 +347,7 @@ def save_predictions(self, model, dataset):
         logging.info("Starting the batch predict of the evaluation set")
 
         predictions_ds = model.batch_predict(
-            dataset,
-            batch_size=self.args.eval_batch_size,
+            dataset, batch_size=self.args.eval_batch_size,
         )
         predictions_ddf = predictions_ds.to_ddf()
 
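For downstream use, `batch_predict` returns a dataset whose `to_ddf()` yields a lazy Dask DataFrame; materializing or writing it is a separate step. A hedged sketch, where `compute()` and `to_parquet()` are standard Dask API rather than part of this diff:

```python
predictions_ddf = predictions_ds.to_ddf()   # lazy Dask DataFrame, as in the diff
pdf = predictions_ddf.compute()             # materialize to a pandas DataFrame
predictions_ddf.to_parquet("predictions/")  # or write partitioned Parquet output
```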
