Skip to content

[MLIR] Add apply_patterns.vector.arm_neon.contraction_to_i8mm TD Op #140251

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/ArmNeon/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,5 @@ add_mlir_doc(ArmNeon ArmNeon Dialects/ -gen-dialect-doc -dialect=arm_neon)
set(LLVM_TARGET_DEFINITIONS ArmNeon.td)
mlir_tablegen(ArmNeonConversions.inc -gen-llvmir-conversions)
add_public_tablegen_target(MLIRArmNeonConversionsIncGen)

add_subdirectory(TransformOps)
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
//===- ArmNeonVectorTransformOps.h - Vector transform ops -------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_ARM_NEON_TRANSFORMOPS_VECTORTRANSFORMOPS_H
#define MLIR_DIALECT_ARM_NEON_TRANSFORMOPS_VECTORTRANSFORMOPS_H

#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/IR/OpImplementation.h"

//===----------------------------------------------------------------------===//
// ArmNeon Vector Transform Operations
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h.inc"

namespace mlir {
class DialectRegistry;

namespace arm_neon {
void registerTransformDialectExtension(DialectRegistry &registry);

} // namespace arm_neon
} // namespace mlir

#endif // MLIR_DIALECT_ARM_NEON_TRANSFORMOPS_VECTORTRANSFORMOPS_H
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
//===- ArmNeonVectorTransformOps.td - Arm Neon TD ops ------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef ARM_NEON_VECTOR_TRANSFORM_OPS
#define ARM_NEON_VECTOR_TRANSFORM_OPS

include "mlir/Dialect/Transform/IR/TransformAttrs.td"
include "mlir/Dialect/Transform/IR/TransformDialect.td"
include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"

def ApplyArmNeonContractionToI8MMPatternsOp
: Op<Transform_Dialect,
"apply_patterns.vector.arm_neon.contraction_to_i8mm",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
let description = [{
Indicates that vector.contract operations should be lowered to
finer-grained vector primitives from the ArmNeon dialect.
}];

let assemblyFormat = "attr-dict";
}

#endif // ARM_NEON_VECTOR_TRANSFORM_OPS
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
set(LLVM_TARGET_DEFINITIONS ArmNeonVectorTransformOps.td)
mlir_tablegen(ArmNeonVectorTransformOps.h.inc -gen-op-decls)
mlir_tablegen(ArmNeonVectorTransformOps.cpp.inc -gen-op-defs)
add_public_tablegen_target(MLIRArmNeonVectorTransformOpsIncGen)

add_mlir_doc(ArmNeonVectorTransformOps ArmNeonVectorTransformOps Dialects/ -gen-op-doc)
2 changes: 2 additions & 0 deletions mlir/include/mlir/InitAllExtensions.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/AMX/Transforms.h"
#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
#include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h"
#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
#include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h"
#include "mlir/Dialect/Func/Extensions/AllExtensions.h"
Expand Down Expand Up @@ -106,6 +107,7 @@ inline void registerAllExtensions(DialectRegistry &registry) {
transform::registerLoopExtension(registry);
transform::registerPDLExtension(registry);
vector::registerTransformDialectExtension(registry);
arm_neon::registerTransformDialectExtension(registry);

// Translation extensions need to be registered by calling
// `registerAllToLLVMIRTranslations` (see All.h).
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/ArmNeon/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
add_subdirectory(IR)
add_subdirectory(Transforms)
add_subdirectory(TransformOps)
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
//===- ArmNeonVectorTransformOps.cpp - Implementation transform ops -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h"

#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
#include "mlir/Dialect/ArmNeon/Transforms.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"

using namespace mlir;

//===----------------------------------------------------------------------===//
// Apply...PatternsOp
//===----------------------------------------------------------------------===//

void transform::ApplyArmNeonContractionToI8MMPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
arm_neon::populateLowerContractionToSMMLAPatternPatterns(patterns);
}

//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//

namespace {
class ArmNeonVectorTransformDialectExtension
: public transform::TransformDialectExtension<
ArmNeonVectorTransformDialectExtension> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
ArmNeonVectorTransformDialectExtension)

