Initial version of reverse mode autodiff #227
I can of course answer some more specific questions, but most of this patch is basically a tiny bit of reverse mode AD code, and a ton of defensive programming to perform bookkeeping and shield it from unsupported features.
oh very cool, thanks for the reference @apaszke :) I'll look at that tutorial and then look over this PR.

using namespace tc;

void generalTest() {
I am curious about a few things so far from looking at the test case:
1. How do we handle temporaries, as in the example at https://facebookresearch.github.io/TensorComprehensions/framework/pytorch_integration/autograd_with_tc.html#reordering-grad-outputs ? Would it be possible to demonstrate such a use case with a test example?
2. How is the inputs/outputs order handled for backwards?
3. Is it possible for users to specify that they want grads for only some inputs and not all?
Regarding 1 and 2:
The test case has temporaries, so you can actually see this in the example. In general the grad TC currently takes in:
- inputs
- outputs (including intermediates)
- grads w.r.t. outputs (including intermediates)
and returns grads w.r.t. inputs. Inputs and outputs are ordered in the same way as they were specified in the original TC.
Regarding 3:
No, that's not currently possible, but is a very straightforward extension (only need a few extra lines of code).
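The calling convention described above can be sketched as a small name-building helper. This is only an illustration, not code from the PR: `gradParamNames` and `gradReturnNames` are hypothetical functions, and the `seed_d_` and `d_` prefixes are assumed from the `grad_mm` signature that appears in the test expectation in this PR.

```cpp
#include <string>
#include <vector>

// Hypothetical helper (not part of the PR): parameter names of the grad TC,
// in the order inputs, outputs (including intermediates), output grads.
std::vector<std::string> gradParamNames(
    const std::vector<std::string>& inputs,
    const std::vector<std::string>& outputs) {
  std::vector<std::string> params;
  // Original inputs keep their original order.
  for (const auto& name : inputs) {
    params.push_back(name);
  }
  // Original outputs (including intermediates) follow, also in order.
  for (const auto& name : outputs) {
    params.push_back(name);
  }
  // One incoming gradient per output; the prefix is an assumed convention.
  for (const auto& name : outputs) {
    params.push_back("seed_d_" + name);
  }
  return params;
}

// Hypothetical helper: the grad TC returns one gradient per original input.
std::vector<std::string> gradReturnNames(
    const std::vector<std::string>& inputs) {
  std::vector<std::string> returns;
  for (const auto& name : inputs) {
    returns.push_back("d_" + name);
  }
  return returns;
}
```

For inputs `A, B` and outputs `O, W`, this yields the parameters `A, B, O, W, seed_d_O, seed_d_W` and the return values `d_A, d_B`.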
The bulk of the work in the Halide autodiff was handling non-trivial indexing in ways that preserve parallelism. E.g. consider:
A(i) = 2 * B(some nasty expression in i)
You don't want to compute the derivatives as:
B'(some nasty expression in i) = A'(i) / 2
because you can't necessarily parallelize that over i, so it's going to be slow. This sort of thing comes up as soon as you have a stencil or broadcast (one input influences multiple outputs, which, when reversed, introduces a race condition).
Dealing with bounds and shapes of the computed derivatives was also interesting in the Halide work. I think that's simpler in TC, because tensors have finite size.
I think this PR is the right sort of approach, and we should do the rewrite of B'(nasty) = A'(i) / 2 into B'(j) = A'(nasty) / 2 as a later pass inside lowering, using the Halide solver or polyhedral tools.
@abadams Thanks for the review! Great, it seems that we're on the same page! 😄 I agree that shifting the formulas to lhs isn't perfect, but it could easily be done as a post processing step. I know how to implement this, it's just that I'd need a linear equation solver (which is not available at this level from what I understand). For now I decided to bail and put the indexing expressions there, but it's likely to cause errors downstream anyway (I don't think it's supported in TC/Halide). It would be useful to support this in general too. Bounds are another problem, and you can see that I had to implement a few checks to make sure that they can be still auto-inferred (
Equation solvers exist once you get down to Halide IR or polyhedral IR. It would be silly to implement yet another one at the TC-front-end-IR level. I'd just allow complex index expressions on the LHS for now, which I think is what you're doing? They may barf inside the tc2halide layer right now, but support for those is an outstanding near-term TODO - we need them for things like histogram computation.
Yep, that's exactly what I've been thinking. Great that we're on the same page!
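To make the scatter-versus-gather point concrete, here is a minimal sketch on plain arrays, not on TC or Halide IR. It assumes a simple stencil forward statement `A(i) = B(i) + B(i + 1)` chosen for illustration; `gradScatter` and `gradGather` are hypothetical names. The scatter form writes to `dB` through an index expression, so two iterations can touch the same element; the gather form iterates over the elements of `dB` and reads from `dA`, so each output is written by exactly one iteration.

```cpp
#include <cstddef>
#include <vector>

// Scatter form: iterate over the forward index i and accumulate into dB.
// Iterations i and i + 1 both write dB[i + 1], so a parallel loop over i
// would race without atomics.
std::vector<double> gradScatter(const std::vector<double>& dA) {
  const std::size_t n = dA.size();
  std::vector<double> dB(n + 1, 0.0);
  for (std::size_t i = 0; i < n; ++i) {
    dB[i] += dA[i];
    dB[i + 1] += dA[i];
  }
  return dB;
}

// Gather form: iterate over the output index j and read the contributing
// dA entries. Each dB[j] is written by one iteration only, so the loop
// over j can be parallelized. The bounds checks are the price of the
// inverted indexing.
std::vector<double> gradGather(const std::vector<double>& dA) {
  const std::size_t n = dA.size();
  std::vector<double> dB(n + 1, 0.0);
  for (std::size_t j = 0; j <= n; ++j) {
    if (j < n) {
      dB[j] += dA[j];
    }
    if (j >= 1) {
      dB[j] += dA[j - 1];
    }
  }
  return dB;
}
```

Here the index expression is trivially invertible by hand; for a general index expression, finding the gather form is the job that the thread proposes to leave to the Halide solver or polyhedral tools.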
throw std::runtime_error("Unknown Halide type");
}

void findAccessedTensors(
We have been implementing this type of function as:
std::unordered_set<std::string> collectAccessedTensors(const TreeRef& tree) {...}
This way it can compose.
Generally we try not to pass results out through mutable reference parameters; we just return the result.
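A minimal sketch of the return-by-value style suggested here, under stated assumptions: `Expr` is a toy stand-in for TC's `TreeRef` (the real tree API is not shown in this thread), and the traversal is only illustrative.

```cpp
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>

// Toy expression node standing in for TreeRef (an assumption of this sketch).
struct Expr {
  std::string tensor;  // non-empty when this node is a tensor access
  std::vector<std::shared_ptr<Expr>> children;
};

// Returns the set of tensor names accessed anywhere in the expression.
// Because the result is returned rather than written into an out parameter,
// callers can nest or merge calls freely.
std::unordered_set<std::string> collectAccessedTensors(const Expr& expr) {
  std::unordered_set<std::string> result;
  if (!expr.tensor.empty()) {
    result.insert(expr.tensor);
  }
  for (const auto& child : expr.children) {
    auto sub = collectAccessedTensors(*child);
    result.insert(sub.begin(), sub.end());
  }
  return result;
}
```

A caller that still wants to accumulate across statements can merge the returned sets itself, which keeps the mutation local to the caller.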
void assertNoWriteAfterRead(Def def) {
  std::unordered_set<std::string> read_only;
  // Inputs are always read-only
  for (Param input : def.params())
braces even around 1-liner for/if please
auto lhs_name = comp.ident().name();
if (read_only.count(lhs_name) > 0)
  throw std::runtime_error(
      "AD not supported in TCs that write to a value after reading it");
So this would not work for FCRelu?
def fcrelu(float(N, M) I, float(M, P) W, float(P) Bias) -> (O) {
  O(n, p) +=! W(r_m, p) * I(n, r_m)
  O(n, p) = O(n, p) + Bias(p)
}
Is this a fundamental limitation or a future TODO?
}

void findAccessedTensors(
    std::unordered_set<std::string>& read_only,
Also, why convert to string at this point? Returning TreeRef seems fine to me.
// that can happen is that we will be too conservative and throw, so it's ok.
std::unordered_set<std::string> usedIndexVars(Comprehension comp) {
  std::unordered_set<std::string> index_vars;
  for (Ident idx : comp.indices())
braces plz
  return has_zero_grad_.count(name) > 0;
}
void markZeroGrad(const std::string& name) {
  has_zero_grad_.count(name);
does this function do anything?
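As written, `markZeroGrad` only queries the set: `std::unordered_set::count` returns the number of matching elements and never modifies the container, so the result is discarded and nothing is recorded. Presumably `insert` was intended; that is an assumption, since the thread does not show the fix. A minimal sketch, with `GradInfo` as a hypothetical wrapper class and `hasZeroGrad` as an assumed name for the query shown in the hunk:

```cpp
#include <string>
#include <unordered_set>

// Hypothetical wrapper class; only the two member functions mirror the hunk.
class GradInfo {
 public:
  bool hasZeroGrad(const std::string& name) const {
    return has_zero_grad_.count(name) > 0;
  }
  void markZeroGrad(const std::string& name) {
    // insert() records the name; count() would leave the set unchanged.
    has_zero_grad_.insert(name);
  }

 private:
  std::unordered_set<std::string> has_zero_grad_;
};
```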
@@ -0,0 +1,110 @@
/**
This extraction is useful by itself and independent of the rest; please split it into a separate commit.
std::string differentiate(const std::string& source) {
  // Parse and check the source
  auto def = Def(Sema().checkFunction(Parser(source).parseFunction()));
What happens if the TC has more than 1 def?
using namespace lang;

static const lang::SourceRange dummyRange{std::make_shared<std::string>(""),
no freestanding globals please, make it a static definition inside a function in an anon namespace and return that instead
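One way to read this request is the function-local static pattern sketched below. `SourceRange` here is a toy stand-in for `lang::SourceRange`, whose real constructor arguments are truncated in the hunk, so the member shown is an assumption.

```cpp
#include <memory>
#include <string>

// Toy stand-in for lang::SourceRange (assumed shape, for illustration only).
struct SourceRange {
  std::shared_ptr<std::string> text;
};

namespace {

// The static lives inside the function, so it is initialized on first use
// (thread-safely since C++11) instead of during static initialization.
const SourceRange& dummyRange() {
  static const SourceRange range{std::make_shared<std::string>("")};
  return range;
}

}  // namespace
```

Call sites change from `dummyRange` to `dummyRange()`, and the anonymous namespace keeps the helper private to the translation unit.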
auto body = def.statements();
auto it = body.end();
if (it == body.begin())
  throw std::runtime_error("empty body");
braces around if
append TC to the message?
auto it = body.end(); | ||
if (it == body.begin()) | ||
throw std::runtime_error("empty body"); | ||
do { |
can we use a for / rbegin / rend loop?
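A sketch of the suggested loop shape on a standard container. Whether the list type returned by `def.statements()` provides `rbegin()`/`rend()` is not shown in this thread, so the example uses `std::vector<std::string>` as a stand-in; `reversedStatements` is a hypothetical name.

```cpp
#include <string>
#include <vector>

// Visits the statements from last to first, as reverse mode AD requires.
// An empty body simply yields zero iterations, so the loop itself needs no
// special case; an explicit "empty body" error can remain a separate check.
std::vector<std::string> reversedStatements(
    const std::vector<std::string>& body) {
  std::vector<std::string> visited;
  for (auto it = body.rbegin(); it != body.rend(); ++it) {
    visited.push_back(*it);
  }
  return visited;
}
```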
int assign_kind = comp.assignment()->kind();
if (assign_kind != '=' && assign_kind != TK_PLUS_EQ_B &&
    assign_kind != TK_PLUS_EQ)
braces plz
    assign_kind != TK_PLUS_EQ)
  throw ErrorReport(comp)
      << "Only =, += and +=! assignments are supported in AD";
if (comp.whereClauses().size() > 0 || comp.equivalent().present())
braces plz
}
}

void findIndexVars(
Same comment as above re return values vs mutable reference as input.
def grad_mm(double(M, K) A, double(K, N) B, double(M, N) O, double(M, N) W, double(M, N) seed_d_O, double(M, N) seed_d_W) -> (d_A, d_B) {
  d_O(i, j) = seed_d_O(i, j)
  d_O(i, j) += (seed_d_W(i, j) * 2)
  d_A(i, k) = 0
In order to be parsed by TC this will need either a where clause or to be folded into +=!.
More generally I'd make the test run parsing + sema on the produced TC and print / compare against the canonical TC. Dogfooding would have caught the where case here.
// XXX: Sema isn't nilpotent, so we have to reparse the source
std::vector<TreeRef> inferOutputTypes(const std::string& source) {
  auto halide_def =
      tc2halide::translate(isl::with_exceptions::globalIslCtx(), source, true);
So this now makes lang depend on ISL which is unnecessary.
We should extract what we need from halide for range inference.
    ListView<TreeRef>::create(dummyRange, TreeList{}),
    Compound::create(TK_OPTION, dummyRange, {}),
    ListView<TreeRef>::create(dummyRange, TreeList{})));
if (usedIndexVars(Comprehension(grad_comps_.back())) !=
braces plz
Made a first quick pass; generally this looks great, thanks for your contribution.
I made a few comments, I'll finish reviewing early next week.
I have one question regarding gathers: I did not see a guard against those in AD (i.e. A(B(i))). Can you point to where such cases are caught and what we throw on them?
Thank you for your pull request. We require contributors to sign our Contributor License Agreement, and yours has expired. Before we can review or merge your code, we need you to email cla@fb.com with your details so we can update your status.
There are still a few things that could have been improved, but I think this can be done in subsequent PRs, and I wanted to get some feedback at this point. Some issues I see:
... lang. On the other hand, this needs tc2halide, which is part of the core, and that would create a dependency cycle between libtc_lang.so and libtc_core.so.