Skip to content

Commit 60d4192

Browse files
[MLIR] Add apply_patterns.vector.arm_neon.lower_contraction TD Op
This patch wraps `populateLowerContractionToSMMLAPatternPatterns` into a new TD Op `apply_patterns.vector.arm_neon.lower_contraction`. It also removes the "test-lower-to-arm-neon" pass.
1 parent 9273091 commit 60d4192

File tree

14 files changed

+153
-78
lines changed

14 files changed

+153
-78
lines changed

mlir/include/mlir/Dialect/ArmNeon/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,5 @@ add_mlir_doc(ArmNeon ArmNeon Dialects/ -gen-dialect-doc -dialect=arm_neon)
44
set(LLVM_TARGET_DEFINITIONS ArmNeon.td)
55
mlir_tablegen(ArmNeonConversions.inc -gen-llvmir-conversions)
66
add_public_tablegen_target(MLIRArmNeonConversionsIncGen)
7+
8+
add_subdirectory(TransformOps)
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
//===- ArmNeonVectorTransformOps.h - Vector transform ops -------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_ARM_NEON_VECTOR_TRANSFORMOPS_VECTORTRANSFORMOPS_H
10+
#define MLIR_DIALECT_ARM_NEON_VECTOR_TRANSFORMOPS_VECTORTRANSFORMOPS_H
11+
12+
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
13+
#include "mlir/IR/OpImplementation.h"
14+
15+
//===----------------------------------------------------------------------===//
16+
// ArmNeon Vector Transform Operations
17+
//===----------------------------------------------------------------------===//
18+
19+
#define GET_OP_CLASSES
20+
#include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h.inc"
21+
22+
namespace mlir {
23+
class DialectRegistry;
24+
25+
namespace arm_neon {
26+
void registerTransformDialectExtension(DialectRegistry &registry);
27+
28+
} // namespace arm_neon
29+
} // namespace mlir
30+
31+
#endif // MLIR_DIALECT_ARM_NEON_VECTOR_TRANSFORMOPS_VECTORTRANSFORMOPS_H
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
//===- ArmNeonTransformOps.td - Arm Neon transform ops------*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
#ifndef ARMNEON_TRANSFORM_OPS
9+
#define ARMNEON_TRANSFORM_OPS
10+
11+
include "mlir/Dialect/Transform/IR/TransformAttrs.td"
12+
include "mlir/Dialect/Transform/IR/TransformDialect.td"
13+
include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
14+
15+
def ApplyArmNeonLowerContractionPatternsOp
16+
: Op<Transform_Dialect, "apply_patterns.vector.arm_neon.lower_contraction",
17+
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
18+
let description = [{
19+
Indicates that vector contraction-like operations should be lowered to
20+
finer-grained vector primitives using the ArmNeon dialect.
21+
}];
22+
23+
let assemblyFormat = "attr-dict";
24+
}
25+
26+
#endif // ARMNEON_TRANSFORM_OPS
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
set(LLVM_TARGET_DEFINITIONS ArmNeonVectorTransformOps.td)
2+
mlir_tablegen(ArmNeonVectorTransformOps.h.inc -gen-op-decls)
3+
mlir_tablegen(ArmNeonVectorTransformOps.cpp.inc -gen-op-defs)
4+
add_public_tablegen_target(MLIRArmNeonVectorTransformOpsIncGen)
5+
6+
add_mlir_doc(ArmNeonVectorTransformOps ArmNeonVectorTransformOps Dialects/ -gen-op-doc)

