-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[HLSL][DXIL] Implement refract
intrinsic
#136026
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
6ce233a
83d69dd
83c4f5a
1853e3d
dff5181
3d87d6d
e3c1b0a
d8b079c
d1e1fe9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,7 +20,7 @@ namespace clang { | |
class SemaSPIRV : public SemaBase { | ||
public: | ||
SemaSPIRV(Sema &S); | ||
|
||
bool CheckVectorArgs(CallExpr *TheCall, unsigned NumArgsToCheck); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a helper, we don't need to expose it in the class. No one needs to use this outside of spirv and if they did it wouldn't live in the semaSPIRV file it would live somewhere like SemaChecking. |
||
bool CheckSPIRVBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall); | ||
}; | ||
} // namespace clang | ||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -71,6 +71,23 @@ constexpr vector<T, L> reflect_vec_impl(vector<T, L> I, vector<T, L> N) { | |||||||||||||||||||||||||
#endif | ||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
template <typename T> constexpr T refract_impl(T I, T N, T Eta) { | ||||||||||||||||||||||||||
T K = 1 - Eta * Eta * (1 - (N * I * N * I)); | ||||||||||||||||||||||||||
T Result = (Eta * I - (Eta * N * I + sqrt(K)) * N); | ||||||||||||||||||||||||||
return select<T>(K < 0, static_cast<T>(0), Result); | ||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
template <typename T, int L> | ||||||||||||||||||||||||||
constexpr vector<T, L> refract_vec_impl(vector<T, L> I, vector<T, L> N, T Eta) { | ||||||||||||||||||||||||||
#if (__has_builtin(__builtin_spirv_refract)) | ||||||||||||||||||||||||||
return __builtin_spirv_refract(I, N, Eta); | ||||||||||||||||||||||||||
#else | ||||||||||||||||||||||||||
vector<T, L> K = 1 - Eta * Eta * (1 - dot(N, I) * dot(N, I)); | ||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we should really be storing |
||||||||||||||||||||||||||
vector<T, L> Result = (Eta * I - (Eta * dot(N, I) + sqrt(K)) * N); | ||||||||||||||||||||||||||
return select<vector<T, L>>(K < 0, vector<T, L>(0), Result); | ||||||||||||||||||||||||||
Comment on lines
+82
to
+87
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Our most recent pattern on these has been to do one function when possible instead of a scalar and vector implementation. I beleive this is a case where that might be possible. Can you experiment with the following? In template <typename T>
struct is_vector {
static const bool value = false;
};
/*NOTE: (don't include this comment) what I am doing here is adding a specialization for vector<T, N>*/
template <typename T, int N>
struct is_vector<vector<T, N>> {
static const bool value = true;
}; Then line 82-87 changes to
Suggested change
|
||||||||||||||||||||||||||
#endif | ||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
template <typename T> constexpr T fmod_impl(T X, T Y) { | ||||||||||||||||||||||||||
#if !defined(__DIRECTX__) | ||||||||||||||||||||||||||
return __builtin_elementwise_fmod(X, Y); | ||||||||||||||||||||||||||
|
Original file line number | Diff line number | Diff line change | ||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -16,88 +16,90 @@ namespace clang { | |||||||||||
|
||||||||||||
SemaSPIRV::SemaSPIRV(Sema &S) : SemaBase(S) {} | ||||||||||||
|
||||||||||||
/// Checks if the first `NumArgsToCheck` arguments of a function call are of | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @spall did a bunch of work to get semantic checking to be more consistent. She should review at least this file for this PR. |
||||||||||||
/// vector type. If any of the arguments is not a vector type, it emits a | ||||||||||||
/// diagnostic error and returns `true`. Otherwise, it returns `false`. | ||||||||||||
/// | ||||||||||||
/// \param TheCall The function call expression to check. | ||||||||||||
/// \param NumArgsToCheck The number of arguments to check for vector type. | ||||||||||||
/// \return `true` if any of the arguments is not a vector type, `false` | ||||||||||||
/// otherwise. | ||||||||||||
|
||||||||||||
bool SemaSPIRV::CheckVectorArgs(CallExpr *TheCall, unsigned NumArgsToCheck) { | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I suspect most cases all the function arguments will be the same type. creae an overload of this function that does not take a
Suggested change
|
||||||||||||
for (unsigned i = 0; i < NumArgsToCheck; ++i) { | ||||||||||||
ExprResult Arg = TheCall->getArg(i); | ||||||||||||
QualType ArgTy = Arg.get()->getType(); | ||||||||||||
auto *VTy = ArgTy->getAs<VectorType>(); | ||||||||||||
if (VTy == nullptr) { | ||||||||||||
SemaRef.Diag(Arg.get()->getBeginLoc(), | ||||||||||||
diag::err_typecheck_convert_incompatible) | ||||||||||||
<< ArgTy | ||||||||||||
<< SemaRef.Context.getVectorType(ArgTy, 2, VectorKind::Generic) << 1 | ||||||||||||
<< 0 << 0; | ||||||||||||
return true; | ||||||||||||
} | ||||||||||||
} | ||||||||||||
return false; | ||||||||||||
} | ||||||||||||
|
||||||||||||
bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(unsigned BuiltinID, | ||||||||||||
CallExpr *TheCall) { | ||||||||||||
switch (BuiltinID) { | ||||||||||||
case SPIRV::BI__builtin_spirv_distance: { | ||||||||||||
if (SemaRef.checkArgCount(TheCall, 2)) | ||||||||||||
return true; | ||||||||||||
|
||||||||||||
ExprResult A = TheCall->getArg(0); | ||||||||||||
QualType ArgTyA = A.get()->getType(); | ||||||||||||
auto *VTyA = ArgTyA->getAs<VectorType>(); | ||||||||||||
if (VTyA == nullptr) { | ||||||||||||
SemaRef.Diag(A.get()->getBeginLoc(), | ||||||||||||
diag::err_typecheck_convert_incompatible) | ||||||||||||
<< ArgTyA | ||||||||||||
<< SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1 | ||||||||||||
<< 0 << 0; | ||||||||||||
return true; | ||||||||||||
} | ||||||||||||
|
||||||||||||
ExprResult B = TheCall->getArg(1); | ||||||||||||
QualType ArgTyB = B.get()->getType(); | ||||||||||||
auto *VTyB = ArgTyB->getAs<VectorType>(); | ||||||||||||
if (VTyB == nullptr) { | ||||||||||||
SemaRef.Diag(A.get()->getBeginLoc(), | ||||||||||||
diag::err_typecheck_convert_incompatible) | ||||||||||||
<< ArgTyB | ||||||||||||
<< SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1 | ||||||||||||
<< 0 << 0; | ||||||||||||
// Use the helper function to check both arguments | ||||||||||||
if (CheckVectorArgs(TheCall, 2)) | ||||||||||||
return true; | ||||||||||||
} | ||||||||||||
|
||||||||||||
QualType RetTy = VTyA->getElementType(); | ||||||||||||
QualType RetTy = | ||||||||||||
TheCall->getArg(0)->getType()->getAs<VectorType>()->getElementType(); | ||||||||||||
TheCall->setType(RetTy); | ||||||||||||
break; | ||||||||||||
} | ||||||||||||
case SPIRV::BI__builtin_spirv_length: { | ||||||||||||
if (SemaRef.checkArgCount(TheCall, 1)) | ||||||||||||
return true; | ||||||||||||
ExprResult A = TheCall->getArg(0); | ||||||||||||
QualType ArgTyA = A.get()->getType(); | ||||||||||||
auto *VTy = ArgTyA->getAs<VectorType>(); | ||||||||||||
if (VTy == nullptr) { | ||||||||||||
SemaRef.Diag(A.get()->getBeginLoc(), | ||||||||||||
diag::err_typecheck_convert_incompatible) | ||||||||||||
<< ArgTyA | ||||||||||||
<< SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1 | ||||||||||||
<< 0 << 0; | ||||||||||||
|
||||||||||||
// Use the helper function to check the argument | ||||||||||||
if (CheckVectorArgs(TheCall, 1)) | ||||||||||||
return true; | ||||||||||||
} | ||||||||||||
QualType RetTy = VTy->getElementType(); | ||||||||||||
|
||||||||||||
QualType RetTy = | ||||||||||||
TheCall->getArg(0)->getType()->getAs<VectorType>()->getElementType(); | ||||||||||||
TheCall->setType(RetTy); | ||||||||||||
break; | ||||||||||||
} | ||||||||||||
case SPIRV::BI__builtin_spirv_reflect: { | ||||||||||||
if (SemaRef.checkArgCount(TheCall, 2)) | ||||||||||||
case SPIRV::BI__builtin_spirv_refract: { | ||||||||||||
if (SemaRef.checkArgCount(TheCall, 3)) | ||||||||||||
return true; | ||||||||||||
|
||||||||||||
ExprResult A = TheCall->getArg(0); | ||||||||||||
QualType ArgTyA = A.get()->getType(); | ||||||||||||
auto *VTyA = ArgTyA->getAs<VectorType>(); | ||||||||||||
if (VTyA == nullptr) { | ||||||||||||
SemaRef.Diag(A.get()->getBeginLoc(), | ||||||||||||
diag::err_typecheck_convert_incompatible) | ||||||||||||
<< ArgTyA | ||||||||||||
<< SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1 | ||||||||||||
<< 0 << 0; | ||||||||||||
// Use the helper function to check the first two arguments | ||||||||||||
if (CheckVectorArgs(TheCall, 2)) | ||||||||||||
return true; | ||||||||||||
} | ||||||||||||
|
||||||||||||
ExprResult B = TheCall->getArg(1); | ||||||||||||
QualType ArgTyB = B.get()->getType(); | ||||||||||||
auto *VTyB = ArgTyB->getAs<VectorType>(); | ||||||||||||
if (VTyB == nullptr) { | ||||||||||||
SemaRef.Diag(A.get()->getBeginLoc(), | ||||||||||||
diag::err_typecheck_convert_incompatible) | ||||||||||||
<< ArgTyB | ||||||||||||
<< SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1 | ||||||||||||
<< 0 << 0; | ||||||||||||
ExprResult C = TheCall->getArg(2); | ||||||||||||
QualType ArgTyC = C.get()->getType(); | ||||||||||||
if (!ArgTyC->isFloatingType()) { | ||||||||||||
SemaRef.Diag(C.get()->getBeginLoc(), diag::err_builtin_invalid_arg_type) | ||||||||||||
<< 3 << /* scalar*/ 5 << /* no int */ 0 << /* fp */ 1 << ArgTyC; | ||||||||||||
return true; | ||||||||||||
} | ||||||||||||
|
||||||||||||
QualType RetTy = ArgTyA; | ||||||||||||
QualType RetTy = TheCall->getArg(0)->getType(); | ||||||||||||
TheCall->setType(RetTy); | ||||||||||||
break; | ||||||||||||
} | ||||||||||||
case SPIRV::BI__builtin_spirv_reflect: { | ||||||||||||
if (SemaRef.checkArgCount(TheCall, 2)) | ||||||||||||
return true; | ||||||||||||
|
||||||||||||
// Use the helper function to check both arguments | ||||||||||||
if (CheckVectorArgs(TheCall, 2)) | ||||||||||||
return true; | ||||||||||||
|
||||||||||||
QualType RetTy = TheCall->getArg(0)->getType(); | ||||||||||||
TheCall->setType(RetTy); | ||||||||||||
break; | ||||||||||||
} | ||||||||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,3 @@ | ||
// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 5 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. undo this change. |
||
// RUN: %clang_cc1 -finclude-default-header -triple \ | ||
// RUN: dxil-pc-shadermodel6.3-library %s -fnative-half-type \ | ||
// RUN: -emit-llvm -O1 -o - | FileCheck %s | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You need to add
CustomTypeChecking
.