Improve Variable interface (pytorch#5127)
* Improve Variable interface

* Address comments from @apaszke and @colesbury

* string ::operator= is not noexcept

* Remove ir.h from tracer_state.h to improve build times

* Make Variable a struct and pack SavedVariable fields

* Implement as_variable_ref

* grad_fn_ptr() -> grad_fn_unsafe()

* Reduce hackiness of set_type hack

* Include variable.h and edge.h in tracer_state.h because it uses them

* class Variable -> struct Variable because Windows can't even

* Make Variable::output_nr uint32_t instead of int

* Add comment about tracing state

* Replaced more static_cast<Variable&> and improve docs

* Remove SavedVariable destructor and construct members in init list

* Clarify docs for Variable

* Variable::set_version -> set_version_counter
goldsborough authored and soumith committed Feb 13, 2018
1 parent 0ef1038 commit 2d5fbe6
Showing 38 changed files with 934 additions and 616 deletions.
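
The commit message above reshapes the Variable API around a few new entry points (as_variable_ref, Edge, set_gradient_edge, bump_version). The following sketch is not part of the commit: it strings those calls together the way the changed files below use them. The helper name attach_history is hypothetical, and the headers and signatures are taken from the usages visible in this diff.

// Hypothetical helper illustrating the reworked Variable interface; a sketch
// based on the usages in this diff, not code from the commit itself.
#include "torch/csrc/autograd/variable.h"
#include "torch/csrc/autograd/edge.h"
#include "torch/csrc/autograd/function.h"

#include <memory>
#include <utility>

namespace torch { namespace autograd {

void attach_history(at::Tensor& tensor, std::shared_ptr<Function> grad_fn) {
  // as_variable_ref replaces the old static_cast<Variable&> pattern and is
  // expected to check that the Tensor really is a Variable.
  Variable& var = as_variable_ref(tensor);
  // Gradient history is attached through an Edge{function, input_nr} value
  // instead of poking var.get()->_grad_fn and output_nr directly.
  var.set_gradient_edge({std::move(grad_fn), /*input_nr=*/0});
  // In-place modifications bump the version counter through the Variable.
  var.bump_version();
}

}} // namespace torch::autograd
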
1 change: 1 addition & 0 deletions .clang-format
@@ -37,6 +37,7 @@ BreakAfterJavaFieldAnnotations: false
BreakStringLiterals: false
ColumnLimit: 80
CommentPragmas: '^ IWYU pragma:'
CompactNamespaces: true
ConstructorInitializerAllOnOneLineOrOnePerLine: true
ConstructorInitializerIndentWidth: 4
ContinuationIndentWidth: 4
3 changes: 2 additions & 1 deletion aten/src/ATen/TensorImpl.h
@@ -16,6 +16,7 @@ struct Storage;
struct TensorImpl : public Retainable {
explicit TensorImpl(Type * type)
: is_scalar(false), type_(type) {}

Type & type() const {
return *type_;
}
@@ -49,7 +50,7 @@ struct TensorImpl : public Retainable {
void setScalar(bool s) {
is_scalar = s;
}
private:
protected:
bool is_scalar;
Type * type_;
};
1 change: 1 addition & 0 deletions setup.py
@@ -474,6 +474,7 @@ def run(self):
"torch/csrc/jit/python_ir.cpp",
"torch/csrc/jit/test_jit.cpp",
"torch/csrc/jit/tracer.cpp",
"torch/csrc/jit/tracer_state.cpp",
"torch/csrc/jit/python_tracer.cpp",
"torch/csrc/jit/passes/shape_analysis.cpp",
"torch/csrc/jit/interned_strings.cpp",
2 changes: 1 addition & 1 deletion tools/autograd/gen_autograd_functions.py
@@ -129,7 +129,7 @@ def save_arg(arg, is_output):
name = arg['name']
if arg['type'] == 'Tensor' or (arg['type'] == 'Scalar' and is_output):
saved_variables.append('SavedVariable {}_;'.format(name))
release_variables.append('{}_.data.reset();'.format(name))
release_variables.append('{}_.reset_data();'.format(name))
ptr = 'shared_from_this()' if is_output else ''
unpack.append('auto {} = {}_.unpack({});'.format(name, name, ptr))
elif arg['type'] == 'TensorList':
6 changes: 3 additions & 3 deletions tools/autograd/templates/Functions.cpp
@@ -478,7 +478,7 @@ Tensor select_backward_scalar(Tensor grad, const Tensor & input, const Tensor &
#ifdef WITH_SCALARS
grad_input.masked_fill_(input == value, grad);
#else
auto grad_data = static_cast<Variable&>(grad).data();
auto grad_data = as_variable_ref(grad).data();
grad_input.masked_fill_(input == value, Scalar(grad_data[0]));
#endif
return grad_input;
@@ -1088,9 +1088,9 @@ std::tuple<Tensor, Tensor, Tensor> batchnorm_double_backward(
for (auto s : input.sizes().slice(2)) {
M *= s;
}
auto mu = unsqueeze_dim1(make_variable(training ? save_mean : running_mean), input);
auto mu = unsqueeze_dim1(make_variable(training ? save_mean : running_mean, /*requires_grad=*/false), input);
auto input_sub_mu = input - mu;
auto sigma2_eps_neg_1_2 = unsqueeze_dim1(make_variable(training ? save_std : running_var.add(Scalar(eps)).pow(-0.5)), input);
auto sigma2_eps_neg_1_2 = unsqueeze_dim1(make_variable(training ? save_std : running_var.add(Scalar(eps)).pow(-0.5), /*requires_grad=*/false), input);
auto sigma2_eps_neg_1 = sigma2_eps_neg_1_2.pow(2);
auto sigma2_eps_neg_3_2 = sigma2_eps_neg_1_2.pow(3);

91 changes: 37 additions & 54 deletions tools/autograd/templates/VariableType.cpp
@@ -5,6 +5,7 @@

#include "torch/csrc/autograd/variable.h"
#include "torch/csrc/autograd/function.h"
#include "torch/csrc/autograd/edge.h"
#include "torch/csrc/autograd/grad_mode.h"
#include "torch/csrc/autograd/saved_variable.h"
#include "torch/csrc/autograd/generated/Functions.h"
@@ -28,7 +29,6 @@ using namespace at;
using namespace torch::autograd::generated;

namespace torch { namespace autograd {

// Helper methods for working with Attributes (torch/csrc/jit/attributes.h)

// The overloaded accessors are convenient for the generated code (since we
@@ -74,7 +74,7 @@ std::unique_ptr<Storage> VariableType::storageWithAllocator(int64_t size, std::u
return baseType->storageWithAllocator(size, std::move(allocator));
}
Tensor VariableType::unsafeTensorFromTH(void * th_pointer, bool retain) const {
return make_variable(baseType->unsafeTensorFromTH(th_pointer, retain), false);
return make_variable(baseType->unsafeTensorFromTH(th_pointer, retain), /*requires_grad=*/false);
}
std::unique_ptr<Generator> VariableType::generator() const {
return baseType->generator();
@@ -164,7 +164,7 @@ Variable & VariableType::checked_cast_variable(const Tensor & t, const char * na
runtime_error("Expected object of type Variable but found type %s for argument #%d '%s'",
t.type().toString(), pos, name);
}
return static_cast<Variable&>(const_cast<Tensor&>(t));
return as_variable_ref(const_cast<Tensor&>(t));
}

Tensor & VariableType::unpack(const Tensor & t, const char * name, int pos) {
@@ -207,49 +207,35 @@ static std::vector<SavedVariable> make_saved_variable_list(TensorList tensors) {
return SavedVariable{tensor, false /* is output */}; });
}

static Tensor as_variable(Tensor tensor) {
return make_variable(std::move(tensor));
}

static std::tuple<Tensor, Tensor>
as_variable(std::tuple<Tensor, Tensor> tensors) {
return std::make_tuple<>(
make_variable(std::move(std::get<0>(tensors))),
make_variable(std::move(std::get<1>(tensors))));
template <typename... Tensors, size_t... Is>
std::tuple<Tensors...> as_variable_impl(
std::tuple<Tensors...> tensors,
Indices<Is...>) {
// Expand the integer parameter pack into a sequence of Variable
// constructions. This turns into (boolean omitted):
// Variable(std::get<0>(tensors)), Variable(std::get<1>(tensors)), ...
return std::tuple<Tensors...>(
make_variable(std::get<Is>(tensors), /*requires_grad=*/false)...);
}

static std::tuple<Tensor, Tensor, Tensor>
as_variable(std::tuple<Tensor, Tensor, Tensor> tensors) {
return std::make_tuple<>(
make_variable(std::move(std::get<0>(tensors))),
make_variable(std::move(std::get<1>(tensors))),
make_variable(std::move(std::get<2>(tensors))));
template <typename... Tensors>
std::tuple<Tensors...> as_variable(std::tuple<Tensors...> tensors) {
// `sizeof...(Tensors)` gets us the size of the `Tensors` parameter pack at
// compile time. We use it to parameterize a `MakeIndices` class, which will
// expand into an Indices object containing the numbers 0 to
// sizeof...(Tensors) - 1.
return as_variable_impl(
tensors, typename MakeIndices<sizeof...(Tensors)>::indices());
}

static std::tuple<Tensor, Tensor, Tensor, Tensor>
as_variable(std::tuple<Tensor, Tensor, Tensor, Tensor> tensors) {
return std::make_tuple<>(
make_variable(std::move(std::get<0>(tensors))),
make_variable(std::move(std::get<1>(tensors))),
make_variable(std::move(std::get<2>(tensors))),
make_variable(std::move(std::get<3>(tensors))));
}

static std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor>
as_variable(std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> tensors) {
return std::make_tuple<>(
make_variable(std::move(std::get<0>(tensors))),
make_variable(std::move(std::get<1>(tensors))),
make_variable(std::move(std::get<2>(tensors))),
make_variable(std::move(std::get<3>(tensors))),
make_variable(std::move(std::get<4>(tensors)))
);
static Tensor as_variable(Tensor tensor) {
return make_variable(std::move(tensor), /*requires_grad=*/false);
}

static std::vector<Tensor> as_variable(TensorList tl) {
std::vector<Tensor> variables;
for (auto& t : tl) {
variables.emplace_back(make_variable(std::move(t)));
variables.emplace_back(make_variable(std::move(t), /*requires_grad=*/false));
}
return variables;
}
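
The as_variable_impl/as_variable pair above relies on Indices and MakeIndices to expand a tuple of tensors. A minimal sketch of such pre-C++14 index-sequence helpers is shown below; the actual PyTorch definitions live in a utility header and may differ in detail.

// Sketch of index-sequence helpers in the style the hunk above assumes;
// not the literal definitions used by the commit.
#include <cstddef>

template <std::size_t... Is>
struct Indices {};

// MakeIndices<N>::indices expands to Indices<0, 1, ..., N - 1> by peeling
// one index off per recursion step.
template <std::size_t N, std::size_t... Is>
struct MakeIndices : MakeIndices<N - 1, N - 1, Is...> {};

template <std::size_t... Is>
struct MakeIndices<0, Is...> {
  using indices = Indices<Is...>;
};

With helpers of this shape, typename MakeIndices<sizeof...(Tensors)>::indices() yields Indices<0, 1, ..., sizeof...(Tensors) - 1>, which lets as_variable_impl expand make_variable(std::get<Is>(tensors), ...)... over every element of the tuple.
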
@@ -316,20 +302,20 @@ static void throw_error_out_requires_grad(const char* name) {

static void rebase_history(Tensor& tensor, std::shared_ptr<Function> grad_fn) {
if (grad_fn && tensor.defined()) {
auto& var = static_cast<Variable&>(tensor);
auto& var = as_variable_ref(tensor);
grad_fn->num_inputs = 1;
var.rebase_history(0, std::move(grad_fn));
var.rebase_history({std::move(grad_fn), 0});
}
}

static void rebase_history(TensorList tensors, std::shared_ptr<Function> grad_fn) {
if (grad_fn) {
grad_fn->num_inputs = tensors.size();
int output_nr = 0;
uint32_t output_nr = 0;
for (auto& tensor : tensors) {
if (tensor.defined()) {
auto& var = static_cast<Variable&>(const_cast<Tensor&>(tensor));
var.rebase_history(output_nr, grad_fn);
auto& var = as_variable_ref(const_cast<Tensor&>(tensor));
var.rebase_history({grad_fn, output_nr});
}
output_nr++;
}
@@ -340,22 +326,20 @@ static void rebase_history(TensorList tensors, std::shared_ptr<Function> grad_fn
// overload for functions with multiple differentiable outputs.
static void set_history(Tensor& tensor, std::shared_ptr<Function> grad_fn) {
if (grad_fn && tensor.defined()) {
auto& var = static_cast<Variable&>(tensor);
auto& var = as_variable_ref(tensor);
grad_fn->num_inputs = 1;
var.get()->output_nr = 0;
var.get()->_grad_fn = std::move(grad_fn);
var.set_gradient_edge({std::move(grad_fn), 0});
}
}

static void set_history(TensorList tensors, std::shared_ptr<Function> grad_fn) {
if (grad_fn) {
grad_fn->num_inputs = tensors.size();
int64_t output_nr = 0;
uint32_t output_nr = 0;
for (auto& tensor : tensors) {
if (tensor.defined()) {
auto& var = static_cast<Variable&>(const_cast<Tensor&>(tensor));
var.get()->output_nr = output_nr;
var.get()->_grad_fn = grad_fn;
auto& var = as_variable_ref(const_cast<Tensor&>(tensor));
var.set_gradient_edge({grad_fn, output_nr});
}
output_nr++;
}
Expand All @@ -378,9 +362,8 @@ template<typename... Args> inline variable_list flatten(Args&&... args) {
return out; // RVO
}

static void increment_version(const Tensor & t) {
auto& var = static_cast<const Variable&>(t);
var.version_counter().increment();
static void increment_version(Tensor & t) {
as_variable_ref(t).bump_version();
}

static bool isFloatingPoint(ScalarType s) {
@@ -411,7 +394,7 @@ Tensor & VariableType::s_copy_(Tensor & self, const Tensor & src, bool non_block

Tensor & VariableType::resize_(Tensor & self, IntList size) const {
auto& self_ = unpack(self, "self", 0);
if (static_cast<Variable&>(self).requires_grad()) {
if (as_variable_ref(self).requires_grad()) {
at::runtime_error("cannot resize variables that require grad");
}
baseType->resize_(self_, size);
@@ -421,7 +404,7 @@ Tensor & VariableType::resize_as_(Tensor & self, const Tensor & the_template) const {
Tensor & VariableType::resize_as_(Tensor & self, const Tensor & the_template) const {
auto& self_ = unpack(self, "self", 0);
auto& the_template_ = unpack(the_template, "the_template", 1);
if (static_cast<Variable&>(self).requires_grad()) {
if (as_variable_ref(self).requires_grad()) {
at::runtime_error("cannot resize variables that require grad");
}
baseType->resize_as_(self_, the_template_);
5 changes: 4 additions & 1 deletion tools/autograd/templates/VariableType.h
@@ -3,6 +3,10 @@
// ${generated_comment}

#include <ATen/ATen.h>

#include <cstdint> // for size_t
#include <functional> // for function
#include <memory> // for unique_ptr
#include <string>
#include <vector>

@@ -56,7 +60,6 @@ struct VariableType final : public at::Type {
static at::Tensor unpack_opt(const Tensor & t, const char * name, int pos);
static std::vector<at::Tensor> unpack(at::TensorList tl, const char *name, int pos);

private:
at::Type* baseType;
std::string str;
};
4 changes: 2 additions & 2 deletions tools/autograd/templates/python_torch_functions.cpp
@@ -25,7 +25,7 @@ using namespace torch::autograd::utils;
namespace torch { namespace autograd {

static Tensor set_requires_grad(Tensor self, bool requires_grad) {
static_cast<Variable&>(self).get()->_requires_grad = requires_grad;
as_variable_ref(self).set_requires_grad(requires_grad);
return self;
}

@@ -70,7 +70,7 @@ static PyObject * THPVariable_from_numpy(PyObject* module, PyObject* arg)
{
HANDLE_TH_ERRORS
auto data = torch::utils::tensor_from_numpy(arg);
return THPVariable_Wrap(make_variable(std::move(data)));
return THPVariable_Wrap(make_variable(std::move(data), /*requires_grad=*/false));
END_HANDLE_TH_ERRORS
}

10 changes: 4 additions & 6 deletions torch/csrc/autograd/edge.h
@@ -6,17 +6,16 @@

#include "torch/csrc/utils/hash.h"

namespace torch {
namespace autograd {
namespace torch { namespace autograd {

struct Function;

/// Represents a particular input of a function.
struct Edge {
Edge() noexcept : function(nullptr), input_nr(0) {}

Edge(const std::shared_ptr<Function>& function_, uint32_t input_nr_) noexcept
: function(function_), input_nr(input_nr_) {}
Edge(std::shared_ptr<Function> function_, uint32_t input_nr_) noexcept
: function(std::move(function_)), input_nr(input_nr_) {}

/// Convenience method to test if an edge is valid.
bool is_valid() const noexcept {
@@ -38,8 +37,7 @@ struct Edge {
/// The identifier of a particular input to the function.
uint32_t input_nr;
};
} // namespace autograd
} // namespace torch
}} // namespace torch::autograd

// The idiomatic way of enabling use of a custom type as the key of hash
// containers in C++11. This method removes the requirement of having to pass
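
The truncated comment above refers to specializing std::hash so Edge can be used as a key in unordered containers. A hedged sketch of what such a specialization typically looks like follows; the exact body in edge.h may differ, and torch::get_hash (from torch/csrc/utils/hash.h, included above) is assumed to provide the combining logic.

// Sketch only; not the literal specialization from edge.h.
namespace std {
template <>
struct hash<torch::autograd::Edge> {
  size_t operator()(const torch::autograd::Edge& edge) const noexcept {
    // Combine the owning function and the input index into one hash value.
    return torch::get_hash(edge.function, edge.input_nr);
  }
};
} // namespace std
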
1 change: 0 additions & 1 deletion torch/csrc/autograd/function_hook.h
@@ -1,6 +1,5 @@
#pragma once

#include <memory>
#include <vector>

// A hook that's called on gradients
Expand Down
4 changes: 3 additions & 1 deletion torch/csrc/autograd/functions/basic_ops.cpp
@@ -5,6 +5,8 @@
#include "torch/csrc/autograd/functions/utils.h"
#include "torch/csrc/utils/auto_gpu.h"

#include <ATen/ATen.h>

#include <memory>
#include <utility>

@@ -19,7 +21,7 @@ auto DelayedError::apply(const variable_list& inputs) -> variable_list {
outputs.reserve(inputs.size());
for (auto& var : inputs) {
// FIXME: share version counters
outputs.emplace_back(var.defined() ? var.data() : Tensor());
outputs.emplace_back(var.defined() ? var.data() : at::Tensor());
}
return wrap_outputs(inputs, std::move(outputs), [&](function_list&& next_functions) {
return std::make_shared<Error>(msg, std::move(next_functions));
9 changes: 4 additions & 5 deletions torch/csrc/autograd/functions/special.cpp
@@ -4,6 +4,7 @@
#include "torch/csrc/autograd/python_engine.h"
#include "torch/csrc/autograd/edge.h"
#include "torch/csrc/autograd/function.h"
#include "torch/csrc/autograd/edge.h"

#include <cstdint>
#include <memory>
@@ -264,11 +265,10 @@ bool Eval::replaceSubgraph(const variable_list& inputs, const variable_list& _ou
// This output is already rebased. This happens when there
// the same Variable has been returned multiple times, and
// is repeated in this list.
if (output.get()->_grad_fn.get() == this) {
if (output.grad_fn_unsafe() == this) {
auto replicate = std::make_shared<Replicate>();
replicate->next_functions.emplace_back(this_shared, output.output_nr());
output.get()->_grad_fn = replicate;
output.get()->output_nr = 0;
output.set_gradient_edge({std::move(replicate), 0});
repeated_outputs.emplace(&output);
}
// NOTE: this check should be fairly cheap, and the set shouldn't
@@ -277,8 +277,7 @@ bool Eval::replaceSubgraph(const variable_list& inputs, const variable_list& _ou
auto & replicate = output.grad_fn();
replicate->next_functions.emplace_back(this_shared, num_inputs++);
} else {
output.get()->_grad_fn = this_shared;
output.get()->output_nr = num_inputs++;
output.set_gradient_edge(Edge(this_shared, num_inputs++));
}
}
