Skip to content

Commit

Permalink
Merge branch 'master' into autodiff
Browse files Browse the repository at this point in the history
  • Loading branch information
NeuralCoder3 committed Oct 21, 2022
2 parents b1772d4 + 2a38a50 commit 639b8fd
Show file tree
Hide file tree
Showing 53 changed files with 135 additions and 131 deletions.
2 changes: 1 addition & 1 deletion dialects/autodiff/autodiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ inline const Def* op_autodiff(const Def* fun) {
World& world = fun->world();
// We rely on the normalized thorin convention that all arguments in functions are grouped.
// `cn[[args], cont:=cn[returns]]`
return world.app(world.app(world.ax<autodiff>(), fun->type()), fun);
return world.app(world.app(world.ax<ad>(), fun->type()), fun);
}

inline const Def* op_zero(const Def* A) {
Expand Down
11 changes: 4 additions & 7 deletions dialects/autodiff/autodiff.thorin
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,20 @@
///
/// ## Types
///
.ax %autodiff.tangent_type: * -> *, normalize_tangent_type;
.ax %autodiff.Tangent: * -> *, normalize_Tangent;
///
/// ## Operations
///
/// ### %autodiff.autodiff
/// ### %autodiff.ad
///
/// This axiom operates on functions and types.
///
/// For function types the augmented type is computed: `(T -> U) => (T -> U × (U -> T))`
.ax %autodiff.autodiff_type: * -> *,
normalize_autodiff_type;
.ax %autodiff.AD: * -> *, normalize_AD;
/// On closed terms (functions, operators, ho arguments, registered axioms, etc.) the augmented term is returned.
/// The augmented term `f'` returns the result together with the pullback.
/// `autodiff f = f' = λ args. (f args, f*)`
.ax %autodiff.autodiff: Π [T: *] -> T ->
%autodiff.autodiff_type T,
normalize_autodiff;
.ax %autodiff.ad: Π [T: *] -> T -> %autodiff.AD T, normalize_ad;
///
/// ### %autodiff.zero
///
Expand Down
6 changes: 3 additions & 3 deletions dialects/autodiff/normalizers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,19 @@ namespace thorin::autodiff {

/// Currently this normalizer does nothin.
/// TODO: Maybe we want to handle trivial lookup replacements here.
const Def* normalize_autodiff(const Def* type, const Def* callee, const Def* arg, const Def* dbg) {
const Def* normalize_ad(const Def* type, const Def* callee, const Def* arg, const Def* dbg) {
auto& world = type->world();
return world.raw_app(callee, arg, dbg);
}

const Def* normalize_autodiff_type(const Def* type, const Def* callee, const Def* arg, const Def* dbg) {
const Def* normalize_AD(const Def* type, const Def* callee, const Def* arg, const Def* dbg) {
auto& world = type->world();
auto ad_ty = autodiff_type_fun(arg);
if (ad_ty) return ad_ty;
return world.raw_app(callee, arg, dbg);
}

const Def* normalize_tangent_type(const Def*, const Def*, const Def* arg, const Def*) { return tangent_type_fun(arg); }
const Def* normalize_Tangent(const Def*, const Def*, const Def* arg, const Def*) { return tangent_type_fun(arg); }

/// Currently this normalizer does nothing.
/// We usually want to keep zeros as long as possible to avoid unnecessary allocations.
Expand Down
2 changes: 1 addition & 1 deletion dialects/autodiff/passes/autodiff_eval.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ const Def* AutoDiffEval::derive(const Def* def) {
}

const Def* AutoDiffEval::rewrite(const Def* def) {
if (auto ad_app = match<autodiff>(def); ad_app) {
if (auto ad_app = match<ad>(def); ad_app) {
// callee = autodiff T
// arg = function of type T
// (or operator)
Expand Down
19 changes: 12 additions & 7 deletions dialects/core/be/ll/ll.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -497,13 +497,18 @@ std::string Emitter::emit_bb(BB& bb, const Def* def) {
unreachable();
} else if (def->isa<Bot>()) {
return "undef";
} else if (auto bit = match<core::bit2>(def)) {
auto [a, b] = bit->args<2>([this](auto def) { return emit(def); });
auto t = convert(bit->type());

auto neg = [&](std::string_view x) { return bb.assign(name + ".neg", "xor {} 0, {}", t, x); };

switch (bit.id()) {
} else if (auto bit1 = match<core::bit1>(def)) {
assert(bit1.id() == core::bit1::neg);
auto x = emit(bit1->arg());
auto t = convert(bit1->type());
return bb.assign(name, "xor {} -1, {}", t, x);
} else if (auto bit2 = match<core::bit2>(def)) {
auto [a, b] = bit2->args<2>([this](auto def) { return emit(def); });
auto t = convert(bit2->type());

auto neg = [&](std::string_view x) { return bb.assign(name + ".neg", "xor {} -1, {}", t, x); };

switch (bit2.id()) {
// clang-format off
case core::bit2::_and: return bb.assign(name, "and {} {}, {}", t, a, b);
case core::bit2:: _or: return bb.assign(name, "or {} {}, {}", t, a, b);
Expand Down
4 changes: 2 additions & 2 deletions dialects/core/normalizers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ const Def* normalize_nop(const Def* type, const Def* callee, const Def* arg, con
}
}

return world.raw_app(callee, a, dbg);
return world.raw_app(callee, arg, dbg);
}

template<ncmp id>
Expand Down Expand Up @@ -379,7 +379,7 @@ const Def* normalize_ncmp(const Def* type, const Def* callee, const Def* arg, co
}
}

return world.raw_app(callee, a, dbg);
return world.raw_app(callee, arg, dbg);
}

template<rop id>
Expand Down
9 changes: 4 additions & 5 deletions docs/langref.md
Original file line number Diff line number Diff line change
Expand Up @@ -138,9 +138,8 @@ The following tables comprise all production rules:
| d | `.pack` Sym ( `:` e<sub>type</sub> )? `,` e<sub>shape</sub> v? n | | nominal pack declaration | thorin::Pack |
| d | `.Sigma` Sym ( `:` e<sub>type</sub> )? `,` L<sub>arity</sub> v? n | | nominal sigma declaration | thorin::Sigma |
| d | `.def` Sym n | | nominal definition | nominals |
| v | `,` `@` Sym \| `,` `@` `(` Sym `,` ... `,` Sym `)` | | nominal variable declaration | nominals |
| n | `;` \| o | | nominal definition | - |
| o | `=` e `;` | | operand of nominal definition | - |
| o | `=` de `;` | | operand of nominal definition | - |
| o | `=` `{` e `,` ... `,` e `}` `;` || operands of nominal definition | - |

### Patterns
Expand All @@ -162,18 +161,19 @@ For this reason there is no rule `b -> s (p, ..., p)`.

| Nonterminal | Right-Hand Side | New Scope? | Comment | Thorin Class |
|-------------|-------------------------------------------------------------------------------|------------|--------------------------------------|-----------------|
| de | d\* e | | declaration expression | - |
| e | `.Univ` | | universise: type of a type level | thorin::Univ |
| e | `.Type` e | | type of level e | thorin::Type |
| e | `*` | | alias for `.Type (0:.Univ)` | thorin::Type |
| e | `` | | alias for `.Type (1:.Univ)` | thorin::Type |
| e | `.Nat` | | natural number | thorin::Nat |
| e | `.Idx` | | builtin constant of type `.Nat -> *` | thorin::Idx |
| e | `.Bool` | | alias for `.Idx 2` | thorin::Idx |
| e | `{` e `}` || block | - |
| e | `{` de `}` || block | - |
| e | L `:` e<sub>type</sub> | | literal | thorin::Lit |
| e | `.ff` | | alias for `0_2` | thorin::Lit |
| e | `.tt` | | alias for `1_2` | thorin::Lit |
| e | ( `.bot` or `.top` ) ( `:` e<sub>type</sub> )? | | bottom/top | thorin::TExt |
| e | ( `.bot` \| `.top` ) ( `:` e<sub>type</sub> )? | | bottom/top | thorin::TExt |
| e | Sym | | identifier | - |
| e | Ax | | use of an axiom | - |
| e | e e | | application | thorin::App |
Expand All @@ -187,7 +187,6 @@ For this reason there is no rule `b -> s (p, ..., p)`.
| e | `[` b `,` ... `,` b `]` || sigma | thorin::Sigma |
| e | `` i e<sub>shape</sub> `;` e<sub>body</sub>`` || pack | thorin::Pack |
| e | `«` i e<sub>shape</sub> `;` e<sub>body</sub>`»` || array | thorin::Arr |
| e | d e | | declaration | - |

An elided type of
* a literal defaults to `.Nat`,
Expand Down
4 changes: 2 additions & 2 deletions lit/autodiff/2out.thorin.disabled
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// RUN: rm -f %t.ll ; \
// RUN: %thorin -d direct -d autodiff %s --output-ll %t.ll --output-thorin - | FileCheck %s
// RUN: %thorin -d direct -d autodiff %s --output-ll %t.ll -o - | FileCheck %s

.import core;
.import autodiff;
Expand All @@ -24,7 +24,7 @@
pb((1:I32,0:I32),pb_ret_cont)
};

.let f_diff = %autodiff.autodiff (.Cn [I32,.Cn[I32,I32]]) f;
.let f_diff = %autodiff.ad (.Cn [I32,.Cn[I32,I32]]) f;
.let c = (42:I32);
f_diff (c,ret_cont)
};
4 changes: 2 additions & 2 deletions lit/autodiff/autodiff_mult_in_call.thorin
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// RUN: rm -f %t.ll ; \
// RUN: %thorin -d direct -d autodiff %s --output-ll %t.ll --output-thorin - | FileCheck %s
// RUN: %thorin -d direct -d autodiff %s --output-ll %t.ll -o - | FileCheck %s

.import core;
.import autodiff;
Expand Down Expand Up @@ -32,7 +32,7 @@
pb((1:I32),pb_ret_cont)
};

.let f_diff = %autodiff.autodiff (.Cn [[I32,I32],.Cn[I32]]) f;
.let f_diff = %autodiff.ad (.Cn [[I32,I32],.Cn[I32]]) f;

.let c = (42:I32,43:I32);
f_diff (c,ret_cont)
Expand Down
4 changes: 2 additions & 2 deletions lit/autodiff/autodiff_mult_in_call2.thorin
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// RUN: rm -f %t.ll ; \
// RUN: %thorin -d direct -d autodiff %s --output-ll %t.ll --output-thorin - | FileCheck %s
// RUN: %thorin -d direct -d autodiff %s --output-ll %t.ll -o - | FileCheck %s

.import core;
.import autodiff;
Expand Down Expand Up @@ -29,7 +29,7 @@
pb((1:I32),pb_ret_cont)
};

.let f_diff = %autodiff.autodiff (.Cn [[I32,I32],.Cn[I32]]) f;
.let f_diff = %autodiff.ad (.Cn [[I32,I32],.Cn[I32]]) f;

.let c = (42:I32,43:I32);
f_diff (c,ret_cont)
Expand Down
4 changes: 2 additions & 2 deletions lit/autodiff/autodiff_mult_in_call2_2.thorin
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// RUN: rm -f %t.ll ; \
// RUN: %thorin -d direct -d autodiff %s --output-ll %t.ll --output-thorin - | FileCheck %s
// RUN: %thorin -d direct -d autodiff %s --output-ll %t.ll -o - | FileCheck %s

.import core;
.import autodiff;
Expand Down Expand Up @@ -29,7 +29,7 @@
pb((1:I32),pb_ret_cont)
};

.let f_diff = %autodiff.autodiff (.Cn [[I32,I32],.Cn[I32]]) f;
.let f_diff = %autodiff.ad (.Cn [[I32,I32],.Cn[I32]]) f;

.let c = (42:I32,43:I32);
f_diff (c,ret_cont)
Expand Down
4 changes: 2 additions & 2 deletions lit/autodiff/autodiff_mult_in_call2_3.thorin
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// RUN: rm -f %t.ll ; \
// RUN: %thorin -d direct -d autodiff %s --output-ll %t.ll --output-thorin - | FileCheck %s
// RUN: %thorin -d direct -d autodiff %s --output-ll %t.ll -o - | FileCheck %s

.import core;
.import autodiff;
Expand Down Expand Up @@ -29,7 +29,7 @@
pb((1:I32),pb_ret_cont)
};