mlir/include/mlir/InitAllExtensions.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
3535
#include "mlir/Dialect/AMX/Transforms.h"
3636
#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
37+
#include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h"
3738
#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
3839
#include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h"
3940
#include "mlir/Dialect/Func/Extensions/AllExtensions.h"
@@ -106,6 +107,7 @@ inline void registerAllExtensions(DialectRegistry &registry) {
106107
transform::registerLoopExtension(registry);
107108
transform::registerPDLExtension(registry);
108109
vector::registerTransformDialectExtension(registry);
110+
arm_neon::registerTransformDialectExtension(registry);
109111

110112
// Translation extensions need to be registered by calling
111113
// `registerAllToLLVMIRTranslations` (see All.h).
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
add_subdirectory(IR)
22
add_subdirectory(Transforms)
3+
add_subdirectory(TransformOps)
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
//===- ArmNeonVectorTransformOps.cpp - Implementation transform ops -------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h"
10+
11+
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
12+
#include "mlir/Dialect/ArmNeon/Transforms.h"
13+
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
14+
15+
using namespace mlir;
16+
17+
//===----------------------------------------------------------------------===//
18+
// Apply...PatternsOp
19+
//===----------------------------------------------------------------------===//
20+
21+
void transform::ApplyArmNeonLowerContractionPatternsOp::populatePatterns(
22+
RewritePatternSet &patterns) {
23+
arm_neon::populateLowerContractionToSMMLAPatternPatterns(patterns);
24+
}
25+
26+
//===----------------------------------------------------------------------===//
27+
// Transform op registration
28+
//===----------------------------------------------------------------------===//
29+
30+
namespace {
31+
class ArmNeonVectorTransformDialectExtension
32+
: public transform::TransformDialectExtension<
33+
ArmNeonVectorTransformDialectExtension> {
34+
public:
35+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
36+
ArmNeonVectorTransformDialectExtension)
37+
38+
ArmNeonVectorTransformDialectExtension() {
39+
declareGeneratedDialect<arm_neon::ArmNeonDialect>();
40+
registerTransformOps<
41+
#define GET_OP_LIST
42+
#include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp.inc"
43+
>();
44+
}
45+
};
46+
} // namespace
47+
48+
#define GET_OP_CLASSES
49+
#include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp.inc"
50+
51+
void mlir::arm_neon::registerTransformDialectExtension(
52+
DialectRegistry &registry) {
53+
registry.addExtensions<ArmNeonVectorTransformDialectExtension>();
54+
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
add_mlir_dialect_library(MLIRArmNeonVectorTransformOps
2+
ArmNeonVectorTransformOps.cpp
3+
4+
ADDITIONAL_HEADER_DIRS
5+
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmNeon/TransformOps
6+
7+
DEPENDS
8+
MLIRArmNeonVectorTransformOpsIncGen
9+
10+
LINK_LIBS PUBLIC
11+
MLIRIR
12+
MLIRLLVMCommonConversion
13+
MLIRLLVMDialect
14+
MLIRVectorDialect
15+
MLIRTransformDialect
16+
MLIRArmNeonDialect
17+
MLIRArmNeonTransforms
18+
)

mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt -test-lower-to-arm-neon -verify-diagnostics -split-input-file %s | FileCheck %s
1+
// RUN: mlir-opt -transform-interpreter %s | FileCheck %s
22

33
// CHECK-LABEL: vector_arm_neon_mixed_types
44
// CHECK-SAME: %[[A0:.*]]: vector<2x8xi8>, %[[A1:.*]]: vector<2x8xi4>, %[[A2:.*]]: vector<2x2xi32>
@@ -354,3 +354,15 @@ func.func @vector_arm_neon_k_unroll_vecmat(%lhs: vector<1x32xi8>, %rhs: vector<2
354354
%res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<1x32xi32>, vector<2x32xi32> into vector<1x2xi32>
355355
return %res : vector<1x2xi32>
356356
}
357+
358+
module attributes {transform.with_named_sequence} {
359+
transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
360+
%func = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.op<"func.func">
361+
362+
transform.apply_patterns to %func {
363+
transform.apply_patterns.vector.arm_neon.lower_contraction
364+
} : !transform.op<"func.func">
365+
366+
transform.yield
367+
}
368+
}

mlir/test/lib/Dialect/ArmNeon/CMakeLists.txt

Lines changed: 0 additions & 13 deletions
This file was deleted.

mlir/test/lib/Dialect/ArmNeon/TestLowerToArmNeon.cpp

Lines changed: 0 additions & 60 deletions
This file was deleted.

mlir/test/lib/Dialect/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
add_subdirectory(Affine)
22
add_subdirectory(Arith)
3-
add_subdirectory(ArmNeon)
43
add_subdirectory(ArmSME)
54
add_subdirectory(Bufferization)
65
add_subdirectory(ControlFlow)

mlir/tools/mlir-opt/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ if(MLIR_INCLUDE_TESTS)
1717
MLIRTestFuncToLLVM
1818
MLIRAffineTransformsTestPasses
1919
MLIRArithTestPasses
20-
MLIRArmNeonTestPasses
2120
MLIRArmSMETestPasses
2221
MLIRBufferizationTestPasses
2322
MLIRControlFlowTestPasses

mlir/tools/mlir-opt/mlir-opt.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,6 @@ void registerTestLLVMLegalizePatternsPass();
120120
void registerTestLoopFusion();
121121
void registerTestLoopMappingPass();
122122
void registerTestLoopUnrollingPass();
123-
void registerTestLowerToArmNeon();
124123
void registerTestLowerToArmSME();
125124
void registerTestLowerToLLVM();
126125
void registerTestMakeIsolatedFromAbovePass();
@@ -264,7 +263,6 @@ void registerTestPasses() {
264263
mlir::test::registerTestLoopFusion();
265264
mlir::test::registerTestLoopMappingPass();
266265
mlir::test::registerTestLoopUnrollingPass();
267-
mlir::test::registerTestLowerToArmNeon();
268266
mlir::test::registerTestLowerToArmSME();
269267
mlir::test::registerTestLowerToLLVM();
270268
mlir::test::registerTestMakeIsolatedFromAbovePass();

0 commit comments

Comments
 (0)