From ef3d7963d8128714ca3481500a8b425080886ba0 Mon Sep 17 00:00:00 2001
From: Soumith Chintala
Date: Sun, 18 Nov 2018 09:20:29 -0800
Subject: [PATCH] Adds symbolic diff for THNN Conv2d and aten native BatchNorm
 (#13888)

Summary:
Adds symbolic differentiation formulas for aten::thnn_conv2d_forward and
aten::native_batch_norm, along with C++ JIT tests that compare the
differentiated graphs against eager-mode forward/backward results.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13888

Differential Revision: D13115548

Pulled By: soumith

fbshipit-source-id: ba75b01a95a5715a7761724dda018168b6188917
---
 test/cpp/jit/gtest.cpp      |   3 +
 test/cpp/jit/tests.h        | 248 +++++++++++++++++++++++++++++++-----
 torch/csrc/jit/autodiff.cpp |  43 ++++++-
 3 files changed, 260 insertions(+), 34 deletions(-)

diff --git a/test/cpp/jit/gtest.cpp b/test/cpp/jit/gtest.cpp
index 27f9f14a78860..f799008d498ad 100644
--- a/test/cpp/jit/gtest.cpp
+++ b/test/cpp/jit/gtest.cpp
@@ -28,6 +28,9 @@ JIT_TEST(TopologicalIndex)
 JIT_TEST(TopologicalMove)
 JIT_TEST(SubgraphUtils)
 
+JIT_TEST(THNNConv)
+JIT_TEST(ATenNativeBatchNorm)
+
 #define JIT_TEST_CUDA(name) \
   TEST(JitTest, name##_CUDA) { \
     test##name();              \
diff --git a/test/cpp/jit/tests.h b/test/cpp/jit/tests.h
index 92caf14299c7e..1556b30b4855c 100644
--- a/test/cpp/jit/tests.h
+++ b/test/cpp/jit/tests.h
@@ -45,6 +45,7 @@
 #include "torch/csrc/jit/passes/create_autodiff_subgraphs.h"
 #include "torch/csrc/jit/passes/dead_code_elimination.h"
 #include "torch/csrc/jit/passes/lower_grad_of.h"
+#include "torch/csrc/jit/passes/lower_tuples.h"
 #include "torch/csrc/jit/passes/requires_grad_analysis.h"
 #include "torch/csrc/jit/passes/shape_analysis.h"
 #include "torch/csrc/jit/passes/utils/subgraph_utils.h"
@@ -446,6 +447,39 @@ void run(InterpreterState & interp, const std::vector<at::Tensor> & inputs, std
   }
 }
 
+std::pair<tensor_list, tensor_list> runGradient(
+    Gradient& grad_spec,
+    tensor_list& tensors_in,
+    tensor_list& tensor_grads_in) {
+  tensor_list tensors_out, tensor_grads_out;
+  Code f_code{grad_spec.f}, df_code{grad_spec.df};
+  InterpreterState f_interpreter{f_code}, df_interpreter{df_code};
+
+  run(f_interpreter, tensors_in, tensors_out);
+
+  tensor_list df_inputs;
+  df_inputs.insert(
+      df_inputs.end(), tensor_grads_in.begin(), tensor_grads_in.end());
+  for (auto offset : grad_spec.df_input_captured_inputs)
+    df_inputs.push_back(tensors_in[offset]);
+  for (auto offset : grad_spec.df_input_captured_outputs)
+    df_inputs.push_back(tensors_out[offset]);
+  run(df_interpreter, df_inputs, tensor_grads_out);
+
+  // Outputs of f need to be sliced
+  tensors_out.erase(
+      tensors_out.begin() + grad_spec.f_real_outputs, tensors_out.end());
+  return std::make_pair(tensors_out, tensor_grads_out);
+}
+
+void assertAllClose(const tensor_list& a, const tensor_list& b) {
+  ASSERT_EQ(a.size(), b.size());
+  for (size_t i = 0; i < a.size(); ++i) {
+    ASSERT_TRUE(a[i].is_same_size(b[i]));
+    ASSERT_TRUE(a[i].allclose(b[i]));
+  }
+}
+
 void testInterp() {
   constexpr int batch_size = 4;
   constexpr int input_size = 256;
@@ -471,6 +505,187 @@ void testInterp() {
   ASSERT_TRUE(exactlyEqual(outputs[1], cx));
 }
 
+void testTHNNConv() {
+  std::vector<int64_t> input_size = {4, 3, 15, 17}; // B x C x H x W
+  std::vector<int64_t> kernel_size = {3, 5};
+  std::vector<int64_t> stride = {1, 2};
+  std::vector<int64_t> padding = {2, 1};
+  constexpr int out_channels = 5;
+
+  // make inputs
+  at::Tensor input = torch::randn(input_size);
+  at::Tensor weight = torch::randn({out_channels, input_size[1], kernel_size[0], kernel_size[1]});
+  at::Tensor bias = torch::randn({out_channels});
+
+  // run forward eagerly
+  at::Tensor output, finput, fgradinput;
+  std::tie(output, finput, fgradinput) = at::thnn_conv2d_forward(input, weight, kernel_size,
+                                                                 bias, stride, padding);
+
+  // make grad_outputs
+  at::Tensor grad_output = torch::randn_like(output);
+  at::Tensor grad_finput = torch::zeros_like(finput);
+  at::Tensor grad_fgradinput = torch::zeros_like(fgradinput);
+
+  // run backward eagerly
+  at::Tensor grad_input, grad_weight, grad_bias;
+  std::tie(grad_input, grad_weight, grad_bias) = at::thnn_conv2d_backward(grad_output, input, weight,
+                                                                          kernel_size, stride, padding,
+                                                                          finput, fgradinput, {true, true, true});
+
+  // make JIT graph
+  auto graph = std::make_shared<Graph>();
+  auto ksz_val = graph->insertConstant(IValue(kernel_size));
+  auto kst_val = graph->insertConstant(IValue(stride));
+  auto pad_val = graph->insertConstant(IValue(padding));
+
+  auto inputg = graph->addInput("self");
+  auto weightg = graph->addInput("weight");
+  auto biasg = graph->addInput("bias");
+
+  Value* conv = graph->insert(aten::thnn_conv2d_forward, {inputg, weightg, ksz_val, biasg, kst_val, pad_val});
+  auto outputs = conv->node()->outputs();
+  for (auto output : outputs) {
+    graph->registerOutput(output);
+  }
+  LowerAllTuples(graph);
+  graph->lint();
+
+  // differentiate JIT graph
+  EliminateDeadCode(graph); // Tracing of some ops depends on the DCE trick
+  ConstantPropagation(graph);
+  auto grad_spec = differentiate(graph);
+  LowerGradOf(*grad_spec.df);
+
+  // prepare JIT inputs / gradients
+  tensor_list tensors_in;
+  tensors_in.push_back(input);
+  tensors_in.push_back(weight);
+  tensors_in.push_back(bias);
+
+  tensor_list tensor_grads_in;
+  tensor_grads_in.push_back(grad_output);
+  tensor_grads_in.push_back(grad_finput);
+  tensor_grads_in.push_back(grad_fgradinput);
+
+  // Get outputs from the interpreter
+  tensor_list tensors_out, tensor_grads_out;
+  std::tie(tensors_out, tensor_grads_out) =
+      runGradient(grad_spec, tensors_in, tensor_grads_in);
+
+  // prepare expected structs
+  tensor_list expected_tensors_out, expected_tensor_grads_out;
+  expected_tensors_out.push_back(output);
+  expected_tensors_out.push_back(finput);
+  expected_tensors_out.push_back(fgradinput);
+  expected_tensor_grads_out.push_back(grad_input);
+  expected_tensor_grads_out.push_back(grad_weight);
+  expected_tensor_grads_out.push_back(grad_bias);
+
+  // Compare results
+  assertAllClose(tensors_out, expected_tensors_out);
+  assertAllClose(tensor_grads_out, expected_tensor_grads_out);
+}
+
+void testATenNativeBatchNorm() {
+  // aten::native_batch_norm(Tensor input, Tensor weight, Tensor bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)
+  std::vector<int64_t> input_size = {4, 3, 15, 17}; // B x C x H x W
+  bool training = true;
+  float momentum = 0.9;
+  float eps = 1e-5;
+
+  // make inputs
+  at::Tensor input = torch::randn(input_size);
+  at::Tensor weight = torch::randn({input_size[1]});
+  at::Tensor bias = torch::randn({input_size[1]});
+  at::Tensor running_mean = torch::randn({input_size[1]});
+  at::Tensor running_var = torch::randn({input_size[1]});
+
+  // running_mean and running_var are updated in-place, so pass separate clones to the eager and JIT runs
+  at::Tensor running_mean_eager = running_mean.clone();
+  at::Tensor running_var_eager = running_var.clone();
+  at::Tensor running_mean_jit = running_mean.clone();
+  at::Tensor running_var_jit = running_var.clone();
+
+  // run forward eagerly
+  at::Tensor output, savemean, saveinvstd;
+  std::tie(output, savemean, saveinvstd) = at::native_batch_norm(input, weight, bias, running_mean_eager, running_var_eager, training, momentum, eps);
+
+  // make grad_outputs
+  at::Tensor grad_output = torch::randn_like(output);
+  at::Tensor grad_savemean = torch::zeros_like(savemean);
+  at::Tensor grad_saveinvstd = torch::zeros_like(saveinvstd);
+
+  // run backward eagerly
+  at::Tensor grad_input, grad_weight, grad_bias;
+  // aten::native_batch_norm_backward(Tensor grad_out, Tensor input, Tensor weight, Tensor running_mean, Tensor running_var, Tensor save_mean, Tensor save_invstd, bool train, float eps, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
+  std::tie(grad_input, grad_weight, grad_bias) = at::native_batch_norm_backward(grad_output, input, weight,
+                                                                                running_mean_eager, running_var_eager,
+                                                                                savemean, saveinvstd, training, eps, {true, true, true});
+
+  // make JIT graph
+  auto graph = std::make_shared<Graph>();
+  auto training_val = graph->insertConstant(IValue(training));
+  auto momentum_val = graph->insertConstant(IValue(momentum));
+  auto eps_val = graph->insertConstant(IValue(eps));
+
+  auto inputg = graph->addInput("self");
+  auto weightg = graph->addInput("weight");
+  auto biasg = graph->addInput("bias");
+  auto running_meang = graph->addInput("running_mean");
+  auto running_varg = graph->addInput("running_var");
+
+  Value* bn = graph->insert(aten::native_batch_norm, {inputg, weightg, biasg, running_meang, running_varg, training_val, momentum_val, eps_val});
+  auto outputs = bn->node()->outputs();
+  for (auto output : outputs) {
+    graph->registerOutput(output);
+  }
+  LowerAllTuples(graph);
+  graph->lint();
+
+  // differentiate JIT graph
+  EliminateDeadCode(graph); // Tracing of some ops depends on the DCE trick
+  ConstantPropagation(graph);
+  auto grad_spec = differentiate(graph);
+  LowerGradOf(*grad_spec.df);
+
+  // prepare JIT inputs / gradients
+  tensor_list tensors_in;
+  tensors_in.push_back(input);
+  tensors_in.push_back(weight);
+  tensors_in.push_back(bias);
+  tensors_in.push_back(running_mean_jit);
+  tensors_in.push_back(running_var_jit);
+
+  tensor_list tensor_grads_in;
+  tensor_grads_in.push_back(grad_output);
+  tensor_grads_in.push_back(grad_savemean);
+  tensor_grads_in.push_back(grad_saveinvstd);
+
+  // Get outputs from the interpreter
+  tensor_list tensors_out, tensor_grads_out;
+  std::tie(tensors_out, tensor_grads_out) =
+      runGradient(grad_spec, tensors_in, tensor_grads_in);
+
+  // prepare expected structs
+  tensor_list expected_tensors_out, expected_tensor_grads_out;
+  expected_tensors_out.push_back(output);
+  expected_tensors_out.push_back(savemean);
+  expected_tensors_out.push_back(saveinvstd);
+  expected_tensors_out.push_back(running_mean_eager);
+  expected_tensors_out.push_back(running_var_eager);
+  expected_tensor_grads_out.push_back(grad_input);
+  expected_tensor_grads_out.push_back(grad_weight);
+  expected_tensor_grads_out.push_back(grad_bias);
+
+  tensors_out.push_back(running_mean_jit);
+  tensors_out.push_back(running_var_jit);
+
+  // Compare results
+  assertAllClose(tensors_out, expected_tensors_out);
+  assertAllClose(tensor_grads_out, expected_tensor_grads_out);
+}
+
 using var_meta_type = std::vector<int64_t>;
 using var_meta_list = std::vector<var_meta_type>;
 using test_fn_type = std::function<variable_list(const variable_list&)>;
@@ -529,39 +744,6 @@ variable_list grad(
       fmap(inputs, get_edge));
 }
 
-void assertAllClose(const tensor_list& a, const tensor_list& b) {
-  ASSERT_EQ(a.size(), b.size());
-  for (size_t i = 0; i < a.size(); ++i) {
-    ASSERT_TRUE(a[i].is_same_size(b[i]));
-    ASSERT_TRUE(a[i].allclose(b[i]));
-  }
-}
-
-std::pair<tensor_list, tensor_list> runGradient(
-    Gradient& grad_spec,
-    tensor_list& tensors_in,
-    tensor_list& tensor_grads_in) {
-  tensor_list tensors_out, tensor_grads_out;
-  Code f_code{grad_spec.f}, df_code{grad_spec.df};
-  InterpreterState f_interpreter{f_code}, df_interpreter{df_code};
-
-  run(f_interpreter, tensors_in, tensors_out);
-
-  tensor_list df_inputs;
-  df_inputs.insert(
-      df_inputs.end(), tensor_grads_in.begin(), tensor_grads_in.end());
-  for (auto offset : grad_spec.df_input_captured_inputs)
-    df_inputs.push_back(tensors_in[offset]);
-  for (auto offset : grad_spec.df_input_captured_outputs)
-    df_inputs.push_back(tensors_out[offset]);
-  run(df_interpreter, df_inputs, tensor_grads_out);
-
-  // Outputs of f needs to be sliced
-  tensors_out.erase(
-      tensors_out.begin() + grad_spec.f_real_outputs, tensors_out.end());
-  return std::make_pair(tensors_out, tensor_grads_out);
-}
-
 void testADFormulas() {
   const auto unwrap = [](const Variable& v) { return v.data(); };
diff --git a/torch/csrc/jit/autodiff.cpp b/torch/csrc/jit/autodiff.cpp
index ba8579c3e5ca3..e28d65fdcaeb4 100644
--- a/torch/csrc/jit/autodiff.cpp
+++ b/torch/csrc/jit/autodiff.cpp
@@ -77,7 +77,9 @@ bool isDifferentiable(Node * n) {
     "aten::trunc(Tensor self) -> Tensor",
     "aten::log_softmax(Tensor self, int dim) -> Tensor",
     "aten::avg_pool2d(Tensor self, int[] kernel_size, int[] stride, int[] padding, bool ceil_mode, bool count_include_pad) -> Tensor",
-    "aten::max_pool2d_with_indices(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> (Tensor, Tensor)"
+    "aten::max_pool2d_with_indices(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> (Tensor, Tensor)",
+    "aten::thnn_conv2d_forward(Tensor self, Tensor weight, int[] kernel_size, Tensor? bias, int[] stride, int[] padding) -> (Tensor, Tensor, Tensor)",
+    "aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)",
   };
 
   // TODO: add support for the following fusible operators.
@@ -401,6 +403,45 @@ static std::vector<Value*> gradientForNode(Node* node, ArrayRef<Value*> grad_val
     });
     return {backward_value->node()->output(0), nullptr, nullptr, nullptr, nullptr, nullptr};
 
+  } else if (node->matches("aten::thnn_conv2d_forward(Tensor self, Tensor weight, int[] kernel_size, Tensor? bias, int[] stride, int[] padding) -> (Tensor, Tensor, Tensor)")) {
+    auto graph = node->owningGraph();
+    auto backward_value = graph->insert(aten::thnn_conv2d_backward, {
+      grads.at(0).value(),
+      inputs.at(0).value(),
+      inputs.at(1).value(),
+      node->namedInput(attr::kernel_size),
+      node->namedInput(attr::stride),
+      node->namedInput(attr::padding),
+      outputs.at(1).value(),
+      outputs.at(2).value(),
+      graph->insertConstant(std::vector<bool>{true, true, true})
+    });
+    // graph->insert automatically wraps multiple outputs in a tuple, so unpack it again.
+    Node* tuple_unpack_node = graph->insertNode(graph->createTupleUnpack(backward_value));
+    auto tuple_outputs = tuple_unpack_node->outputs();
+    JIT_ASSERT(tuple_outputs.size() == size_t(3));
+    return {tuple_outputs[0], tuple_outputs[1], nullptr, tuple_outputs[2], nullptr, nullptr};
+
+  } else if (node->matches("aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)")) {
+    auto graph = node->owningGraph();
+    auto backward_value = graph->insert(aten::native_batch_norm_backward, {
+      grads.at(0).value(),
+      inputs.at(0).value(),
+      inputs.at(1).value(),
+      inputs.at(3).value(),
+      inputs.at(4).value(),
+      outputs.at(1).value(),
+      outputs.at(2).value(),
+      inputs.at(5).value(),
+      inputs.at(7).value(),
+      graph->insertConstant(std::vector<bool>{true, true, true})
+    });
+    // graph->insert automatically wraps multiple outputs in a tuple, so unpack it again.
+    Node* tuple_unpack_node = graph->insertNode(graph->createTupleUnpack(backward_value));
+    auto tuple_outputs = tuple_unpack_node->outputs();
+    JIT_ASSERT(tuple_outputs.size() == size_t(3));
+    return {tuple_outputs[0], tuple_outputs[1], tuple_outputs[2], nullptr, nullptr, nullptr, nullptr, nullptr};
+
   } else if (node->matches("aten::log_softmax(Tensor self, int dim) -> Tensor")) {
     JIT_ASSERT(grads.size() == 1);
    auto graph = node->owningGraph();
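
The two new tests share one skeleton: build a Graph around the forward op, run LowerAllTuples, dead code elimination, and constant propagation, differentiate the graph, execute the resulting f and df graphs through the interpreter with runGradient, and compare against eager-mode results with assertAllClose. The condensed sketch below restates that skeleton for aten::thnn_conv2d_forward. It is illustrative only, not part of the patch, and assumes the includes and the runGradient/assertAllClose helpers that test/cpp/jit/tests.h already provides are in scope.

// Illustrative sketch (not part of the patch): the shared structure of
// testTHNNConv and testATenNativeBatchNorm, assuming the headers and helpers
// from test/cpp/jit/tests.h are available.
void sketchSymbolicDiffCheck() {
  // Eager reference: forward + backward through the ATen kernels.
  std::vector<int64_t> kernel_size = {3, 5}, stride = {1, 2}, padding = {2, 1};
  at::Tensor input = torch::randn({4, 3, 15, 17});
  at::Tensor weight = torch::randn({5, 3, kernel_size[0], kernel_size[1]});
  at::Tensor bias = torch::randn({5});
  at::Tensor output, finput, fgradinput;
  std::tie(output, finput, fgradinput) =
      at::thnn_conv2d_forward(input, weight, kernel_size, bias, stride, padding);
  at::Tensor grad_output = torch::randn_like(output);
  at::Tensor grad_input, grad_weight, grad_bias;
  std::tie(grad_input, grad_weight, grad_bias) = at::thnn_conv2d_backward(
      grad_output, input, weight, kernel_size, stride, padding,
      finput, fgradinput, {true, true, true});

  // JIT path: build a graph around the forward op and differentiate it.
  auto graph = std::make_shared<Graph>();
  auto ksz_val = graph->insertConstant(IValue(kernel_size));
  auto kst_val = graph->insertConstant(IValue(stride));
  auto pad_val = graph->insertConstant(IValue(padding));
  auto inputg = graph->addInput("self");
  auto weightg = graph->addInput("weight");
  auto biasg = graph->addInput("bias");
  Value* conv = graph->insert(
      aten::thnn_conv2d_forward, {inputg, weightg, ksz_val, biasg, kst_val, pad_val});
  for (auto output_val : conv->node()->outputs())
    graph->registerOutput(output_val);
  LowerAllTuples(graph);
  EliminateDeadCode(graph);
  ConstantPropagation(graph);
  auto grad_spec = differentiate(graph);
  LowerGradOf(*grad_spec.df);

  // Run f and df through the interpreter and compare with the eager results.
  tensor_list tensors_in = {input, weight, bias};
  tensor_list tensor_grads_in = {
      grad_output, torch::zeros_like(finput), torch::zeros_like(fgradinput)};
  tensor_list tensors_out, tensor_grads_out;
  std::tie(tensors_out, tensor_grads_out) =
      runGradient(grad_spec, tensors_in, tensor_grads_in);
  assertAllClose(tensors_out, {output, finput, fgradinput});
  assertAllClose(tensor_grads_out, {grad_input, grad_weight, grad_bias});
}

The batch norm test follows the same flow, additionally handing cloned running_mean/running_var buffers to the eager and JIT runs so that the in-place updates of those buffers can be compared as well.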