-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[HLSL][DXIL] Implement refract
intrinsic
#136026
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Thank you for submitting a Pull Request (PR) to the LLVM Project! This PR will be automatically labeled and the relevant teams will be notified. If you wish to, you can add reviewers by using the "Reviewers" section on this page. If this is not working for you, it is probably because you do not have write permissions for the repository. In which case you can instead tag reviewers by name in a comment by using If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers. If you have further questions, they may be answered by the LLVM GitHub User Guide. You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums. |
a1ccf10
to
dff5181
Compare
@llvm/pr-subscribers-llvm-ir @llvm/pr-subscribers-clang Author: None (raoanag) Changes
Resolves #99153 Patch is 53.62 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/136026.diff 12 Files Affected:
diff --git a/clang/include/clang/Basic/BuiltinsSPIRV.td b/clang/include/clang/Basic/BuiltinsSPIRV.td
index 9f76d672cc7ce..c0f652b4f24e4 100644
--- a/clang/include/clang/Basic/BuiltinsSPIRV.td
+++ b/clang/include/clang/Basic/BuiltinsSPIRV.td
@@ -26,6 +26,12 @@ def SPIRVReflect : Builtin {
let Prototype = "void(...)";
}
+def SPIRVRefract : Builtin {
+ let Spellings = ["__builtin_spirv_refract"];
+ let Attributes = [NoThrow, Const];
+ let Prototype = "void(...)";
+}
+
def SPIRVSmoothStep : Builtin {
let Spellings = ["__builtin_spirv_smoothstep"];
let Attributes = [NoThrow, Const, CustomTypeChecking];
diff --git a/clang/lib/CodeGen/TargetBuiltins/SPIR.cpp b/clang/lib/CodeGen/TargetBuiltins/SPIR.cpp
index 92e2c1c6da68f..5fedb9553699f 100644
--- a/clang/lib/CodeGen/TargetBuiltins/SPIR.cpp
+++ b/clang/lib/CodeGen/TargetBuiltins/SPIR.cpp
@@ -58,6 +58,21 @@ Value *CodeGenFunction::EmitSPIRVBuiltinExpr(unsigned BuiltinID,
/*ReturnType=*/I->getType(), Intrinsic::spv_reflect,
ArrayRef<Value *>{I, N}, nullptr, "spv.reflect");
}
+ case SPIRV::BI__builtin_spirv_refract: {
+ Value *I = EmitScalarExpr(E->getArg(0));
+ Value *N = EmitScalarExpr(E->getArg(1));
+ Value *eta = EmitScalarExpr(E->getArg(2));
+ assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
+ E->getArg(1)->getType()->hasFloatingRepresentation() &&
+ E->getArg(2)->getType()->hasFloatingRepresentation() &&
+ "refract operands must have a float representation");
+ assert(E->getArg(0)->getType()->isVectorType() &&
+ E->getArg(1)->getType()->isVectorType() &&
+ "refract I and N operands must be a vector");
+ return Builder.CreateIntrinsic(
+ /*ReturnType=*/I->getType(), Intrinsic::spv_refract,
+ ArrayRef<Value *>{I, N, eta}, nullptr, "spv.refract");
+ }
case SPIRV::BI__builtin_spirv_smoothstep: {
Value *Min = EmitScalarExpr(E->getArg(0));
Value *Max = EmitScalarExpr(E->getArg(1));
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
index 3a8a9b6fa2a45..9a320a78453ac 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
@@ -71,6 +71,27 @@ constexpr vector<T, L> reflect_vec_impl(vector<T, L> I, vector<T, L> N) {
#endif
}
+template <typename T> constexpr T refract_impl(T I, T N, T eta) {
+ T k = 1 - eta * eta * (1 - (N * I * N *I));
+ if(k < 0)
+ return 0;
+ else
+ return (eta * I - (eta * N * I + sqrt(k)) * N);
+}
+
+template <typename T, int L>
+constexpr vector<T, L> refract_vec_impl(vector<T, L> I, vector<T, L> N, T eta) {
+#if (__has_builtin(__builtin_spirv_refract))
+ return __builtin_spirv_refract(I, N, eta);
+#else
+ vector<T, L> k = 1 - eta * eta * (1 - dot(N, I) * dot(N, I));
+ if(k < 0)
+ return 0;
+ else
+ return (eta * I - (eta * dot(N, I) + sqrt(k)) * N);
+#endif
+}
+
template <typename T> constexpr T fmod_impl(T X, T Y) {
#if !defined(__DIRECTX__)
return __builtin_elementwise_fmod(X, Y);
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 35ff80052cf43..bb5b770b4141a 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -386,6 +386,65 @@ reflect(__detail::HLSL_FIXED_VECTOR<float, L> I,
return __detail::reflect_vec_impl(I, N);
}
+//===----------------------------------------------------------------------===//
+// refract builtin
+//===----------------------------------------------------------------------===//
+
+/// \fn T refract(T I, T N, T eta)
+/// \brief Returns a refraction using an entering ray, \a I, a surface
+/// normal, \a N and refraction index \a eta
+/// \param I The entering ray.
+/// \param N The surface normal.
+/// \param eta The refraction index.
+///
+/// The return value is a floating-point vector that represents the refraction
+/// using the refraction index, \a eta, for the direction of the entering ray, \a I,
+/// off a surface with the normal \a N.
+///
+/// This function calculates the refraction vector using the following formulas:
+/// k = 1.0 - eta * eta * (1.0 - dot(N, I) * dot(N, I))
+/// if k < 0.0 the result is 0.0
+/// otherwise, the result is eta * I - (eta * dot(N, I) + sqrt(k)) * N
+///
+/// I and N must already be normalized in order to achieve the desired result.
+///
+/// I and N must be a scalar or vector whose component type is
+/// floating-point.
+///
+/// eta must be a 16-bit or 32-bit floating-point scalar.
+///
+/// Result type, the type of I, and the type of N must all be the same type.
+
+template <typename T>
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+const inline __detail::enable_if_t<__detail::is_arithmetic<T>::Value &&
+ __detail::is_same<half, T>::value,
+ T> refract(T I, T N, T eta) {
+ return __detail::refract_impl(I, N, eta);
+}
+
+template <typename T>
+const inline __detail::enable_if_t<
+ __detail::is_arithmetic<T>::Value && __detail::is_same<float, T>::value, T>
+ refract(T I, T N, T eta) {
+ return __detail::refract_impl(I, N, eta);
+}
+
+template <int L>
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+const inline __detail::HLSL_FIXED_VECTOR<half, L> refract(
+ __detail::HLSL_FIXED_VECTOR<half, L> I,
+ __detail::HLSL_FIXED_VECTOR<half, L> N, half eta) {
+ return __detail::refract_vec_impl(I, N, eta);
+}
+
+template <int L>
+const inline __detail::HLSL_FIXED_VECTOR<float, L>
+refract(__detail::HLSL_FIXED_VECTOR<float, L> I,
+ __detail::HLSL_FIXED_VECTOR<float, L> N, float eta) {
+ return __detail::refract_vec_impl(I, N, eta);
+}
+
//===----------------------------------------------------------------------===//
// smoothstep builtin
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaSPIRV.cpp b/clang/lib/Sema/SemaSPIRV.cpp
index 7131514d53421..2ad2089323cc3 100644
--- a/clang/lib/Sema/SemaSPIRV.cpp
+++ b/clang/lib/Sema/SemaSPIRV.cpp
@@ -69,6 +69,49 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(unsigned BuiltinID,
TheCall->setType(RetTy);
break;
}
+ case SPIRV::BI__builtin_spirv_refract: {
+ if (SemaRef.checkArgCount(TheCall, 3))
+ return true;
+
+ ExprResult A = TheCall->getArg(0);
+ QualType ArgTyA = A.get()->getType();
+ auto *VTyA = ArgTyA->getAs<VectorType>();
+ if (VTyA == nullptr) {
+ SemaRef.Diag(A.get()->getBeginLoc(),
+ diag::err_typecheck_convert_incompatible)
+ << ArgTyA
+ << SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1
+ << 0 << 0;
+ return true;
+ }
+
+ ExprResult B = TheCall->getArg(1);
+ QualType ArgTyB = B.get()->getType();
+ auto *VTyB = ArgTyB->getAs<VectorType>();
+ if (VTyB == nullptr) {
+ SemaRef.Diag(B.get()->getBeginLoc(),
+ diag::err_typecheck_convert_incompatible)
+ << ArgTyB
+ << SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1
+ << 0 << 0;
+ return true;
+ }
+
+ ExprResult C = TheCall->getArg(2);
+ QualType ArgTyC = C.get()->getType();
+ if (!ArgTyC->hasFloatingRepresentation()) {
+ SemaRef.Diag(C.get()->getBeginLoc(),
+ diag::err_builtin_invalid_arg_type)
+ << 3 << /* scalar or vector */ 5 << /* no int */ 0 << /* fp */ 1
+ << ArgTyC;
+ return true;
+ }
+
+ QualType RetTy = ArgTyA;
+ TheCall->setType(RetTy);
+ assert(RetTy == ArgTyA);
+ break;
+ }
case SPIRV::BI__builtin_spirv_reflect: {
if (SemaRef.checkArgCount(TheCall, 2))
return true;
@@ -89,7 +132,7 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(unsigned BuiltinID,
QualType ArgTyB = B.get()->getType();
auto *VTyB = ArgTyB->getAs<VectorType>();
if (VTyB == nullptr) {
- SemaRef.Diag(A.get()->getBeginLoc(),
+ SemaRef.Diag(B.get()->getBeginLoc(),
diag::err_typecheck_convert_incompatible)
<< ArgTyB
<< SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1
diff --git a/clang/test/CodeGenHLSL/builtins/refract.hlsl b/clang/test/CodeGenHLSL/builtins/refract.hlsl
new file mode 100644
index 0000000000000..a2e160f17b582
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/refract.hlsl
@@ -0,0 +1,356 @@
+// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 5
+// RUN: %clang_cc1 -finclude-default-header -triple \
+// RUN: dxil-pc-shadermodel6.3-library %s -fnative-half-type \
+// RUN: -emit-llvm -O1 -o - | FileCheck %s
+// RUN: %clang_cc1 -finclude-default-header -triple \
+// RUN: spirv-unknown-vulkan-compute %s -fnative-half-type \
+// RUN: -emit-llvm -O1 -o - | FileCheck %s --check-prefix=SPVCHECK
+
+// CHECK-LABEL: define noundef nofpclass(nan inf) half @_Z17test_refract_halfDhDhDh(
+// CHECK-SAME: half noundef nofpclass(nan inf) [[I:%.*]], half noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
+// CHECK-NEXT: [[ENTRY:.*:]]
+// CHECK-NEXT: [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ETA]], [[ETA]]
+// CHECK-NEXT: [[TMP0:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[N]], [[I]]
+// CHECK-NEXT: [[TMP1:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[TMP0]], [[TMP0]]
+// CHECK-NEXT: [[SUB1_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[TMP1]]
+// CHECK-NEXT: [[MUL4_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL_I]], [[SUB1_I]]
+// CHECK-NEXT: [[SUB5_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL4_I]]
+// CHECK-NEXT: [[CMP_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt half [[SUB5_I]], 0xH0000
+// CHECK-NEXT: br i1 [[CMP_I]], label %_ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit, label %if.else.i
+// CHECK: if.else.i: ; preds = %entry
+// CHECK-NEXT: [[MUL6_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ETA]], [[I]]
+// CHECK-NEXT: [[MUL7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[N]], [[I]]
+// CHECK-NEXT: [[MUL8_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL7_I]], [[ETA]]
+// CHECK-NEXT: [[TMP2:%.*]] = tail call reassoc nnan ninf nsz arcp afn half @llvm.sqrt.f16(half [[SUB5_I]])
+// CHECK-NEXT: [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn half [[TMP2]], [[MUL8_I]]
+// CHECK-NEXT: [[MUL9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ADD_I]], [[N]]
+// CHECK-NEXT: [[SUB10_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half [[MUL6_I]], [[MUL9_I]]
+// CHECK-NEXT: br label %_ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit
+// CHECK: _ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit: ; preds = %entry, %if.else.i
+// CHECK-NEXT: [[RETVAL_0_I:%.*]] = phi nsz half [ [[SUB10_I]], %if.else.i ], [ 0xH0000, %entry ]
+// CHECK-NEXT: ret half [[RETVAL_0_I]]
+//
+//
+// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) half @_Z17test_refract_halfDhDhDh(
+// SPVCHECK-SAME: half noundef nofpclass(nan inf) [[I:%.*]], half noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]]
+// SPVCHECK-NEXT: [[ENTRY:.*:]]
+// SPVCHECK-NEXT: [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ETA]], [[ETA]]
+// SPVCHECK-NEXT: [[TMP0:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[N]], [[I]]
+// SPVCHECK-NEXT: [[TMP1:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[TMP0]], [[TMP0]]
+// SPVCHECK-NEXT: [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[TMP1]]
+// SPVCHECK-NEXT: [[MUL_4_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL_I]], [[SUB_I]]
+// SPVCHECK-NEXT: [[SUB5_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_4_I]]
+// SPVCHECK-NEXT: [[CMP_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt half [[SUB5_I]], 0xH0000
+// SPVCHECK-NEXT: br i1 [[CMP_I]], label %_ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit, label %if.else.i
+// SPVCHECK: if.else.i: ; preds = %entry
+// SPVCHECK-NEXT: [[MUL_6_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ETA]], [[I]]
+// SPVCHECK-NEXT: [[MUL_7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[N]], [[I]]
+// SPVCHECK-NEXT: [[MUL_8_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL_7_I]], [[ETA]]
+// SPVCHECK-NEXT: [[TMP2:%.*]] = tail call reassoc nnan ninf nsz arcp afn half @llvm.sqrt.f16(half [[SUB5_I]])
+// SPVCHECK-NEXT: [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn half [[TMP2]], [[MUL_8_I]]
+// SPVCHECK-NEXT: [[MUL_9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ADD_I]], [[N]]
+// SPVCHECK-NEXT: [[SUB10_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half [[MUL_6_I]], [[MUL_9_I]]
+// SPVCHECK-NEXT: br label %_ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit
+// SPVCHECK: _ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit: ; preds = %entry, %if.else.i
+// SPVCHECK-NEXT: [[RETVAL_0_I:%.*]] = phi nsz half [ [[SUB10_I]], %if.else.i ], [ 0xH0000, %entry ]
+// SPVCHECK-NEXT: ret half [[RETVAL_0_I]]
+//
+half test_refract_half(half I, half N, half ETA) {
+ return refract(I, N, ETA);
+}
+
+// CHECK-LABEL: define noundef nofpclass(nan inf) <2 x half> @_Z18test_refract_half2Dv2_DhS_S_
+// CHECK-SAME: <2 x half> noundef nofpclass(nan inf) [[I:%.*]], <2 x half> noundef nofpclass(nan inf) [[N:%.*]], <2 x half> noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
+// CHECK-NEXT: [[ENTRY:.*:]]
+// CHECK-NEXT: [[CAST_VTRUNC:%.*]] = extractelement <2 x half> [[ETA]], i64 0
+// CHECK-NEXT: [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[CAST_VTRUNC]], [[CAST_VTRUNC]]
+// CHECK-NEXT: [[HLSL_DOT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn half @llvm.dx.fdot.v2f16(<2 x half> [[N]], <2 x half> [[I]])
+// CHECK-NEXT: [[MUL_2_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[HLSL_DOT_I]], [[HLSL_DOT_I]]
+// CHECK-NEXT: [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_2_I]]
+// CHECK-NEXT: [[MUL_3_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL_I]], [[SUB_I]]
+// CHECK-NEXT: [[SUB4_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_3_I]]
+// CHECK-NEXT: [[CAST_VTRUNC_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt half [[SUB4_I]], 0xH0000
+// CHECK: br i1 [[CAST_VTRUNC_I]], label %_ZN4hlsl8__detail16refract_vec_implIDhLi2EEEDvT0__T_S3_S3_S2_.exit, label %if.else.i
+// CHECK: if.else.i: ; preds = %entry
+// CHECK-NEXT: [[SPLAT_SPLATINSERT_I:%.*]] = insertelement <2 x half> poison, half [[SUB4_I]], i64 0
+// CHECK-NEXT: [[SPLAT_SPLAT_I:%.*]] = shufflevector <2 x half> [[SPLAT_SPLATINSERT_I]], <2 x half> poison, <2 x i32> zeroinitializer
+// CHECK-NEXT: [[SPLAT_SPLAT6_I:%.*]] = shufflevector <2 x half> [[ETA]], <2 x half> poison, <2 x i32> zeroinitializer
+// CHECK-NEXT: [[MUL_7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <2 x half> [[SPLAT_SPLAT6_I]], [[I]]
+// CHECK-NEXT: [[MUL_9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[HLSL_DOT_I]], [[CAST_VTRUNC]]
+// CHECK-NEXT: [[SPLAT_SPLATINSERT10_I:%.*]] = insertelement <2 x half> poison, half [[MUL_9_I]], i64 0
+// CHECK-NEXT: [[SPLAT_SPLAT11_I:%.*]] = shufflevector <2 x half> [[SPLAT_SPLATINSERT10_I]], <2 x half> poison, <2 x i32> zeroinitializer
+// CHECK-NEXT: [[TMP0:%.*]] = tail call reassoc nnan ninf nsz arcp afn <2 x half> @llvm.sqrt.v2f16(<2 x half> [[SPLAT_SPLAT_I]])
+// CHECK-NEXT: [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn <2 x half> [[TMP0]], [[SPLAT_SPLAT11_I]]
+// CHECK-NEXT: [[MUL_12_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <2 x half> [[ADD_I]], [[N]]
+// CHECK-NEXT: [[SUB13_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <2 x half> [[MUL_7_I]], [[MUL_12_I]]
+// CHECK-NEXT: br label %_ZN4hlsl8__detail16refract_vec_implIDhLi2EEEDvT0__T_S3_S3_S2_.exit
+// CHECK: _ZN4hlsl8__detail16refract_vec_implIDhLi2EEEDvT0__T_S3_S3_S2_.exit: ; preds = %entry, %if.else.i
+// CHECK-NEXT: [[RETVAL_0_I:%.*]] = phi nsz <2 x half> [ %sub13.i, %if.else.i ], [ zeroinitializer, %entry ]
+// CHECK-NEXT: ret <2 x half> [[RETVAL_0_I]]
+
+//
+// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) <2 x half> @_Z18test_refract_half2Dv2_DhS_S_(
+// SPVCHECK-SAME: <2 x half> noundef nofpclass(nan inf) [[I:%.*]], <2 x half> noundef nofpclass(nan inf) [[N:%.*]], <2 x half> noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
+// SPVCHECK-NEXT: [[ENTRY:.*:]]
+// SPVCHECK-NEXT: [[CAST_VTRUNC:%.*]] = extractelement <2 x half> [[ETA]], i64 0
+// SPVCHECK-NEXT: [[CONV_I:%.*]] = fpext reassoc nnan ninf nsz arcp afn half [[CAST_VTRUNC]] to double
+// SPVCHECK-NEXT: [[SPV_REFRACT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef <2 x half> @llvm.spv.refract.v2f16.f64(<2 x half> [[I]], <2 x half> [[N]], double [[CONV_I]])
+// SPVCHECK-NEXT: ret <2 x half> [[SPV_REFRACT_I]]
+//
+half2 test_refract_half2(half2 I, half2 N, half2 ETA) {
+ return refract(I, N, ETA);
+}
+
+// CHECK-LABEL: define noundef nofpclass(nan inf) <3 x half> @_Z18test_refract_half3Dv3_DhS_Dh(
+// CHECK-SAME: <3 x half> noundef nofpclass(nan inf) [[I:%.*]], <3 x half> noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
+// CHECK-NEXT: [[ENTRY:.*:]]
+// CHECK-NEXT: [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ETA]], [[ETA]]
+// CHECK-NEXT: [[HLSL_DOT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn half @llvm.dx.fdot.v3f16(<3 x half> [[N]], <3 x half> [[I]])
+// CHECK-NEXT: [[MUL_2_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[HLSL_DOT_I]], [[HLSL_DOT_I]]
+// CHECK-NEXT: [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_2_I]]
+// CHECK-NEXT: [[MUL_3_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL_I]], [[SUB_I]]
+// CHECK-NEXT: [[SUB4_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_3_I]]
+// CHECK-NEXT: [[CAST_VTRUNC_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt half [[SUB4_I]], 0xH0000
+// CHECK: br i1 [[CAST_VTRUNC_I]], label %_ZN4hlsl8__detail16refract_vec_implIDhLi3EEEDvT0__T_S3_S3_S2_.exit, label %if.else.i
+// CHECK: if.else.i: ; preds = %entry
+// CHECK-NEXT: [[SPLAT_SPLATINSERT_I:%.*]] = insertelement <3 x half> poison, half [[SUB4_I]], i64 0
+// CHECK-NEXT: [[SPLAT_SPLAT_I:%.*]] = shufflevector <3 x half> [[SPLAT_SPLATINSERT_I]], <3 x half> poison, <3 x i32> zeroinitializer
+// CHECK-NEXT: [[SPLAT_SPLATINSERT5_I:%.*]] = insertelement <3 x half> poison, half [[ETA]], i64 0
+// CHECK-NEXT: [[SPLAT_SPLAT6_I:%.*]] = shufflevector <3 x half> [[SPLAT_SPLATINSERT5_I]], <3 x half> poison, <3 x i32> zeroinitializer
+// CHECK-NEXT: [[MUL_7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <3 x half> [[SPLAT_SPLAT6_I]], [[I]]
+// CHECK-NEXT: [[MUL_9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[HLSL_DOT_I]], [[ETA]]
+// CHECK-NEXT: [[SPLAT_SPLATINSERT10_I:%.*]] = insertelement <3 x half> poison, half [[MUL_9_I]], i64 0
+// CHECK-NEXT: [[SPLAT_SPLAT11_I:%.*]] = shufflevector <3 x half> [[SPLAT_SPLATINSERT10_I]], <3 x half> poison, <3 x i32> zeroinitializer
+// CHECK-NEXT: [[TMP0:%.*]] = tail call reassoc nnan ninf nsz arcp afn <3 x half> @llvm.sqrt.v3f16(<3 x half> [[SPLAT_SPLAT_I]])
+// CHECK-NEXT: [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn <3 x half> [[TMP0]], [[SPLAT_SPLAT11_I]]
+// CHECK-NEXT: [[MUL_12_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <3 x half> [[ADD_I]], [[N]]
+// CHECK-NEXT: [[SUB13_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <3 x half> [[MUL_7_I]], [[MUL_12_I]]
+// CHECK: br label %_ZN4hlsl8__detail16refract_vec_implIDhLi3EEEDvT0__T_S3_S3_S2_.exit
+// CHECK: _ZN4hlsl8__detail16refract_vec_implIDhLi3EEEDvT0__T_S3_S3_S2_.exit: ; preds = %entry, %if.else.i
+// CHECK-NEXT: [[RETVAL_0_I:%.*]] = phi nsz <3 x half> [ [[SUB13_I]], %if.else.i ], [ zeroinitializer, %entry ]
+// CHECK-NEXT: ret <3 x half> [[RETVAL_0_I]]
+//
+// SPVCHECK-LABEL: define spir_func noundef nofp...
[truncated]
|
@llvm/pr-subscribers-clang-codegen Author: None (raoanag) Changes
Resolves #99153 Patch is 53.62 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/136026.diff 12 Files Affected:
diff --git a/clang/include/clang/Basic/BuiltinsSPIRV.td b/clang/include/clang/Basic/BuiltinsSPIRV.td
index 9f76d672cc7ce..c0f652b4f24e4 100644
--- a/clang/include/clang/Basic/BuiltinsSPIRV.td
+++ b/clang/include/clang/Basic/BuiltinsSPIRV.td
@@ -26,6 +26,12 @@ def SPIRVReflect : Builtin {
let Prototype = "void(...)";
}
+def SPIRVRefract : Builtin {
+ let Spellings = ["__builtin_spirv_refract"];
+ let Attributes = [NoThrow, Const];
+ let Prototype = "void(...)";
+}
+
def SPIRVSmoothStep : Builtin {
let Spellings = ["__builtin_spirv_smoothstep"];
let Attributes = [NoThrow, Const, CustomTypeChecking];
diff --git a/clang/lib/CodeGen/TargetBuiltins/SPIR.cpp b/clang/lib/CodeGen/TargetBuiltins/SPIR.cpp
index 92e2c1c6da68f..5fedb9553699f 100644
--- a/clang/lib/CodeGen/TargetBuiltins/SPIR.cpp
+++ b/clang/lib/CodeGen/TargetBuiltins/SPIR.cpp
@@ -58,6 +58,21 @@ Value *CodeGenFunction::EmitSPIRVBuiltinExpr(unsigned BuiltinID,
/*ReturnType=*/I->getType(), Intrinsic::spv_reflect,
ArrayRef<Value *>{I, N}, nullptr, "spv.reflect");
}
+ case SPIRV::BI__builtin_spirv_refract: {
+ Value *I = EmitScalarExpr(E->getArg(0));
+ Value *N = EmitScalarExpr(E->getArg(1));
+ Value *eta = EmitScalarExpr(E->getArg(2));
+ assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
+ E->getArg(1)->getType()->hasFloatingRepresentation() &&
+ E->getArg(2)->getType()->hasFloatingRepresentation() &&
+ "refract operands must have a float representation");
+ assert(E->getArg(0)->getType()->isVectorType() &&
+ E->getArg(1)->getType()->isVectorType() &&
+ "refract I and N operands must be a vector");
+ return Builder.CreateIntrinsic(
+ /*ReturnType=*/I->getType(), Intrinsic::spv_refract,
+ ArrayRef<Value *>{I, N, eta}, nullptr, "spv.refract");
+ }
case SPIRV::BI__builtin_spirv_smoothstep: {
Value *Min = EmitScalarExpr(E->getArg(0));
Value *Max = EmitScalarExpr(E->getArg(1));
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
index 3a8a9b6fa2a45..9a320a78453ac 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
@@ -71,6 +71,27 @@ constexpr vector<T, L> reflect_vec_impl(vector<T, L> I, vector<T, L> N) {
#endif
}
+template <typename T> constexpr T refract_impl(T I, T N, T eta) {
+ T k = 1 - eta * eta * (1 - (N * I * N *I));
+ if(k < 0)
+ return 0;
+ else
+ return (eta * I - (eta * N * I + sqrt(k)) * N);
+}
+
+template <typename T, int L>
+constexpr vector<T, L> refract_vec_impl(vector<T, L> I, vector<T, L> N, T eta) {
+#if (__has_builtin(__builtin_spirv_refract))
+ return __builtin_spirv_refract(I, N, eta);
+#else
+ vector<T, L> k = 1 - eta * eta * (1 - dot(N, I) * dot(N, I));
+ if(k < 0)
+ return 0;
+ else
+ return (eta * I - (eta * dot(N, I) + sqrt(k)) * N);
+#endif
+}
+
template <typename T> constexpr T fmod_impl(T X, T Y) {
#if !defined(__DIRECTX__)
return __builtin_elementwise_fmod(X, Y);
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 35ff80052cf43..bb5b770b4141a 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -386,6 +386,65 @@ reflect(__detail::HLSL_FIXED_VECTOR<float, L> I,
return __detail::reflect_vec_impl(I, N);
}
+//===----------------------------------------------------------------------===//
+// refract builtin
+//===----------------------------------------------------------------------===//
+
+/// \fn T refract(T I, T N, T eta)
+/// \brief Returns a refraction using an entering ray, \a I, a surface
+/// normal, \a N and refraction index \a eta
+/// \param I The entering ray.
+/// \param N The surface normal.
+/// \param eta The refraction index.
+///
+/// The return value is a floating-point vector that represents the refraction
+/// using the refraction index, \a eta, for the direction of the entering ray, \a I,
+/// off a surface with the normal \a N.
+///
+/// This function calculates the refraction vector using the following formulas:
+/// k = 1.0 - eta * eta * (1.0 - dot(N, I) * dot(N, I))
+/// if k < 0.0 the result is 0.0
+/// otherwise, the result is eta * I - (eta * dot(N, I) + sqrt(k)) * N
+///
+/// I and N must already be normalized in order to achieve the desired result.
+///
+/// I and N must be a scalar or vector whose component type is
+/// floating-point.
+///
+/// eta must be a 16-bit or 32-bit floating-point scalar.
+///
+/// Result type, the type of I, and the type of N must all be the same type.
+
+template <typename T>
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+const inline __detail::enable_if_t<__detail::is_arithmetic<T>::Value &&
+ __detail::is_same<half, T>::value,
+ T> refract(T I, T N, T eta) {
+ return __detail::refract_impl(I, N, eta);
+}
+
+template <typename T>
+const inline __detail::enable_if_t<
+ __detail::is_arithmetic<T>::Value && __detail::is_same<float, T>::value, T>
+ refract(T I, T N, T eta) {
+ return __detail::refract_impl(I, N, eta);
+}
+
+template <int L>
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+const inline __detail::HLSL_FIXED_VECTOR<half, L> refract(
+ __detail::HLSL_FIXED_VECTOR<half, L> I,
+ __detail::HLSL_FIXED_VECTOR<half, L> N, half eta) {
+ return __detail::refract_vec_impl(I, N, eta);
+}
+
+template <int L>
+const inline __detail::HLSL_FIXED_VECTOR<float, L>
+refract(__detail::HLSL_FIXED_VECTOR<float, L> I,
+ __detail::HLSL_FIXED_VECTOR<float, L> N, float eta) {
+ return __detail::refract_vec_impl(I, N, eta);
+}
+
//===----------------------------------------------------------------------===//
// smoothstep builtin
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaSPIRV.cpp b/clang/lib/Sema/SemaSPIRV.cpp
index 7131514d53421..2ad2089323cc3 100644
--- a/clang/lib/Sema/SemaSPIRV.cpp
+++ b/clang/lib/Sema/SemaSPIRV.cpp
@@ -69,6 +69,49 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(unsigned BuiltinID,
TheCall->setType(RetTy);
break;
}
+ case SPIRV::BI__builtin_spirv_refract: {
+ if (SemaRef.checkArgCount(TheCall, 3))
+ return true;
+
+ ExprResult A = TheCall->getArg(0);
+ QualType ArgTyA = A.get()->getType();
+ auto *VTyA = ArgTyA->getAs<VectorType>();
+ if (VTyA == nullptr) {
+ SemaRef.Diag(A.get()->getBeginLoc(),
+ diag::err_typecheck_convert_incompatible)
+ << ArgTyA
+ << SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1
+ << 0 << 0;
+ return true;
+ }
+
+ ExprResult B = TheCall->getArg(1);
+ QualType ArgTyB = B.get()->getType();
+ auto *VTyB = ArgTyB->getAs<VectorType>();
+ if (VTyB == nullptr) {
+ SemaRef.Diag(B.get()->getBeginLoc(),
+ diag::err_typecheck_convert_incompatible)
+ << ArgTyB
+ << SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1
+ << 0 << 0;
+ return true;
+ }
+
+ ExprResult C = TheCall->getArg(2);
+ QualType ArgTyC = C.get()->getType();
+ if (!ArgTyC->hasFloatingRepresentation()) {
+ SemaRef.Diag(C.get()->getBeginLoc(),
+ diag::err_builtin_invalid_arg_type)
+ << 3 << /* scalar or vector */ 5 << /* no int */ 0 << /* fp */ 1
+ << ArgTyC;
+ return true;
+ }
+
+ QualType RetTy = ArgTyA;
+ TheCall->setType(RetTy);
+ assert(RetTy == ArgTyA);
+ break;
+ }
case SPIRV::BI__builtin_spirv_reflect: {
if (SemaRef.checkArgCount(TheCall, 2))
return true;
@@ -89,7 +132,7 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(unsigned BuiltinID,
QualType ArgTyB = B.get()->getType();
auto *VTyB = ArgTyB->getAs<VectorType>();
if (VTyB == nullptr) {
- SemaRef.Diag(A.get()->getBeginLoc(),
+ SemaRef.Diag(B.get()->getBeginLoc(),
diag::err_typecheck_convert_incompatible)
<< ArgTyB
<< SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1
diff --git a/clang/test/CodeGenHLSL/builtins/refract.hlsl b/clang/test/CodeGenHLSL/builtins/refract.hlsl
new file mode 100644
index 0000000000000..a2e160f17b582
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/refract.hlsl
@@ -0,0 +1,356 @@
+// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 5
+// RUN: %clang_cc1 -finclude-default-header -triple \
+// RUN: dxil-pc-shadermodel6.3-library %s -fnative-half-type \
+// RUN: -emit-llvm -O1 -o - | FileCheck %s
+// RUN: %clang_cc1 -finclude-default-header -triple \
+// RUN: spirv-unknown-vulkan-compute %s -fnative-half-type \
+// RUN: -emit-llvm -O1 -o - | FileCheck %s --check-prefix=SPVCHECK
+
+// CHECK-LABEL: define noundef nofpclass(nan inf) half @_Z17test_refract_halfDhDhDh(
+// CHECK-SAME: half noundef nofpclass(nan inf) [[I:%.*]], half noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
+// CHECK-NEXT: [[ENTRY:.*:]]
+// CHECK-NEXT: [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ETA]], [[ETA]]
+// CHECK-NEXT: [[TMP0:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[N]], [[I]]
+// CHECK-NEXT: [[TMP1:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[TMP0]], [[TMP0]]
+// CHECK-NEXT: [[SUB1_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[TMP1]]
+// CHECK-NEXT: [[MUL4_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL_I]], [[SUB1_I]]
+// CHECK-NEXT: [[SUB5_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL4_I]]
+// CHECK-NEXT: [[CMP_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt half [[SUB5_I]], 0xH0000
+// CHECK-NEXT: br i1 [[CMP_I]], label %_ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit, label %if.else.i
+// CHECK: if.else.i: ; preds = %entry
+// CHECK-NEXT: [[MUL6_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ETA]], [[I]]
+// CHECK-NEXT: [[MUL7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[N]], [[I]]
+// CHECK-NEXT: [[MUL8_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL7_I]], [[ETA]]
+// CHECK-NEXT: [[TMP2:%.*]] = tail call reassoc nnan ninf nsz arcp afn half @llvm.sqrt.f16(half [[SUB5_I]])
+// CHECK-NEXT: [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn half [[TMP2]], [[MUL8_I]]
+// CHECK-NEXT: [[MUL9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ADD_I]], [[N]]
+// CHECK-NEXT: [[SUB10_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half [[MUL6_I]], [[MUL9_I]]
+// CHECK-NEXT: br label %_ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit
+// CHECK: _ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit: ; preds = %entry, %if.else.i
+// CHECK-NEXT: [[RETVAL_0_I:%.*]] = phi nsz half [ [[SUB10_I]], %if.else.i ], [ 0xH0000, %entry ]
+// CHECK-NEXT: ret half [[RETVAL_0_I]]
+//
+//
+// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) half @_Z17test_refract_halfDhDhDh(
+// SPVCHECK-SAME: half noundef nofpclass(nan inf) [[I:%.*]], half noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]]
+// SPVCHECK-NEXT: [[ENTRY:.*:]]
+// SPVCHECK-NEXT: [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ETA]], [[ETA]]
+// SPVCHECK-NEXT: [[TMP0:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[N]], [[I]]
+// SPVCHECK-NEXT: [[TMP1:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[TMP0]], [[TMP0]]
+// SPVCHECK-NEXT: [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[TMP1]]
+// SPVCHECK-NEXT: [[MUL_4_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL_I]], [[SUB_I]]
+// SPVCHECK-NEXT: [[SUB5_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_4_I]]
+// SPVCHECK-NEXT: [[CMP_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt half [[SUB5_I]], 0xH0000
+// SPVCHECK-NEXT: br i1 [[CMP_I]], label %_ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit, label %if.else.i
+// SPVCHECK: if.else.i: ; preds = %entry
+// SPVCHECK-NEXT: [[MUL_6_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ETA]], [[I]]
+// SPVCHECK-NEXT: [[MUL_7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[N]], [[I]]
+// SPVCHECK-NEXT: [[MUL_8_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL_7_I]], [[ETA]]
+// SPVCHECK-NEXT: [[TMP2:%.*]] = tail call reassoc nnan ninf nsz arcp afn half @llvm.sqrt.f16(half [[SUB5_I]])
+// SPVCHECK-NEXT: [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn half [[TMP2]], [[MUL_8_I]]
+// SPVCHECK-NEXT: [[MUL_9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ADD_I]], [[N]]
+// SPVCHECK-NEXT: [[SUB10_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half [[MUL_6_I]], [[MUL_9_I]]
+// SPVCHECK-NEXT: br label %_ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit
+// SPVCHECK: _ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit: ; preds = %entry, %if.else.i
+// SPVCHECK-NEXT: [[RETVAL_0_I:%.*]] = phi nsz half [ [[SUB10_I]], %if.else.i ], [ 0xH0000, %entry ]
+// SPVCHECK-NEXT: ret half [[RETVAL_0_I]]
+//
+half test_refract_half(half I, half N, half ETA) {
+ return refract(I, N, ETA);
+}
+
+// CHECK-LABEL: define noundef nofpclass(nan inf) <2 x half> @_Z18test_refract_half2Dv2_DhS_S_
+// CHECK-SAME: <2 x half> noundef nofpclass(nan inf) [[I:%.*]], <2 x half> noundef nofpclass(nan inf) [[N:%.*]], <2 x half> noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
+// CHECK-NEXT: [[ENTRY:.*:]]
+// CHECK-NEXT: [[CAST_VTRUNC:%.*]] = extractelement <2 x half> [[ETA]], i64 0
+// CHECK-NEXT: [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[CAST_VTRUNC]], [[CAST_VTRUNC]]
+// CHECK-NEXT: [[HLSL_DOT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn half @llvm.dx.fdot.v2f16(<2 x half> [[N]], <2 x half> [[I]])
+// CHECK-NEXT: [[MUL_2_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[HLSL_DOT_I]], [[HLSL_DOT_I]]
+// CHECK-NEXT: [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_2_I]]
+// CHECK-NEXT: [[MUL_3_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL_I]], [[SUB_I]]
+// CHECK-NEXT: [[SUB4_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_3_I]]
+// CHECK-NEXT: [[CAST_VTRUNC_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt half [[SUB4_I]], 0xH0000
+// CHECK: br i1 [[CAST_VTRUNC_I]], label %_ZN4hlsl8__detail16refract_vec_implIDhLi2EEEDvT0__T_S3_S3_S2_.exit, label %if.else.i
+// CHECK: if.else.i: ; preds = %entry
+// CHECK-NEXT: [[SPLAT_SPLATINSERT_I:%.*]] = insertelement <2 x half> poison, half [[SUB4_I]], i64 0
+// CHECK-NEXT: [[SPLAT_SPLAT_I:%.*]] = shufflevector <2 x half> [[SPLAT_SPLATINSERT_I]], <2 x half> poison, <2 x i32> zeroinitializer
+// CHECK-NEXT: [[SPLAT_SPLAT6_I:%.*]] = shufflevector <2 x half> [[ETA]], <2 x half> poison, <2 x i32> zeroinitializer
+// CHECK-NEXT: [[MUL_7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <2 x half> [[SPLAT_SPLAT6_I]], [[I]]
+// CHECK-NEXT: [[MUL_9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[HLSL_DOT_I]], [[CAST_VTRUNC]]
+// CHECK-NEXT: [[SPLAT_SPLATINSERT10_I:%.*]] = insertelement <2 x half> poison, half [[MUL_9_I]], i64 0
+// CHECK-NEXT: [[SPLAT_SPLAT11_I:%.*]] = shufflevector <2 x half> [[SPLAT_SPLATINSERT10_I]], <2 x half> poison, <2 x i32> zeroinitializer
+// CHECK-NEXT: [[TMP0:%.*]] = tail call reassoc nnan ninf nsz arcp afn <2 x half> @llvm.sqrt.v2f16(<2 x half> [[SPLAT_SPLAT_I]])
+// CHECK-NEXT: [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn <2 x half> [[TMP0]], [[SPLAT_SPLAT11_I]]
+// CHECK-NEXT: [[MUL_12_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <2 x half> [[ADD_I]], [[N]]
+// CHECK-NEXT: [[SUB13_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <2 x half> [[MUL_7_I]], [[MUL_12_I]]
+// CHECK-NEXT: br label %_ZN4hlsl8__detail16refract_vec_implIDhLi2EEEDvT0__T_S3_S3_S2_.exit
+// CHECK: _ZN4hlsl8__detail16refract_vec_implIDhLi2EEEDvT0__T_S3_S3_S2_.exit: ; preds = %entry, %if.else.i
+// CHECK-NEXT: [[RETVAL_0_I:%.*]] = phi nsz <2 x half> [ %sub13.i, %if.else.i ], [ zeroinitializer, %entry ]
+// CHECK-NEXT: ret <2 x half> [[RETVAL_0_I]]
+
+//
+// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) <2 x half> @_Z18test_refract_half2Dv2_DhS_S_(
+// SPVCHECK-SAME: <2 x half> noundef nofpclass(nan inf) [[I:%.*]], <2 x half> noundef nofpclass(nan inf) [[N:%.*]], <2 x half> noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
+// SPVCHECK-NEXT: [[ENTRY:.*:]]
+// SPVCHECK-NEXT: [[CAST_VTRUNC:%.*]] = extractelement <2 x half> [[ETA]], i64 0
+// SPVCHECK-NEXT: [[CONV_I:%.*]] = fpext reassoc nnan ninf nsz arcp afn half [[CAST_VTRUNC]] to double
+// SPVCHECK-NEXT: [[SPV_REFRACT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef <2 x half> @llvm.spv.refract.v2f16.f64(<2 x half> [[I]], <2 x half> [[N]], double [[CONV_I]])
+// SPVCHECK-NEXT: ret <2 x half> [[SPV_REFRACT_I]]
+//
+half2 test_refract_half2(half2 I, half2 N, half2 ETA) {
+ return refract(I, N, ETA);
+}
+
+// CHECK-LABEL: define noundef nofpclass(nan inf) <3 x half> @_Z18test_refract_half3Dv3_DhS_Dh(
+// CHECK-SAME: <3 x half> noundef nofpclass(nan inf) [[I:%.*]], <3 x half> noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
+// CHECK-NEXT: [[ENTRY:.*:]]
+// CHECK-NEXT: [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ETA]], [[ETA]]
+// CHECK-NEXT: [[HLSL_DOT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn half @llvm.dx.fdot.v3f16(<3 x half> [[N]], <3 x half> [[I]])
+// CHECK-NEXT: [[MUL_2_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[HLSL_DOT_I]], [[HLSL_DOT_I]]
+// CHECK-NEXT: [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_2_I]]
+// CHECK-NEXT: [[MUL_3_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL_I]], [[SUB_I]]
+// CHECK-NEXT: [[SUB4_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_3_I]]
+// CHECK-NEXT: [[CAST_VTRUNC_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt half [[SUB4_I]], 0xH0000
+// CHECK: br i1 [[CAST_VTRUNC_I]], label %_ZN4hlsl8__detail16refract_vec_implIDhLi3EEEDvT0__T_S3_S3_S2_.exit, label %if.else.i
+// CHECK: if.else.i: ; preds = %entry
+// CHECK-NEXT: [[SPLAT_SPLATINSERT_I:%.*]] = insertelement <3 x half> poison, half [[SUB4_I]], i64 0
+// CHECK-NEXT: [[SPLAT_SPLAT_I:%.*]] = shufflevector <3 x half> [[SPLAT_SPLATINSERT_I]], <3 x half> poison, <3 x i32> zeroinitializer
+// CHECK-NEXT: [[SPLAT_SPLATINSERT5_I:%.*]] = insertelement <3 x half> poison, half [[ETA]], i64 0
+// CHECK-NEXT: [[SPLAT_SPLAT6_I:%.*]] = shufflevector <3 x half> [[SPLAT_SPLATINSERT5_I]], <3 x half> poison, <3 x i32> zeroinitializer
+// CHECK-NEXT: [[MUL_7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <3 x half> [[SPLAT_SPLAT6_I]], [[I]]
+// CHECK-NEXT: [[MUL_9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[HLSL_DOT_I]], [[ETA]]
+// CHECK-NEXT: [[SPLAT_SPLATINSERT10_I:%.*]] = insertelement <3 x half> poison, half [[MUL_9_I]], i64 0
+// CHECK-NEXT: [[SPLAT_SPLAT11_I:%.*]] = shufflevector <3 x half> [[SPLAT_SPLATINSERT10_I]], <3 x half> poison, <3 x i32> zeroinitializer
+// CHECK-NEXT: [[TMP0:%.*]] = tail call reassoc nnan ninf nsz arcp afn <3 x half> @llvm.sqrt.v3f16(<3 x half> [[SPLAT_SPLAT_I]])
+// CHECK-NEXT: [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn <3 x half> [[TMP0]], [[SPLAT_SPLAT11_I]]
+// CHECK-NEXT: [[MUL_12_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <3 x half> [[ADD_I]], [[N]]
+// CHECK-NEXT: [[SUB13_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <3 x half> [[MUL_7_I]], [[MUL_12_I]]
+// CHECK: br label %_ZN4hlsl8__detail16refract_vec_implIDhLi3EEEDvT0__T_S3_S3_S2_.exit
+// CHECK: _ZN4hlsl8__detail16refract_vec_implIDhLi3EEEDvT0__T_S3_S3_S2_.exit: ; preds = %entry, %if.else.i
+// CHECK-NEXT: [[RETVAL_0_I:%.*]] = phi nsz <3 x half> [ [[SUB13_I]], %if.else.i ], [ zeroinitializer, %entry ]
+// CHECK-NEXT: ret <3 x half> [[RETVAL_0_I]]
+//
+// SPVCHECK-LABEL: define spir_func noundef nofp...
[truncated]
|
refract
intrinsic
// | ||
// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) half @_Z17test_refract_halfDhDhDh( | ||
// SPVCHECK-SAME: half noundef nofpclass(nan inf) [[I:%.*]], half noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] | ||
// SPVCHECK-NEXT: [[ENTRY:.*:]] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the SPIRV and CHECK lines are the same we can combine these by adding a third check that covers the cases where the runlines are the same. There should be examples of this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is https://github.com/llvm/llvm-project/blob/main/clang/test/CodeGenHLSL/builtins/any.hlsl
a good example?
It just seems to not have SPVCHECK
@@ -71,6 +71,27 @@ constexpr vector<T, L> reflect_vec_impl(vector<T, L> I, vector<T, L> N) { | |||
#endif | |||
} | |||
|
|||
template <typename T> constexpr T refract_impl(T I, T N, T eta) { | |||
T k = 1 - eta * eta * (1 - (N * I * N *I)); | |||
if(k < 0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I feel like we can simplify these conditionals to use a select instead.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the latest commit uses select, but the hlsl tests have not been updated, will decide next steps based on offline discuss
You can test this locally with the following command:git-clang-format --diff HEAD~1 HEAD --extensions cpp,c,h -- clang/test/CodeGenSPIRV/Builtins/refract.c clang/test/SemaSPIRV/BuiltIns/refract-errors.c clang/lib/CodeGen/TargetBuiltins/SPIR.cpp clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h clang/lib/Headers/hlsl/hlsl_intrinsics.h clang/lib/Sema/SemaSPIRV.cpp llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp View the diff from clang-format here.diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
index e9d376b7f..e076f4ded 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
@@ -72,8 +72,8 @@ constexpr vector<T, L> reflect_vec_impl(vector<T, L> I, vector<T, L> N) {
}
template <typename T> constexpr T refract_impl(T I, T N, T eta) {
- T k = 1 - eta * eta * (1 - (N * I * N *I));
- if(k < 0)
+ T k = 1 - eta * eta * (1 - (N * I * N * I));
+ if (k < 0)
return 0;
else
return (eta * I - (eta * N * I + sqrt(k)) * N);
@@ -85,7 +85,7 @@ constexpr vector<T, L> refract_vec_impl(vector<T, L> I, vector<T, L> N, T eta) {
return __builtin_spirv_refract(I, N, eta);
#else
vector<T, L> k = 1 - eta * eta * (1 - dot(N, I) * dot(N, I));
- if(k < 0)
+ if (k < 0)
return 0;
else
return (eta * I - (eta * dot(N, I) + sqrt(k)) * N);
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 54fae9f3b..ccfd0b75a 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -445,8 +445,8 @@ reflect(__detail::HLSL_FIXED_VECTOR<float, L> I,
/// \param eta The refraction index.
///
/// The return value is a floating-point vector that represents the refraction
-/// using the refraction index, \a eta, for the direction of the entering ray, \a I,
-/// off a surface with the normal \a N.
+/// using the refraction index, \a eta, for the direction of the entering ray,
+/// \a I, off a surface with the normal \a N.
///
/// This function calculates the refraction vector using the following formulas:
/// k = 1.0 - eta * eta * (1.0 - dot(N, I) * dot(N, I))
@@ -473,7 +473,7 @@ const inline __detail::enable_if_t<__detail::is_arithmetic<T>::Value &&
template <typename T>
const inline __detail::enable_if_t<
__detail::is_arithmetic<T>::Value && __detail::is_same<float, T>::value, T>
- refract(T I, T N, T eta) {
+refract(T I, T N, T eta) {
return __detail::refract_impl(I, N, eta);
}
diff --git a/clang/lib/Sema/SemaSPIRV.cpp b/clang/lib/Sema/SemaSPIRV.cpp
index b10c6b767..0f8c8c660 100644
--- a/clang/lib/Sema/SemaSPIRV.cpp
+++ b/clang/lib/Sema/SemaSPIRV.cpp
@@ -117,8 +117,7 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(unsigned BuiltinID,
ExprResult C = TheCall->getArg(2);
QualType ArgTyC = C.get()->getType();
if (!ArgTyC->hasFloatingRepresentation()) {
- SemaRef.Diag(C.get()->getBeginLoc(),
- diag::err_builtin_invalid_arg_type)
+ SemaRef.Diag(C.get()->getBeginLoc(), diag::err_builtin_invalid_arg_type)
<< 3 << /* scalar or vector */ 5 << /* no int */ 0 << /* fp */ 1
<< ArgTyC;
return true;
|
run |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The tests need to be redone. Try not to use any automation tooling to generate your tests. its causing you to miss sublte but important implementation details.
// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) <3 x half> @_Z18test_refract_half3Dv3_DhS_Dh( | ||
// SPVCHECK-SAME: <3 x half> noundef nofpclass(nan inf) [[I:%.*]], <3 x half> noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) #[[ATTR0:[0-9]+]] { | ||
// SPVCHECK: [[ENTRY:.*:]] | ||
// SPVCHECK: [[CONV_I:%.*]] = fpext reassoc nnan ninf nsz arcp afn half %{{.*}} to double |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is wrong we should not be casting half to double.
@@ -26,6 +26,12 @@ def SPIRVReflect : Builtin { | |||
let Prototype = "void(...)"; | |||
} | |||
|
|||
def SPIRVRefract : Builtin { | |||
let Spellings = ["__builtin_spirv_refract"]; | |||
let Attributes = [NoThrow, Const]; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You need to add CustomTypeChecking
.
let Attributes = [NoThrow, Const]; | |
let Attributes = [NoThrow, Const, CustomTypeChecking]; |
#if (__has_builtin(__builtin_spirv_refract)) | ||
return __builtin_spirv_refract(I, N, Eta); | ||
#else | ||
vector<T, L> K = 1 - Eta * Eta * (1 - dot(N, I) * dot(N, I)); | ||
vector<T, L> Result = (Eta * I - (Eta * dot(N, I) + sqrt(K)) * N); | ||
return select<vector<T, L>>(K < 0, vector<T, L>(0), Result); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Our most recent pattern on these has been to do one function when possible instead of a scalar and vector implementation. I beleive this is a case where that might be possible. Can you experiment with the following?
In hlsl_detail.h
add
template <typename T>
struct is_vector {
static const bool value = false;
};
/*NOTE: (don't include this comment) what I am doing here is adding a specialization for vector<T, N>*/
template <typename T, int N>
struct is_vector<vector<T, N>> {
static const bool value = true;
};
Then line 82-87 changes to
#if (__has_builtin(__builtin_spirv_refract)) | |
return __builtin_spirv_refract(I, N, Eta); | |
#else | |
vector<T, L> K = 1 - Eta * Eta * (1 - dot(N, I) * dot(N, I)); | |
vector<T, L> Result = (Eta * I - (Eta * dot(N, I) + sqrt(K)) * N); | |
return select<vector<T, L>>(K < 0, vector<T, L>(0), Result); | |
#if (__has_builtin(__builtin_spirv_refract) && is_vector<T>) | |
return __builtin_spirv_refract(I, N, Eta); | |
#else | |
T K = 1 - Eta * Eta * (1 - dot(N, I) * dot(N, I)); | |
T Result = (Eta * I - (Eta * dot(N, I) + sqrt(K)) * N); | |
return select(K < 0, 0, Result); |
@@ -16,88 +16,90 @@ namespace clang { | |||
|
|||
SemaSPIRV::SemaSPIRV(Sema &S) : SemaBase(S) {} | |||
|
|||
/// Checks if the first `NumArgsToCheck` arguments of a function call are of |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@spall did a bunch of work to get semantic checking to be more consistent. She should review at least this file for this PR.
/// \return `true` if any of the arguments is not a vector type, `false` | ||
/// otherwise. | ||
|
||
bool SemaSPIRV::CheckVectorArgs(CallExpr *TheCall, unsigned NumArgsToCheck) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I suspect most cases all the function arguments will be the same type. creae an overload of this function that does not take a NumArgsToCheck
and instead just calls this function:
bool SemaSPIRV::CheckVectorArgs(CallExpr *TheCall, unsigned NumArgsToCheck) { | |
bool SemaSPIRV::CheckVectorArgs(CallExpr *TheCall) { | |
return CheckVectorArgs(theCall, TheCall->getNumArgs()); | |
} | |
bool SemaSPIRV::CheckVectorArgs(CallExpr *TheCall, unsigned NumArgsToCheck) { |
reflect
and distance
should use the one that doesn't require you to specify argument size.
@@ -1,4 +1,3 @@ | |||
// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 5 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
undo this change.
#if (__has_builtin(__builtin_spirv_refract)) | ||
return __builtin_spirv_refract(I, N, Eta); | ||
#else | ||
vector<T, L> K = 1 - Eta * Eta * (1 - dot(N, I) * dot(N, I)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we should really be storing dot(N, I)
somewhere so that we aren't repeating the computation 3 times. That should simplify the codegen for the -O0 cases.
|
||
float2 test_no_second_arg(float2 p0) { | ||
return __builtin_spirv_refract(p0); | ||
// expected-error@-1 {{too few arguments to function call, expected 3, have 1}} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you only need one less than test
@@ -20,7 +20,7 @@ namespace clang { | |||
class SemaSPIRV : public SemaBase { | |||
public: | |||
SemaSPIRV(Sema &S); | |||
|
|||
bool CheckVectorArgs(CallExpr *TheCall, unsigned NumArgsToCheck); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a helper, we don't need to expose it in the class. No one needs to use this outside of spirv and if they did it wouldn't live in the semaSPIRV file it would live somewhere like SemaChecking.
Resolves #99153