feat: custom lowering #963
base: main
Conversation
Need to figure out how to run the shape refinement; the public API for the StableHLO pass is applied to the module op (a sketch of one way to drive it per wrapper follows the IR dump below).

(jaxtesting) avik-pal@hydra:~/reactant/Enzyme-JAX$ ./bazel-bin/enzymexlamlir-opt --resolve-custom-lowering --allow-unregistered-dialect --mlir-print-ir-after-all test/lit_tests/custom_lowering/simple.mlir
=== Lowering Map Starts ===
Op: mydialect.sin_or_cos
Function: custom_lowering1
Config:
"op": "sin"
Function: custom_lowering2
Config:
"op": "cos"
=== Lowering Map Ends ===
Checking op: mydialect.sin_or_cos
✔ Matched lowering: custom_lowering1
Wrapper name: custom_lowering10
Checking op: mydialect.sin_or_cos
✔ Matched lowering: custom_lowering2
Wrapper name: custom_lowering21
=== Lowered All Ops ===
test/lit_tests/custom_lowering/simple.mlir:17:10: error: 'func.call' op operand type mismatch: expected operand type 'tensor<?x?xf32>', but provided 'tensor<8x8xf32>' for operand number 0
%0 = "mydialect.sin_or_cos"(%arg0) {
^
test/lit_tests/custom_lowering/simple.mlir:17:10: note: see current operation: %0 = "func.call"(%arg0) <{callee = @custom_lowering10}> : (tensor<8x8xf32>) -> tensor<8x8xf32>
// -----// IR Dump After ResolveCustomLoweringPass Failed (resolve-custom-lowering) //----- //
"builtin.module"() ({
"func.func"() <{function_type = (tensor<?x?xf32>) -> tensor<?x?xf32>, sym_name = "custom_lowering21", sym_visibility = "private"}> ({
^bb0(%arg4: tensor<?x?xf32>):
%7 = "stablehlo.cosine"(%arg4) : (tensor<?x?xf32>) -> tensor<?x?xf32>
"func.return"(%7) : (tensor<?x?xf32>) -> ()
}) : () -> ()
"func.func"() <{function_type = (tensor<?x?xf32>) -> tensor<?x?xf32>, sym_name = "custom_lowering10", sym_visibility = "private"}> ({
^bb0(%arg3: tensor<?x?xf32>):
%6 = "stablehlo.sine"(%arg3) : (tensor<?x?xf32>) -> tensor<?x?xf32>
"func.return"(%6) : (tensor<?x?xf32>) -> ()
}) : () -> ()
"func.func"() <{function_type = (tensor<?x?xf32>) -> tensor<?x?xf32>, sym_name = "custom_lowering1"}> ({
^bb0(%arg2: tensor<?x?xf32>):
%5 = "stablehlo.sine"(%arg2) : (tensor<?x?xf32>) -> tensor<?x?xf32>
"func.return"(%5) : (tensor<?x?xf32>) -> ()
}) : () -> ()
"func.func"() <{function_type = (tensor<?x?xf32>) -> tensor<?x?xf32>, sym_name = "custom_lowering2"}> ({
^bb0(%arg1: tensor<?x?xf32>):
%4 = "stablehlo.cosine"(%arg1) : (tensor<?x?xf32>) -> tensor<?x?xf32>
"func.return"(%4) : (tensor<?x?xf32>) -> ()
}) : () -> ()
"enzymexla.lowering.register"() <{config = {op = "sin"}, fn = @custom_lowering1, op_name = "mydialect.sin_or_cos"}> : () -> ()
"enzymexla.lowering.register"() <{config = {op = "cos"}, fn = @custom_lowering2, op_name = "mydialect.sin_or_cos"}> : () -> ()
"func.func"() <{function_type = (tensor<8x8xf32>) -> tensor<4x4xf32>, sym_name = "main"}> ({
^bb0(%arg0: tensor<8x8xf32>):
%0 = "func.call"(%arg0) <{callee = @custom_lowering10}> : (tensor<8x8xf32>) -> tensor<8x8xf32>
%1 = "stablehlo.slice"(%0) <{limit_indices = array<i64: 4, 4>, start_indices = array<i64: 0, 0>, strides = array<i64: 1, 1>}> : (tensor<8x8xf32>) -> tensor<4x4xf32>
%2 = "func.call"(%1) <{callee = @custom_lowering21}> : (tensor<4x4xf32>) -> tensor<4x4xf32>
%3 = "stablehlo.add"(%2, %1) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
"func.return"(%2) : (tensor<4x4xf32>) -> ()
}) : () -> ()
}) : () -> ()
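On the shape-refinement question above: since the public StableHLO entry point is a module-level pass, one option is to run it over a throwaway module holding a clone of a single wrapper function. Below is a minimal C++ sketch of that idea, not the implementation in this PR; refineClonedLowering is a hypothetical helper name, and it assumes the createStablehloRefineShapesPass factory from stablehlo/transforms/Passes.h.

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/OwningOpRef.h"
#include "mlir/Pass/PassManager.h"
#include "stablehlo/transforms/Passes.h"

// Minimal sketch (not this PR's code): refine the shapes of one cloned
// lowering wrapper by running the module-anchored StableHLO pass on a
// throwaway module. refineClonedLowering is a hypothetical helper name.
static mlir::LogicalResult refineClonedLowering(mlir::func::FuncOp wrapper) {
  // The public refinement API anchors on ModuleOp, so move a clone of the
  // function into a temporary module before running the pass.
  mlir::OwningOpRef<mlir::ModuleOp> tmp =
      mlir::ModuleOp::create(wrapper.getLoc());
  mlir::ModuleOp tmpModule = tmp.get();
  tmpModule.push_back(wrapper.clone());

  mlir::PassManager pm(wrapper.getContext(),
                       mlir::ModuleOp::getOperationName());
  pm.addPass(mlir::stablehlo::createStablehloRefineShapesPass());
  if (mlir::failed(pm.run(tmpModule)))
    return mlir::failure();

  // Splicing the refined body back over the original wrapper is elided here.
  return mlir::success();
}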
ok, I think I have a nice solution:
module {
func.func private @custom_lowering2__1(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
%c = stablehlo.constant dense<4> : tensor<2xi64>
%0 = stablehlo.custom_call @stablehlo.shape_refinement_operand_wrapper(%arg0, %c) {indices_of_shape_operands = dense<1> : tensor<1xi64>} : (tensor<4x4xf32>, tensor<2xi64>) -> tensor<?x?xf32>
%1 = stablehlo.cosine %0 : tensor<?x?xf32>
%2 = "enzymexla.lowering.shape_refinement"(%1) : (tensor<?x?xf32>) -> tensor<4x4xf32>
return %2 : tensor<4x4xf32>
}
func.func private @custom_lowering1__0(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> {
%c = stablehlo.constant dense<8> : tensor<2xi64>
%0 = stablehlo.custom_call @stablehlo.shape_refinement_operand_wrapper(%arg0, %c) {indices_of_shape_operands = dense<1> : tensor<1xi64>} : (tensor<8x8xf32>, tensor<2xi64>) -> tensor<?x?xf32>
%1 = stablehlo.sine %0 : tensor<?x?xf32>
%2 = "enzymexla.lowering.shape_refinement"(%1) : (tensor<?x?xf32>) -> tensor<8x8xf32>
return %2 : tensor<8x8xf32>
}
func.func @custom_lowering1(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
%0 = stablehlo.sine %arg0 : tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
func.func @custom_lowering2(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
%0 = stablehlo.cosine %arg0 : tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
func.func @main(%arg0: tensor<8x8xf32>) -> tensor<4x4xf32> {
%0 = call @custom_lowering1__0(%arg0) : (tensor<8x8xf32>) -> tensor<8x8xf32>
%1 = stablehlo.slice %0 [0:4, 0:4] : (tensor<8x8xf32>) -> tensor<4x4xf32>
%2 = call @custom_lowering2__1(%1) : (tensor<4x4xf32>) -> tensor<4x4xf32>
%3 = stablehlo.add %2, %1 : tensor<4x4xf32>
return %2 : tensor<4x4xf32>
}
}
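For context on the wrappers above, here is a rough, hypothetical C++ sketch of how a per-call-site wrapper could pin its operand shape with the stablehlo.shape_refinement_operand_wrapper custom call before refinement runs. This is an illustration, not the code in this PR; pinOperandShape is an invented helper name.

#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "stablehlo/dialect/StablehloOps.h"

// Hypothetical sketch: emit the shape_refinement_operand_wrapper custom_call
// for one ranked operand, mirroring the wrappers shown in the module above.
static mlir::Value pinOperandShape(mlir::OpBuilder &b, mlir::Location loc,
                                   mlir::Value operand,
                                   mlir::RankedTensorType staticType) {
  // Constant holding the concrete dimension sizes, e.g. dense<8> : tensor<2xi64>.
  auto shapeAttr = b.getI64TensorAttr(staticType.getShape());
  mlir::Value shapeConst = b.create<mlir::stablehlo::ConstantOp>(
      loc, shapeAttr.getType(), shapeAttr);

  // The wrapper's result is the fully dynamic counterpart of the operand
  // type; stablehlo-refine-shapes later restores the static shape.
  auto dynType = mlir::RankedTensorType::get(
      llvm::SmallVector<int64_t>(staticType.getRank(),
                                 mlir::ShapedType::kDynamic),
      staticType.getElementType());

  auto wrapped = b.create<mlir::stablehlo::CustomCallOp>(
      loc, mlir::TypeRange{dynType},
      mlir::ValueRange{operand, shapeConst});
  wrapped.setCallTargetNameAttr(
      b.getStringAttr("stablehlo.shape_refinement_operand_wrapper"));
  wrapped->setAttr("indices_of_shape_operands", b.getI64TensorAttr({1}));
  return wrapped->getResult(0);
}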
Force-pushed from 2c90e8f to f3c9907
cc @ftynse for thoughts
Didn't do a detailed review, only high-level thoughts. The main problem is type co-/contra-variance at function boundaries. There are two general approaches: