Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
ebfbc03
separating out things
sbrantq Mar 21, 2025
a479207
Add trace op
sbrantq Mar 23, 2025
bdd5957
clean up
sbrantq Mar 23, 2025
bcb6943
format
sbrantq Mar 23, 2025
a436e25
rename
sbrantq Mar 23, 2025
962568c
improve op
sbrantq Mar 25, 2025
716f30c
improve test
sbrantq Mar 25, 2025
be7a16b
Force logpdf func
sbrantq Mar 26, 2025
b9d724b
verify logpdf
sbrantq Mar 26, 2025
3ed105e
save
sbrantq Apr 4, 2025
402caa9
simulate op
sbrantq Apr 10, 2025
8eb4eba
Merge branch 'main' of https://github.com/EnzymeAD/Enzyme into sampleop
sbrantq Apr 10, 2025
a77b141
suppress
sbrantq Apr 10, 2025
094327a
fix
sbrantq Apr 10, 2025
ded5990
condOp
sbrantq Apr 10, 2025
e3953aa
Bayesian linear regression test
sbrantq Apr 10, 2025
b70a6e6
Refactor stuff to use probprogutils
sbrantq Apr 18, 2025
38f2b85
scf test (not working)
sbrantq Apr 18, 2025
971da10
remove logpdf
sbrantq Apr 22, 2025
dd01e86
direct pp model call
sbrantq Apr 23, 2025
fd5084b
Merge branch 'main' of https://github.com/EnzymeAD/Enzyme into sampleop
sbrantq Apr 24, 2025
447201e
fix build
sbrantq Apr 24, 2025
8333afc
fix up
sbrantq Apr 25, 2025
6511a8d
merge
sbrantq Apr 30, 2025
36691e5
Merge branch 'main' of https://github.com/EnzymeAD/Enzyme into sampleop
sbrantq Apr 30, 2025
707b882
simple generate test
sbrantq May 2, 2025
0128e0f
remove call test
sbrantq May 14, 2025
4107d2e
Merge branch 'main' of https://github.com/EnzymeAD/Enzyme into sampleop
sbrantq May 17, 2025
94c684f
undo merge hack
sbrantq May 17, 2025
03f3034
clean up and fix roundtrip test
sbrantq May 21, 2025
8efe677
format
sbrantq May 21, 2025
050f67d
Passing symbol ptr to sample ops too. Adding simulate op
sbrantq May 27, 2025
385ed5d
Merge commit 'c3a95700012478aaac5c8a2a898e56a568f4769f' into probprog…
sbrantq May 31, 2025
be4894d
change to anytype
sbrantq Jun 1, 2025
8b56126
asm format for initTrace
sbrantq Jun 1, 2025
c4afd9e
fix name
sbrantq Jun 1, 2025
adb7b72
asm format for addSampleToTraceOp
sbrantq Jun 2, 2025
825b93b
type preserving hacks
sbrantq Jun 2, 2025
bda0d6f
Refactor: making pointers ui64 attributes, updating tests
sbrantq Jun 10, 2025
989ff23
remove initTraceOp
sbrantq Jun 10, 2025
614821e
Merge branch 'main' of https://github.com/EnzymeAD/Enzyme into probpr…
sbrantq Jun 13, 2025
5729a66
traced_output_indices attr to specify which output to trace
sbrantq Jun 14, 2025
23bfd0d
Real generate op without recusive support
sbrantq Jun 23, 2025
3e6f8c0
simplify
sbrantq Jun 23, 2025
72ab30e
encoding constraints as a custom attribute
sbrantq Jun 24, 2025
7ed5126
simplify attr
sbrantq Jun 24, 2025
e162dfe
add checks
sbrantq Jun 25, 2025
453312c
postpasses
sbrantq Jun 25, 2025
16b819d
bug fix for aliasing outputs
sbrantq Jun 26, 2025
56560f3
more tests
sbrantq Jun 26, 2025
3390fd0
untraced call
sbrantq Jun 26, 2025
de39226
Merge branch 'main' of https://github.com/EnzymeAD/Enzyme into probpr…
sbrantq Jun 27, 2025
d6bfd3c
Merge branch 'main' into probprog-simulate
sbrantq Jun 27, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 108 additions & 2 deletions enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -256,8 +256,114 @@ def BroadcastOp : Enzyme_Op<"broadcast"> {

def SampleOp : Enzyme_Op<"sample",
[DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
let summary = "Sample from a distribution. Arguments to the distribution are: a random number generator object, followed by arguments to the sample op itself";
let arguments = (ins FlatSymbolRefAttr:$fn, Variadic<AnyType>:$inputs, DefaultValuedStrAttr<StrAttr, "">:$name);
let summary = "Sample from a distribution";
let arguments = (ins
FlatSymbolRefAttr:$fn,
Variadic<AnyType>:$inputs,
OptionalAttr<FlatSymbolRefAttr>:$logpdf,
OptionalAttr<UI64Attr>:$symbol,
OptionalAttr<DenseI64ArrayAttr>:$traced_input_indices,
OptionalAttr<DenseI64ArrayAttr>:$traced_output_indices,
OptionalAttr<DenseI64ArrayAttr>:$alias_map,
DefaultValuedStrAttr<StrAttr, "">:$name
);
let results = (outs Variadic<AnyType>:$outputs);

let assemblyFormat = [{
$fn `(` $inputs `)` attr-dict `:` functional-type($inputs, results)
}];
}

def SimulateOp : Enzyme_Op<"simulate", [DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
let summary = "Simulate a probabilistic function to generate execution trace";
let description = [{
Simulate a probabilistic function to generate execution trace
by replacing all SampleOps with distribution calls and inserting
sampled values into the choice map.
}];

let arguments = (ins
FlatSymbolRefAttr:$fn,
Variadic<AnyType>:$inputs,
UI64Attr:$trace,
DefaultValuedStrAttr<StrAttr, "">:$name
);

let results = (outs Variadic<AnyType>:$outputs);

let assemblyFormat = [{
$fn `(` $inputs `)` attr-dict `:` functional-type($inputs, results)
}];
}

def AddSampleToTraceOp : Enzyme_Op<"addSampleToTrace"> {
let summary = "Add a sampled value into the execution trace";
let description = [{
Add a sampled value into the execution trace.
}];

let arguments = (ins
Variadic<AnyType>:$sample,
UI64Attr:$symbol,
UI64Attr:$trace,
DefaultValuedStrAttr<StrAttr, "">:$name
);

let results = (outs );

let assemblyFormat = [{
$sample attr-dict `:` functional-type($sample, results)
}];
}

def ConstraintAttr : Enzyme_Attr<"Constraint", "constraint"> {
let summary = "In probabilistic programming, mapping a symbol to a constraint value.";
let description = [{
This attribute represents a symbol/value pair inside the constraints dict
optionally attached to `enzyme.generate` op. The `symbol` field should be
an ui64. The `values` field is an array attribute containing constraint values
in the same order as the SampleOp's traced output indices.
}];

let parameters = (ins "uint64_t":$symbol, "ArrayAttr":$values);

let assemblyFormat = "`<` `symbol` `=` $symbol `,` `values` `=` $values `>`";
}

def UntracedCallOp : Enzyme_Op<"untracedCall"> {
let summary = "Call a probabilistic function without tracing";
let description = [{
Call a probabilistic function without tracing.
}];

let arguments = (ins
FlatSymbolRefAttr:$fn,
Variadic<AnyType>:$inputs,
DefaultValuedStrAttr<StrAttr, "">:$name
);

let results = (outs Variadic<AnyType>:$outputs);

let assemblyFormat = [{
$fn `(` $inputs `)` attr-dict `:` functional-type($inputs, results)
}];
}

def GenerateOp : Enzyme_Op<"generate", [DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
let summary = "Generate an execution trace and weight from a probabilistic function";
let description = [{
Generate an execution trace and weight from a probabilistic function. If a constraint dict
is provided, we will use the corresponding constraint values instead of generating new samples.
}];

let arguments = (ins
FlatSymbolRefAttr:$fn,
Variadic<AnyType>:$inputs,
UI64Attr:$trace,
OptionalAttr<ArrayAttr>:$constraints,
DefaultValuedStrAttr<StrAttr, "">:$name
);

let results = (outs Variadic<AnyType>:$outputs);

let assemblyFormat = [{
Expand Down
73 changes: 57 additions & 16 deletions enzyme/Enzyme/MLIR/Dialect/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -439,22 +439,6 @@ void BroadcastOp::build(OpBuilder &builder, OperationState &result, Value input,
build(builder, result, resultTy, input, shapeAttr);
}

//===----------------------------------------------------------------------===//
// SampleOp
//===----------------------------------------------------------------------===//

LogicalResult SampleOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
// TODO: Verify that the result type is same as the type of the referenced
// func.func op.
auto global =
symbolTable.lookupNearestSymbolFrom<func::FuncOp>(*this, getFnAttr());
if (!global)
return emitOpError("'")
<< getFn() << "' does not reference a valid global funcOp";

return success();
}

/**
*
* Modifies activities for the AutoDiffOp.
Expand Down Expand Up @@ -713,3 +697,60 @@ void AutoDiffOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add<ReverseRetOpt>(context);
}

//===----------------------------------------------------------------------===//
// SampleOp
//===----------------------------------------------------------------------===//

LogicalResult SampleOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
// TODO: Verify that the result type is same as the type of the referenced
// func.func op.
auto global =
symbolTable.lookupNearestSymbolFrom<func::FuncOp>(*this, getFnAttr());
if (!global)
return emitOpError("'")
<< getFn() << "' does not reference a valid global funcOp";

if (getLogpdfAttr()) {
auto global = symbolTable.lookupNearestSymbolFrom<func::FuncOp>(
*this, getLogpdfAttr());
if (!global)
return emitOpError("'")
<< getLogpdf().value() << "' does not reference a valid global "
<< "funcOp";
}

return success();
}

//===----------------------------------------------------------------------===//
// GenerateOp
//===----------------------------------------------------------------------===//

LogicalResult GenerateOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
// TODO: Verify that the result type is same as the type of the referenced
// func.func op.
auto global =
symbolTable.lookupNearestSymbolFrom<func::FuncOp>(*this, getFnAttr());
if (!global)
return emitOpError("'")
<< getFn() << "' does not reference a valid global funcOp";

return success();
}

//===----------------------------------------------------------------------===//
// SimulateOp
//===----------------------------------------------------------------------===//

LogicalResult SimulateOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
// TODO: Verify that the result type is same as the type of the referenced
// func.func op.
auto global =
symbolTable.lookupNearestSymbolFrom<func::FuncOp>(*this, getFnAttr());
if (!global)
return emitOpError("'")
<< getFn() << "' does not reference a valid global funcOp";

return success();
}
2 changes: 2 additions & 0 deletions enzyme/Enzyme/MLIR/Interfaces/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@ add_mlir_library(MLIREnzymeAutoDiffInterface
CloneFunction.cpp
GradientUtils.cpp
GradientUtilsReverse.cpp
ProbProgUtils.cpp
EnzymeLogic.cpp
EnzymeLogicReverse.cpp
EnzymeLogicProbProg.cpp

DEPENDS
MLIRAutoDiffOpInterfaceIncGen
Expand Down
68 changes: 68 additions & 0 deletions enzyme/Enzyme/MLIR/Interfaces/ProbProgUtils.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
//===- ProbProgUtils.cpp - Utilities for probprog interfaces
//--------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "Interfaces/ProbProgUtils.h"
#include "Dialect/Ops.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/FunctionInterfaces.h"

// TODO: this shouldn't depend on specific dialects except Enzyme.
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"

#include "CloneFunction.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Dominance.h"
#include "llvm/ADT/BreadthFirstIterator.h"

using namespace mlir;
using namespace mlir::enzyme;

MProbProgUtils *MProbProgUtils::CreateFromClone(FunctionOpInterface toeval,
MProbProgMode mode) {
if (toeval.getFunctionBody().empty()) {
llvm::errs() << toeval << "\n";
llvm_unreachable("Creating MProbProgUtils from empty function");
}

std::string suffix;

switch (mode) {
case MProbProgMode::Call:
suffix = "call";
break;
case MProbProgMode::Generate:
suffix = "generate";
break;
case MProbProgMode::Simulate:
suffix = "simulate";
break;
default:
llvm_unreachable("Invalid MProbProgMode\n");
}

OpBuilder builder(toeval.getContext());

auto NewF = cast<FunctionOpInterface>(toeval->cloneWithoutRegions());
SymbolTable::setSymbolName(NewF, toeval.getName().str() + "." + suffix);
NewF.setType(toeval.getFunctionType());

Operation *parent = toeval->getParentWithTrait<OpTrait::SymbolTable>();
SymbolTable table(parent);
table.insert(NewF);

IRMapping originalToNew;
std::map<Operation *, Operation *> originalToNewOps;
cloneInto(&toeval.getFunctionBody(), &NewF.getFunctionBody(), originalToNew,
originalToNewOps);

return new MProbProgUtils(NewF, toeval, originalToNew, originalToNewOps,
mode);
}
53 changes: 53 additions & 0 deletions enzyme/Enzyme/MLIR/Interfaces/ProbProgUtils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
//===- ProbProgUtils.h - Utilities for probprog interfaces -------* C++
//-*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef ENZYME_MLIR_INTERFACES_PROBPROG_UTILS_H
#define ENZYME_MLIR_INTERFACES_PROBPROG_UTILS_H

#include "mlir/IR/IRMapping.h"
#include "mlir/Interfaces/FunctionInterfaces.h"

#include "CloneFunction.h"
#include "EnzymeLogic.h"

#include <functional>

namespace mlir {
namespace enzyme {

class MProbProgUtils {
public:
FunctionOpInterface newFunc;

MProbProgMode mode;
FunctionOpInterface oldFunc;
IRMapping originalToNewFn;
std::map<Operation *, Operation *> originalToNewFnOps;

private:
Block *initializationBlock;

public:
MProbProgUtils(FunctionOpInterface newFunc_, FunctionOpInterface oldFunc_,
IRMapping &originalToNewFn_,
std::map<Operation *, Operation *> &originalToNewFnOps_,
MProbProgMode mode_)
: newFunc(newFunc_), mode(mode_), oldFunc(oldFunc_),
originalToNewFn(originalToNewFn_),
originalToNewFnOps(originalToNewFnOps_),
initializationBlock(&*(newFunc.getFunctionBody().begin())) {}

static MProbProgUtils *CreateFromClone(FunctionOpInterface toeval,
MProbProgMode mode);
};

} // namespace enzyme
} // namespace mlir

#endif // ENZYME_MLIR_INTERFACES_PROBPROG_UTILS_H
1 change: 1 addition & 0 deletions enzyme/Enzyme/MLIR/Passes/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ add_mlir_doc(Passes EnzymePasses ./ -gen-pass-doc)

add_mlir_dialect_library(MLIREnzymeTransforms
EnzymeMLIRPass.cpp
ProbProgMLIRPass.cpp
EnzymeBatchPass.cpp
EnzymeWrapPass.cpp
PrintActivityAnalysis.cpp
Expand Down
2 changes: 2 additions & 0 deletions enzyme/Enzyme/MLIR/Passes/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ class DominanceInfo;
namespace enzyme {
std::unique_ptr<Pass> createDifferentiatePass();

std::unique_ptr<Pass> createProbProgPass();

std::unique_ptr<Pass> createBatchPass();

std::unique_ptr<Pass> createDifferentiateWrapperPass();
Expand Down
21 changes: 21 additions & 0 deletions enzyme/Enzyme/MLIR/Passes/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,27 @@ def DifferentiatePass : Pass<"enzyme"> {
let constructor = "mlir::enzyme::createDifferentiatePass()";
}

def ProbProgPass : Pass<"probprog"> {
let summary = "ProbProg Passes";
let dependentDialects = [
"arith::ArithDialect",
"complex::ComplexDialect",
"cf::ControlFlowDialect",
"tensor::TensorDialect",
"enzyme::EnzymeDialect",
];
let options = [
Option<
/*C++ variable name=*/"postpasses",
/*CLI argument=*/"postpasses",
/*type=*/"std::string",
/*default=*/"",
/*description=*/"Optimization passes to apply to generated probabilistic programs"
>,
];
let constructor = "mlir::enzyme::createProbProgPass()";
}

def BatchPass : Pass<"enzyme-batch"> {
let summary = "Batch Passes";
let dependentDialects = [
Expand Down
Loading
Loading