.let f_diff = %autodiff.autodiff (.Cn [[I32,I32],.Cn[I32]]) f;
.let f_diff = %autodiff.ad (.Cn [[I32,I32],.Cn[I32]]) f;

.let c = (42:I32,43:I32);
f_diff (c,ret_cont)
Expand Down
4 changes: 2 additions & 2 deletions lit/autodiff/autodiff_mult_in_call_cont.thorin
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// RUN: rm -f %t.ll ; \
// RUN: %thorin -d direct -d autodiff %s --output-ll %t.ll --output-thorin - | FileCheck %s
// RUN: %thorin -d direct -d autodiff %s --output-ll %t.ll -o - | FileCheck %s

.import core;
.import autodiff;
Expand Down Expand Up @@ -28,7 +28,7 @@
pb((1:I32),pb_ret_cont)
};

.let f_diff = %autodiff.autodiff (.Cn [[I32,I32],.Cn[I32]]) f;
.let f_diff = %autodiff.ad (.Cn [[I32,I32],.Cn[I32]]) f;

.let c = (42:I32,43:I32);
f_diff (c,ret_cont)
Expand Down
4 changes: 2 additions & 2 deletions lit/autodiff/call_autodiff.thorin
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// RUN: rm -f %t.ll ; \
// RUN: %thorin -d direct -d autodiff %s --output-ll %t.ll --output-thorin - | FileCheck %s
// RUN: %thorin -d direct -d autodiff %s --output-ll %t.ll -o - | FileCheck %s

