Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
86 commits
Select commit Hold shift + click to select a range
ebfbc03
separating out things
sbrantq Mar 21, 2025
a479207
Add trace op
sbrantq Mar 23, 2025
bdd5957
clean up
sbrantq Mar 23, 2025
bcb6943
format
sbrantq Mar 23, 2025
a436e25
rename
sbrantq Mar 23, 2025
962568c
improve op
sbrantq Mar 25, 2025
716f30c
improve test
sbrantq Mar 25, 2025
be7a16b
Force logpdf func
sbrantq Mar 26, 2025
b9d724b
verify logpdf
sbrantq Mar 26, 2025
3ed105e
save
sbrantq Apr 4, 2025
402caa9
simulate op
sbrantq Apr 10, 2025
8eb4eba
Merge branch 'main' of https://github.com/EnzymeAD/Enzyme into sampleop
sbrantq Apr 10, 2025
a77b141
suppress
sbrantq Apr 10, 2025
094327a
fix
sbrantq Apr 10, 2025
ded5990
condOp
sbrantq Apr 10, 2025
e3953aa
Bayesian linear regression test
sbrantq Apr 10, 2025
b70a6e6
Refactor stuff to use probprogutils
sbrantq Apr 18, 2025
38f2b85
scf test (not working)
sbrantq Apr 18, 2025
971da10
remove logpdf
sbrantq Apr 22, 2025
dd01e86
direct pp model call
sbrantq Apr 23, 2025
fd5084b
Merge branch 'main' of https://github.com/EnzymeAD/Enzyme into sampleop
sbrantq Apr 24, 2025
447201e
fix build
sbrantq Apr 24, 2025
8333afc
fix up
sbrantq Apr 25, 2025
6511a8d
merge
sbrantq Apr 30, 2025
36691e5
Merge branch 'main' of https://github.com/EnzymeAD/Enzyme into sampleop
sbrantq Apr 30, 2025
707b882
simple generate test
sbrantq May 2, 2025
0128e0f
remove call test
sbrantq May 14, 2025
4107d2e
Merge branch 'main' of https://github.com/EnzymeAD/Enzyme into sampleop
sbrantq May 17, 2025
94c684f
undo merge hack
sbrantq May 17, 2025
03f3034
clean up and fix roundtrip test
sbrantq May 21, 2025
8efe677
format
sbrantq May 21, 2025
050f67d
Passing symbol ptr to sample ops too. Adding simulate op
sbrantq May 27, 2025
385ed5d
Merge commit 'c3a95700012478aaac5c8a2a898e56a568f4769f' into probprog…
sbrantq May 31, 2025
be4894d
change to anytype
sbrantq Jun 1, 2025
8b56126
asm format for initTrace
sbrantq Jun 1, 2025
c4afd9e
fix name
sbrantq Jun 1, 2025
adb7b72
asm format for addSampleToTraceOp
sbrantq Jun 2, 2025
825b93b
type preserving hacks
sbrantq Jun 2, 2025
bda0d6f
Refactor: making pointers ui64 attributes, updating tests
sbrantq Jun 10, 2025
989ff23
remove initTraceOp
sbrantq Jun 10, 2025
614821e
Merge branch 'main' of https://github.com/EnzymeAD/Enzyme into probpr…
sbrantq Jun 13, 2025
5729a66
traced_output_indices attr to specify which output to trace
sbrantq Jun 14, 2025
23bfd0d
Real generate op without recusive support
sbrantq Jun 23, 2025
3e6f8c0
simplify
sbrantq Jun 23, 2025
72ab30e
encoding constraints as a custom attribute
sbrantq Jun 24, 2025
7ed5126
simplify attr
sbrantq Jun 24, 2025
e162dfe
add checks
sbrantq Jun 25, 2025
453312c
postpasses
sbrantq Jun 25, 2025
16b819d
bug fix for aliasing outputs
sbrantq Jun 26, 2025
56560f3
more tests
sbrantq Jun 26, 2025
3390fd0
untraced call
sbrantq Jun 26, 2025
de39226
Merge branch 'main' of https://github.com/EnzymeAD/Enzyme into probpr…
sbrantq Jun 27, 2025
6040f76
add trace type
sbrantq Jun 27, 2025
9c59cd8
refactored simulate op simple case
sbrantq Jun 29, 2025
d554f4c
simulate op should return weights too
sbrantq Jun 30, 2025
c9ca205
greedy rewrite simulate ops
sbrantq Jun 30, 2025
747d8f6
bug fix
sbrantq Jun 30, 2025
3f585fc
update tests
sbrantq Jun 30, 2025
f50013b
symbol attr to addsubtrace op
sbrantq Jul 1, 2025
3b70c56
no dump
sbrantq Jul 2, 2025
8694452
return updated traces to enforce data dependences
sbrantq Jul 2, 2025
54dcddd
add weights and return values to trace too
sbrantq Jul 4, 2025
aa5bbc2
adding traced_output_indices attr to simulate op
sbrantq Jul 4, 2025
ebe8af7
rearrange op defs; adding new generate op with trace operand
sbrantq Jul 4, 2025
3c1d086
WIP generate op
sbrantq Jul 7, 2025
4ebef8e
WIP generate op
sbrantq Jul 7, 2025
ed55d19
account for new constraint argument
sbrantq Jul 7, 2025
4ad792b
refactored generate op & simple test
sbrantq Jul 7, 2025
157b1a6
fix gsfcop result indices
sbrantq Jul 7, 2025
5bdd2aa
Merge branch 'main' of https://github.com/EnzymeAD/Enzyme into probpr…
sbrantq Jul 14, 2025
fd633fa
test following calling convention
sbrantq Jul 16, 2025
7385aee
enforcing probprog calling convention (rng state being the 0th input/…
sbrantq Jul 17, 2025
f3d28b9
cleanup obsolete tests
sbrantq Jul 17, 2025
487666d
format
sbrantq Jul 18, 2025
e6ae96a
Merge branch 'main' of https://github.com/EnzymeAD/Enzyme into probpr…
sbrantq Jul 18, 2025
22916cb
dont print
sbrantq Jul 18, 2025
614aa70
improve
sbrantq Jul 18, 2025
db3a3f9
clean up
sbrantq Jul 18, 2025
f896b6a
Merge branch 'main' of https://github.com/EnzymeAD/Enzyme into probpr…
sbrantq Jul 20, 2025
3846d6b
minor: get rid of initConstraint op
sbrantq Jul 28, 2025
97864db
generate op fixup: constrained_symbols --> constrained_addresses
sbrantq Jul 28, 2025
3af6546
Merge branch 'main' of https://github.com/EnzymeAD/Enzyme into probpr…
sbrantq Jul 29, 2025
d55e8fc
Merge branch 'main' into probprog-trace-operand
sbrantq Jul 31, 2025
3a05a27
Merge branch 'main' of https://github.com/EnzymeAD/Enzyme into probpr…
sbrantq Jul 31, 2025
5150607
Merge branch 'probprog-trace-operand' of https://github.com/EnzymeAD/…
sbrantq Jul 31, 2025
a011faf
fix build
sbrantq Jul 31, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
239 changes: 231 additions & 8 deletions enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -254,25 +254,248 @@ def BroadcastOp : Enzyme_Op<"broadcast"> {
];
}

