Skip to content

Commit d48bb60

Browse files
authored
ProbProg: Making trace an operand (EnzymeAD#2363)
* separating out things * Add trace op * clean up * format * rename * improve op * improve test * Force logpdf func * verify logpdf * save * simulate op * suppress * fix * condOp * Bayesian linear regression test * Refactor stuff to use probprogutils * scf test (not working) * remove logpdf * direct pp model call * fix build * fix up * merge * simple generate test * remove call test * undo merge hack * clean up and fix roundtrip test * format * Passing symbol ptr to sample ops too. Adding simulate op * change to anytype * asm format for initTrace * fix name * asm format for addSampleToTraceOp * type preserving hacks * Refactor: making pointers ui64 attributes, updating tests * remove initTraceOp * traced_output_indices attr to specify which output to trace * Real generate op without recusive support * simplify * encoding constraints as a custom attribute * simplify attr * add checks * postpasses * bug fix for aliasing outputs * more tests * untraced call * add trace type * refactored simulate op simple case * simulate op should return weights too * greedy rewrite simulate ops * bug fix * update tests * symbol attr to addsubtrace op * no dump * return updated traces to enforce data dependences * add weights and return values to trace too * adding traced_output_indices attr to simulate op * rearrange op defs; adding new generate op with trace operand * WIP generate op * WIP generate op * account for new constraint argument * refactored generate op & simple test * fix gsfcop result indices * test following calling convention * enforcing probprog calling convention (rng state being the 0th input/output operand) * cleanup obsolete tests * format * dont print * improve * clean up * minor: get rid of initConstraint op * generate op fixup: constrained_symbols --> constrained_addresses * fix build
1 parent c3da306 commit d48bb60

File tree

18 files changed

+1437
-35
lines changed

18 files changed

+1437
-35
lines changed

enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td

Lines changed: 231 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -254,25 +254,248 @@ def BroadcastOp : Enzyme_Op<"broadcast"> {
254254
];
255255
}
256256

257+
def IgnoreDerivativesOp : Enzyme_Op<"ignore_derivatives",
258+
[Pure, SameOperandsAndResultElementType, SameOperandsAndResultShape]> {
259+
let summary = "Prevents the flow of gradients (and higher-order derivatives) by creating a new value that is detached from the original value. This is an identity operation on the primal.";
260+
let arguments = (ins AnyType:$input);
261+
let results = (outs AnyType:$output);
262+
263+
let assemblyFormat = [{
264+
$input attr-dict `:` type($input) `->` type($output)
265+
}];
266+
}
267+
268+
// Probabilistic programming
269+
def Trace : Enzyme_Type<"Trace"> {
270+
let summary = "Execution trace for probabilistic programs";
271+
let description = [{
272+
Reference: https://www.gen.dev/docs/stable/tutorials/basics/gfi/#Traces
273+
Mutable storage for mapping all sampled symbols to their values, log-likelihoods,
274+
and various other information from executing a traced probabilistic function.
275+
}];
276+
let mnemonic = "Trace";
277+
}
278+
279+
def SymbolAttr : Enzyme_Attr<"Symbol", "symbol"> {
280+
let summary = "Symbol associated with a Sample op";
281+
let description = [{
282+
Symbol associated with a Sample op.
283+
}];
284+
let parameters = (ins "uint64_t":$ptr);
285+
let assemblyFormat = "`<` $ptr `>`";
286+
}
287+
288+
def AddressAttr : TypedArrayAttrBase<SymbolAttr, "Address as an array of symbols">;
289+
def AddressArrayAttr : TypedArrayAttrBase<AddressAttr, "Array of addresses">;
290+
291+
def Constraint : Enzyme_Type<"Constraint"> {
292+
let summary = "A mutable storage mapping symbols to value constraints";
293+
let description = [{
294+
A mutable storage mapping symbols to value constraints.
295+
}];
296+
let mnemonic = "Constraint";
297+
}
298+
257299
def SampleOp : Enzyme_Op<"sample",
258300
[DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
259-
let summary = "Sample from a distribution. Arguments to the distribution are: a random number generator object, followed by arguments to the sample op itself";
260-
let arguments = (ins FlatSymbolRefAttr:$fn, Variadic<AnyType>:$inputs, DefaultValuedStrAttr<StrAttr, "">:$name);
301+
let summary = "Sample from a distribution";
302+
let description = [{
303+
Sample from a distribution. By convention, the 0th operand in `inputs`
304+
or `outputs` is the initial RNG state (seed).
305+
}];
306+
307+
let arguments = (ins
308+
FlatSymbolRefAttr:$fn,
309+
Variadic<AnyType>:$inputs,
310+
OptionalAttr<FlatSymbolRefAttr>:$logpdf,
311+
OptionalAttr<SymbolAttr>:$symbol,
312+
DefaultValuedStrAttr<StrAttr, "">:$name
313+
);
314+
261315
let results = (outs Variadic<AnyType>:$outputs);
262316

263317
let assemblyFormat = [{
264318
$fn `(` $inputs `)` attr-dict `:` functional-type($inputs, results)
265319
}];
266320
}
267321

