Skip to content

Commit 12fff8d

Browse files
authored
[mlir][LLVMIR] Add support for tan intrinsic op (#125748)
This patch adds support for Tan trig. function intrinsic in LLVM dialect & adds missing import/export tests for Sin
1 parent 0815b0e commit 12fff8d

File tree

3 files changed

+34
-10
lines changed

3 files changed

+34
-10
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,6 @@ def LLVM_IsFPClass : LLVM_OneResultIntrOp<"is.fpclass", [], [0], [Pure],
107107
}
108108

109109
def LLVM_CopySignOp : LLVM_BinarySameArgsIntrOpF<"copysign">;
110-
def LLVM_CosOp : LLVM_UnaryIntrOpF<"cos">;
111110
def LLVM_ExpOp : LLVM_UnaryIntrOpF<"exp">;
112111
def LLVM_Exp2Op : LLVM_UnaryIntrOpF<"exp2">;
113112
def LLVM_FAbsOp : LLVM_UnaryIntrOpF<"fabs">;
@@ -125,7 +124,6 @@ def LLVM_Prefetch : LLVM_ZeroResultIntrOp<"prefetch", [0],
125124
> {
126125
let arguments = (ins LLVM_AnyPointer:$addr, I32Attr:$rw, I32Attr:$hint, I32Attr:$cache);
127126
}
128-
def LLVM_SinOp : LLVM_UnaryIntrOpF<"sin">;
129127
def LLVM_RoundEvenOp : LLVM_UnaryIntrOpF<"roundeven">;
130128
def LLVM_RoundOp : LLVM_UnaryIntrOpF<"round">;
131129
def LLVM_FTruncOp : LLVM_UnaryIntrOpF<"trunc">;
@@ -167,6 +165,11 @@ def LLVM_SMaxOp : LLVM_BinarySameArgsIntrOpI<"smax">;
167165
def LLVM_SMinOp : LLVM_BinarySameArgsIntrOpI<"smin">;
168166
def LLVM_UMaxOp : LLVM_BinarySameArgsIntrOpI<"umax">;
169167
def LLVM_UMinOp : LLVM_BinarySameArgsIntrOpI<"umin">;
168+
169+
def LLVM_SinOp : LLVM_UnaryIntrOpF<"sin">;
170+
def LLVM_CosOp : LLVM_UnaryIntrOpF<"cos">;
171+
def LLVM_TanOp : LLVM_UnaryIntrOpF<"tan">;
172+
170173
def LLVM_SinhOp : LLVM_UnaryIntrOpF<"sinh">;
171174
def LLVM_CoshOp : LLVM_UnaryIntrOpF<"cosh">;
172175
def LLVM_TanhOp : LLVM_UnaryIntrOpF<"tanh">;

mlir/test/Target/LLVMIR/Import/intrinsic.ll

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -101,12 +101,23 @@ define void @floor_test(float %0, <8 x float> %1) {
101101
%4 = call <8 x float> @llvm.floor.v8f32(<8 x float> %1)
102102
ret void
103103
}
104-
; CHECK-LABEL: llvm.func @cos_test
105-
define void @cos_test(float %0, <8 x float> %1) {
104+
; CHECK-LABEL: llvm.func @trig_test
105+
define void @trig_test(float %0, <8 x float> %1) {
106+
; CHECK: llvm.intr.sin(%{{.*}}) : (f32) -> f32
107+
%3 = call float @llvm.sin.f32(float %0)
108+
; CHECK: llvm.intr.sin(%{{.*}}) : (vector<8xf32>) -> vector<8xf32>
109+
%4 = call <8 x float> @llvm.sin.v8f32(<8 x float> %1)
110+
106111
; CHECK: llvm.intr.cos(%{{.*}}) : (f32) -> f32
107-
%3 = call float @llvm.cos.f32(float %0)
112+
%5 = call float @llvm.cos.f32(float %0)
108113
; CHECK: llvm.intr.cos(%{{.*}}) : (vector<8xf32>) -> vector<8xf32>
109-
%4 = call <8 x float> @llvm.cos.v8f32(<8 x float> %1)
114+
%6 = call <8 x float> @llvm.cos.v8f32(<8 x float> %1)
115+
116+
; CHECK: llvm.intr.tan(%{{.*}}) : (f32) -> f32
117+
%7 = call float @llvm.tan.f32(float %0)
118+
; CHECK: llvm.intr.tan(%{{.*}}) : (vector<8xf32>) -> vector<8xf32>
119+
%8 = call <8 x float> @llvm.tan.v8f32(<8 x float> %1)
120+
110121
ret void
111122
}
112123
; CHECK-LABEL: llvm.func @hyperbolic_trig_test

mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -103,12 +103,22 @@ llvm.func @floor_test(%arg0: f32, %arg1: vector<8xf32>) {
103103
llvm.return
104104
}
105105

106-
// CHECK-LABEL: @cos_test
107-
llvm.func @cos_test(%arg0: f32, %arg1: vector<8xf32>) {
106+
// CHECK-LABEL: @trig_test
107+
llvm.func @trig_test(%arg0: f32, %arg1: vector<8xf32>) {
108+
// CHECK: call float @llvm.sin.f32
109+
llvm.intr.sin(%arg0) : (f32) -> f32
110+
// CHECK: call <8 x float> @llvm.sin.v8f32
111+
llvm.intr.sin(%arg1) : (vector<8xf32>) -> vector<8xf32>
112+
108113
// CHECK: call float @llvm.cos.f32
109-
"llvm.intr.cos"(%arg0) : (f32) -> f32
114+
llvm.intr.cos(%arg0) : (f32) -> f32
110115
// CHECK: call <8 x float> @llvm.cos.v8f32
111-
"llvm.intr.cos"(%arg1) : (vector<8xf32>) -> vector<8xf32>
116+
llvm.intr.cos(%arg1) : (vector<8xf32>) -> vector<8xf32>
117+
118+
// CHECK: call float @llvm.tan.f32
119+
llvm.intr.tan(%arg0) : (f32) -> f32
120+
// CHECK: call <8 x float> @llvm.tan.v8f32
121+
llvm.intr.tan(%arg1) : (vector<8xf32>) -> vector<8xf32>
112122
llvm.return
113123
}
114124

0 commit comments

Comments
 (0)