def IgnoreDerivativesOp : Enzyme_Op<"ignore_derivatives",
[Pure, SameOperandsAndResultElementType, SameOperandsAndResultShape]> {
let summary = "Prevents the flow of gradients (and higher-order derivatives) by creating a new value that is detached from the original value. This is an identity operation on the primal.";
let arguments = (ins AnyType:$input);
let results = (outs AnyType:$output);

let assemblyFormat = [{
$input attr-dict `:` type($input) `->` type($output)
}];
}

// Probabilistic programming
def Trace : Enzyme_Type<"Trace"> {
let summary = "Execution trace for probabilistic programs";
let description = [{
Reference: https://www.gen.dev/docs/stable/tutorials/basics/gfi/#Traces
Mutable storage for mapping all sampled symbols to their values, log-likelihoods,
and various other information from executing a traced probabilistic function.
}];
let mnemonic = "Trace";
}

def SymbolAttr : Enzyme_Attr<"Symbol", "symbol"> {
let summary = "Symbol associated with a Sample op";
let description = [{
Symbol associated with a Sample op.
}];
let parameters = (ins "uint64_t":$ptr);
let assemblyFormat = "`<` $ptr `>`";
}

def AddressAttr : TypedArrayAttrBase<SymbolAttr, "Address as an array of symbols">;
def AddressArrayAttr : TypedArrayAttrBase<AddressAttr, "Array of addresses">;

def Constraint : Enzyme_Type<"Constraint"> {
let summary = "A mutable storage mapping symbols to value constraints";
let description = [{
A mutable storage mapping symbols to value constraints.
}];
let mnemonic = "Constraint";
}

