-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[MLIR] Add apply_patterns.vector.arm_neon.contraction_to_i8mm TD Op #140251
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[MLIR] Add apply_patterns.vector.arm_neon.contraction_to_i8mm TD Op #140251
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-core Author: Momchil Velikov (momchil-velikov) ChangesThis patch wraps It also removes the "test-lower-to-arm-neon" pass. Full diff: https://github.com/llvm/llvm-project/pull/140251.diff 14 Files Affected:
diff --git a/mlir/include/mlir/Dialect/ArmNeon/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmNeon/CMakeLists.txt
index 1c679bcd049b8..3de3ec3f3a0e8 100644
--- a/mlir/include/mlir/Dialect/ArmNeon/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/ArmNeon/CMakeLists.txt
@@ -4,3 +4,5 @@ add_mlir_doc(ArmNeon ArmNeon Dialects/ -gen-dialect-doc -dialect=arm_neon)
set(LLVM_TARGET_DEFINITIONS ArmNeon.td)
mlir_tablegen(ArmNeonConversions.inc -gen-llvmir-conversions)
add_public_tablegen_target(MLIRArmNeonConversionsIncGen)
+
+add_subdirectory(TransformOps)
diff --git a/mlir/include/mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h b/mlir/include/mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h
new file mode 100644
index 0000000000000..5bc03535a86c2
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h
@@ -0,0 +1,31 @@
+//===- ArmNeonVectorTransformOps.h - Vector transform ops -------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARM_NEON_VECTOR_TRANSFORMOPS_VECTORTRANSFORMOPS_H
+#define MLIR_DIALECT_ARM_NEON_VECTOR_TRANSFORMOPS_VECTORTRANSFORMOPS_H
+
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/IR/OpImplementation.h"
+
+//===----------------------------------------------------------------------===//
+// ArmNeon Vector Transform Operations
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h.inc"
+
+namespace mlir {
+class DialectRegistry;
+
+namespace arm_neon {
+void registerTransformDialectExtension(DialectRegistry ®istry);
+
+} // namespace arm_neon
+} // namespace mlir
+
+#endif // MLIR_DIALECT_ARM_NEON_VECTOR_TRANSFORMOPS_VECTORTRANSFORMOPS_H
diff --git a/mlir/include/mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.td b/mlir/include/mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.td
new file mode 100644
index 0000000000000..f863ccaea3765
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.td
@@ -0,0 +1,26 @@
+//===- ArmNeonTransformOps.td - Arm Neon transform ops------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef ARMNEON_TRANSFORM_OPS
+#define ARMNEON_TRANSFORM_OPS
+
+include "mlir/Dialect/Transform/IR/TransformAttrs.td"
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
+
+def ApplyArmNeonLowerContractionPatternsOp
+ : Op<Transform_Dialect, "apply_patterns.vector.arm_neon.lower_contraction",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+ Indicates that vector contraction-like operations should be lowered to
+ finer-grained vector primitives using the ArmNeon dialect.
+ }];
+
+ let assemblyFormat = "attr-dict";
+}
+
+#endif // ARMNEON_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/ArmNeon/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmNeon/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000000000..b8bc72a2bb734
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmNeon/TransformOps/CMakeLists.txt
@@ -0,0 +1,6 @@
+set(LLVM_TARGET_DEFINITIONS ArmNeonVectorTransformOps.td)
+mlir_tablegen(ArmNeonVectorTransformOps.h.inc -gen-op-decls)
+mlir_tablegen(ArmNeonVectorTransformOps.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIRArmNeonVectorTransformOpsIncGen)
+
+add_mlir_doc(ArmNeonVectorTransformOps ArmNeonVectorTransformOps Dialects/ -gen-op-doc)
diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h
index 37e4904cb48ed..619ac88ad76d3 100644
--- a/mlir/include/mlir/InitAllExtensions.h
+++ b/mlir/include/mlir/InitAllExtensions.h
@@ -34,6 +34,7 @@
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/AMX/Transforms.h"
#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
+#include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h"
#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
#include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h"
#include "mlir/Dialect/Func/Extensions/AllExtensions.h"
@@ -106,6 +107,7 @@ inline void registerAllExtensions(DialectRegistry ®istry) {
transform::registerLoopExtension(registry);
transform::registerPDLExtension(registry);
vector::registerTransformDialectExtension(registry);
+ arm_neon::registerTransformDialectExtension(registry);
// Translation extensions need to be registered by calling
// `registerAllToLLVMIRTranslations` (see All.h).
diff --git a/mlir/lib/Dialect/ArmNeon/CMakeLists.txt b/mlir/lib/Dialect/ArmNeon/CMakeLists.txt
index 9f57627c321fb..cb1e9d01821a2 100644
--- a/mlir/lib/Dialect/ArmNeon/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmNeon/CMakeLists.txt
@@ -1,2 +1,3 @@
add_subdirectory(IR)
add_subdirectory(Transforms)
+add_subdirectory(TransformOps)
diff --git a/mlir/lib/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp b/mlir/lib/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp
new file mode 100644
index 0000000000000..b096c2cbc503f
--- /dev/null
+++ b/mlir/lib/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp
@@ -0,0 +1,54 @@
+//===- ArmNeonVectorTransformOps.cpp - Implementation transform ops -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h"
+
+#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
+#include "mlir/Dialect/ArmNeon/Transforms.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Apply...PatternsOp
+//===----------------------------------------------------------------------===//
+
+void transform::ApplyArmNeonLowerContractionPatternsOp::populatePatterns(
+ RewritePatternSet &patterns) {
+ arm_neon::populateLowerContractionToSMMLAPatternPatterns(patterns);
+}
+
+//===----------------------------------------------------------------------===//
+// Transform op registration
+//===----------------------------------------------------------------------===//
+
+namespace {
+class ArmNeonVectorTransformDialectExtension
+ : public transform::TransformDialectExtension<
+ ArmNeonVectorTransformDialectExtension> {
+public:
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+ ArmNeonVectorTransformDialectExtension)
+
+ ArmNeonVectorTransformDialectExtension() {
+ declareGeneratedDialect<arm_neon::ArmNeonDialect>();
+ registerTransformOps<
+#define GET_OP_LIST
+#include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp.inc"
+ >();
+ }
+};
+} // namespace
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp.inc"
+
+void mlir::arm_neon::registerTransformDialectExtension(
+ DialectRegistry ®istry) {
+ registry.addExtensions<ArmNeonVectorTransformDialectExtension>();
+}
diff --git a/mlir/lib/Dialect/ArmNeon/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/ArmNeon/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000000000..69d2143ad4e1f
--- /dev/null
+++ b/mlir/lib/Dialect/ArmNeon/TransformOps/CMakeLists.txt
@@ -0,0 +1,18 @@
+add_mlir_dialect_library(MLIRArmNeonVectorTransformOps
+ ArmNeonVectorTransformOps.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmNeon/TransformOps
+
+ DEPENDS
+ MLIRArmNeonVectorTransformOpsIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRLLVMCommonConversion
+ MLIRLLVMDialect
+ MLIRVectorDialect
+ MLIRTransformDialect
+ MLIRArmNeonDialect
+ MLIRArmNeonTransforms
+ )
diff --git a/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir b/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
index 297be91e77283..ccad307e89dfb 100644
--- a/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
+++ b/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-lower-to-arm-neon -verify-diagnostics -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -transform-interpreter %s | FileCheck %s
// CHECK-LABEL: vector_arm_neon_mixed_types
// CHECK-SAME: %[[A0:.*]]: vector<2x8xi8>, %[[A1:.*]]: vector<2x8xi4>, %[[A2:.*]]: vector<2x2xi32>
@@ -354,3 +354,15 @@ func.func @vector_arm_neon_k_unroll_vecmat(%lhs: vector<1x32xi8>, %rhs: vector<2
%res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<1x32xi32>, vector<2x32xi32> into vector<1x2xi32>
return %res : vector<1x2xi32>
}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.op<"func.func">
+
+ transform.apply_patterns to %func {
+ transform.apply_patterns.vector.arm_neon.lower_contraction
+ } : !transform.op<"func.func">
+
+ transform.yield
+ }
+}
diff --git a/mlir/test/lib/Dialect/ArmNeon/CMakeLists.txt b/mlir/test/lib/Dialect/ArmNeon/CMakeLists.txt
deleted file mode 100644
index 460842d238533..0000000000000
--- a/mlir/test/lib/Dialect/ArmNeon/CMakeLists.txt
+++ /dev/null
@@ -1,13 +0,0 @@
-# Exclude tests from libMLIR.so
-add_mlir_library(MLIRArmNeonTestPasses
- TestLowerToArmNeon.cpp
-
- EXCLUDE_FROM_LIBMLIR
- )
-mlir_target_link_libraries(MLIRArmNeonTestPasses PUBLIC
- MLIRArmNeonDialect
- MLIRArmNeonTransforms
- MLIRIR
- MLIRPass
- MLIRTransforms
- )
diff --git a/mlir/test/lib/Dialect/ArmNeon/TestLowerToArmNeon.cpp b/mlir/test/lib/Dialect/ArmNeon/TestLowerToArmNeon.cpp
deleted file mode 100644
index 03c80b601a347..0000000000000
--- a/mlir/test/lib/Dialect/ArmNeon/TestLowerToArmNeon.cpp
+++ /dev/null
@@ -1,60 +0,0 @@
-//===- TestLowerToArmNeon.cpp - Test lowering to ArmNeon as a sink pass -===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file implements a pass for testing the lowering to ArmNeon as a
-// generally usable sink pass.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
-#include "mlir/Dialect/ArmNeon/Transforms.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Pass/PassManager.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-
-#define PASS_NAME "test-lower-to-arm-neon"
-
-using namespace mlir;
-using namespace mlir::arm_neon;
-
-namespace {
-struct TestLowerToArmNeon
- : public PassWrapper<TestLowerToArmNeon, OperationPass<func::FuncOp>> {
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLowerToArmNeon)
-
- StringRef getArgument() const final { return PASS_NAME; }
- StringRef getDescription() const final { return "Tests lower to arm Neon."; }
- TestLowerToArmNeon() = default;
- TestLowerToArmNeon(const TestLowerToArmNeon &pass) = default;
-
- void getDependentDialects(DialectRegistry ®istry) const override {
- registry.insert<arm_neon::ArmNeonDialect>();
- }
-
- void runOnOperation() override;
-};
-
-} // namespace
-
-void TestLowerToArmNeon::runOnOperation() {
- MLIRContext *context = &getContext();
- RewritePatternSet patterns(context);
- populateLowerContractionToSMMLAPatternPatterns(patterns);
- if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
- return signalPassFailure();
-}
-
-namespace mlir {
-namespace test {
-
-void registerTestLowerToArmNeon() { PassRegistration<TestLowerToArmNeon>(); }
-
-} // namespace test
-} // namespace mlir
diff --git a/mlir/test/lib/Dialect/CMakeLists.txt b/mlir/test/lib/Dialect/CMakeLists.txt
index a8fd70e6397a5..5614237d80f02 100644
--- a/mlir/test/lib/Dialect/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/CMakeLists.txt
@@ -1,6 +1,5 @@
add_subdirectory(Affine)
add_subdirectory(Arith)
-add_subdirectory(ArmNeon)
add_subdirectory(ArmSME)
add_subdirectory(Bufferization)
add_subdirectory(ControlFlow)
diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
index 3220dca282eac..5256cf7ae90d7 100644
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -17,7 +17,6 @@ if(MLIR_INCLUDE_TESTS)
MLIRTestFuncToLLVM
MLIRAffineTransformsTestPasses
MLIRArithTestPasses
- MLIRArmNeonTestPasses
MLIRArmSMETestPasses
MLIRBufferizationTestPasses
MLIRControlFlowTestPasses
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index cdcf59b2add13..aa9c33dd9150c 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -120,7 +120,6 @@ void registerTestLLVMLegalizePatternsPass();
void registerTestLoopFusion();
void registerTestLoopMappingPass();
void registerTestLoopUnrollingPass();
-void registerTestLowerToArmNeon();
void registerTestLowerToArmSME();
void registerTestLowerToLLVM();
void registerTestMakeIsolatedFromAbovePass();
@@ -264,7 +263,6 @@ void registerTestPasses() {
mlir::test::registerTestLoopFusion();
mlir::test::registerTestLoopMappingPass();
mlir::test::registerTestLoopUnrollingPass();
- mlir::test::registerTestLowerToArmNeon();
mlir::test::registerTestLowerToArmSME();
mlir::test::registerTestLowerToLLVM();
mlir::test::registerTestMakeIsolatedFromAbovePass();
|
This patch wraps `populateLowerContractionToSMMLAPatternPatterns` into a new TD Op `apply_patterns.vector.arm_neon.lower_contraction`. It also removes the "test-lower-to-arm-neon" pass.
60d4192
to
f9d5ad1
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks!
I've left a couple of small suggestion inline. One larger suggestion - I would drop vector
from the Op name. To me, this is the naming scheme:
apply_patterns.{dialect}.{meta-name-for-patterns}
This will help us avoid "polluting" the "Vector" namespace with things specific to hardware targets.
mlir/include/mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.td
Outdated
Show resolved
Hide resolved
mlir/include/mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.td
Outdated
Show resolved
Hide resolved
In that case the TD op has to be simply IMHO, having target dependent components in pattern names (like we have |
The naming convention so far is to have |
Why not |
Because it's less descriptive - there's no indication on which dialect it operates. It also pollutes the "apply_patterns" namespace (considering "lower_contraction" is a generic name that does not allow unprincipled dubious implications like "oh, we have contraction only on vectors, so it must be operating on the vector dialect"). Why Why not |
This patch wraps
populateLowerContractionToSMMLAPatternPatterns
into a new TD Opapply_patterns.vector.arm_neon.contraction_to_i8mm
.It also removes the "test-lower-to-arm-neon" pass.