Adds symbolic diff for THNN Conv2d and aten native BatchNorm (pytorch#13888)

Summary:
Adds symbolic diff and tests.
Pull Request resolved: pytorch#13888

Differential Revision: D13115548

Pulled By: soumith

fbshipit-source-id: ba75b01a95a5715a7761724dda018168b6188917
soumith authored and facebook-github-bot committed Nov 18, 2018
1 parent 07a8a73 commit ef3d796
Showing 3 changed files with 260 additions and 34 deletions.
3 changes: 3 additions & 0 deletions test/cpp/jit/gtest.cpp
@@ -28,6 +28,9 @@ JIT_TEST(TopologicalIndex)
JIT_TEST(TopologicalMove)
JIT_TEST(SubgraphUtils)

JIT_TEST(THNNConv)
JIT_TEST(ATenNativeBatchNorm)

#define JIT_TEST_CUDA(name) \
TEST(JitTest, name##_CUDA) { \
test##name(); \
248 changes: 215 additions & 33 deletions test/cpp/jit/tests.h
@@ -45,6 +45,7 @@
#include "torch/csrc/jit/passes/create_autodiff_subgraphs.h"
#include "torch/csrc/jit/passes/dead_code_elimination.h"
#include "torch/csrc/jit/passes/lower_grad_of.h"
#include "torch/csrc/jit/passes/lower_tuples.h"
#include "torch/csrc/jit/passes/requires_grad_analysis.h"
#include "torch/csrc/jit/passes/shape_analysis.h"
#include "torch/csrc/jit/passes/utils/subgraph_utils.h"
@@ -446,6 +447,39 @@ void run(InterpreterState & interp, const std::vector<at::Tensor> & inputs, std:
}
}

std::pair<tensor_list, tensor_list> runGradient(
Gradient& grad_spec,
tensor_list& tensors_in,
tensor_list& tensor_grads_in) {
tensor_list tensors_out, tensor_grads_out;
Code f_code{grad_spec.f}, df_code{grad_spec.df};
InterpreterState f_interpreter{f_code}, df_interpreter{df_code};

run(f_interpreter, tensors_in, tensors_out);

tensor_list df_inputs;
df_inputs.insert(
df_inputs.end(), tensor_grads_in.begin(), tensor_grads_in.end());
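// After the output gradients, append the forward inputs and outputs that df captured.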
for (auto offset : grad_spec.df_input_captured_inputs)
df_inputs.push_back(tensors_in[offset]);
for (auto offset : grad_spec.df_input_captured_outputs)
df_inputs.push_back(tensors_out[offset]);
run(df_interpreter, df_inputs, tensor_grads_out);

// Outputs of f need to be sliced off: anything past f_real_outputs was only produced to be captured by df
tensors_out.erase(
tensors_out.begin() + grad_spec.f_real_outputs, tensors_out.end());
return std::make_pair(tensors_out, tensor_grads_out);
}

void assertAllClose(const tensor_list& a, const tensor_list& b) {
ASSERT_EQ(a.size(), b.size());
for (size_t i = 0; i < a.size(); ++i) {
ASSERT_TRUE(a[i].is_same_size(b[i]));
ASSERT_TRUE(a[i].allclose(b[i]));
}
}

void testInterp() {
constexpr int batch_size = 4;
constexpr int input_size = 256;
@@ -471,6 +505,187 @@ void testInterp() {
ASSERT_TRUE(exactlyEqual(outputs[1], cx));
}

void testTHNNConv() {
std::vector<int64_t> input_size = {4, 3, 15, 17}; // B x C x H x W
std::vector<int64_t> kernel_size = {3, 5};
std::vector<int64_t> stride = {1, 2};
std::vector<int64_t> padding = {2, 1};
constexpr int out_channels = 5;

// make inputs
at::Tensor input = torch::randn(input_size);
at::Tensor weight = torch::randn({out_channels, input_size[1], kernel_size[0], kernel_size[1]});
at::Tensor bias = torch::randn({out_channels});

// run forward eagerly
at::Tensor output, finput, fgradinput;
std::tie(output, finput, fgradinput) = at::thnn_conv2d_forward(input, weight, kernel_size,
bias, stride, padding);

// make grad_outputs
at::Tensor grad_output = torch::randn_like(output);
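// finput and fgradinput are only workspace buffers, so zero gradients are fed for them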
at::Tensor grad_finput = torch::zeros_like(finput);
at::Tensor grad_fgradinput = torch::zeros_like(fgradinput);

// run backward eagerly
at::Tensor grad_input, grad_weight, grad_bias;
std::tie(grad_input, grad_weight, grad_bias) = at::thnn_conv2d_backward(grad_output, input, weight,
kernel_size, stride, padding,
finput, fgradinput, {true, true, true});

// make JIT graph
auto graph = std::make_shared<Graph>();
auto ksz_val = graph->insertConstant(IValue(kernel_size));
auto kst_val = graph->insertConstant(IValue(stride));
auto pad_val = graph->insertConstant(IValue(padding));

auto inputg = graph->addInput("self");
auto weightg = graph->addInput("weight");
auto biasg = graph->addInput("bias");

Value* conv = graph->insert(aten::thnn_conv2d_forward, {inputg, weightg, ksz_val, biasg, kst_val, pad_val});
auto outputs = conv->node()->outputs();
for (auto output : outputs) {
graph->registerOutput(output);
}
LowerAllTuples(graph);
graph->lint();

// differentiate JIT graph
EliminateDeadCode(graph); // Tracing of some ops depends on the DCE trick
ConstantPropagation(graph);
auto grad_spec = differentiate(graph);
LowerGradOf(*grad_spec.df);

// prepare JIT inputs / gradients
tensor_list tensors_in;
tensors_in.push_back(input);
tensors_in.push_back(weight);
tensors_in.push_back(bias);

tensor_list tensor_grads_in;
tensor_grads_in.push_back(grad_output);
tensor_grads_in.push_back(grad_finput);
tensor_grads_in.push_back(grad_fgradinput);

// Get outputs from the interpreter
tensor_list tensors_out, tensor_grads_out;
std::tie(tensors_out, tensor_grads_out) =
runGradient(grad_spec, tensors_in, tensor_grads_in);

// prepare expected structs
tensor_list expected_tensors_out, expected_tensor_grads_out;
expected_tensors_out.push_back(output);
expected_tensors_out.push_back(finput);
expected_tensors_out.push_back(fgradinput);
expected_tensor_grads_out.push_back(grad_input);
expected_tensor_grads_out.push_back(grad_weight);
expected_tensor_grads_out.push_back(grad_bias);

// Compare results
assertAllClose(tensors_out, expected_tensors_out);
assertAllClose(tensor_grads_out, expected_tensor_grads_out);
}

void testATenNativeBatchNorm() {
// aten::native_batch_norm(Tensor input, Tensor weight, Tensor bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)
std::vector<int64_t> input_size = {4, 3, 15, 17}; // B x C x H x W
bool training = true;
float momentum = 0.9;
float eps = 1e-5;

// make inputs
at::Tensor input = torch::randn(input_size);
at::Tensor weight = torch::randn({input_size[1]});
at::Tensor bias = torch::randn({input_size[1]});
at::Tensor running_mean = torch::randn({input_size[1]});
at::Tensor running_var = torch::randn({input_size[1]});

// running_mean and running_var are updated in place, so clone them and use separate copies for the eager and JIT runs
at::Tensor running_mean_eager = running_mean.clone();
at::Tensor running_var_eager = running_var.clone();
at::Tensor running_mean_jit = running_mean.clone();
at::Tensor running_var_jit = running_var.clone();

// run forward eagerly
at::Tensor output, savemean, saveinvstd;
std::tie(output, savemean, saveinvstd) = at::native_batch_norm(input, weight, bias, running_mean_eager, running_var_eager, training, momentum, eps);

// make grad_outputs
at::Tensor grad_output = torch::randn_like(output);
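// save_mean and save_invstd are saved statistics, so their incoming gradients are simply zeros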
at::Tensor grad_savemean = torch::zeros_like(savemean);
at::Tensor grad_saveinvstd = torch::zeros_like(saveinvstd);

// run backward eagerly
at::Tensor grad_input, grad_weight, grad_bias;
// aten::native_batch_norm_backward(Tensor grad_out, Tensor input, Tensor weight, Tensor running_mean, Tensor running_var, Tensor save_mean, Tensor save_invstd, bool train, float eps, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
std::tie(grad_input, grad_weight, grad_bias) = at::native_batch_norm_backward(grad_output, input, weight,
running_mean_eager, running_var_eager,
savemean, saveinvstd, training, eps, {true, true, true});

// make JIT graph
auto graph = std::make_shared<Graph>();
auto training_val = graph->insertConstant(IValue(training));
auto momentum_val = graph->insertConstant(IValue(momentum));
auto eps_val = graph->insertConstant(IValue(eps));

auto inputg = graph->addInput("self");
auto weightg = graph->addInput("weight");
auto biasg = graph->addInput("bias");
auto running_meang = graph->addInput("running_mean");
auto running_varg = graph->addInput("running_var");

Value* bn = graph->insert(aten::native_batch_norm, {inputg, weightg, biasg, running_meang, running_varg, training_val, momentum_val, eps_val});
auto outputs = bn->node()->outputs();
for (auto output : outputs) {
graph->registerOutput(output);
}
LowerAllTuples(graph);
graph->lint();

// differentiate JIT graph
EliminateDeadCode(graph); // Tracing of some ops depends on the DCE trick
ConstantPropagation(graph);
auto grad_spec = differentiate(graph);
LowerGradOf(*grad_spec.df);

// prepare JIT inputs / gradients
tensor_list tensors_in;
tensors_in.push_back(input);
tensors_in.push_back(weight);
tensors_in.push_back(bias);
tensors_in.push_back(running_mean_jit);
tensors_in.push_back(running_var_jit);

tensor_list tensor_grads_in;
tensor_grads_in.push_back(grad_output);
tensor_grads_in.push_back(grad_savemean);
tensor_grads_in.push_back(grad_saveinvstd);

// Get outputs from the interpreter
tensor_list tensors_out, tensor_grads_out;
std::tie(tensors_out, tensor_grads_out) =
runGradient(grad_spec, tensors_in, tensor_grads_in);

// prepare expected structs
tensor_list expected_tensors_out, expected_tensor_grads_out;
expected_tensors_out.push_back(output);
expected_tensors_out.push_back(savemean);
expected_tensors_out.push_back(saveinvstd);
expected_tensors_out.push_back(running_mean_eager);
expected_tensors_out.push_back(running_var_eager);
expected_tensor_grads_out.push_back(grad_input);
expected_tensor_grads_out.push_back(grad_weight);
expected_tensor_grads_out.push_back(grad_bias);

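// The interpreter run updated running_mean_jit and running_var_jit in place; append them so
// the in-place statistics are checked against the eager copies above.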
tensors_out.push_back(running_mean_jit);
tensors_out.push_back(running_var_jit);

// Compare results
assertAllClose(tensors_out, expected_tensors_out);
assertAllClose(tensor_grads_out, expected_tensor_grads_out);
}

using var_meta_type = std::vector<int64_t>;
using var_meta_list = std::vector<var_meta_type>;
using test_fn_type = std::function<variable_list(const variable_list&)>;
@@ -529,39 +744,6 @@ variable_list grad(
fmap(inputs, get_edge));
}

void assertAllClose(const tensor_list& a, const tensor_list& b) {
ASSERT_EQ(a.size(), b.size());
for (size_t i = 0; i < a.size(); ++i) {
ASSERT_TRUE(a[i].is_same_size(b[i]));
ASSERT_TRUE(a[i].allclose(b[i]));
}
}

std::pair<tensor_list, tensor_list> runGradient(
Gradient& grad_spec,
tensor_list& tensors_in,
tensor_list& tensor_grads_in) {
tensor_list tensors_out, tensor_grads_out;
Code f_code{grad_spec.f}, df_code{grad_spec.df};
InterpreterState f_interpreter{f_code}, df_interpreter{df_code};

run(f_interpreter, tensors_in, tensors_out);

tensor_list df_inputs;
df_inputs.insert(
df_inputs.end(), tensor_grads_in.begin(), tensor_grads_in.end());
for (auto offset : grad_spec.df_input_captured_inputs)
df_inputs.push_back(tensors_in[offset]);
for (auto offset : grad_spec.df_input_captured_outputs)
df_inputs.push_back(tensors_out[offset]);
run(df_interpreter, df_inputs, tensor_grads_out);

// Outputs of f needs to be sliced
tensors_out.erase(
tensors_out.begin() + grad_spec.f_real_outputs, tensors_out.end());
return std::make_pair(tensors_out, tensor_grads_out);
}

void testADFormulas() {
const auto unwrap = [](const Variable& v) { return v.data(); };

43 changes: 42 additions & 1 deletion torch/csrc/jit/autodiff.cpp
@@ -77,7 +77,9 @@ bool isDifferentiable(Node * n) {
"aten::trunc(Tensor self) -> Tensor",
"aten::log_softmax(Tensor self, int dim) -> Tensor",
"aten::avg_pool2d(Tensor self, int[] kernel_size, int[] stride, int[] padding, bool ceil_mode, bool count_include_pad) -> Tensor",
"aten::max_pool2d_with_indices(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> (Tensor, Tensor)"
"aten::max_pool2d_with_indices(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> (Tensor, Tensor)",
"aten::thnn_conv2d_forward(Tensor self, Tensor weight, int[] kernel_size, Tensor? bias, int[] stride, int[] padding) -> (Tensor, Tensor, Tensor)",
"aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)",
};
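// Every schema listed above must have a matching gradient case in gradientForNode below.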

// TODO: add support for the following fusible operators.
@@ -401,6 +403,45 @@ static std::vector<Value*> gradientForNode(Node* node, ArrayRef<Value*> grad_val
});
return {backward_value->node()->output(0), nullptr, nullptr, nullptr, nullptr, nullptr};

} else if (node->matches("aten::thnn_conv2d_forward(Tensor self, Tensor weight, int[] kernel_size, Tensor? bias, int[] stride, int[] padding) -> (Tensor, Tensor, Tensor)")) {
auto graph = node->owningGraph();
auto backward_value = graph->insert(aten::thnn_conv2d_backward, {
grads.at(0).value(),
inputs.at(0).value(),
inputs.at(1).value(),
node->namedInput(attr::kernel_size),
node->namedInput(attr::stride),
node->namedInput(attr::padding),
outputs.at(1).value(),
outputs.at(2).value(),
graph->insertConstant(std::vector<bool>{true, true, true})
});
// graph->insert automatically packs multiple outputs into a tuple, so unpack it here.
Node* tuple_unpack_node = graph->insertNode(graph->createTupleUnpack(backward_value));
auto tuple_outputs = tuple_unpack_node->outputs();
JIT_ASSERT(tuple_outputs.size() == size_t(3));
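// Return one gradient slot per forward input: self, weight, kernel_size (non-differentiable), bias, stride, padding.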
return {tuple_outputs[0], tuple_outputs[1], nullptr, tuple_outputs[2], nullptr, nullptr};

} else if (node->matches("aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)")) {
auto graph = node->owningGraph();
auto backward_value = graph->insert(aten::native_batch_norm_backward, {
grads.at(0).value(),
inputs.at(0).value(),
inputs.at(1).value(),
inputs.at(3).value(),
inputs.at(4).value(),
outputs.at(1).value(),
outputs.at(2).value(),
inputs.at(5).value(),
inputs.at(7).value(),
graph->insertConstant(std::vector<bool>{true, true, true})
});
// graph->insert automatically packs multiple outputs into a tuple, so unpack it here.
Node* tuple_unpack_node = graph->insertNode(graph->createTupleUnpack(backward_value));
auto tuple_outputs = tuple_unpack_node->outputs();
JIT_ASSERT(tuple_outputs.size() == size_t(3));
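// Gradient slots follow the forward inputs: input, weight, bias, then nullptr for running_mean, running_var, training, momentum and eps.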
return {tuple_outputs[0], tuple_outputs[1], tuple_outputs[2], nullptr, nullptr, nullptr, nullptr, nullptr};

} else if (node->matches("aten::log_softmax(Tensor self, int dim) -> Tensor")) {
JIT_ASSERT(grads.size() == 1);
auto graph = node->owningGraph();
