Skip to content

[HLSL][DXIL] Implement refract intrinsic #136026

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions clang/include/clang/Basic/BuiltinsSPIRV.td
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ def SPIRVReflect : Builtin {
let Prototype = "void(...)";
}

def SPIRVRefract : Builtin {
let Spellings = ["__builtin_spirv_refract"];
let Attributes = [NoThrow, Const];
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You need to add CustomTypeChecking.

Suggested change
let Attributes = [NoThrow, Const];
let Attributes = [NoThrow, Const, CustomTypeChecking];

let Prototype = "void(...)";
}

def SPIRVSmoothStep : Builtin {
let Spellings = ["__builtin_spirv_smoothstep"];
let Attributes = [NoThrow, Const, CustomTypeChecking];
Expand Down
2 changes: 1 addition & 1 deletion clang/include/clang/Sema/SemaSPIRV.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ namespace clang {
class SemaSPIRV : public SemaBase {
public:
SemaSPIRV(Sema &S);

bool CheckVectorArgs(CallExpr *TheCall, unsigned NumArgsToCheck);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a helper, we don't need to expose it in the class. No one needs to use this outside of spirv and if they did it wouldn't live in the semaSPIRV file it would live somewhere like SemaChecking.

bool CheckSPIRVBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall);
};
} // namespace clang
Expand Down
15 changes: 15 additions & 0 deletions clang/lib/CodeGen/TargetBuiltins/SPIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,21 @@ Value *CodeGenFunction::EmitSPIRVBuiltinExpr(unsigned BuiltinID,
/*ReturnType=*/I->getType(), Intrinsic::spv_reflect,
ArrayRef<Value *>{I, N}, nullptr, "spv.reflect");
}
case SPIRV::BI__builtin_spirv_refract: {
Value *I = EmitScalarExpr(E->getArg(0));
Value *N = EmitScalarExpr(E->getArg(1));
Value *eta = EmitScalarExpr(E->getArg(2));
assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
E->getArg(1)->getType()->hasFloatingRepresentation() &&
E->getArg(2)->getType()->isFloatingType() &&
"refract operands must have a float representation");
assert(E->getArg(0)->getType()->isVectorType() &&
E->getArg(1)->getType()->isVectorType() &&
"refract I and N operands must be a vector");
return Builder.CreateIntrinsic(
/*ReturnType=*/I->getType(), Intrinsic::spv_refract,
ArrayRef<Value *>{I, N, eta}, nullptr, "spv.refract");
}
case SPIRV::BI__builtin_spirv_smoothstep: {
Value *Min = EmitScalarExpr(E->getArg(0));
Value *Max = EmitScalarExpr(E->getArg(1));
Expand Down
17 changes: 17 additions & 0 deletions clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,23 @@ constexpr vector<T, L> reflect_vec_impl(vector<T, L> I, vector<T, L> N) {
#endif
}

template <typename T> constexpr T refract_impl(T I, T N, T Eta) {
T K = 1 - Eta * Eta * (1 - (N * I * N * I));
T Result = (Eta * I - (Eta * N * I + sqrt(K)) * N);
return select<T>(K < 0, static_cast<T>(0), Result);
}

template <typename T, int L>
constexpr vector<T, L> refract_vec_impl(vector<T, L> I, vector<T, L> N, T Eta) {
#if (__has_builtin(__builtin_spirv_refract))
return __builtin_spirv_refract(I, N, Eta);
#else
vector<T, L> K = 1 - Eta * Eta * (1 - dot(N, I) * dot(N, I));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should really be storing dot(N, I) somewhere so that we aren't repeating the computation 3 times. That should simplify the codegen for the -O0 cases.

vector<T, L> Result = (Eta * I - (Eta * dot(N, I) + sqrt(K)) * N);
return select<vector<T, L>>(K < 0, vector<T, L>(0), Result);
Comment on lines +82 to +87
Copy link
Member

@farzonl farzonl May 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Our most recent pattern on these has been to do one function when possible instead of a scalar and vector implementation. I beleive this is a case where that might be possible. Can you experiment with the following?

In hlsl_detail.h add

template <typename T>
struct is_vector {
  static const bool value = false;
};

/*NOTE: (don't include this comment) what I am doing here is adding a specialization for vector<T, N>*/
template <typename T, int N>
struct is_vector<vector<T, N>> {
  static const bool value = true;
};

Then line 82-87 changes to

Suggested change
#if (__has_builtin(__builtin_spirv_refract))
return __builtin_spirv_refract(I, N, Eta);
#else
vector<T, L> K = 1 - Eta * Eta * (1 - dot(N, I) * dot(N, I));
vector<T, L> Result = (Eta * I - (Eta * dot(N, I) + sqrt(K)) * N);
return select<vector<T, L>>(K < 0, vector<T, L>(0), Result);
#if (__has_builtin(__builtin_spirv_refract) && is_vector<T>)
return __builtin_spirv_refract(I, N, Eta);
#else
T K = 1 - Eta * Eta * (1 - dot(N, I) * dot(N, I));
T Result = (Eta * I - (Eta * dot(N, I) + sqrt(K)) * N);
return select(K < 0, 0, Result);

#endif
}

template <typename T> constexpr T fmod_impl(T X, T Y) {
#if !defined(__DIRECTX__)
return __builtin_elementwise_fmod(X, Y);
Expand Down
59 changes: 59 additions & 0 deletions clang/lib/Headers/hlsl/hlsl_intrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,65 @@ reflect(__detail::HLSL_FIXED_VECTOR<float, L> I,
return __detail::reflect_vec_impl(I, N);
}

//===----------------------------------------------------------------------===//
// refract builtin
//===----------------------------------------------------------------------===//

/// \fn T refract(T I, T N, T eta)
/// \brief Returns a refraction using an entering ray, \a I, a surface
/// normal, \a N and refraction index \a eta
/// \param I The entering ray.
/// \param N The surface normal.
/// \param eta The refraction index.
///
/// The return value is a floating-point vector that represents the refraction
/// using the refraction index, \a eta, for the direction of the entering ray,
/// \a I, off a surface with the normal \a N.
///
/// This function calculates the refraction vector using the following formulas:
/// k = 1.0 - eta * eta * (1.0 - dot(N, I) * dot(N, I))
/// if k < 0.0 the result is 0.0
/// otherwise, the result is eta * I - (eta * dot(N, I) + sqrt(k)) * N
///
/// I and N must already be normalized in order to achieve the desired result.
///
/// I and N must be a scalar or vector whose component type is
/// floating-point.
///
/// eta must be a 16-bit or 32-bit floating-point scalar.
///
/// Result type, the type of I, and the type of N must all be the same type.

template <typename T>
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
const inline __detail::enable_if_t<__detail::is_arithmetic<T>::Value &&
__detail::is_same<half, T>::value,
T> refract(T I, T N, T eta) {
return __detail::refract_impl(I, N, eta);
}

template <typename T>
const inline __detail::enable_if_t<
__detail::is_arithmetic<T>::Value && __detail::is_same<float, T>::value, T>
refract(T I, T N, T eta) {
return __detail::refract_impl(I, N, eta);
}

template <int L>
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
const inline __detail::HLSL_FIXED_VECTOR<half, L> refract(
__detail::HLSL_FIXED_VECTOR<half, L> I,
__detail::HLSL_FIXED_VECTOR<half, L> N, half eta) {
return __detail::refract_vec_impl(I, N, eta);
}

template <int L>
const inline __detail::HLSL_FIXED_VECTOR<float, L>
refract(__detail::HLSL_FIXED_VECTOR<float, L> I,
__detail::HLSL_FIXED_VECTOR<float, L> N, float eta) {
return __detail::refract_vec_impl(I, N, eta);
}

//===----------------------------------------------------------------------===//
// smoothstep builtin
//===----------------------------------------------------------------------===//
Expand Down
114 changes: 58 additions & 56 deletions clang/lib/Sema/SemaSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,88 +16,90 @@ namespace clang {

SemaSPIRV::SemaSPIRV(Sema &S) : SemaBase(S) {}

/// Checks if the first `NumArgsToCheck` arguments of a function call are of
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@spall did a bunch of work to get semantic checking to be more consistent. She should review at least this file for this PR.

/// vector type. If any of the arguments is not a vector type, it emits a
/// diagnostic error and returns `true`. Otherwise, it returns `false`.
///
/// \param TheCall The function call expression to check.
/// \param NumArgsToCheck The number of arguments to check for vector type.
/// \return `true` if any of the arguments is not a vector type, `false`
/// otherwise.

bool SemaSPIRV::CheckVectorArgs(CallExpr *TheCall, unsigned NumArgsToCheck) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suspect most cases all the function arguments will be the same type. creae an overload of this function that does not take a NumArgsToCheck and instead just calls this function:

Suggested change
bool SemaSPIRV::CheckVectorArgs(CallExpr *TheCall, unsigned NumArgsToCheck) {
bool SemaSPIRV::CheckVectorArgs(CallExpr *TheCall) {
return CheckVectorArgs(theCall, TheCall->getNumArgs());
}
bool SemaSPIRV::CheckVectorArgs(CallExpr *TheCall, unsigned NumArgsToCheck) {

reflect and distance should use the one that doesn't require you to specify argument size.

for (unsigned i = 0; i < NumArgsToCheck; ++i) {
ExprResult Arg = TheCall->getArg(i);
QualType ArgTy = Arg.get()->getType();
auto *VTy = ArgTy->getAs<VectorType>();
if (VTy == nullptr) {
SemaRef.Diag(Arg.get()->getBeginLoc(),
diag::err_typecheck_convert_incompatible)
<< ArgTy
<< SemaRef.Context.getVectorType(ArgTy, 2, VectorKind::Generic) << 1
<< 0 << 0;
return true;
}
}
return false;
}

bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(unsigned BuiltinID,
CallExpr *TheCall) {
switch (BuiltinID) {
case SPIRV::BI__builtin_spirv_distance: {
if (SemaRef.checkArgCount(TheCall, 2))
return true;

ExprResult A = TheCall->getArg(0);
QualType ArgTyA = A.get()->getType();
auto *VTyA = ArgTyA->getAs<VectorType>();
if (VTyA == nullptr) {
SemaRef.Diag(A.get()->getBeginLoc(),
diag::err_typecheck_convert_incompatible)
<< ArgTyA
<< SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1
<< 0 << 0;
return true;
}

ExprResult B = TheCall->getArg(1);
QualType ArgTyB = B.get()->getType();
auto *VTyB = ArgTyB->getAs<VectorType>();
if (VTyB == nullptr) {
SemaRef.Diag(A.get()->getBeginLoc(),
diag::err_typecheck_convert_incompatible)
<< ArgTyB
<< SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1
<< 0 << 0;
// Use the helper function to check both arguments
if (CheckVectorArgs(TheCall, 2))
return true;
}

QualType RetTy = VTyA->getElementType();
QualType RetTy =
TheCall->getArg(0)->getType()->getAs<VectorType>()->getElementType();
TheCall->setType(RetTy);
break;
}
case SPIRV::BI__builtin_spirv_length: {
if (SemaRef.checkArgCount(TheCall, 1))
return true;
ExprResult A = TheCall->getArg(0);
QualType ArgTyA = A.get()->getType();
auto *VTy = ArgTyA->getAs<VectorType>();
if (VTy == nullptr) {
SemaRef.Diag(A.get()->getBeginLoc(),
diag::err_typecheck_convert_incompatible)
<< ArgTyA
<< SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1
<< 0 << 0;

// Use the helper function to check the argument
if (CheckVectorArgs(TheCall, 1))
return true;
}
QualType RetTy = VTy->getElementType();

QualType RetTy =
TheCall->getArg(0)->getType()->getAs<VectorType>()->getElementType();
TheCall->setType(RetTy);
break;
}
case SPIRV::BI__builtin_spirv_reflect: {
if (SemaRef.checkArgCount(TheCall, 2))
case SPIRV::BI__builtin_spirv_refract: {
if (SemaRef.checkArgCount(TheCall, 3))
return true;

ExprResult A = TheCall->getArg(0);
QualType ArgTyA = A.get()->getType();
auto *VTyA = ArgTyA->getAs<VectorType>();
if (VTyA == nullptr) {
SemaRef.Diag(A.get()->getBeginLoc(),
diag::err_typecheck_convert_incompatible)
<< ArgTyA
<< SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1
<< 0 << 0;
// Use the helper function to check the first two arguments
if (CheckVectorArgs(TheCall, 2))
return true;
}

ExprResult B = TheCall->getArg(1);
QualType ArgTyB = B.get()->getType();
auto *VTyB = ArgTyB->getAs<VectorType>();
if (VTyB == nullptr) {
SemaRef.Diag(A.get()->getBeginLoc(),
diag::err_typecheck_convert_incompatible)
<< ArgTyB
<< SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1
<< 0 << 0;
ExprResult C = TheCall->getArg(2);
QualType ArgTyC = C.get()->getType();
if (!ArgTyC->isFloatingType()) {
SemaRef.Diag(C.get()->getBeginLoc(), diag::err_builtin_invalid_arg_type)
<< 3 << /* scalar*/ 5 << /* no int */ 0 << /* fp */ 1 << ArgTyC;
return true;
}

QualType RetTy = ArgTyA;
QualType RetTy = TheCall->getArg(0)->getType();
TheCall->setType(RetTy);
break;
}
case SPIRV::BI__builtin_spirv_reflect: {
if (SemaRef.checkArgCount(TheCall, 2))
return true;

// Use the helper function to check both arguments
if (CheckVectorArgs(TheCall, 2))
return true;

QualType RetTy = TheCall->getArg(0)->getType();
TheCall->setType(RetTy);
break;
}
Expand Down
1 change: 0 additions & 1 deletion clang/test/CodeGenHLSL/builtins/reflect.hlsl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 5
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

undo this change.

// RUN: %clang_cc1 -finclude-default-header -triple \
// RUN: dxil-pc-shadermodel6.3-library %s -fnative-half-type \
// RUN: -emit-llvm -O1 -o - | FileCheck %s
Expand Down
Loading