CNTK v2 library: Add support for executing eval before grad.
amitaga committed May 16, 2017
1 parent be99d3e commit 9d19a25
Showing 5 changed files with 39 additions and 11 deletions.
5 changes: 0 additions & 5 deletions Source/CNTKv2LibraryDll/API/CNTKLibrary.h
@@ -3294,11 +3294,6 @@ namespace CNTK
         CNTK_API static FunctionPtr Load(std::istream& inputStream,
                                          const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice());

-        ///
-        /// Prints the entire graph underlying this Function to stderr
-        ///
-        CNTK_API void PrintGraph() const;
-
         ///
         /// Returns a string representation of this Function
         ///
5 changes: 5 additions & 0 deletions Source/CNTKv2LibraryDll/CompositeFunction.cpp
@@ -1317,6 +1317,11 @@ namespace CNTK
                                               const std::unordered_set<Variable>& inputsToExcludeGradientsFor,
                                               bool allocateNetworkMatrices)
     {
+        // Let's purge the current computation network and regenerate the network if the CompositeFunction
+        // was previously compiled just for evaluation and not for gradient backpropagation.
+        if ((m_computationNetwork != nullptr) && (m_currentBackpropRoots.empty() && !backpropRoots.empty()))
+            PurgeComputationNetwork();
+
         if (m_computationNetwork != nullptr)
         {
             // TODO: We should either invalidate and readapt the network if the backpropRoots change compared to what was specified when the network
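At the API level, the check above fires on the first backward request after a forward-only compilation: eval() builds the computation network with no backprop roots, so a later grad() call finds m_currentBackpropRoots empty, discards the cached network, and recompiles it with gradient state. A minimal Python sketch of that call sequence, as an illustration only (the new test at the bottom of this commit exercises the same path with full assertions):

    import numpy as np
    import cntk as C

    x = C.input_variable((2,))
    w = C.parameter(init=np.eye(2, dtype=np.float32))
    f = C.times(x, w)

    data = np.asarray([[1.0, 2.0]], np.float32)
    f.eval({x: data})            # first call: network compiled for evaluation only
    f.grad({x: data}, wrt=[w])   # backprop roots now non-empty: purge and recompile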
12 changes: 12 additions & 0 deletions Source/CNTKv2LibraryDll/CompositeFunction.h
@@ -338,6 +338,18 @@ namespace CNTK
             m_existingNetworkStorageReferences.clear();
         }

+        void PurgeComputationNetwork()
+        {
+            m_currentBackpropRoots.clear();
+            m_inputsExcludedFromGradientComputation.clear();
+            m_variableToNodeMap.clear();
+            m_currentOutputsToEvaluate.clear();
+            m_lastRecordedTimeStamps.clear();
+
+            m_networkMatricesAllocated = false;
+            m_computationNetwork = nullptr;
+        }
+
     private:

         // Set of all primitive functions in the graph underlying 'this' Function. Also keeps the primitive Function objects alive
6 changes: 0 additions & 6 deletions Source/CNTKv2LibraryDll/Function.cpp
@@ -925,12 +925,6 @@ namespace CNTK
         return CompositeFunction::Deserialize(modelDictionary, device);
     }

-    void Function::PrintGraph() const
-    {
-        CompositeFunction::PreorderTraverseFunctions(RootFunction(), [](const FunctionPtr& function) {
-        });
-    }
-
     std::wstring Function::AsString(bool doNotInferOutputs) const
     {
         wstringstream wss;
22 changes: 22 additions & 0 deletions bindings/python/cntk/ops/tests/evaluation_test.py
@@ -85,3 +85,25 @@ def test_input_without_dynamic_axes():
     assert np.allclose(eval_result, [3.006, 2.992])
     assert np.allclose(grad_result, [.01, .01])
+
+
+def test_grad_after_eval():
+    x = C.input_variable((C.FreeDimension, 2))
+    w = C.parameter(init=np.asarray([[2, 5], [1, 3]], dtype=np.float32))
+    t = C.times(x, w)
+
+    x_data = np.asarray([[0.5, 0.2]], np.float32)
+    t_val = t.eval({x : x_data})
+    assert np.array_equal(t_val, np.asarray([[[1.2, 3.1]]], dtype=np.float32))
+
+    w_grad, t_val = t.grad({x : x_data}, wrt=[w], outputs=[t])
+    assert np.array_equal(t_val, np.asarray([[[1.2, 3.1]]], dtype=np.float32))
+    assert np.array_equal(w_grad, np.asarray([[0.5, .5], [.2, .2]], dtype=np.float32))
+
+    x_data = np.asarray([[0.5, 0.2], [0.1, .6]], np.float32)
+    t_val = t.eval({x : x_data})
+    assert np.allclose(t_val, np.asarray([[[1.2, 3.1], [0.8, 2.3]]], dtype=np.float32))
+
+    w_grad, t_val = t.grad({x : x_data}, wrt=[w], outputs=[t])
+    assert np.allclose(t_val, np.asarray([[[1.2, 3.1], [0.8, 2.3]]], dtype=np.float32))
+    assert np.array_equal(w_grad, np.asarray([[0.6, .6], [.8, .8]], dtype=np.float32))
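For reference, the gradient values asserted in this test follow directly from t = times(x, w): grad() accumulates the derivative of every output element with respect to w, so each w[i][j] receives the sum of x[row][i] over the batch rows, independent of j. A quick NumPy check of the final assertion (illustrative only, not part of the commit):

    import numpy as np

    x_data = np.asarray([[0.5, 0.2], [0.1, 0.6]], np.float32)
    # d(sum(x @ w)) / dw[i, j] = sum over rows of x[row, i], the same for every column j
    expected_w_grad = np.tile(x_data.sum(axis=0).reshape(-1, 1), (1, 2))
    print(expected_w_grad)   # [[0.6 0.6]
                             #  [0.8 0.8]]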