def SampleOp : Enzyme_Op<"sample",
[DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
let summary = "Sample from a distribution. Arguments to the distribution are: a random number generator object, followed by arguments to the sample op itself";
let arguments = (ins FlatSymbolRefAttr:$fn, Variadic<AnyType>:$inputs, DefaultValuedStrAttr<StrAttr, "">:$name);
let summary = "Sample from a distribution";
let description = [{
Sample from a distribution. By convention, the 0th operand in `inputs`
or `outputs` is the initial RNG state (seed).
}];

let arguments = (ins
FlatSymbolRefAttr:$fn,
Variadic<AnyType>:$inputs,
OptionalAttr<FlatSymbolRefAttr>:$logpdf,
OptionalAttr<SymbolAttr>:$symbol,
DefaultValuedStrAttr<StrAttr, "">:$name
);

let results = (outs Variadic<AnyType>:$outputs);

let assemblyFormat = [{
$fn `(` $inputs `)` attr-dict `:` functional-type($inputs, results)
}];
}

def IgnoreDerivativesOp : Enzyme_Op<"ignore_derivatives",
[Pure, SameOperandsAndResultElementType, SameOperandsAndResultShape]> {
let summary = "Prevents the flow of gradients (and higher-order derivatives) by creating a new value that is detached from the original value. This is an identity operation on the primal.";
let arguments = (ins AnyType:$input);
let results = (outs AnyType:$output);
def UntracedCallOp : Enzyme_Op<"untracedCall"> {
let summary = "Call a probabilistic function without tracing";
let description = [{
Call a probabilistic function without tracing. By convention, the 0th operand in `inputs`
or `outputs` is the initial RNG state (seed).
}];

let arguments = (ins
FlatSymbolRefAttr:$fn,
Variadic<AnyType>:$inputs,
DefaultValuedStrAttr<StrAttr, "">:$name
);

let results = (outs Variadic<AnyType>:$outputs);

let assemblyFormat = [{
$input attr-dict `:` type($input) `->` type($output)
$fn `(` $inputs `)` attr-dict `:` functional-type($inputs, results)
}];
}

def SimulateOp : Enzyme_Op<"simulate", [DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
let summary = "Simulate a probabilistic function to generate execution trace";
let description = [{
Simulate a probabilistic function to generate execution trace
by replacing all SampleOps with distribution calls and recording
all sampled values into the trace. This op returns the trace, the weight
(accumulated log-probability), and the other outputs. By convention,
the 0th operand in `inputs` or `outputs` is the initial RNG state (seed).
}];

let arguments = (ins
FlatSymbolRefAttr:$fn,
Variadic<AnyType>:$inputs,
DefaultValuedStrAttr<StrAttr, "">:$name
);

let results = (outs Trace:$trace, AnyType:$weight, Variadic<AnyType>:$outputs);

let assemblyFormat = [{
$fn `(` $inputs `)` attr-dict `:` functional-type($inputs, results)
}];
}

def GenerateOp : Enzyme_Op<"generate", [DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
let summary = "Generate an execution trace and weight from a probabilistic function";
let description = [{
Generate an execution trace and weight from a probabilistic function.
If a `constraint` dict is provided AND the sample op's `symbol` is in the
`constrained_symbols` array, we will use the corresponding constraint value
instead of generating new samples from the probabilistic function.
By convention, the 0th operand in `inputs` or `outputs` is the initial RNG
state (seed).
}];

let arguments = (ins
FlatSymbolRefAttr:$fn,
Variadic<AnyType>:$inputs,
AddressArrayAttr:$constrained_addresses,
Constraint:$constraint,
DefaultValuedStrAttr<StrAttr, "">:$name
);

let results = (outs Trace:$trace, AnyType:$weight, Variadic<AnyType>:$outputs);

let assemblyFormat = [{
$fn `(` $inputs `)` `given` $constraint attr-dict `:` functional-type($inputs, results)
}];
}

def InitTraceOp : Enzyme_Op<"initTrace"> {
let summary = "Initialize an execution trace for a probabilistic function";
let description = [{
Initialize an execution trace for a probabilistic function.
}];
let arguments = (ins );
let results = (outs Trace:$trace);
let assemblyFormat = "attr-dict `:` type($trace)";
}

def AddSampleToTraceOp : Enzyme_Op<"addSampleToTrace"> {
let summary = "Add a sampled value into the execution trace";
let description = [{
Add a sampled value into the execution trace.
}];

let arguments = (ins
Trace:$trace,
SymbolAttr:$symbol,
Variadic<AnyType>:$sample
);

let results = (outs Trace:$updated_trace);

let assemblyFormat = [{
`(` $sample `:` type($sample) `)` `into` $trace attr-dict
}];
}

def AddSubtraceOp : Enzyme_Op<"addSubtrace"> {
let summary = "Insert a subtrace into a parent trace";
let description = [{
Insert a subtrace into a parent trace.
}];

let arguments = (ins
Trace:$subtrace,
SymbolAttr:$symbol,
Trace:$trace
);

let results = (outs Trace:$updated_trace);

let assemblyFormat = [{
$subtrace `into` $trace attr-dict
}];
}

def AddWeightToTraceOp : Enzyme_Op<"addWeightToTrace"> {
let summary = "Add the aggregated weight into the execution trace";
let description = [{
Add the aggregated log-probability weight to the execution trace.
}];

let arguments = (ins
Trace:$trace,
AnyType:$weight
);

let results = (outs Trace:$updated_trace);

let assemblyFormat = [{
`(` $weight `:` type($weight) `)` `into` $trace attr-dict
}];
}

def AddRetvalToTraceOp : Enzyme_Op<"addRetvalToTrace"> {
let summary = "Add the function's return value(s) into the execution trace";
let description = [{
Add the function's return value(s) into the execution trace.
}];

let arguments = (ins
Trace:$trace,
Variadic<AnyType>:$retval
);

let results = (outs Trace:$updated_trace);

let assemblyFormat = [{
`(` $retval `:` type($retval) `)` `into` $trace attr-dict
}];
}

def GetSampleFromConstraintOp : Enzyme_Op<"getSampleFromConstraint"> {
let summary = "Get sampled values from a constraint for a given symbol";
let description = [{
Get sampled values from a constraint for a given symbol.
}];
let arguments = (ins Constraint:$constraint, SymbolAttr:$symbol);
let results = (outs Variadic<AnyType>:$outputs);

let assemblyFormat = [{
$constraint attr-dict `:` type(results)
}];
}

def GetSubconstraintOp : Enzyme_Op<"getSubconstraint"> {
let summary = "Get a subconstraint from a constraint for a given symbol";
let description = [{
Get a subconstraint from a constraint for a given symbol.
}];
let arguments = (ins Constraint:$constraint, SymbolAttr:$symbol);

let results = (outs Constraint:$subconstraint);

let assemblyFormat = [{
$constraint attr-dict
}];
}

Expand Down
73 changes: 57 additions & 16 deletions enzyme/Enzyme/MLIR/Dialect/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -439,22 +439,6 @@ void BroadcastOp::build(OpBuilder &builder, OperationState &result, Value input,
build(builder, result, resultTy, input, shapeAttr);
}

//===----------------------------------------------------------------------===//
// SampleOp
//===----------------------------------------------------------------------===//

LogicalResult SampleOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
// TODO: Verify that the result type is same as the type of the referenced
// func.func op.
auto global =
symbolTable.lookupNearestSymbolFrom<func::FuncOp>(*this, getFnAttr());
if (!global)
return emitOpError("'")
<< getFn() << "' does not reference a valid global funcOp";

return success();
}

