Skip to content

Commit

Permalink
Use platform:: instead of std::abs and std::conditional (NVIDIA#452)
Browse files Browse the repository at this point in the history
* Fixed template struct/class mismatch

* Use platform implementation instead of std::abs and std::conditional during nvrtc compilation

* Use platform implementation instead of std::abs and std::conditional during nvrtc compilation

* Revert absolute_value() usage
  • Loading branch information
kroburg authored Apr 25, 2022
1 parent 70f3ba5 commit 71def2f
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 3 deletions.
4 changes: 2 additions & 2 deletions include/cutlass/conv/conv2d_problem_size.h
Original file line number Diff line number Diff line change
Expand Up @@ -537,12 +537,12 @@ void strided_dgrad_starting_coords(
// function locals for remainder by fast divmod
int pad_h_rem_, pad_w_rem_;

// start_h = std::abs(problem_size.stride_h - ((problem_size.pad_h % problem_size.stride_h) - r)) % problem_size.stride_h;
// start_h = platform::abs(problem_size.stride_h - ((problem_size.pad_h % problem_size.stride_h) - r)) % problem_size.stride_h;
stride_h_divmod.divmod(pad_h_rem_, problem_size.pad_h);
int r_ = absolute_value(problem_size.stride_h - (pad_h_rem_ - r));
stride_h_divmod.divmod(start_h, r_);

//start_w = std::abs(problem_size.stride_w - ((problem_size.pad_w % problem_size.stride_w) - s)) % problem_size.stride_w;
//start_w = platform::abs(problem_size.stride_w - ((problem_size.pad_w % problem_size.stride_w) - s)) % problem_size.stride_w;
stride_w_divmod.divmod(pad_w_rem_, problem_size.pad_w);
int s_ = absolute_value(problem_size.stride_w - (pad_w_rem_ - s));
stride_w_divmod.divmod(start_w, s_);
Expand Down
18 changes: 17 additions & 1 deletion include/cutlass/platform/platform.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
* (2) Re-implementations of STL functions and types:
* - C++ features that need the \p __device__ annotation. These are
* placed into the \p platform namespace.
* - \p abs
* - \p plus
* - \p less
* - \p greater
Expand Down Expand Up @@ -184,6 +185,22 @@
namespace cutlass {
namespace platform {

//-----------------------------------------------------------------------------
// Abs operations <algorithm>
//-----------------------------------------------------------------------------

#if defined(__CUDACC_RTC__)
/// std::abs
CUTLASS_HOST_DEVICE constexpr int abs(int a) {
return (a < 0) ? -a : a;
}
CUTLASS_HOST_DEVICE constexpr long long abs(long long a) {
return (a < 0) ? -a : a;
}
#else
using std::abs;
#endif

//-----------------------------------------------------------------------------
// Minimum/maximum operations <algorithm>
//-----------------------------------------------------------------------------
Expand Down Expand Up @@ -435,7 +452,6 @@ struct is_base_of
typename remove_cv<DerivedT>::type>::value) ||
(is_same<typename remove_cv<BaseT>::type,
typename remove_cv<DerivedT>::type>::value)> {};

#else

using std::is_same;
Expand Down

0 comments on commit 71def2f

Please sign in to comment.