.import core;
.import autodiff;
Expand Down Expand Up @@ -30,7 +30,7 @@
pb((1:I32),pb_ret_cont)
};

.let f_diff = %autodiff.autodiff (.Cn [I32,.Cn[I32]]) f;
.let f_diff = %autodiff.ad (.Cn [I32,.Cn[I32]]) f;
.let f_diff_cast = f_diff;

.let c = (42:I32);
Expand Down
4 changes: 2 additions & 2 deletions lit/autodiff/call_autodiff_cont.thorin
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// RUN: rm -f %t.ll ; \
// RUN: %thorin -d direct -d autodiff %s --output-ll %t.ll --output-thorin - | FileCheck %s
// RUN: %thorin -d direct -d autodiff %s --output-ll %t.ll -o - | FileCheck %s

.import core;
.import autodiff;
Expand Down Expand Up @@ -34,7 +34,7 @@
pb((1:I32),pb_ret_cont)
};

.let f_diff = %autodiff.autodiff (.Cn [I32,.Cn[I32]]) f;
.let f_diff = %autodiff.ad (.Cn [I32,.Cn[I32]]) f;

.let c = (42:I32);
f_diff (c,ret_cont)
Expand Down
2 changes: 1 addition & 1 deletion lit/autodiff/general/2out.thorin
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// RUN: rm -f %t.ll ; \
// RUN: %thorin -d direct -d autodiff %s --output-ll %t.ll --output-thorin - | FileCheck %s
// RUN: %thorin -d direct -d autodiff %s --output-ll %t.ll -o - | FileCheck %s

