From b250faccd357fb82c6917be4713c9f2816f5d8d5 Mon Sep 17 00:00:00 2001 From: "Gregory Meyer (gregjm)" Date: Tue, 9 May 2023 13:33:01 -0700 Subject: [PATCH] Make operator() const-correct and add missing static functions. (#936) * Make operator() const-correct and add missing static functions. Currently, `*Converter::operator()` requires a mutable object to invoke, and there are missing `static result_type convert(source_type const & source)` overloads for certain partial specializations of `*Converter` objects. This commit makes `operator()` const-correct and adds missing function overloads where appropriate. * minor changes * format --------- Co-authored-by: Haicheng Wu --- include/cutlass/numeric_conversion.h | 182 +++++++++++++++------------ 1 file changed, 99 insertions(+), 83 deletions(-) diff --git a/include/cutlass/numeric_conversion.h b/include/cutlass/numeric_conversion.h index 68c259ccfa..0ba84c74e7 100644 --- a/include/cutlass/numeric_conversion.h +++ b/include/cutlass/numeric_conversion.h @@ -75,13 +75,13 @@ struct NumericConverter { static FloatRoundStyle const round_style = Round; CUTLASS_HOST_DEVICE - static result_type convert(source_type const & s) { + static result_type convert(source_type const & s) { return static_cast(s); } CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -107,7 +107,7 @@ struct NumericConverter { } CUTLASS_DEVICE - result_type operator()(source_type const &s) { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -126,7 +126,7 @@ struct NumericConverter { } CUTLASS_DEVICE - result_type operator()(source_type const &s) { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -145,7 +145,7 @@ struct NumericConverter { return (result_type)std::nearbyint(s); } - result_type operator()(source_type const &s) { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -162,7 +162,7 @@ struct NumericConverter { return (result_type)std::nearbyint(s); } - result_type operator()(source_type const &s) { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -192,7 +192,7 @@ struct NumericConverter { } CUTLASS_DEVICE - result_type operator()(source_type const &s) { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -214,7 +214,7 @@ struct NumericConverter { } CUTLASS_DEVICE - result_type operator()(source_type const &s) { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -241,7 +241,7 @@ struct NumericConverter { return static_cast(intermediate); } - result_type operator()(source_type const &s) { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -266,7 +266,7 @@ struct NumericConverter { return static_cast(intermediate); } - result_type operator()(source_type const &s) { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -290,7 +290,7 @@ struct NumericConverter { } CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -318,7 +318,7 @@ struct NumericConverter { } CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -340,7 +340,7 @@ struct NumericConverter { } CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -409,7 +409,7 @@ struct NumericConverter { } CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -435,7 +435,7 @@ struct NumericConverter { } CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -452,7 +452,7 @@ struct NumericConverter { } CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -482,7 +482,7 @@ struct NumericConverter { } CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -529,7 +529,7 @@ struct NumericConverter { } CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -579,7 +579,7 @@ struct NumericConverter { } CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -596,7 +596,7 @@ struct NumericConverter { } CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -682,7 +682,7 @@ struct NumericConverterFastF32 { } CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -715,7 +715,7 @@ struct NumericConverterClamp { } CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -732,11 +732,15 @@ struct NumericConverterClamp { using source_type = S; CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) { - return static_cast(s); + static result_type convert(source_type const &source) { + return static_cast(source); } -}; + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; ///////////////////////////////////////////////////////////////////////////////////////////////// // @@ -782,7 +786,7 @@ struct NumericArrayConverter { } CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -804,20 +808,23 @@ struct NumericArrayConverter { "Unary Operator not supported."); CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) { - if( platform::is_same::value ) - { - return s; - } else { - result_type result; - for (int i = 0; i < N; ++i) { - result[i] = conj(s[i]); - } - return result; + static result_type convert(source_type const &source) { + if (platform::is_same::value) { + return source; + } else { + result_type result; + for (int i = 0; i < N; ++i) { + result[i] = conj(source[i]); } + return result; + } } -}; + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -846,7 +853,7 @@ struct NumericArrayConverter { } CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -918,7 +925,7 @@ struct NumericArrayConverter { } CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -959,12 +966,11 @@ struct NumericArrayConverter { } CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) { + result_type operator()(source_type const &s) const { return convert(s); } }; - ///////////////////////////////////////////////////////////////////////////////////////////////// #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) @@ -989,7 +995,7 @@ struct NumericArrayConverter { } CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -1067,7 +1073,7 @@ struct NumericArrayConverter { } CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -1096,7 +1102,7 @@ struct NumericArrayConverter { } CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -1127,7 +1133,7 @@ struct NumericArrayConverter { } CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -1163,7 +1169,7 @@ struct NumericArrayConverter { } CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -1190,7 +1196,7 @@ struct NumericArrayConverter { } CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -1219,7 +1225,7 @@ struct NumericArrayConverter { } CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -1250,7 +1256,7 @@ struct NumericArrayConverter { } CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -1286,7 +1292,7 @@ struct NumericArrayConverter { } CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -1349,7 +1355,7 @@ struct NumericArrayConverter { } CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -1397,7 +1403,7 @@ struct NumericArrayConverter { } CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -1452,7 +1458,7 @@ struct NumericArrayConverter { } CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -1500,7 +1506,7 @@ struct NumericArrayConverter { } CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -1551,7 +1557,7 @@ struct NumericArrayConverter { } CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -1600,7 +1606,7 @@ struct NumericArrayConverter { } CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -1645,7 +1651,7 @@ struct NumericArrayConverter { } CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -1694,7 +1700,7 @@ struct NumericArrayConverter { } CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -1748,7 +1754,7 @@ struct NumericArrayConverter { } CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -1794,7 +1800,7 @@ struct NumericArrayConverter { } CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -1842,7 +1848,7 @@ struct NumericArrayConverter { } CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -1888,7 +1894,7 @@ struct NumericArrayConverter { } CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -1925,7 +1931,7 @@ struct NumericArrayConverter { } CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -1956,7 +1962,7 @@ struct NumericArrayConverter { } CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -1986,8 +1992,13 @@ struct NumericArrayConverter { static FloatRoundStyle const round_style = Round; CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) { - return s; + static result_type convert(source_type const &source) { + return source; + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); } }; @@ -2004,8 +2015,13 @@ struct NumericArrayConverter { static FloatRoundStyle const round_style = Round; CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) { - return s; + static result_type convert(source_type const &source) { + return source; + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); } }; @@ -2063,7 +2079,7 @@ struct PackedNumericArrayConverter { } CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -2165,7 +2181,7 @@ struct NumericArrayConverter { } CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -2206,7 +2222,7 @@ struct NumericArrayConverter { } CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -2242,7 +2258,7 @@ struct NumericArrayConverter { } CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -2277,7 +2293,7 @@ struct NumericArrayConverter { } CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -2313,7 +2329,7 @@ struct NumericArrayConverter { } CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -2341,7 +2357,7 @@ struct FastNumericArrayConverter { } CUTLASS_DEVICE - result_type operator()(source_type const &s) { return convert(s); } + result_type operator()(source_type const &s) const { return convert(s); } }; /// Partial specialization for Array <= Array @@ -2365,7 +2381,7 @@ struct FastNumericArrayConverter { } CUTLASS_DEVICE - result_type operator()(source_type const &s) { return convert(s); } + result_type operator()(source_type const &s) const { return convert(s); } }; /// Partial specialization for Array <= Array @@ -2393,7 +2409,7 @@ struct FastNumericArrayConverter { } CUTLASS_DEVICE - result_type operator()(source_type const &s) { return convert(s); } + result_type operator()(source_type const &s) const { return convert(s); } }; /// Partial specialization for Array <= Array @@ -2425,7 +2441,7 @@ struct FastNumericArrayConverter { } CUTLASS_DEVICE - result_type operator()(source_type const &s) { return convert(s); } + result_type operator()(source_type const &s) const { return convert(s); } }; /////////////////////////////////////////////////////////////////////////////////////////////////