/**
*
* Modifies activities for the AutoDiffOp.
Expand Down Expand Up @@ -713,3 +697,60 @@ void AutoDiffOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add<ReverseRetOpt>(context);
}

//===----------------------------------------------------------------------===//
// SampleOp
//===----------------------------------------------------------------------===//

LogicalResult SampleOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
// TODO: Verify that the result type is same as the type of the referenced
// func.func op.
auto global =
symbolTable.lookupNearestSymbolFrom<func::FuncOp>(*this, getFnAttr());
if (!global)
return emitOpError("'")
<< getFn() << "' does not reference a valid global funcOp";

if (getLogpdfAttr()) {
auto global = symbolTable.lookupNearestSymbolFrom<func::FuncOp>(
*this, getLogpdfAttr());
if (!global)
return emitOpError("'")
<< getLogpdf().value() << "' does not reference a valid global "
<< "funcOp";
}

return success();
}

//===----------------------------------------------------------------------===//
// GenerateOp
//===----------------------------------------------------------------------===//

LogicalResult GenerateOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
// TODO: Verify that the result type is same as the type of the referenced
// func.func op.
auto global =
symbolTable.lookupNearestSymbolFrom<func::FuncOp>(*this, getFnAttr());
if (!global)
return emitOpError("'")
<< getFn() << "' does not reference a valid global funcOp";

return success();
}

//===----------------------------------------------------------------------===//
// SimulateOp
//===----------------------------------------------------------------------===//

LogicalResult SimulateOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
// TODO: Verify that the result type is same as the type of the referenced
// func.func op.
auto global =
symbolTable.lookupNearestSymbolFrom<func::FuncOp>(*this, getFnAttr());
if (!global)
return emitOpError("'")
<< getFn() << "' does not reference a valid global funcOp";

return success();
}
1 change: 1 addition & 0 deletions enzyme/Enzyme/MLIR/Interfaces/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ add_mlir_library(MLIREnzymeAutoDiffInterface
CloneFunction.cpp
GradientUtils.cpp
GradientUtilsReverse.cpp
ProbProgUtils.cpp
EnzymeLogic.cpp
EnzymeLogicReverse.cpp

Expand Down
Loading
Loading