Skip to content

Commit

Permalink
fix: Added support for training a multi-input model using a dataset.
Browse files Browse the repository at this point in the history
  • Loading branch information
ASolomatin committed Jul 1, 2024
1 parent f8b7bde commit 93dda17
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 2 deletions.
14 changes: 13 additions & 1 deletion src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,19 @@ public Dictionary<string, float> evaluate(IDatasetV2 x, int verbose = 1, bool is
Steps = data_handler.Inferredsteps
});

return evaluate(data_handler, callbacks, is_val, test_function);
Func<DataHandler, OwnedIterator, Dictionary<string, float>> testFunction;

if (data_handler.DataAdapter.GetDataset().structure.Length > 2 ||
data_handler.DataAdapter.GetDataset().FirstInputTensorCount > 1)
{
testFunction = test_step_multi_inputs_function;
}
else
{
testFunction = test_function;
}

return evaluate(data_handler, callbacks, is_val, testFunction);
}

/// <summary>
Expand Down
13 changes: 12 additions & 1 deletion src/TensorFlowNET.Keras/Engine/Model.Fit.cs
Original file line number Diff line number Diff line change
Expand Up @@ -179,9 +179,20 @@ public ICallback fit(IDatasetV2 dataset,
StepsPerExecution = _steps_per_execution
});

Func<DataHandler, OwnedIterator, Dictionary<string, float>> trainStepFunction;

if (data_handler.DataAdapter.GetDataset().structure.Length > 2 ||
data_handler.DataAdapter.GetDataset().FirstInputTensorCount > 1)
{
trainStepFunction = train_step_multi_inputs_function;
}
else
{
trainStepFunction = train_step_function;
}

return FitInternal(data_handler, epochs, validation_step, verbose, callbacks, validation_data: validation_data,
train_step_func: train_step_function);
train_step_func: trainStepFunction);
}

History FitInternal(DataHandler data_handler, int epochs, int validation_step, int verbose, List<ICallback> callbackList, IDatasetV2 validation_data,
Expand Down

0 comments on commit 93dda17

Please sign in to comment.