This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Initial version of reverse mode autodiff #227

Open
wants to merge 1 commit into
base: master
from

Conversation

apaszke
Contributor

@apaszke apaszke commented Mar 27, 2018

There are still a few things that could be improved, but I think that can be done in subsequent PRs, and I wanted to get some feedback at this point. Some issues I see:

  1. It might be nicer to link it as a part of lang. On the other hand, this needs tc2halide, which is part of the core, and that would create a dependency cycle between libtc_lang.so and libtc_core.so.
  2. Not sure where the Python bindings should go. It's not really part of the engine, but I don't think we have a different pybind file that fits it better. I can create a new one if you want, just let me know.

@prigoyal
Contributor

prigoyal commented Mar 27, 2018

yay :) @apaszke, I was wondering if you could give some high-level idea of the approach for educational purposes? Thanks for adding this :)

cc @abadams who might also have some idea on the autodiff related to Halide.

@apaszke
Contributor Author

apaszke commented Mar 27, 2018

I can of course answer some more specific questions, but most of this patch is basically a tiny bit of reverse mode AD code, and a ton of defensive programming to perform bookkeeping and shield it from unsupported features.

@prigoyal
Contributor

oh very cool, thanks for the reference @apaszke :) I'll look at that tutorial and then look over this PR.


using namespace tc;

void generalTest() {
Contributor

@prigoyal prigoyal Mar 27, 2018

I am curious about three things so far from looking at the test case:

  1. How do we handle temporaries like in the example at https://facebookresearch.github.io/TensorComprehensions/framework/pytorch_integration/autograd_with_tc.html#reordering-grad-outputs ? Would it be possible to demonstrate such a use case with a test example?

  2. How is the inputs/outputs order handled for backwards?

  3. Is it possible for users to specify that they want grad for only some inputs and not all?

Contributor Author

@apaszke apaszke Mar 29, 2018

Regarding 1 and 2:

The test case has temporaries, so you can actually see this in the example. In general the grad TC currently takes in:

  • inputs
  • outputs (including intermediates)
  • grads w.r.t. outputs (including intermediates)

and returns grads w.r.t. inputs. Inputs and outputs are ordered in the same way as they were specified in the original TC.

Regarding 3:

No, that's not currently possible, but it is a very straightforward extension (it only needs a few extra lines of code).
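
For illustration, the calling convention described under 1 and 2 can be sketched as below. This is only a sketch, not the code in this PR: gradParamNames and gradReturnNames are hypothetical helpers, and the seed_d_ and d_ prefixes are assumed from the grad_mm test output quoted elsewhere in this review.

```cpp
#include <string>
#include <vector>

// Hypothetical helper: parameter names of the gradient TC, in the order
// described above (inputs, then outputs, then grads w.r.t. outputs).
std::vector<std::string> gradParamNames(
    const std::vector<std::string>& inputs,
    const std::vector<std::string>& outputs) {
  std::vector<std::string> params;
  // 1. Original inputs, in their original order.
  params.insert(params.end(), inputs.begin(), inputs.end());
  // 2. Original outputs, including intermediates.
  params.insert(params.end(), outputs.begin(), outputs.end());
  // 3. Incoming gradients w.r.t. each output (assumed "seed_d_" prefix).
  for (const auto& out : outputs) {
    params.push_back("seed_d_" + out);
  }
  return params;
}

// Hypothetical helper: the gradient TC returns one gradient per input.
std::vector<std::string> gradReturnNames(
    const std::vector<std::string>& inputs) {
  std::vector<std::string> returns;
  for (const auto& in : inputs) {
    returns.push_back("d_" + in);
  }
  return returns;
}
```

With inputs A, B and outputs O, W this yields the parameter list A, B, O, W, seed_d_O, seed_d_W and the return list d_A, d_B.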

@prigoyal prigoyal requested review from zdevito and abadams March 30, 2018 20:56
@abadams
Contributor

abadams commented Mar 30, 2018

The bulk of the work in the Halide autodiff was handling non-trivial indexing in ways that preserve parallelism. E.g. consider:

A(i) = 2 * B(some nasty expression in i)

You don't want to compute the derivatives as:

B'(some nasty expression in i) = 2 * A'(i)

because you can't necessarily parallelize that over i, so it's going to be slow. This sort of thing comes up as soon as you have a stencil or broadcast (one input influences multiple outputs, which, when reversed, introduces a race condition).

Dealing with bounds and shapes of the computed derivatives was also interesting in the Halide work. I think that's simpler in TC, because tensors have finite size.

I think this PR is the right sort of approach, and we should do the rewrite of B'(nasty) = 2 * A'(i) into B'(j) = 2 * A'(nasty) as a later pass inside lowering, using the Halide solver or polyhedral tools.
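
To make the difference between the two forms concrete, here is a minimal sketch in plain C++ rather than TC or Halide. It assumes a specific forward computation, A(i) = 2 * B(i / 2), chosen only for illustration, so the chain-rule factor is 2 and each B(j) feeds A(2j) and A(2j + 1). The function names are hypothetical.

```cpp
#include <cstddef>
#include <vector>

// Scatter form: the loop runs over i, and two different i write to the same
// d_B element. Parallelizing over i would need atomics or a reduction.
std::vector<double> gradScatter(const std::vector<double>& d_A) {
  std::vector<double> d_B((d_A.size() + 1) / 2, 0.0);
  for (std::size_t i = 0; i < d_A.size(); ++i) {
    d_B[i / 2] += 2.0 * d_A[i];
  }
  return d_B;
}

// Gather form: the loop runs over j, and each iteration is the only writer
// of d_B[j], so the outer loop can be parallelized directly.
std::vector<double> gradGather(const std::vector<double>& d_A) {
  std::vector<double> d_B((d_A.size() + 1) / 2, 0.0);
  for (std::size_t j = 0; j < d_B.size(); ++j) {
    // Solving i / 2 == j for i gives the range [2j, 2j + 2).
    for (std::size_t i = 2 * j; i < 2 * j + 2 && i < d_A.size(); ++i) {
      d_B[j] += 2.0 * d_A[i];
    }
  }
  return d_B;
}
```

Both functions compute the same values. Only the gather form gives each output element a single writer, which is what the index rewrite discussed above is meant to achieve for general index expressions.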

@apaszke
Contributor Author

apaszke commented Mar 31, 2018

@abadams Thanks for the review! Great, it seems that we're on the same page! 😄

I agree that shifting the formulas to the lhs isn't perfect, but it could easily be done as a post-processing step. I know how to implement this, it's just that I'd need a linear equation solver (which is not available at this level from what I understand). For now I decided to bail and put the indexing expressions there, but it's likely to cause errors downstream anyway (I don't think it's supported in TC/Halide). It would be useful to support this in general too.

Bounds are another problem, and you can see that I had to implement a few checks to make sure that they can still be auto-inferred (usedIndexVars + requireAllIndexVarsOf). This is still not perfect, because a few simple things such as where-clause ranges also cause an error for now, but they are fairly easy to add later. I didn't want to complicate the initial patch, so that we can quickly get this in and then work on improvements.

@abadams
Contributor

abadams commented Apr 1, 2018

Equation solvers exist once you get down to Halide IR or polyhedral IR. It would be silly to implement yet another one at the TC-front-end-IR level. I'd just allow complex index expressions on the LHS for now, which I think is what you're doing? They may barf inside the tc2halide layer right now, but support for those is an outstanding near-term TODO - we need them for things like histogram computation.

@apaszke
Contributor Author

apaszke commented Apr 4, 2018

Yep, that's exactly what I've been thinking. Great that we're on the same page!

throw std::runtime_error("Unknown Halide type");
}

void findAccessedTensors(
Contributor

We have been implementing this type of function as:

std::unordered_set<std::string> collectAccessedTensors(const TreeRef& tree) {...}

This way it can compose.
Generally we try not to pass output parameters as mutable references, but just return the result.
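
A minimal sketch of the return-by-value style, for illustration only. Node is a hypothetical stand-in for the real TreeRef, and the traversal is simplified to "a node is a tensor access if its tensor name is non-empty".

```cpp
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>

// Hypothetical stand-in for a TC tree node; the real code operates on TreeRef.
struct Node {
  std::string tensor;  // non-empty if this node is a tensor access
  std::vector<std::shared_ptr<Node>> children;  // assumed non-null
};

// Returns the names of all tensors accessed in the subtree rooted at node.
std::unordered_set<std::string> collectAccessedTensors(const Node& node) {
  std::unordered_set<std::string> result;
  if (!node.tensor.empty()) {
    result.insert(node.tensor);
  }
  for (const auto& child : node.children) {
    // Results of sub-calls compose by merging the returned sets.
    auto sub = collectAccessedTensors(*child);
    result.insert(sub.begin(), sub.end());
  }
  return result;
}
```

Because the result is returned rather than written into a caller-owned set, a call can be used directly inside a larger expression, and the caller decides whether to merge it into an existing set.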

void assertNoWriteAfterRead(Def def) {
  std::unordered_set<std::string> read_only;
  // Inputs are always read-only
  for (Param input : def.params())
Contributor

braces even around 1-liner for/if please

auto lhs_name = comp.ident().name();
if (read_only.count(lhs_name) > 0)
  throw std::runtime_error(
      "AD not supported in TCs that write to a value after reading it");
Contributor

So this would not work for FCRelu?

def fcrelu(float(N, M) I, float(M, P) W, float(P) Bias) -> (O) {
  O(n, p) +=! W(r_m, p) * I(n, r_m)
  O(n, p)  = O(n, p) + Bias(p)
}

Is this a fundamental limitation or a future TODO?
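
To see why a TC like the one above would be rejected, here is a simplified sketch of the write-after-read check. It is not the PR's implementation: Statement is a hypothetical flattened view of a comprehension, and the sketch assumes that tensors read by a statement are marked read-only before that statement's own left-hand side is checked.

```cpp
#include <stdexcept>
#include <string>
#include <unordered_set>
#include <vector>

// Hypothetical, simplified view of one comprehension: the tensor it writes
// and the tensors its right-hand side reads.
struct Statement {
  std::string lhs;
  std::vector<std::string> reads;
};

void assertNoWriteAfterRead(
    const std::vector<std::string>& inputs,
    const std::vector<Statement>& body) {
  // Inputs are always read-only.
  std::unordered_set<std::string> read_only(inputs.begin(), inputs.end());
  for (const auto& stmt : body) {
    // Assumption: reads of the current statement already count as reads.
    read_only.insert(stmt.reads.begin(), stmt.reads.end());
    if (read_only.count(stmt.lhs) > 0) {
      throw std::runtime_error(
          "AD not supported in TCs that write to a value after reading it");
    }
  }
}
```

Under that assumption the second statement, which both reads and writes O, triggers the error, while a variant that writes the biased result to a fresh tensor passes the check.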

}

void findAccessedTensors(
    std::unordered_set<std::string>& read_only,
Contributor

Also, why convert to string at this point? Returning TreeRef seems fine to me.

// that can happen is that we will be too conservative and throw, so it's ok.
std::unordered_set<std::string> usedIndexVars(Comprehension comp) {
  std::unordered_set<std::string> index_vars;
  for (Ident idx : comp.indices())
Contributor

braces plz

  return has_zero_grad_.count(name) > 0;
}
void markZeroGrad(const std::string& name) {
  has_zero_grad_.count(name);
Contributor

does this function do anything?
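
As written, count() only queries the set and its result is discarded, so nothing is ever marked. A minimal sketch of what the function presumably intends is below. GradInfo is a hypothetical enclosing class, and the type of has_zero_grad_ is assumed to be std::unordered_set<std::string>.

```cpp
#include <string>
#include <unordered_set>

// Hypothetical enclosing class; only the two member functions mirror the PR.
class GradInfo {
 public:
  bool hasZeroGrad(const std::string& name) const {
    return has_zero_grad_.count(name) > 0;
  }
  void markZeroGrad(const std::string& name) {
    // insert() records the name, whereas count() would leave the set unchanged.
    has_zero_grad_.insert(name);
  }

 private:
  std::unordered_set<std::string> has_zero_grad_;
};
```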

@@ -0,0 +1,110 @@
/**
Contributor

This extraction is useful by itself and independent, please split it into a separate commit


std::string differentiate(const std::string& source) {
  // Parse and check the source
  auto def = Def(Sema().checkFunction(Parser(source).parseFunction()));
Contributor

What happens if the TC has more than 1 def?


using namespace lang;

static const lang::SourceRange dummyRange{std::make_shared<std::string>(""),
Contributor

no freestanding globals please, make it a static definition inside a function in an anon namespace and return that instead
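
A minimal sketch of the suggested pattern, for illustration. SourceRange here is a hypothetical stand-in for lang::SourceRange, and the two zero offsets are assumptions, since the original initializer is cut off in the hunk above.

```cpp
#include <cstddef>
#include <memory>
#include <string>

// Hypothetical stand-in for lang::SourceRange.
struct SourceRange {
  std::shared_ptr<std::string> text;
  std::size_t start;
  std::size_t end;
};

namespace {
// Accessor replacing the freestanding global.
const SourceRange& dummyRange() {
  // Function-local static: constructed on first call, and that
  // initialization is thread-safe since C++11.
  static const SourceRange range{std::make_shared<std::string>(""), 0, 0};
  return range;
}
} // namespace
```

Call sites then use dummyRange() instead of dummyRange. This avoids static initialization order problems between translation units, because the object is created on first use.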

auto body = def.statements();
auto it = body.end();
if (it == body.begin())
  throw std::runtime_error("empty body");
Contributor

braces around if
append TC to the message?

auto it = body.end();
if (it == body.begin())
  throw std::runtime_error("empty body");
do {
Contributor

can we use a for / rbegin / rend loop?
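
A minimal sketch of the reverse loop, assuming the container exposes rbegin() and rend(). Whether the TC list view type does is not confirmed here, so a std::vector of strings stands in for the statement list, and reversedStatements is a hypothetical name.

```cpp
#include <stdexcept>
#include <string>
#include <vector>

// Visits statements from last to first, as reverse-mode AD requires.
// Statements are plain strings here purely for illustration.
std::vector<std::string> reversedStatements(
    const std::vector<std::string>& body) {
  if (body.empty()) {
    throw std::runtime_error("empty body");
  }
  std::vector<std::string> visited;
  for (auto it = body.rbegin(); it != body.rend(); ++it) {
    visited.push_back(*it);
  }
  return visited;
}
```

Compared with the do/while over a decrementing forward iterator, this form states the traversal order in the loop header and needs no special handling of body.end().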


int assign_kind = comp.assignment()->kind();
if (assign_kind != '=' && assign_kind != TK_PLUS_EQ_B &&
    assign_kind != TK_PLUS_EQ)
Contributor

braces plz

    assign_kind != TK_PLUS_EQ)
  throw ErrorReport(comp)
      << "Only =, += and +=! assignments are supported in AD";
if (comp.whereClauses().size() > 0 || comp.equivalent().present())
Contributor

braces plz

}
}

void findIndexVars(
Contributor

Same comment as above re return values vs mutable reference as input.

def grad_mm(double(M, K) A, double(K, N) B, double(M, N) O, double(M, N) W, double(M, N) seed_d_O, double(M, N) seed_d_W) -> (d_A, d_B) {
d_O(i, j) = seed_d_d_O(i, j)
d_O(i, j) += (seed_d_W(i, j) * 2)
d_A(i, k) = 0
Contributor

In order to be parsed by TC this will need either a where clause or to be folded into +=!.

More generally I'd make the test run parsing + sema on the produced TC and print / compare against the canonical TC. Dogfooding would have caught the where case here.

// XXX: Sema isn't idempotent, so we have to reparse the source
std::vector<TreeRef> inferOutputTypes(const std::string& source) {
  auto halide_def =
      tc2halide::translate(isl::with_exceptions::globalIslCtx(), source, true);
Contributor

So this now makes lang depend on ISL which is unnecessary.
We should extract what we need from halide for range inference.

ListView<TreeRef>::create(dummyRange, TreeList{}),
Compound::create(TK_OPTION, dummyRange, {}),
ListView<TreeRef>::create(dummyRange, TreeList{})));
if (usedIndexVars(Comprehension(grad_comps_.back())) !=
Contributor

braces plz

Contributor

@nicolasvasilache nicolasvasilache left a comment

Made a first quick pass, generally this looks great, thanks for your contribution.
I made a few comments, I'll finish reviewing early next week.

I have one question regarding gathers: I did not see a guard against those in AD (i.e. A(B(i))). Can you point to where such cases are caught and what we throw on them?
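
For reference, a guard of this kind could look roughly like the sketch below. It is purely illustrative and not taken from the PR: Expr is a hypothetical expression type, and a gather is taken to mean a tensor access whose index expression itself contains a tensor access.

```cpp
#include <string>
#include <vector>

// Hypothetical expression node, not the TC tree representation.
struct Expr {
  bool is_access;          // true for a tensor access such as A(i)
  std::string name;        // tensor, operator or index-variable name
  std::vector<Expr> args;  // index expressions (access) or operands
};

// True if the expression contains any tensor access.
bool containsAccess(const Expr& e) {
  if (e.is_access) {
    return true;
  }
  for (const auto& arg : e.args) {
    if (containsAccess(arg)) {
      return true;
    }
  }
  return false;
}

// True if some tensor access is indexed by another tensor access, e.g. A(B(i)).
bool hasGather(const Expr& e) {
  for (const auto& arg : e.args) {
    if (e.is_access && containsAccess(arg)) {
      return true;
    }
    if (hasGather(arg)) {
      return true;
    }
  }
  return false;
}
```

A check like this would run on each right-hand side before differentiation and report an error for A(B(i)) while accepting B(i) + C(i).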

@facebook-github-bot

Thank you for your pull request. We require contributors to sign our Contributor License Agreement, and yours has expired.

Before we can review or merge your code, we need you to email cla@fb.com with your details so we can update your status.

@zdevito zdevito removed request for zdevito and abadams March 2, 2020 21:13

5 participants