.import core;
.import autodiff;
Expand Down
2 changes: 1 addition & 1 deletion lit/autodiff/general/42.thorin
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// RUN: rm -f %t.ll ; \
// RUN: %thorin -d direct -d autodiff %s --output-ll %t.ll --output-thorin - | FileCheck %s
// RUN: %thorin -d direct -d autodiff %s --output-ll %t.ll -o - | FileCheck %s

.import core;
.import autodiff;
Expand Down
2 changes: 1 addition & 1 deletion lit/autodiff/general/add_tuple.thorin
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// RUN: rm -f %t.ll ; \
// RUN: %thorin -d autodiff %s --output-ll %t.ll --output-thorin - | FileCheck %s
// RUN: %thorin -d autodiff %s --output-ll %t.ll -o - | FileCheck %s

.import core;
.import direct;
Expand Down
2 changes: 1 addition & 1 deletion lit/autodiff/general/cps_inline.thorin
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// RUN: rm -f %t.ll ; \
// RUN: %thorin %s --output-ll %t.ll --output-thorin - | FileCheck %s
// RUN: %thorin %s --output-ll %t.ll -o - | FileCheck %s

.import mem;
.import core;
Expand Down
2 changes: 1 addition & 1 deletion lit/autodiff/general/ds_inline.thorin
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// RUN: rm -f %t.ll ; \
// RUN: %thorin %s --output-ll %t.ll --output-thorin - | FileCheck %s
// RUN: %thorin %s --output-ll %t.ll -o - | FileCheck %s