268-
def IgnoreDerivativesOp : Enzyme_Op<"ignore_derivatives",
269-
[Pure, SameOperandsAndResultElementType, SameOperandsAndResultShape]> {
270-
let summary = "Prevents the flow of gradients (and higher-order derivatives) by creating a new value that is detached from the original value. This is an identity operation on the primal.";
271-
let arguments = (ins AnyType:$input);
272-
let results = (outs AnyType:$output);
322+
def UntracedCallOp : Enzyme_Op<"untracedCall"> {
323+
let summary = "Call a probabilistic function without tracing";
324+
let description = [{
325+
Call a probabilistic function without tracing. By convention, the 0th operand in `inputs`
326+
or `outputs` is the initial RNG state (seed).
327+
}];
328+
329+
let arguments = (ins
330+
FlatSymbolRefAttr:$fn,
331+
Variadic<AnyType>:$inputs,
332+
DefaultValuedStrAttr<StrAttr, "">:$name
333+
);
334+
335+
let results = (outs Variadic<AnyType>:$outputs);
273336

274337
let assemblyFormat = [{
275-
$input attr-dict `:` type($input) `->` type($output)
338+
$fn `(` $inputs `)` attr-dict `:` functional-type($inputs, results)
339+
}];
340+
}
341+
342+
def SimulateOp : Enzyme_Op<"simulate", [DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
343+
let summary = "Simulate a probabilistic function to generate execution trace";
344+
let description = [{
345+
Simulate a probabilistic function to generate execution trace
346+
by replacing all SampleOps with distribution calls and recording
347+
all sampled values into the trace. This op returns the trace, the weight
348+
(accumulated log-probability), and the other outputs. By convention,
349+
the 0th operand in `inputs` or `outputs` is the initial RNG state (seed).
350+
}];
351+
352+
let arguments = (ins
353+
FlatSymbolRefAttr:$fn,
354+
Variadic<AnyType>:$inputs,
355+
DefaultValuedStrAttr<StrAttr, "">:$name
356+
);
357+
358+
let results = (outs Trace:$trace, AnyType:$weight, Variadic<AnyType>:$outputs);
359+
360+
let assemblyFormat = [{
361+
$fn `(` $inputs `)` attr-dict `:` functional-type($inputs, results)
362+
}];
363+
}
364+
365+
def GenerateOp : Enzyme_Op<"generate", [DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
366+
let summary = "Generate an execution trace and weight from a probabilistic function";
367+
let description = [{
368+
Generate an execution trace and weight from a probabilistic function.
369+
If a `constraint` dict is provided AND the sample op's `symbol` is in the
370+
`constrained_symbols` array, we will use the corresponding constraint value
371+
instead of generating new samples from the probabilistic function.
372+
By convention, the 0th operand in `inputs` or `outputs` is the initial RNG
373+
state (seed).
374+
}];
375+
376+
let arguments = (ins
377+
FlatSymbolRefAttr:$fn,
378+
Variadic<AnyType>:$inputs,
379+
AddressArrayAttr:$constrained_addresses,
380+
Constraint:$constraint,
381+
DefaultValuedStrAttr<StrAttr, "">:$name
382+
);
383+
384+
let results = (outs Trace:$trace, AnyType:$weight, Variadic<AnyType>:$outputs);
385+
386+
let assemblyFormat = [{
387+
$fn `(` $inputs `)` `given` $constraint attr-dict `:` functional-type($inputs, results)
388+
}];
389+
}
390+
391+
def InitTraceOp : Enzyme_Op<"initTrace"> {
392+
let summary = "Initialize an execution trace for a probabilistic function";
393+
let description = [{
394+
Initialize an execution trace for a probabilistic function.
395+
}];
396+
let arguments = (ins );
397+
let results = (outs Trace:$trace);
398+
let assemblyFormat = "attr-dict `:` type($trace)";
399+
}
400+
401+
def AddSampleToTraceOp : Enzyme_Op<"addSampleToTrace"> {
402+
let summary = "Add a sampled value into the execution trace";
403+
let description = [{
404+
Add a sampled value into the execution trace.
405+
}];
406+
407+
let arguments = (ins
408+
Trace:$trace,
409+
SymbolAttr:$symbol,
410+
Variadic<AnyType>:$sample
411+
);
412+
413+
let results = (outs Trace:$updated_trace);
414+
415+
let assemblyFormat = [{
416+
`(` $sample `:` type($sample) `)` `into` $trace attr-dict
417+
}];
418+
}
419+
420+
def AddSubtraceOp : Enzyme_Op<"addSubtrace"> {
421+
let summary = "Insert a subtrace into a parent trace";
422+
let description = [{
423+
Insert a subtrace into a parent trace.
424+
}];
425+
426+
let arguments = (ins
427+
Trace:$subtrace,
428+
SymbolAttr:$symbol,
429+
Trace:$trace
430+
);
431+
432+
let results = (outs Trace:$updated_trace);
433+
434+
let assemblyFormat = [{
435+
$subtrace `into` $trace attr-dict
436+
}];
437+
}
438+
439+
def AddWeightToTraceOp : Enzyme_Op<"addWeightToTrace"> {
440+
let summary = "Add the aggregated weight into the execution trace";
441+
let description = [{
442+
Add the aggregated log-probability weight to the execution trace.
443+
}];
444+
445+
let arguments = (ins
446+
Trace:$trace,
447+
AnyType:$weight
448+
);
449+
450+
let results = (outs Trace:$updated_trace);
451+
452+
let assemblyFormat = [{
453+
`(` $weight `:` type($weight) `)` `into` $trace attr-dict
454+
}];
455+
}
456+
457+
def AddRetvalToTraceOp : Enzyme_Op<"addRetvalToTrace"> {
458+
let summary = "Add the function's return value(s) into the execution trace";
459+
let description = [{
460+
Add the function's return value(s) into the execution trace.
461+
}];
462+
463+
let arguments = (ins
464+
Trace:$trace,
465+
Variadic<AnyType>:$retval
466+
);
467+
468+
let results = (outs Trace:$updated_trace);
469+
470+
let assemblyFormat = [{
471+
`(` $retval `:` type($retval) `)` `into` $trace attr-dict
472+
}];
473+
}
474+
475+
def GetSampleFromConstraintOp : Enzyme_Op<"getSampleFromConstraint"> {
476+
let summary = "Get sampled values from a constraint for a given symbol";
477+
let description = [{
478+
Get sampled values from a constraint for a given symbol.
479+
}];
480+
let arguments = (ins Constraint:$constraint, SymbolAttr:$symbol);
481+
let results = (outs Variadic<AnyType>:$outputs);
482+
483+
let assemblyFormat = [{
484+
$constraint attr-dict `:` type(results)
485+
}];
486+
}
487+
488+
def GetSubconstraintOp : Enzyme_Op<"getSubconstraint"> {
489+
let summary = "Get a subconstraint from a constraint for a given symbol";
490+
let description = [{
491+
Get a subconstraint from a constraint for a given symbol.
492+
}];
493+
let arguments = (ins Constraint:$constraint, SymbolAttr:$symbol);
494+
495+
let results = (outs Constraint:$subconstraint);
496+
497+
let assemblyFormat = [{
498+
$constraint attr-dict
276499
}];
277500
}
278501

enzyme/Enzyme/MLIR/Dialect/Ops.cpp

Lines changed: 57 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -540,22 +540,6 @@ void BroadcastOp::build(OpBuilder &builder, OperationState &result, Value input,
540540
build(builder, result, resultTy, input, shapeAttr);
541541
}
542542

543-
//===----------------------------------------------------------------------===//
544-
// SampleOp
545-
//===----------------------------------------------------------------------===//
546-
547-
LogicalResult SampleOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
548-
// TODO: Verify that the result type is same as the type of the referenced
549-
// func.func op.
550-
auto global =
551-
symbolTable.lookupNearestSymbolFrom<func::FuncOp>(*this, getFnAttr());
552-
if (!global)
553-
return emitOpError("'")
554-
<< getFn() << "' does not reference a valid global funcOp";
555-
556-
return success();
557-
}
558-
559543
/**
560544
*
561545
* Modifies activities for the AutoDiffOp.
@@ -814,3 +798,60 @@ void AutoDiffOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
814798
MLIRContext *context) {
815799
patterns.add<ReverseRetOpt>(context);
816800
}
801+
802+
//===----------------------------------------------------------------------===//
803+
// SampleOp
804+
//===----------------------------------------------------------------------===//
805+
806+
LogicalResult SampleOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
807+
// TODO: Verify that the result type is same as the type of the referenced
808+
// func.func op.
809+
auto global =
810+
symbolTable.lookupNearestSymbolFrom<func::FuncOp>(*this, getFnAttr());
811+
if (!global)
812+
return emitOpError("'")
813+
<< getFn() << "' does not reference a valid global funcOp";
814+
815+
if (getLogpdfAttr()) {
816+
auto global = symbolTable.lookupNearestSymbolFrom<func::FuncOp>(
817+
*this, getLogpdfAttr());
818+
if (!global)
819+
return emitOpError("'")
820+
<< getLogpdf().value() << "' does not reference a valid global "
821+
<< "funcOp";
822+
}
823+
824+
return success();
825+
}
826+
827+
//===----------------------------------------------------------------------===//
828+
// GenerateOp
829+
//===----------------------------------------------------------------------===//
830+
831+
LogicalResult GenerateOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
832+
// TODO: Verify that the result type is same as the type of the referenced
833+
// func.func op.
834+
auto global =
835+
symbolTable.lookupNearestSymbolFrom<func::FuncOp>(*this, getFnAttr());
836+
if (!global)
837+
return emitOpError("'")
838+
<< getFn() << "' does not reference a valid global funcOp";
839+
840+
return success();
841+
}
842+
843+
//===----------------------------------------------------------------------===//
844+
// SimulateOp
845+
//===----------------------------------------------------------------------===//
846+
847+
LogicalResult SimulateOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
848+
// TODO: Verify that the result type is same as the type of the referenced
849+
// func.func op.
850+
auto global =
851+
symbolTable.lookupNearestSymbolFrom<func::FuncOp>(*this, getFnAttr());
852+
if (!global)
853+
return emitOpError("'")
854+
<< getFn() << "' does not reference a valid global funcOp";
855+
856+
return success();
857+
}

enzyme/Enzyme/MLIR/Interfaces/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ add_mlir_library(MLIREnzymeAutoDiffInterface
1111
CloneFunction.cpp
1212
GradientUtils.cpp
1313
GradientUtilsReverse.cpp
14+
ProbProgUtils.cpp
1415
EnzymeLogic.cpp
1516
EnzymeLogicReverse.cpp
1617

0 commit comments

Comments
 (0)