Skip to content

[mlir][c] Expose AsmState. #66693

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 19, 2023
Merged

Conversation

jpienaar
Copy link
Member

Enable usage where capturing AsmState is good. Haven't plumbed through to python yet. This also only changes one C API to verify plumbing.

@llvmbot
Copy link
Member

llvmbot commented Sep 18, 2023

@llvm/pr-subscribers-mlir

Changes

Enable usage where capturing AsmState is good. Haven't plumbed through to python yet. This also only changes one C API to verify plumbing.


Full diff: https://github.com/llvm/llvm-project/pull/66693.diff

4 Files Affected:

  • (modified) mlir/include/mlir-c/IR.h (+23)
  • (modified) mlir/include/mlir/CAPI/IR.h (+1)
  • (modified) mlir/lib/CAPI/IR/IR.cpp (+19)
  • (modified) mlir/test/CAPI/ir.c (+7)
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index b5c6a3094bc67df..8ed126ad1775760 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -48,6 +48,7 @@ extern "C" {
   };                                                                           \
   typedef struct name name
 
+DEFINE_C_API_STRUCT(MlirAsmState, void);
 DEFINE_C_API_STRUCT(MlirBytecodeWriterConfig, void);
 DEFINE_C_API_STRUCT(MlirContext, void);
 DEFINE_C_API_STRUCT(MlirDialect, void);
@@ -383,6 +384,23 @@ mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n,
 MLIR_CAPI_EXPORTED void
 mlirOperationStateEnableResultTypeInference(MlirOperationState *state);
 
+//===----------------------------------------------------------------------===//
+// AsmState API.
+// While many of these are simple settings that could be represented in a
+// struct, they are wrapped in a heap allocated object and accessed via
+// functions to maximize the possibility of compatibility over time.
+//===----------------------------------------------------------------------===//
+
+/// Creates new AsmState, as with AsmState the IR should not be mutated
+/// in-between using this state.
+/// Must be freed with a call to mlirAsmStateDestroy().
+// TODO: This should be expanded to handle location & resouce map.
+MLIR_CAPI_EXPORTED MlirAsmState mlirAsmStateCreate(MlirOperation op,
+                                                   MlirOpPrintingFlags flags);
+
+/// Destroys printing flags created with mlirAsmStateCreate.
+MLIR_CAPI_EXPORTED void mlirAsmStateDestroy(MlirAsmState state);
+
 //===----------------------------------------------------------------------===//
 // Op Printing flags API.
 // While many of these are simple settings that could be represented in a
@@ -819,6 +837,11 @@ MLIR_CAPI_EXPORTED void mlirValuePrintAsOperand(MlirValue value,
                                                 MlirStringCallback callback,
                                                 void *userData);
 
+/// Prints a value as an operand (i.e., the ValueID) using prepopulated state.
+MLIR_CAPI_EXPORTED void
+mlirValuePrintAsOperandWithState(MlirValue value, MlirAsmState state,
+                                 MlirStringCallback callback, void *userData);
+
 /// Returns an op operand representing the first use of the value, or a null op
 /// operand if there are no uses.
 MLIR_CAPI_EXPORTED MlirOpOperand mlirValueGetFirstUse(MlirValue value);
diff --git a/mlir/include/mlir/CAPI/IR.h b/mlir/include/mlir/CAPI/IR.h
index b8ccec896c27ba5..1836cb0acb67e7e 100644
--- a/mlir/include/mlir/CAPI/IR.h
+++ b/mlir/include/mlir/CAPI/IR.h
@@ -21,6 +21,7 @@
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Operation.h"
 
+DEFINE_C_API_PTR_METHODS(MlirAsmState, mlir::AsmState)
 DEFINE_C_API_PTR_METHODS(MlirBytecodeWriterConfig, mlir::BytecodeWriterConfig)
 DEFINE_C_API_PTR_METHODS(MlirContext, mlir::MLIRContext)
 DEFINE_C_API_PTR_METHODS(MlirDialect, mlir::Dialect)
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index ef234a912490eea..2a4db78c9e664c3 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -138,6 +138,17 @@ void mlirDialectRegistryDestroy(MlirDialectRegistry registry) {
   delete unwrap(registry);
 }
 
