Skip to content

Commit e268afb

Browse files
[MLIR][ROCDL] Add dynamically legal ops to LowerGpuOpsToROCDLOpsPass
1 parent 0d5d355 commit e268afb

File tree

3 files changed

+63
-1
lines changed

3 files changed

+63
-1
lines changed

mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
2727
#include "mlir/Conversion/LLVMCommon/Pattern.h"
2828
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
29+
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
2930
#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
3031
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
3132
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
@@ -290,6 +291,7 @@ struct LowerGpuOpsToROCDLOpsPass
290291
populateAMDGPUToROCDLConversionPatterns(converter, llvmPatterns,
291292
*maybeChipset);
292293
populateVectorToLLVMConversionPatterns(converter, llvmPatterns);
294+
populateMathToLLVMConversionPatterns(converter, llvmPatterns);
293295
cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns);
294296
populateFuncToLLVMConversionPatterns(converter, llvmPatterns);
295297
populateFinalizeMemRefToLLVMConversionPatterns(converter, llvmPatterns);
@@ -332,7 +334,12 @@ void mlir::configureGpuToROCDLConversionLegality(ConversionTarget &target) {
332334
target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FCeilOp,
333335
LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp, LLVM::Log10Op,
334336
LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp>();
335-
337+
// These ops are not legal for f64 type but are legal for all other types.
338+
target.addDynamicallyLegalOp<LLVM::ExpOp, LLVM::LogOp>([](Operation *op) {
339+
return any_of(op->getOperandTypes(), [](Type type) {
340+
return isa<FloatType>(type) && type.getIntOrFloatBitWidth() < 64;
341+
});
342+
});
336343
// TODO: Remove once we support replacing non-root ops.
337344
target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp>();
338345
}

mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,60 @@ gpu.module @test_module {
131131

132132
// -----
133133

134+
gpu.module @test_module {
135+
// CHECK-LABEL: func @gpu_sqrt
136+
func.func @gpu_sqrt(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
137+
%result32 = math.sqrt %arg_f32 : f32
138+
// CHECK: llvm.intr.sqrt(%{{.*}}) : (f32) -> f32
139+
%result64 = math.sqrt %arg_f64 : f64
140+
// CHECK: llvm.intr.sqrt(%{{.*}}) : (f64) -> f64
141+
func.return %result32, %result64 : f32, f64
142+
}
143+
}
144+
145+
// -----
146+
147+
gpu.module @test_module {
148+
// CHECK-LABEL: func @gpu_fabs
149+
func.func @gpu_fabs(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
150+
%result32 = math.absf %arg_f32 : f32
151+
// CHECK: llvm.intr.fabs(%{{.*}}) : (f32) -> f32
152+
%result64 = math.absf %arg_f64 : f64
153+
// CHECK: llvm.intr.fabs(%{{.*}}) : (f64) -> f64
154+
func.return %result32, %result64 : f32, f64
155+
}
156+
}
157+
158+
// -----
159+
160+
gpu.module @test_module {
161+
// CHECK: llvm.func @__ocml_exp_f64(f64) -> f64
162+
// CHECK-LABEL: func @gpu_exp
163+
func.func @gpu_exp(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
164+
%result32 = math.exp %arg_f32 : f32
165+
// CHECK: llvm.intr.exp(%{{.*}}) : (f32) -> f32
166+
%result64 = math.exp %arg_f64 : f64
167+
// CHECK: llvm.call @__ocml_exp_f64(%{{.*}}) : (f64) -> f64
168+
func.return %result32, %result64 : f32, f64
169+
}
170+
}
171+
172+
// -----
173+
174+
gpu.module @test_module {
175+
// CHECK: llvm.func @__ocml_log_f64(f64) -> f64
176+
// CHECK-LABEL: func @gpu_log
177+
func.func @gpu_log(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
178+
%result32 = math.log %arg_f32 : f32
179+
// CHECK: llvm.intr.log(%{{.*}}) : (f32) -> f32
180+
%result64 = math.log %arg_f64 : f64
181+
// CHECK: llvm.call @__ocml_log_f64(%{{.*}}) : (f64) -> f64
182+
func.return %result32, %result64 : f32, f64
183+
}
184+
}
185+
186+
// -----
187+
134188
gpu.module @test_module {
135189
// CHECK: llvm.func @__ocml_cbrt_f32(f32) -> f32
136190
// CHECK: llvm.func @__ocml_cbrt_f64(f64) -> f64

utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6004,6 +6004,7 @@ cc_library(
60046004
":LLVMCommonConversion",
60056005
":LLVMDialect",
60066006
":MathDialect",
6007+
":MathToLLVM",
60076008
":MathToROCDL",
60086009
":MemRefDialect",
60096010
":MemRefToLLVM",

0 commit comments

Comments
 (0)