ArmNeonVectorTransformDialectExtension() {
declareGeneratedDialect<arm_neon::ArmNeonDialect>();
registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp.inc"
>();
}
};
} // namespace

#define GET_OP_CLASSES
#include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp.inc"

void mlir::arm_neon::registerTransformDialectExtension(
DialectRegistry &registry) {
registry.addExtensions<ArmNeonVectorTransformDialectExtension>();
}
18 changes: 18 additions & 0 deletions mlir/lib/Dialect/ArmNeon/TransformOps/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
add_mlir_dialect_library(MLIRArmNeonVectorTransformOps
ArmNeonVectorTransformOps.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmNeon/TransformOps

DEPENDS
MLIRArmNeonVectorTransformOpsIncGen

LINK_LIBS PUBLIC
MLIRIR
MLIRLLVMCommonConversion
MLIRLLVMDialect
MLIRVectorDialect
MLIRTransformDialect
MLIRArmNeonDialect
MLIRArmNeonTransforms
)
14 changes: 13 additions & 1 deletion mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: mlir-opt -test-lower-to-arm-neon -verify-diagnostics -split-input-file %s | FileCheck %s
// RUN: mlir-opt -transform-interpreter %s | FileCheck %s

// CHECK-LABEL: vector_arm_neon_mixed_types
// CHECK-SAME: %[[A0:.*]]: vector<2x8xi8>, %[[A1:.*]]: vector<2x8xi4>, %[[A2:.*]]: vector<2x2xi32>
Expand Down Expand Up @@ -354,3 +354,15 @@ func.func @vector_arm_neon_k_unroll_vecmat(%lhs: vector<1x32xi8>, %rhs: vector<2
%res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<1x32xi32>, vector<2x32xi32> into vector<1x2xi32>
return %res : vector<1x2xi32>
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.op<"func.func">

transform.apply_patterns to %func {
transform.apply_patterns.vector.arm_neon.contraction_to_i8mm
} : !transform.op<"func.func">

transform.yield
}
}
13 changes: 0 additions & 13 deletions mlir/test/lib/Dialect/ArmNeon/CMakeLists.txt

This file was deleted.

60 changes: 0 additions & 60 deletions mlir/test/lib/Dialect/ArmNeon/TestLowerToArmNeon.cpp

This file was deleted.

1 change: 0 additions & 1 deletion mlir/test/lib/Dialect/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
add_subdirectory(Affine)
add_subdirectory(Arith)
add_subdirectory(ArmNeon)
add_subdirectory(ArmSME)
add_subdirectory(Bufferization)
add_subdirectory(ControlFlow)
Expand Down
1 change: 0 additions & 1 deletion mlir/tools/mlir-opt/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ if(MLIR_INCLUDE_TESTS)
MLIRTestFuncToLLVM
MLIRAffineTransformsTestPasses
MLIRArithTestPasses
MLIRArmNeonTestPasses
MLIRArmSMETestPasses
MLIRBufferizationTestPasses
MLIRControlFlowTestPasses
Expand Down
2 changes: 0 additions & 2 deletions mlir/tools/mlir-opt/mlir-opt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,6 @@ void registerTestLLVMLegalizePatternsPass();
void registerTestLoopFusion();
void registerTestLoopMappingPass();
void registerTestLoopUnrollingPass();
void registerTestLowerToArmNeon();
void registerTestLowerToArmSME();
void registerTestLowerToLLVM();
void registerTestMakeIsolatedFromAbovePass();
Expand Down Expand Up @@ -264,7 +263,6 @@ void registerTestPasses() {
mlir::test::registerTestLoopFusion();
mlir::test::registerTestLoopMappingPass();
mlir::test::registerTestLoopUnrollingPass();
mlir::test::registerTestLowerToArmNeon();
mlir::test::registerTestLowerToArmSME();
mlir::test::registerTestLowerToLLVM();
mlir::test::registerTestMakeIsolatedFromAbovePass();
Expand Down