
feat: custom lowering #963


Draft
avik-pal wants to merge 6 commits into main from ap/simple_lowering

Conversation

avik-pal
Collaborator

No description provided.

@avik-pal
Collaborator Author

Need to figure out how to run the shape refinement: the public API for the StableHLO pass is applied to the module op.
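
For reference, a minimal sketch of driving it through that public entry point (hypothetical helper, not part of this PR; assumes the upstream stablehlo-refine-shapes pass and its createStablehloRefineShapesPass constructor), which only applies at module scope:

// Hypothetical helper, not part of this PR: run StableHLO shape refinement
// over the whole module, since the public pass only applies to the module op.
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "stablehlo/transforms/Passes.h"

static mlir::LogicalResult refineModuleShapes(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addPass(mlir::stablehlo::createStablehloRefineShapesPass());
  return pm.run(module);
}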

(jaxtesting) avik-pal@hydra:~/reactant/Enzyme-JAX$ ./bazel-bin/enzymexlamlir-opt --resolve-custom-lowering --allow-unregistered-dialect --mlir-print-ir-after-all test/lit_tests/custom_lowering/simple.mlir 
=== Lowering Map Starts ===

Op: mydialect.sin_or_cos
  Function: custom_lowering1
  Config:
    "op": "sin"
  Function: custom_lowering2
  Config:
    "op": "cos"

=== Lowering Map Ends ===

Checking op: mydialect.sin_or_cos
Matched lowering: custom_lowering1
Wrapper name: custom_lowering10
Checking op: mydialect.sin_or_cos
Matched lowering: custom_lowering2
Wrapper name: custom_lowering21
=== Lowered All Ops ===

test/lit_tests/custom_lowering/simple.mlir:17:10: error: 'func.call' op operand type mismatch: expected operand type 'tensor<?x?xf32>', but provided 'tensor<8x8xf32>' for operand number 0
    %0 = "mydialect.sin_or_cos"(%arg0) {
         ^
test/lit_tests/custom_lowering/simple.mlir:17:10: note: see current operation: %0 = "func.call"(%arg0) <{callee = @custom_lowering10}> : (tensor<8x8xf32>) -> tensor<8x8xf32>
// -----// IR Dump After ResolveCustomLoweringPass Failed (resolve-custom-lowering) //----- //
"builtin.module"() ({
  "func.func"() <{function_type = (tensor<?x?xf32>) -> tensor<?x?xf32>, sym_name = "custom_lowering21", sym_visibility = "private"}> ({
  ^bb0(%arg4: tensor<?x?xf32>):
    %7 = "stablehlo.cosine"(%arg4) : (tensor<?x?xf32>) -> tensor<?x?xf32>
    "func.return"(%7) : (tensor<?x?xf32>) -> ()
  }) : () -> ()
  "func.func"() <{function_type = (tensor<?x?xf32>) -> tensor<?x?xf32>, sym_name = "custom_lowering10", sym_visibility = "private"}> ({
  ^bb0(%arg3: tensor<?x?xf32>):
    %6 = "stablehlo.sine"(%arg3) : (tensor<?x?xf32>) -> tensor<?x?xf32>
    "func.return"(%6) : (tensor<?x?xf32>) -> ()
  }) : () -> ()
  "func.func"() <{function_type = (tensor<?x?xf32>) -> tensor<?x?xf32>, sym_name = "custom_lowering1"}> ({
  ^bb0(%arg2: tensor<?x?xf32>):
    %5 = "stablehlo.sine"(%arg2) : (tensor<?x?xf32>) -> tensor<?x?xf32>
    "func.return"(%5) : (tensor<?x?xf32>) -> ()
  }) : () -> ()
  "func.func"() <{function_type = (tensor<?x?xf32>) -> tensor<?x?xf32>, sym_name = "custom_lowering2"}> ({
  ^bb0(%arg1: tensor<?x?xf32>):
    %4 = "stablehlo.cosine"(%arg1) : (tensor<?x?xf32>) -> tensor<?x?xf32>
    "func.return"(%4) : (tensor<?x?xf32>) -> ()
  }) : () -> ()
  "enzymexla.lowering.register"() <{config = {op = "sin"}, fn = @custom_lowering1, op_name = "mydialect.sin_or_cos"}> : () -> ()
  "enzymexla.lowering.register"() <{config = {op = "cos"}, fn = @custom_lowering2, op_name = "mydialect.sin_or_cos"}> : () -> ()
  "func.func"() <{function_type = (tensor<8x8xf32>) -> tensor<4x4xf32>, sym_name = "main"}> ({
  ^bb0(%arg0: tensor<8x8xf32>):
    %0 = "func.call"(%arg0) <{callee = @custom_lowering10}> : (tensor<8x8xf32>) -> tensor<8x8xf32>
    %1 = "stablehlo.slice"(%0) <{limit_indices = array<i64: 4, 4>, start_indices = array<i64: 0, 0>, strides = array<i64: 1, 1>}> : (tensor<8x8xf32>) -> tensor<4x4xf32>
    %2 = "func.call"(%1) <{callee = @custom_lowering21}> : (tensor<4x4xf32>) -> tensor<4x4xf32>
    %3 = "stablehlo.add"(%2, %1) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
    "func.return"(%2) : (tensor<4x4xf32>) -> ()
  }) : () -> ()
}) : () -> ()

@avik-pal
Collaborator Author

OK, I think I have a nice solution.

@avik-pal
Collaborator Author

avik-pal commented May 18, 2025

module {
  func.func private @custom_lowering2__1(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
    %c = stablehlo.constant dense<4> : tensor<2xi64>
    %0 = stablehlo.custom_call @stablehlo.shape_refinement_operand_wrapper(%arg0, %c) {indices_of_shape_operands = dense<1> : tensor<1xi64>} : (tensor<4x4xf32>, tensor<2xi64>) -> tensor<?x?xf32>
    %1 = stablehlo.cosine %0 : tensor<?x?xf32>
    %2 = "enzymexla.lowering.shape_refinement"(%1) : (tensor<?x?xf32>) -> tensor<4x4xf32>
    return %2 : tensor<4x4xf32>
  }
  func.func private @custom_lowering1__0(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> {
    %c = stablehlo.constant dense<8> : tensor<2xi64>
    %0 = stablehlo.custom_call @stablehlo.shape_refinement_operand_wrapper(%arg0, %c) {indices_of_shape_operands = dense<1> : tensor<1xi64>} : (tensor<8x8xf32>, tensor<2xi64>) -> tensor<?x?xf32>
    %1 = stablehlo.sine %0 : tensor<?x?xf32>
    %2 = "enzymexla.lowering.shape_refinement"(%1) : (tensor<?x?xf32>) -> tensor<8x8xf32>
    return %2 : tensor<8x8xf32>
  }
  func.func @custom_lowering1(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
    %0 = stablehlo.sine %arg0 : tensor<?x?xf32>
    return %0 : tensor<?x?xf32>
  }
  func.func @custom_lowering2(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
    %0 = stablehlo.cosine %arg0 : tensor<?x?xf32>
    return %0 : tensor<?x?xf32>
  }
  func.func @main(%arg0: tensor<8x8xf32>) -> tensor<4x4xf32> {
    %0 = call @custom_lowering1__0(%arg0) : (tensor<8x8xf32>) -> tensor<8x8xf32>
    %1 = stablehlo.slice %0 [0:4, 0:4] : (tensor<8x8xf32>) -> tensor<4x4xf32>
    %2 = call @custom_lowering2__1(%1) : (tensor<4x4xf32>) -> tensor<4x4xf32>
    %3 = stablehlo.add %2, %1 : tensor<4x4xf32>
    return %2 : tensor<4x4xf32>
  }
}

avik-pal force-pushed the ap/simple_lowering branch from 2c90e8f to f3c9907 on May 18, 2025 at 23:51
avik-pal changed the title from "feat: initial setup for custom lowerings from frontend" to "feat: custom lowering" on May 18, 2025
avik-pal requested a review from wsmoses on May 18, 2025 at 23:51
@wsmoses
Member

wsmoses left a comment

cc @ftynse for thoughts

@ftynse
Collaborator

ftynse commented May 19, 2025

Didn't do a detailed review, only high-level thoughts.

The main problem is type co-/contra-variance at function boundaries. There are two general approaches (a rough sketch of both is below):

  1. Inject casts at the boundaries and add a post-pass that propagates static shape information throughout the "lowered"/inlined code. We can attach an interface (or other logic) to the tensor type itself, and to other types if needed, that checks whether two types are co-/contra-variant and constructs the corresponding cast, tensor.cast in this case. Most cast-like operations already provide an areCastCompatible method for verification purposes that we can also use up front.
  2. Recreate operations programmatically, relying on interfaces, instead of cloning/inlining. Many operations implement InferTypeOpInterface (and more implementations can be added without changing upstream). We can have C++ logic that maintains the mapping between "pattern" values and actual values, reads the operation kind and non-system attributes from the IR, then calls the interface to deduce result types from the actual value types, which in turn makes it possible to construct a copy of the operation with specific result types.
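
A rough illustration of what these two options might look like (hypothetical helpers, not part of this PR; the names and surrounding wiring are made up):

// Option 1: reconcile a static caller value with the dynamic type expected by
// the cloned lowering body by inserting a tensor.cast where the two types are
// cast-compatible (tensor.cast exposes areCastCompatible for exactly this check).
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"

using namespace mlir;

static Value castToBoundaryType(OpBuilder &builder, Location loc, Value actual,
                                Type expected) {
  Type actualType = actual.getType();
  if (actualType == expected)
    return actual;
  if (tensor::CastOp::areCastCompatible(ArrayRef<Type>(actualType),
                                        ArrayRef<Type>(expected)))
    return builder.create<tensor::CastOp>(loc, expected, actual);
  return Value(); // Neither equal nor cast-compatible; the caller reports the error.
}

// Option 2: instead of cloning the "pattern" op verbatim, rebuild it with result
// types inferred from the actual operand types via InferTypeOpInterface.
static Operation *recreateWithInferredTypes(OpBuilder &builder,
                                            Operation *pattern,
                                            ValueRange actualOperands) {
  auto inferIface = dyn_cast<InferTypeOpInterface>(pattern);
  if (!inferIface)
    return nullptr;

  SmallVector<Type> inferredTypes;
  if (failed(inferIface.inferReturnTypes(
          pattern->getContext(), pattern->getLoc(), actualOperands,
          pattern->getAttrDictionary(), pattern->getPropertiesStorage(),
          pattern->getRegions(), inferredTypes)))
    return nullptr;

  OperationState state(pattern->getLoc(), pattern->getName());
  state.addOperands(actualOperands);
  state.addAttributes(pattern->getAttrs());
  state.addTypes(inferredTypes);
  return builder.create(state);
}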
