diff --git a/src/Microsoft.ML.Dnn/DnnUtils.cs b/src/Microsoft.ML.Dnn/DnnUtils.cs index c46e0b4c74..cf7f3224cb 100644 --- a/src/Microsoft.ML.Dnn/DnnUtils.cs +++ b/src/Microsoft.ML.Dnn/DnnUtils.cs @@ -399,6 +399,11 @@ public Runner AddInput(Tensor value, int index) return this; } + public List GetInputValues() + { + return _inputValues; + } + public Runner AddOutputs(string output) { _outputs.Add(ParseOutput(output)); diff --git a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs index ebccbcb851..902356d527 100644 --- a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs +++ b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs @@ -463,6 +463,11 @@ private void Dispose(bool disposing) { Session.close(); // invoked Dispose() } + + if (Session != null && Session.graph != IntPtr.Zero) + { + Session.graph.Dispose(); + } } finally { @@ -645,7 +650,7 @@ private void UpdateCacheIfNeeded(long position, ITensorValueGetter[] srcTensorGe Runner runner = new Runner(_parent.Session); // Feed inputs to the graph. - for (int i = 0; i < _parent.Inputs.Length; i++) + for (int i = 0; i < _parent.Inputs.Length; i++) { var tensor = srcTensorGetters[i].GetTensor(); runner.AddInput(_parent.Inputs[i], tensor); @@ -658,6 +663,12 @@ private void UpdateCacheIfNeeded(long position, ITensorValueGetter[] srcTensorGe // Execute the graph. var tensors = runner.Run(); + List inputTensors = runner.GetInputValues(); + foreach (Tensor inputTensor in inputTensors) + { + inputTensor.Dispose(); + } + Contracts.Assert(tensors.Length > 0); for (int j = 0; j < activeOutputColNames.Length; j++)