Skip to content

[mlir][c] Expose AsmState. #66693

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion mlir/include/mlir-c/IR.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ extern "C" {
}; \
typedef struct name name

DEFINE_C_API_STRUCT(MlirAsmState, void);
DEFINE_C_API_STRUCT(MlirBytecodeWriterConfig, void);
DEFINE_C_API_STRUCT(MlirContext, void);
DEFINE_C_API_STRUCT(MlirDialect, void);
Expand Down Expand Up @@ -383,6 +384,29 @@ mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n,
MLIR_CAPI_EXPORTED void
mlirOperationStateEnableResultTypeInference(MlirOperationState *state);

//===----------------------------------------------------------------------===//
// AsmState API.
// While many of these are simple settings that could be represented in a
// struct, they are wrapped in a heap allocated object and accessed via
// functions to maximize the possibility of compatibility over time.
//===----------------------------------------------------------------------===//

/// Creates new AsmState, as with AsmState the IR should not be mutated
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: should this say maybe "as with OperationState"? Not sure.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure I follow, OperationState is created while mutating the IR.

/// in-between using this state.
/// Must be freed with a call to mlirAsmStateDestroy().
// TODO: This should be expanded to handle location & resouce map.
MLIR_CAPI_EXPORTED MlirAsmState
mlirAsmStateCreateForOperation(MlirOperation op, MlirOpPrintingFlags flags);

/// Creates new AsmState from value.
/// Must be freed with a call to mlirAsmStateDestroy().
// TODO: This should be expanded to handle location & resouce map.
MLIR_CAPI_EXPORTED MlirAsmState
mlirAsmStateCreateForValue(MlirValue value, MlirOpPrintingFlags flags);

/// Destroys printing flags created with mlirAsmStateCreate.
MLIR_CAPI_EXPORTED void mlirAsmStateDestroy(MlirAsmState state);

//===----------------------------------------------------------------------===//
// Op Printing flags API.
// While many of these are simple settings that could be represented in a
Expand Down Expand Up @@ -815,7 +839,7 @@ mlirValuePrint(MlirValue value, MlirStringCallback callback, void *userData);

/// Prints a value as an operand (i.e., the ValueID).
MLIR_CAPI_EXPORTED void mlirValuePrintAsOperand(MlirValue value,
MlirOpPrintingFlags flags,
MlirAsmState state,
MlirStringCallback callback,
void *userData);

Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/CAPI/IR.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"

DEFINE_C_API_PTR_METHODS(MlirAsmState, mlir::AsmState)
DEFINE_C_API_PTR_METHODS(MlirBytecodeWriterConfig, mlir::BytecodeWriterConfig)
DEFINE_C_API_PTR_METHODS(MlirContext, mlir::MLIRContext)
DEFINE_C_API_PTR_METHODS(MlirDialect, mlir::Dialect)
Expand Down
4 changes: 3 additions & 1 deletion mlir/lib/Bindings/Python/IRCore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3430,9 +3430,11 @@ void mlir::python::populateIRCore(py::module &m) {
MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
if (useLocalScope)
mlirOpPrintingFlagsUseLocalScope(flags);
mlirValuePrintAsOperand(self.get(), flags, printAccum.getCallback(),
MlirAsmState state = mlirAsmStateCreateForValue(self.get(), flags);
mlirValuePrintAsOperand(self.get(), state, printAccum.getCallback(),
printAccum.getUserData());
mlirOpPrintingFlagsDestroy(flags);
mlirAsmStateDestroy(state);
return printAccum.join();
},
py::arg("use_local_scope") = false, kGetNameAsOperand)
Expand Down
49 changes: 47 additions & 2 deletions mlir/lib/CAPI/IR/IR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,51 @@ void mlirDialectRegistryDestroy(MlirDialectRegistry registry) {
delete unwrap(registry);
}

//===----------------------------------------------------------------------===//
// AsmState API.
//===----------------------------------------------------------------------===//

MlirAsmState mlirAsmStateCreateForOperation(MlirOperation op,
MlirOpPrintingFlags flags) {
return wrap(new AsmState(unwrap(op), *unwrap(flags)));
}

static Operation *findParent(Operation *op, bool shouldUseLocalScope) {
do {
// If we are printing local scope, stop at the first operation that is
// isolated from above.
if (shouldUseLocalScope && op->hasTrait<OpTrait::IsIsolatedFromAbove>())
break;

// Otherwise, traverse up to the next parent.
Operation *parentOp = op->getParentOp();
if (!parentOp)
break;
op = parentOp;
} while (true);
return op;
}

MlirAsmState mlirAsmStateCreateForValue(MlirValue value,
MlirOpPrintingFlags flags) {
Operation *op;
mlir::Value val = unwrap(value);
if (auto result = llvm::dyn_cast<OpResult>(val)) {
op = result.getOwner();
} else {
op = llvm::cast<BlockArgument>(val).getOwner()->getParentOp();
if (!op) {
emitError(val.getLoc()) << "<<UNKNOWN SSA VALUE>>";
return {nullptr};
}
}
op = findParent(op, unwrap(flags)->shouldUseLocalScope());
return wrap(new AsmState(op, *unwrap(flags)));
}

/// Destroys printing flags created with mlirAsmStateCreate.
void mlirAsmStateDestroy(MlirAsmState state) { delete unwrap(state); }

//===----------------------------------------------------------------------===//
// Printing flags API.
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -840,11 +885,11 @@ void mlirValuePrint(MlirValue value, MlirStringCallback callback,
unwrap(value).print(stream);
}

void mlirValuePrintAsOperand(MlirValue value, MlirOpPrintingFlags flags,
void mlirValuePrintAsOperand(MlirValue value, MlirAsmState state,
MlirStringCallback callback, void *userData) {
detail::CallbackOstream stream(callback, userData);
Value cppValue = unwrap(value);
cppValue.printAsOperand(stream, *unwrap(flags));
cppValue.printAsOperand(stream, *unwrap(state));
}

MlirOpOperand mlirValueGetFirstUse(MlirValue value) {
Expand Down
7 changes: 7 additions & 0 deletions mlir/test/CAPI/ir.c
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,13 @@ static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
// CHECK: Op print with all flags: %{{.*}} = "arith.constant"() <{value = 0 : index}> {elts = dense_resource<__elided__> : tensor<4xi32>} : () -> index loc(unknown)
// clang-format on

MlirAsmState state = mlirAsmStateCreateForOperation(parentOperation, flags);
fprintf(stderr, "With state: |");
mlirValuePrintAsOperand(value, state, printToStderr, NULL);
// CHECK: With state: |%0|
fprintf(stderr, "|\n");
mlirAsmStateDestroy(state);

mlirOpPrintingFlagsDestroy(flags);
}

Expand Down