.import core;
.import mem;
Expand Down
2 changes: 1 addition & 1 deletion lit/autodiff/general/invoke_ds.thorin
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// RUN: rm -f %t.ll ; \
// RUN: %thorin -d direct -d autodiff %s --output-ll %t.ll --output-thorin - | FileCheck %s
// RUN: %thorin -d direct -d autodiff %s --output-ll %t.ll -o - | FileCheck %s

.import core;
.import direct;
Expand Down
2 changes: 1 addition & 1 deletion lit/autodiff/general/simple_real.thorin.disabled
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// RUN: rm -f %t.ll ; \
// RUN: %thorin -d autodiff %s --output-ll %t.ll --output-thorin - | FileCheck %s
// RUN: %thorin -d autodiff %s --output-ll %t.ll -o - | FileCheck %s

.import autodiff;
.import mem;
Expand Down
8 changes: 4 additions & 4 deletions lit/autodiff/general/tangent_type_cast.thorin
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// RUN: rm -f %t.ll ; \
// RUN: %thorin -d autodiff %s --output-ll %t.ll --output-thorin - | FileCheck %s
// RUN: %thorin -d autodiff %s --output-ll %t.ll -o - | FileCheck %s

.import core;
.import mem;
Expand All @@ -14,13 +14,13 @@

.cn .extern main [mem : %mem.M, argc : I32, argv : %mem.Ptr (%mem.Ptr (.Idx 256, 0:.Nat), 0:.Nat), return : .Cn [%mem.M, I32]] = {

.cn ret_cont r::[%autodiff.tangent_type I32] = {
.let r2=%core.bitcast (I32,(%autodiff.tangent_type I32)) r;
.cn ret_cont r::[%autodiff.Tangent I32] = {
.let r2=%core.bitcast (I32,(%autodiff.Tangent I32)) r;
return (mem, r2)
};

.cn ret_wrap r::[I32] = {
.let r2=%core.bitcast ((%autodiff.tangent_type I32),I32) r;
.let r2=%core.bitcast ((%autodiff.Tangent I32),I32) r;
ret_cont r2
};

Expand Down
2 changes: 1 addition & 1 deletion lit/autodiff/general/zero_tuple.thorin
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// RUN: rm -f %t.ll ; \
// RUN: %thorin -d autodiff %s --output-ll %t.ll --output-thorin - | FileCheck %s
// RUN: %thorin -d autodiff %s --output-ll %t.ll -o - | FileCheck %s

.import core;
.import direct;
Expand Down
4 changes: 2 additions & 2 deletions lit/autodiff/id_autodiff.thorin
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// RUN: rm -f %t.ll ; \
// RUN: %thorin -d direct -d autodiff %s --output-ll %t.ll --output-thorin - | FileCheck %s
// RUN: %thorin -d direct -d autodiff %s --output-ll %t.ll -o - | FileCheck %s

.import core;
.import autodiff;
Expand All @@ -23,7 +23,7 @@
pb((1:I32),pb_ret_cont)
};

.let f_diff = %autodiff.autodiff (.Cn [I32,.Cn[I32]]) f;
.let f_diff = %autodiff.ad (.Cn [I32,.Cn[I32]]) f;
.let f_diff_cast = f_diff;

.let c = (43:I32);
Expand Down
4 changes: 2 additions & 2 deletions lit/autodiff/id_autodiff_info_out.thorin
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// RUN: rm -f %t.ll ; \
// RUN: %thorin -d direct -d autodiff %s --output-ll %t.ll --output-thorin - | FileCheck %s
// RUN: %thorin -d direct -d autodiff %s --output-ll %t.ll -o - | FileCheck %s

.import core;
.import autodiff;
Expand All @@ -23,7 +23,7 @@
pb((1:I32),pb_ret_cont)
};

.let f_diff = %autodiff.autodiff (.Cn [I32,.Cn[I32]]) f;
.let f_diff = %autodiff.ad (.Cn [I32,.Cn[I32]]) f;
.let f_diff_cast = f_diff;

.let c = (43:I32);
Expand Down
Loading

0 comments on commit 639b8fd

Please sign in to comment.