Skip to content

Commit

Permalink
Fix DFPTraining validation set option (#709)
Browse files Browse the repository at this point in the history
- Fix `validation_size` validation. Allow for `validation_size=0.0` for no validation set.
- Pass `run_validation` to dfencoder `fit()`. This option was recently added to `fit()` and is now required to enable use of a validation set and early stopping. 

Fixes #707
Fixes #708

Authors:
  - Eli Fajardo (https://github.com/efajardo-nv)

Approvers:
  - Michael Demoret (https://github.com/mdemoret-nv)
  - David Gardner (https://github.com/dagardner-nv)

URL: #709
  • Loading branch information
efajardo-nv authored Feb 23, 2023
1 parent b8671c6 commit ccb0f4a
Showing 1 changed file with 4 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __init__(self, c: Config, model_kwargs: dict = None, epochs=30, validation_s

self._epochs = epochs

if (validation_size > 0.0 and validation_size < 1.0):
if (validation_size >= 0.0 and validation_size < 1.0):
self._validation_size = validation_size
else:
raise ValueError("validation_size={0} should be a positive float in the "
Expand Down Expand Up @@ -85,13 +85,15 @@ def on_data(self, message: MultiDFPMessage):
# Only train on the feature columns
train_df = train_df[train_df.columns.intersection(self._config.ae.feature_columns)]
validation_df = None
run_validation = False

# Split into training and validation sets
if self._validation_size > 0.0:
train_df, validation_df = train_test_split(train_df, test_size=self._validation_size, shuffle=False)
run_validation = True

logger.debug("Training AE model for user: '%s'...", user_id)
model.fit(train_df, epochs=self._epochs, val=validation_df)
model.fit(train_df, epochs=self._epochs, val=validation_df, run_validation=run_validation)
logger.debug("Training AE model for user: '%s'... Complete.", user_id)

output_message = MultiAEMessage(message.meta,
Expand Down

0 comments on commit ccb0f4a

Please sign in to comment.