@@ -342,6 +342,59 @@ LongVector::TestConfig<DataTypeT, LongVectorOpTypeT>::TestConfig(LongVector::Bin
342342 }
343343}
344344
345+ template <typename DataTypeT, typename LongVectorOpTypeT>
346+ LongVector::TestConfig<DataTypeT, LongVectorOpTypeT>::TestConfig(LongVector::TrigonometricOpType OpType)
347+ : OpTypeTraits(OpType) {
348+ IntrinsicString = " " ;
349+ BasicOpType = LongVector::BasicOpType_Unary;
350+
351+ // All trigonometric ops are floating point types.
352+ // These trig functions are defined to have a max absolute error of 0.0008
353+ // as per the D3D functional specs. An example with this spec for sin and
354+ // cos is available here:
355+ // https://microsoft.github.io/DirectX-Specs/d3d/archive/D3D11_3_FunctionalSpec.htm#22.10.20
356+ ValidationType = LongVector::ValidationType_Epsilon;
357+ if (std::is_same_v<DataTypeT, HLSLHalf_t>)
358+ Tolerance = 0 .0010f ;
359+ else if (std::is_same_v<DataTypeT, float >)
360+ Tolerance = 0 .0008f ;
361+ else
362+ VERIFY_FAIL (
363+ " Invalid type for trigonometric op. Expecting half or float." );
364+
365+ switch (OpType) {
366+ case LongVector::TrigonometricOpType_Acos:
367+ IntrinsicString = " acos" ;
368+ break ;
369+ case LongVector::TrigonometricOpType_Asin:
370+ IntrinsicString = " asin" ;
371+ break ;
372+ case LongVector::TrigonometricOpType_Atan:
373+ IntrinsicString = " atan" ;
374+ break ;
375+ case LongVector::TrigonometricOpType_Cos:
376+ IntrinsicString = " cos" ;
377+ break ;
378+ case LongVector::TrigonometricOpType_Cosh:
379+ IntrinsicString = " cosh" ;
380+ break ;
381+ case LongVector::TrigonometricOpType_Sin:
382+ IntrinsicString = " sin" ;
383+ break ;
384+ case LongVector::TrigonometricOpType_Sinh:
385+ IntrinsicString = " sinh" ;
386+ break ;
387+ case LongVector::TrigonometricOpType_Tan:
388+ IntrinsicString = " tan" ;
389+ break ;
390+ case LongVector::TrigonometricOpType_Tanh:
391+ IntrinsicString = " tanh" ;
392+ break ;
393+ default :
394+ VERIFY_FAIL (" Invalid TrigonometricOpType" );
395+ }
396+ }
397+
345398template <typename DataTypeT, typename LongVectorOpTypeT>
346399bool LongVector::TestConfig<DataTypeT, LongVectorOpTypeT>::hasFunctionDefinition() const {
347400 if constexpr (std::is_same_v<LongVectorOpTypeT, LongVector::UnaryOpType>) {
@@ -463,6 +516,13 @@ DataTypeT LongVector::TestConfig<DataTypeT, LongVectorOpTypeT>::computeExpectedV
463516template <typename DataTypeT, typename LongVectorOpTypeT>
464517DataTypeT LongVector::TestConfig<DataTypeT, LongVectorOpTypeT>::computeExpectedValue(const DataTypeT &A) const {
465518
519+ if constexpr (std::is_same_v<LongVectorOpTypeT, LongVector::TrigonometricOpType>) {
520+ const auto OpType = static_cast <LongVector::TrigonometricOpType>(OpTypeTraits.OpType );
521+ // HLSLHalf_t is a struct. We need to call the constructor to get the
522+ // expected value.
523+ return computeExpectedValue (A, OpType);
524+ }
525+
466526 if constexpr (std::is_same_v<LongVectorOpTypeT, LongVector::UnaryOpType>) {
467527 const auto OpType = static_cast <LongVector::UnaryOpType>(OpTypeTraits.OpType );
468528 // HLSLHalf_t is a struct. We need to call the constructor to get the
@@ -477,6 +537,67 @@ DataTypeT LongVector::TestConfig<DataTypeT, LongVectorOpTypeT>::computeExpectedV
477537 return DataTypeT ();
478538}
479539
540+ template <typename DataTypeT, typename LongVectorOpTypeT>
541+ DataTypeT LongVector::TestConfig<DataTypeT, LongVectorOpTypeT>::computeExpectedValue(const DataTypeT &A,
542+ LongVector::TrigonometricOpType OpType) const {
543+ // The trig functions are only valid on floating point types. The constexpr in
544+ // this case is a relatively easy and clean way to prevent the compiler from
545+ // erroring out trying to resolve these for the non floating point types. We
546+ // won't use them in the first place.
547+ if constexpr (isFloatingPointType<DataTypeT>()) {
548+ switch (OpType) {
549+ case LongVector::TrigonometricOpType_Acos:
550+ return std::acos (A);
551+ case LongVector::TrigonometricOpType_Asin:
552+ return std::asin (A);
553+ case LongVector::TrigonometricOpType_Atan:
554+ return std::atan (A);
555+ case LongVector::TrigonometricOpType_Cos:
556+ return std::cos (A);
557+ case LongVector::TrigonometricOpType_Cosh:
558+ return std::cosh (A);
559+ case LongVector::TrigonometricOpType_Sin:
560+ return std::sin (A);
561+ case LongVector::TrigonometricOpType_Sinh:
562+ return std::sinh (A);
563+ case LongVector::TrigonometricOpType_Tan:
564+ return std::tan (A);
565+ case LongVector::TrigonometricOpType_Tanh:
566+ return std::tanh (A);
567+ default :
568+ LOG_ERROR_FMT_THROW (L" Unknown TrigonometricOpType: %d" ,
569+ OpTypeTraits.OpType );
570+ return DataTypeT ();
571+ }
572+ }
573+
574+ LOG_ERROR_FMT_THROW (L" ComputeExpectedValue(const DataTypeT &A, "
575+ L" LongVectorOpTypeT OpType) called on a "
576+ L" non-float type: %d" ,
577+ OpType);
578+
579+ return DataTypeT ();
580+ }
581+
582+ template <typename DataTypeT, typename LongVectorOpTypeT>
583+ std::vector<DataTypeT> LongVector::TestConfig<DataTypeT, LongVectorOpTypeT>::getInputArgsArray() const {
584+
585+ std::vector<DataTypeT> InputArgs;
586+
587+ std::wstring InputArgsArrayName = this ->InputArgsArrayName ;
588+
589+ if (InputArgsArrayName.empty ())
590+ VERIFY_FAIL (" No args array name set." );
591+
592+ if (std::is_same_v<DataTypeT, HLSLBool_t> && isClampOp ())
593+ VERIFY_FAIL (" Clamp is not supported for bools." );
594+ else
595+ return getInputValueSetByKey<DataTypeT>(InputArgsArrayName, false );
596+
597+ VERIFY_FAIL (" Invalid type for args array." );
598+ return std::vector<DataTypeT>();
599+ }
600+
480601template <typename DataTypeT, typename LongVectorOpTypeT>
481602std::string LongVector::TestConfig<DataTypeT, LongVectorOpTypeT>::getCompilerOptionsString(size_t VectorSize) const {
482603 std::stringstream CompilerOptions (" " );
0 commit comments