+//===----------------------------------------------------------------------===//
+// AsmState API.
+//===----------------------------------------------------------------------===//
+
+MlirAsmState mlirAsmStateCreate(MlirOperation op, MlirOpPrintingFlags flags) {
+  return wrap(new AsmState(unwrap(op), *unwrap(flags)));
+}
+
+/// Destroys printing flags created with mlirAsmStateCreate.
+void mlirAsmStateDestroy(MlirAsmState state) { delete unwrap(state); }
+
 //===----------------------------------------------------------------------===//
 // Printing flags API.
 //===----------------------------------------------------------------------===//
@@ -585,6 +596,14 @@ void mlirOperationPrintWithFlags(MlirOperation op, MlirOpPrintingFlags flags,
   unwrap(op)->print(stream, *unwrap(flags));
 }
 
+void mlirValuePrintAsOperandWithState(MlirValue value, MlirAsmState state,
+                                      MlirStringCallback callback,
+                                      void *userData) {
+  detail::CallbackOstream stream(callback, userData);
+  Value cppValue = unwrap(value);
+  cppValue.printAsOperand(stream, *unwrap(state));
+}
+
 void mlirOperationWriteBytecode(MlirOperation op, MlirStringCallback callback,
                                 void *userData) {
   detail::CallbackOstream stream(callback, userData);
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index 5725d05a3e132f7..08d7512b8c4b259 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -487,6 +487,13 @@ static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
   // CHECK: Op print with all flags: %{{.*}} = "arith.constant"() <{value = 0 : index}> {elts = dense_resource<__elided__> : tensor<4xi32>} : () -> index loc(unknown)
   // clang-format on
 
+  MlirAsmState state = mlirAsmStateCreate(parentOperation, flags);
+  fprintf(stderr, "With state: |");
+  mlirValuePrintAsOperandWithState(value, state, printToStderr, NULL);
+  // CHECK: With state: |%0|
+  fprintf(stderr, "|\n");
+  mlirAsmStateDestroy(state);
+
   mlirOpPrintingFlagsDestroy(flags);
 }
 

Copy link
Contributor

@makslevental makslevental left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM modulo small nits/questions.

// functions to maximize the possibility of compatibility over time.
//===----------------------------------------------------------------------===//

/// Creates new AsmState, as with AsmState the IR should not be mutated
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: should this say maybe "as with OperationState"? Not sure.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure I follow, OperationState is created while mutating the IR.

@@ -819,6 +837,11 @@ MLIR_CAPI_EXPORTED void mlirValuePrintAsOperand(MlirValue value,
MlirStringCallback callback,
void *userData);

/// Prints a value as an operand (i.e., the ValueID) using prepopulated state.
MLIR_CAPI_EXPORTED void
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a way to unify this with mlirValuePrintAsOperand? Maybe awkward.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I wish overloading was a thing. Problem is as AsmState is initialized with OpPrintingFlags, if you pass in both is confusing ... that being said AsmState is the better one to pass in, currently its being created in opaque way. Tagged union seems the best I could do here if I wanted to unify.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I went the other way: I made this the call signature and made it easier to construct the original using building blocks.

AsmState usage here subsumes OpPrintingFlags one and allows for much more efficient printing (avoiding the quadratic behavior if one walks & prints and ends up recreating the state over and over). Also makes it much clearer that there is a cost rather then hiding it. I had to add one helper to do so and changed the corresponding Python interface. I think probably all of the other print ones should be similarly changed, but wanted to check this change first.

@jpienaar jpienaar force-pushed the piper_export_cl_566384869 branch from d68b3ef to 49b4701 Compare September 19, 2023 00:45
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ug, seems formatter got creative. I'll revert this file.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lol G style guide strikes again

@jpienaar jpienaar force-pushed the piper_export_cl_566384869 branch from 49b4701 to e331c03 Compare September 19, 2023 00:50
Enable usage where capturing AsmState is good. Haven't plumbed through to python yet.
@jpienaar jpienaar force-pushed the piper_export_cl_566384869 branch from e331c03 to 18722cd Compare September 19, 2023 01:06
@jpienaar jpienaar merged commit 31ebe98 into llvm:main Sep 19, 2023
ZijunZhaoCCK pushed a commit to ZijunZhaoCCK/llvm-project that referenced this pull request Sep 19, 2023
Enable usage where capturing AsmState is good (e.g., avoiding creating AsmState over and over again when walking IR and printing).

This also only changes one C API to verify plumbing. But using the AsmState makes the cost more explicit than the flags interface (which hides the traversals and construction here) and also enables a more efficient usage C side.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants