Skip to content

Commit 8544cfc

Browse files
Address reviewer comments
1 parent 4213aa7 commit 8544cfc

File tree

2 files changed

+61
-7
lines changed

2 files changed

+61
-7
lines changed

mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
2727
#include "mlir/Conversion/LLVMCommon/Pattern.h"
2828
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
29+
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
2930
#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
3031
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
3132
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
@@ -290,6 +291,7 @@ struct LowerGpuOpsToROCDLOpsPass
290291
populateAMDGPUToROCDLConversionPatterns(converter, llvmPatterns,
291292
*maybeChipset);
292293
populateVectorToLLVMConversionPatterns(converter, llvmPatterns);
294+
populateMathToLLVMConversionPatterns(converter, llvmPatterns);
293295
cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns);
294296
populateFuncToLLVMConversionPatterns(converter, llvmPatterns);
295297
populateFinalizeMemRefToLLVMConversionPatterns(converter, llvmPatterns);
@@ -333,13 +335,11 @@ void mlir::configureGpuToROCDLConversionLegality(ConversionTarget &target) {
333335
LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp, LLVM::Log10Op,
334336
LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp>();
335337
// These ops are legal for f32 type.
336-
target.addDynamicallyLegalOp<LLVM::ExpOp, LLVM::LogOp>(
337-
[](mlir::Operation *op) {
338-
return llvm::any_of(op->getOperandTypes(), [](Type type) {
339-
return llvm::isa<FloatType>(type) &&
340-
type.getIntOrFloatBitWidth() == 32;
341-
});
342-
});
338+
target.addDynamicallyLegalOp<LLVM::ExpOp, LLVM::LogOp>([](Operation *op) {
339+
return any_of(op->getOperandTypes(), [](Type type) {
340+
return isa<FloatType>(type) && type.getIntOrFloatBitWidth() == 32;
341+
});
342+
});
343343
// TODO: Remove once we support replacing non-root ops.
344344
target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp>();
345345
}

mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,60 @@ gpu.module @test_module {
131131

132132
// -----
133133

134+
gpu.module @test_module {
135+
// CHECK-LABEL: func @gpu_sqrt
136+
func.func @gpu_sqrt(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
137+
%result32 = math.sqrt %arg_f32 : f32
138+
// CHECK: llvm.intr.sqrt(%{{.*}}) : (f32) -> f32
139+
%result64 = math.sqrt %arg_f64 : f64
140+
// CHECK: llvm.intr.sqrt(%{{.*}}) : (f64) -> f64
141+
func.return %result32, %result64 : f32, f64
142+
}
143+
}
144+
145+
// -----
146+
147+
gpu.module @test_module {
148+
// CHECK-LABEL: func @gpu_fabs
149+
func.func @gpu_fabs(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
150+
%result32 = math.absf %arg_f32 : f32
151+
// CHECK: llvm.intr.fabs(%{{.*}}) : (f32) -> f32
152+
%result64 = math.absf %arg_f64 : f64
153+
// CHECK: llvm.intr.fabs(%{{.*}}) : (f64) -> f64
154+
func.return %result32, %result64 : f32, f64
155+
}
156+
}
157+
158+
// -----
159+
160+
gpu.module @test_module {
161+
// CHECK: llvm.func @__ocml_exp_f64(f64) -> f64
162+
// CHECK-LABEL: func @gpu_exp
163+
func.func @gpu_exp(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
164+
%result32 = math.exp %arg_f32 : f32
165+
// CHECK: llvm.intr.exp(%{{.*}}) : (f32) -> f32
166+
%result64 = math.exp %arg_f64 : f64
167+
// CHECK: llvm.call @__ocml_exp_f64(%{{.*}}) : (f64) -> f64
168+
func.return %result32, %result64 : f32, f64
169+
}
170+
}
171+
172+
// -----
173+
174+
gpu.module @test_module {
175+
// CHECK: llvm.func @__ocml_log_f64(f64) -> f64
176+
// CHECK-LABEL: func @gpu_log
177+
func.func @gpu_log(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
178+
%result32 = math.log %arg_f32 : f32
179+
// CHECK: llvm.intr.log(%{{.*}}) : (f32) -> f32
180+
%result64 = math.log %arg_f64 : f64
181+
// CHECK: llvm.call @__ocml_log_f64(%{{.*}}) : (f64) -> f64
182+
func.return %result32, %result64 : f32, f64
183+
}
184+
}
185+
186+
// -----
187+
134188
gpu.module @test_module {
135189
// CHECK: llvm.func @__ocml_cbrt_f32(f32) -> f32
136190
// CHECK: llvm.func @__ocml_cbrt_f64(f64) -> f64

0 commit comments

Comments
 (0)