diff --git a/Source/CNTKv2LibraryDll/API/CNTKLibrary.h b/Source/CNTKv2LibraryDll/API/CNTKLibrary.h index 6fff556dd3d3..79dd5819b223 100644 --- a/Source/CNTKv2LibraryDll/API/CNTKLibrary.h +++ b/Source/CNTKv2LibraryDll/API/CNTKLibrary.h @@ -3294,11 +3294,6 @@ namespace CNTK CNTK_API static FunctionPtr Load(std::istream& inputStream, const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice()); - /// - /// Prints the entire graph underlying this Function to stderr - /// - CNTK_API void PrintGraph() const; - /// /// Returns a string representation of this Function /// diff --git a/Source/CNTKv2LibraryDll/CompositeFunction.cpp b/Source/CNTKv2LibraryDll/CompositeFunction.cpp index a3969d9aa893..a68d58933cc7 100755 --- a/Source/CNTKv2LibraryDll/CompositeFunction.cpp +++ b/Source/CNTKv2LibraryDll/CompositeFunction.cpp @@ -1317,6 +1317,11 @@ namespace CNTK const std::unordered_set& inputsToExcludeGradientsFor, bool allocateNetworkMatrices) { + // Lets purge the current computation network and regenerate the network if the CompositeFunction + // was previously compiled just for evaluation and not for gradient backpropagation. + if ((m_computationNetwork != nullptr) && (m_currentBackpropRoots.empty() && !backpropRoots.empty())) + PurgeComputationNetwork(); + if (m_computationNetwork != nullptr) { // TODO: We should either invalidate and readapt the network if the backpropRoots change compared to what was specified when the network diff --git a/Source/CNTKv2LibraryDll/CompositeFunction.h b/Source/CNTKv2LibraryDll/CompositeFunction.h index d472dcd67e49..4d533f6f61a5 100644 --- a/Source/CNTKv2LibraryDll/CompositeFunction.h +++ b/Source/CNTKv2LibraryDll/CompositeFunction.h @@ -338,6 +338,18 @@ namespace CNTK m_existingNetworkStorageReferences.clear(); } + void PurgeComputationNetwork() + { + m_currentBackpropRoots.clear(); + m_inputsExcludedFromGradientComputation.clear(); + m_variableToNodeMap.clear(); + m_currentOutputsToEvaluate.clear(); + m_lastRecordedTimeStamps.clear(); + + m_networkMatricesAllocated = false; + m_computationNetwork = nullptr; + } + private: // Set of all primitive functions in the graph underlying 'this' Function. Also keeps the primitive Function objects alive diff --git a/Source/CNTKv2LibraryDll/Function.cpp b/Source/CNTKv2LibraryDll/Function.cpp index 759b53329853..1eab88f99a8d 100755 --- a/Source/CNTKv2LibraryDll/Function.cpp +++ b/Source/CNTKv2LibraryDll/Function.cpp @@ -925,12 +925,6 @@ namespace CNTK return CompositeFunction::Deserialize(modelDictionary, device); } - void Function::PrintGraph() const - { - CompositeFunction::PreorderTraverseFunctions(RootFunction(), [](const FunctionPtr& function) { - }); - } - std::wstring Function::AsString(bool doNotInferOutputs) const { wstringstream wss; diff --git a/bindings/python/cntk/ops/tests/evaluation_test.py b/bindings/python/cntk/ops/tests/evaluation_test.py index 2e29b08a87da..f357a86a9a8a 100644 --- a/bindings/python/cntk/ops/tests/evaluation_test.py +++ b/bindings/python/cntk/ops/tests/evaluation_test.py @@ -85,3 +85,25 @@ def test_input_without_dynamic_axes(): assert np.allclose(eval_result, [3.006, 2.992]) assert np.allclose(grad_result, [.01, .01]) + +def test_grad_after_eval(): + x = C.input_variable((C.FreeDimension, 2)) + w = C.parameter(init=np.asarray([[2, 5], [1, 3]], dtype=np.float32)) + t = C.times(x, w) + + x_data = np.asarray([[0.5, 0.2]], np.float32) + t_val = t.eval({x : x_data}) + assert np.array_equal(t_val, np.asarray([[[1.2, 3.1]]], dtype=np.float32)) + + w_grad, t_val = t.grad({x : x_data}, wrt=[w], outputs=[t]) + assert np.array_equal(t_val, np.asarray([[[1.2, 3.1]]], dtype=np.float32)) + assert np.array_equal(w_grad, np.asarray([[0.5, .5], [.2, .2]], dtype=np.float32)) + + x_data = np.asarray([[0.5, 0.2], [0.1, .6]], np.float32) + t_val = t.eval({x : x_data}) + assert np.allclose(t_val, np.asarray([[[1.2, 3.1], [0.8, 2.3]]], dtype=np.float32)) + + w_grad, t_val = t.grad({x : x_data}, wrt=[w], outputs=[t]) + assert np.allclose(t_val, np.asarray([[[1.2, 3.1], [0.8, 2.3]]], dtype=np.float32)) + assert np.array_equal(w_grad, np.asarray([[0.6, .6], [.8, .8]], dtype=np.float32)) + \ No newline at end of file