improve docker packages, fix bugs, enable tests, enable FFT (pytorch#10893)

Summary:
* improve docker packages (install OpenBLAS to have at-compile-time LAPACK functionality w/ optimizations for both Intel and AMD CPUs)
* integrate rocFFT (i.e., enable Fourier functionality)
* fix bugs in ROCm caused by wrong warp size
* enable more test sets, skipping the tests that don't yet work on ROCm
* no longer disable asserts during hipification
* small improvements
Pull Request resolved: pytorch#10893

Differential Revision: D9615053

Pulled By: ezyang

fbshipit-source-id: 864b4d27bf089421f7dfd8065e5017f9ea2f7b3b
iotamudelta authored and facebook-github-bot committed Sep 2, 2018
1 parent abe8b33 commit 33c7cc1
Showing 23 changed files with 794 additions and 238 deletions.
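Several of the kernel fixes below share one root cause: NVIDIA GPUs schedule threads in 32-wide warps, while the AMD GCN hardware targeted by ROCm uses 64-wide wavefronts, so any arithmetic that hard-codes 32 silently breaks under HIP. A minimal sketch of the guard pattern the commit applies in each affected file (the constant and helper names here are illustrative, not from the commit):

    #if defined(__HIP_PLATFORM_HCC__)
    constexpr int kWarpSize = 64;  // AMD GCN wavefront width
    #else
    constexpr int kWarpSize = 32;  // NVIDIA warp width
    #endif

    // Example use: how many warp-sized groups a thread block contains.
    __device__ inline int warpsPerBlock() { return blockDim.x / kWarpSize; }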
67 changes: 66 additions & 1 deletion aten/src/ATen/native/cuda/CuFFTPlanCache.h
@@ -90,8 +90,13 @@ class CuFFTConfig {
IntList output_sizes) {

// signal sizes
#ifdef __HIP_PLATFORM_HCC__
std::vector<int> signal_sizes(checked_signal_sizes.begin(),
checked_signal_sizes.end());
#else
std::vector<long long int> signal_sizes(checked_signal_sizes.begin(),
checked_signal_sizes.end());
#endif

// input batch size
long long int batch = input.size(0);
@@ -149,7 +154,11 @@ class CuFFTConfig {
// TODO: Figure out why windows fails to compile
// at::optional<std::vector<long long int>> inembed_opt = at::nullopt;
// Then move the following to a helper function.
#ifdef __HIP_PLATFORM_HCC__
std::vector<int> inembed(signal_ndim);
#else
std::vector<long long int> inembed(signal_ndim);
#endif
if (!clone_input) {
auto istrides = input.strides();
auto last_istride = istrides[signal_ndim];
@@ -192,6 +201,37 @@ class CuFFTConfig {
inembed.begin()); // begin of output
}

#ifdef __HIP_PLATFORM_HCC__

hipfftType exec_type;
if (input.type().scalarType() == ScalarType::Float) {
if (complex_input && complex_output) {
exec_type = HIPFFT_C2C;
} else if (complex_input && !complex_output) {
exec_type = HIPFFT_C2R;
} else if (!complex_input && complex_output) {
exec_type = HIPFFT_R2C;
} else {
throw std::runtime_error("hipFFT doesn't support r2r (float)");
}
} else if (input.type().scalarType() == ScalarType::Double) {
if (complex_input && complex_output) {
exec_type = HIPFFT_Z2Z;
} else if (complex_input && !complex_output) {
exec_type = HIPFFT_Z2D;
} else if (!complex_input && complex_output) {
exec_type = HIPFFT_D2Z;
} else {
throw std::runtime_error("hipFFT doesn't support r2r (double)");
}
} else {
std::ostringstream ss;
ss << "hipFFT doesn't support tensor of type: "
<< at::toString(input.type().scalarType());
throw std::runtime_error(ss.str());
}

#else
cudaDataType itype, otype, exec_type;
if (input.type().scalarType() == ScalarType::Float) {
itype = complex_input ? CUDA_C_32F : CUDA_R_32F;
@@ -211,6 +251,7 @@ class CuFFTConfig {
<< at::toString(input.type().scalarType());
throw std::runtime_error(ss.str());
}
#endif

// create plan
auto raw_plan_ptr = new cufftHandle();
@@ -229,10 +270,17 @@ class CuFFTConfig {
// by assuming base_istride = base_ostride = 1.
//
// See NOTE [ cuFFT Embedded Strides ] in native/cuda/SpectralOps.cu.
#ifdef __HIP_PLATFORM_HCC__
CUFFT_CHECK(hipfftMakePlanMany(plan(), signal_ndim, signal_sizes.data(),
/* inembed */ nullptr, /* base_istride */ 1, /* idist */ 1,
/* onembed */ nullptr, /* base_ostride */ 1, /* odist */ 1,
exec_type, batch, &ws_size_t));
#else
CUFFT_CHECK(cufftXtMakePlanMany(plan(), signal_ndim, signal_sizes.data(),
/* inembed */ nullptr, /* base_istride */ 1, /* idist */ 1, itype,
/* onembed */ nullptr, /* base_ostride */ 1, /* odist */ 1, otype,
batch, &ws_size_t, exec_type));
#endif
} else {
// set idist (stride at batch dim)
// set base_istride (stride at innermost dim of signal)
@@ -254,6 +302,18 @@ class CuFFTConfig {
}

// set odist, onembed, base_ostride
#ifdef __HIP_PLATFORM_HCC__
int odist = at::prod_intlist(output_sizes.slice(1, signal_ndim));
std::vector<int> onembed(output_sizes.data() + 1, output_sizes.data() + signal_ndim + 1);
int base_ostride = 1;

int istride = base_istride;
int iidist = idist;
CUFFT_CHECK(hipfftMakePlanMany(plan(), signal_ndim, signal_sizes.data(),
inembed.data(), istride, iidist,
onembed.data(), base_ostride, odist,
exec_type, batch, &ws_size_t));
#else
long long int odist = at::prod_intlist(output_sizes.slice(1, signal_ndim));
std::vector<long long int> onembed(output_sizes.data() + 1, output_sizes.data() + signal_ndim + 1);
long long int base_ostride = 1;
@@ -262,11 +322,16 @@ class CuFFTConfig {
      CUFFT_CHECK(cufftXtMakePlanMany(plan(), signal_ndim, signal_sizes.data(),
              inembed.data(), base_istride, idist, itype,
onembed.data(), base_ostride, odist, otype,
batch, &ws_size_t, exec_type));
#endif
}
ws_size = static_cast<int64_t>(ws_size_t);
}

#ifdef __HIP_PLATFORM_HCC__
cufftHandle &plan() const { return *plan_ptr.get(); }
#else
const cufftHandle &plan() const { return *plan_ptr.get(); }
#endif

bool should_clone_input() const { return clone_input; }

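For context on the branches above: the two plan APIs diverge in signature. cufftXtMakePlanMany takes long long dimensions plus separate cudaDataType descriptors for input, output, and execution, while hipfftMakePlanMany takes plain int dimensions, an int batch count, and a single fused hipfftType such as HIPFFT_C2C. A standalone sketch of the HIP path for a simple-layout plan (hypothetical helper name; error handling reduced to the returned status):

    #include <hipfft.h>
    #include <vector>

    // Create a batched, contiguous ("simple layout") C2C plan and report the
    // required workspace size, mirroring the hipfftMakePlanMany call above.
    hipfftResult makeSimpleC2CPlan(hipfftHandle plan, std::vector<int>& signalSizes,
                                   int batch, size_t* workSize) {
      return hipfftMakePlanMany(plan, static_cast<int>(signalSizes.size()),
                                signalSizes.data(),
                                /* inembed */ nullptr, /* istride */ 1, /* idist */ 1,
                                /* onembed */ nullptr, /* ostride */ 1, /* odist */ 1,
                                HIPFFT_C2C, batch, workSize);
    }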
2 changes: 2 additions & 0 deletions aten/src/ATen/native/cuda/CuFFTUtils.h
@@ -49,8 +49,10 @@ static inline std::string _cudaGetErrorEnum(cufftResult error)
return "CUFFT_NO_WORKSPACE";
case CUFFT_NOT_IMPLEMENTED:
return "CUFFT_NOT_IMPLEMENTED";
#ifndef __HIP_PLATFORM_HCC__
case CUFFT_LICENSE_ERROR:
return "CUFFT_LICENSE_ERROR";
#endif
case CUFFT_NOT_SUPPORTED:
return "CUFFT_NOT_SUPPORTED";
default:
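These strings are what the CUFFT_CHECK wrapper reports when a call fails; the new guard is needed because hipFFT's result enum defines no license-error value. A minimal sketch of such a wrapper (illustrative only; the real macro lives alongside _cudaGetErrorEnum in this header):

    #include <sstream>
    #include <stdexcept>

    // Sketch: decode the result enum and fail loudly on any non-success status.
    #define CUFFT_CHECK_SKETCH(expr)                              \
      do {                                                        \
        cufftResult status = (expr);                              \
        if (status != CUFFT_SUCCESS) {                            \
          std::ostringstream ss;                                  \
          ss << "cuFFT error: " << _cudaGetErrorEnum(status);     \
          throw std::runtime_error(ss.str());                     \
        }                                                         \
      } while (0)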
37 changes: 37 additions & 0 deletions aten/src/ATen/native/cuda/SpectralOps.cu
@@ -189,8 +189,45 @@ static inline Tensor _run_cufft(
CUFFT_CHECK(cufftSetWorkArea(plan, ws.data_ptr()));

// run
#ifdef __HIP_PLATFORM_HCC__
if (input.type().scalarType() == ScalarType::Float) {
if (complex_input && complex_output) {
CUFFT_CHECK(hipfftExecC2C(plan, static_cast<hipfftComplex*>(input.data_ptr()),
static_cast<hipfftComplex*>(output.data_ptr()),
inverse ? HIPFFT_BACKWARD : HIPFFT_FORWARD));
} else if (complex_input && !complex_output) {
CUFFT_CHECK(hipfftExecC2R(plan, static_cast<hipfftComplex*>(input.data_ptr()),
static_cast<hipfftReal*>(output.data_ptr())));
} else if (!complex_input && complex_output) {
CUFFT_CHECK(hipfftExecR2C(plan, static_cast<hipfftReal*>(input.data_ptr()),
static_cast<hipfftComplex*>(output.data_ptr())));
} else {
throw std::runtime_error("hipFFT doesn't support r2r (float)");
}
} else if (input.type().scalarType() == ScalarType::Double) {
if (complex_input && complex_output) {
CUFFT_CHECK(hipfftExecZ2Z(plan, static_cast<hipfftDoubleComplex*>(input.data_ptr()),
static_cast<hipfftDoubleComplex*>(output.data_ptr()),
inverse ? HIPFFT_BACKWARD : HIPFFT_FORWARD));
} else if (complex_input && !complex_output) {
CUFFT_CHECK(hipfftExecZ2D(plan, static_cast<hipfftDoubleComplex*>(input.data_ptr()),
static_cast<hipfftDoubleReal*>(output.data_ptr())));
} else if (!complex_input && complex_output) {
CUFFT_CHECK(hipfftExecD2Z(plan, static_cast<hipfftDoubleReal*>(input.data_ptr()),
static_cast<hipfftDoubleComplex*>(output.data_ptr())));
} else {
throw std::runtime_error("hipFFT doesn't support r2r (double)");
}
} else {
std::ostringstream ss;
ss << "hipFFT doesn't support tensor of type: "
<< at::toString(input.type().scalarType());
throw std::runtime_error(ss.str());
}
#else
CUFFT_CHECK(cufftXtExec(plan, input.data_ptr(), output.data_ptr(),
inverse ? CUFFT_INVERSE : CUFFT_FORWARD));
#endif

// rescale if needed by normalized flag or inverse transform
auto size_last_signal_dim = checked_signal_sizes[signal_ndim - 1];
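On CUDA a single type-erased cufftXtExec call covers every precision and direction because the plan already encodes the data types; hipFFT instead exposes one entry point per transform, which is why the HIP branch above dispatches by hand. Its mapping, restated as a small helper (hypothetical, for illustration only):

    // Which hipFFT exec function the branch above selects, keyed on precision
    // and input/output complexity. Real-to-real has no hipFFT entry point.
    const char* hipfftExecName(bool isDouble, bool complexIn, bool complexOut) {
      if (complexIn && complexOut) return isDouble ? "hipfftExecZ2Z" : "hipfftExecC2C";
      if (complexIn)               return isDouble ? "hipfftExecZ2D" : "hipfftExecC2R";
      if (complexOut)              return isDouble ? "hipfftExecD2Z" : "hipfftExecR2C";
      return nullptr;  // r2r: unsupported, the code above throws instead
    }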
2 changes: 2 additions & 0 deletions aten/src/THC/THCAtomics.cuh
@@ -138,8 +138,10 @@ static inline __device__ void atomicAdd(double *address, double val) {
} while (assumed != old);
}
#elif !defined(__CUDA_ARCH__) && (CUDA_VERSION < 8000) || defined(__HIP_PLATFORM_HCC__)
#if defined(__HIP_PLATFORM_HCC__) && __hcc_workweek__ < 18312
// This needs to be defined for the host side pass
static inline __device__ void atomicAdd(double *address, double val) { }
#endif
#endif

#endif // THC_ATOMICS_INC
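For context, the do/while tail shown above belongs to the standard compare-and-swap emulation of double-precision atomicAdd, required wherever the hardware instruction is missing (CUDA provides it natively only from sm_60 on). The canonical idiom, as a sketch (the textbook CUDA pattern, not necessarily this file's exact text):

    static inline __device__ void atomicAddDouble(double* address, double val) {
      // Reinterpret the double as 64-bit integer storage so atomicCAS applies.
      unsigned long long int* address_as_ull = (unsigned long long int*)address;
      unsigned long long int old = *address_as_ull, assumed;
      do {
        assumed = old;
        // Add in double precision, then try to publish the new bit pattern.
        old = atomicCAS(address_as_ull, assumed,
                        __double_as_longlong(val + __longlong_as_double(assumed)));
        // If another thread won the race, retry against its value.
      } while (assumed != old);
    }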
14 changes: 11 additions & 3 deletions aten/src/THC/THCScanUtils.cuh
@@ -4,6 +4,12 @@
#include "THCAsmUtils.cuh"
#include "THCDeviceUtils.cuh"

#if defined(__HIP_PLATFORM_HCC__)
#define SCAN_UTILS_WARP_SIZE 64
#else
#define SCAN_UTILS_WARP_SIZE 32
#endif

// Collection of in-kernel scan / prefix sum utilities

// Inclusive Scan via an upsweep/downsweep mechanism. Assumes:
@@ -157,7 +163,7 @@ __device__ void inclusiveBinaryPrefixScan(T* smem, bool in, T* out, BinaryFuncti
T index = __popc(getLaneMaskLe() & vote);
T carry = __popc(vote);

  int warp = threadIdx.x / SCAN_UTILS_WARP_SIZE;

// Per each warp, write out a value
if (getLaneId() == 0) {
@@ -170,7 +176,7 @@ __device__ void inclusiveBinaryPrefixScan(T* smem, bool in, T* out, BinaryFuncti
// warp shuffle scan for CC 3.0+
if (threadIdx.x == 0) {
int current = 0;
    for (int i = 0; i < blockDim.x / SCAN_UTILS_WARP_SIZE; ++i) {
T v = smem[i];
smem[i] = binop(smem[i], current);
current = binop(current, v);
Expand Down Expand Up @@ -201,11 +207,13 @@ __device__ void exclusiveBinaryPrefixScan(T* smem, bool in, T* out, T* carry, Bi
*out -= (T) in;

// The outgoing carry for all threads is the last warp's sum
  *carry = smem[(blockDim.x / SCAN_UTILS_WARP_SIZE) - 1];

if (KillWARDependency) {
__syncthreads();
}
}

#undef SCAN_UTILS_WARP_SIZE

#endif // THC_SCAN_UTILS_INC
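The bug these SCAN_UTILS_WARP_SIZE substitutions fix is concrete: with a 256-thread block, dividing by a hard-coded 32 counts eight per-warp shared-memory slots, but on ROCm such a block holds only four wavefronts, so the scan read slots no wavefront had written. A host-side illustration of the two quotients:

    #include <cstdio>

    int main() {
      const int blockThreads = 256;
      // 32-wide warps (CUDA): 8 per-warp smem slots get written.
      printf("warps (32-wide):      %d\n", blockThreads / 32);
      // 64-wide wavefronts (ROCm): only 4 slots get written, so the old
      // '/ 32' indexing also read 4 uninitialized slots.
      printf("wavefronts (64-wide): %d\n", blockThreads / 64);
      return 0;
    }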
8 changes: 8 additions & 0 deletions aten/src/THC/THCTensorTopK.cuh
@@ -213,7 +213,11 @@ __device__ DataType findPattern(DataType* smem,
IndexType withinSliceStride,
BitDataType desired,
BitDataType desiredMask) {
#ifdef __HIP_PLATFORM_HCC__
if (threadIdx.x < 64) {
#else
if (threadIdx.x < 32) {
#endif
smem[threadIdx.x] = ScalarConvert<int, DataType>::to(0);
}
__syncthreads();
@@ -366,7 +370,11 @@ __global__ void gatherTopK(TensorInfo<T, IndexType> input,
IndexType indicesWithinSliceStride) {
// Indices are limited to integer fp precision, so counts can fit in
// int32, regardless of IndexType
#ifdef __HIP_PLATFORM_HCC__
__shared__ int smem[64];
#else
__shared__ int smem[32]; // one per each warp, up to warp limit
#endif

IndexType slice = getLinearBlockId<IndexType>();
if (slice >= numInputSlices) {
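The shared-memory sizes above follow a one-slot-per-warp rule: 32 slots cover the 1024/32 warps a maximal CUDA block can hold, while 64 matches the wavefront width and comfortably covers the 1024/64 = 16 wavefronts of a maximal ROCm block. A compile-time restatement of that bound (illustrative only):

    #if defined(__HIP_PLATFORM_HCC__)
    constexpr int kWarp = 64;
    #else
    constexpr int kWarp = 32;
    #endif
    constexpr int kMaxBlockThreads = 1024;

    // kWarp shared-memory slots always suffice for one slot per warp, since
    // the per-block warp count never exceeds the warp width at these sizes.
    static_assert(kMaxBlockThreads / kWarp <= kWarp,
                  "one shared-memory slot per warp fits");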
26 changes: 18 additions & 8 deletions aten/src/THC/generic/THCTensorTopK.cu
@@ -29,21 +29,24 @@ THC_API void THCTensor_(topk)(THCState* state,
THCTensor_(resize)(state, topK, topKSize, {});
THCudaLongTensor_resize(state, indices, topKSize, {});

// static_cast is required to ensure that the correct type (INDEX_T)
// is provided to the kernel for the arguments.

#define RUN_K(INDEX_T, DIM, DIR) \
gatherTopK<real, INDEX_T, DIM, DIR> \
<<<grid, block, 0, THCState_getCurrentStream(state)>>>( \
inputInfo, \
      static_cast<INDEX_T>(sliceSize),                                \
      static_cast<INDEX_T>(k),                                        \
      static_cast<INDEX_T>(inputSlices),                              \
      /* The actual dimension that the k-selection is running in */   \
      /* may have changed from collapseDims() */                      \
      static_cast<INDEX_T>(inputInfo.strides[collapseInputDim]),      \
      topKInfo,                                                       \
      static_cast<INDEX_T>(topKSlices),                               \
      static_cast<INDEX_T>(topKInfo.strides[collapseTopKDim]),        \
      indicesInfo,                                                    \
      static_cast<INDEX_T>(indicesInfo.strides[collapseIndicesDim]))

#define RUN_DIR(INDEX_T, DIM) \
if (dir) { \
@@ -63,6 +66,12 @@ THC_API void THCTensor_(topk)(THCState* state,
RUN_DIR(INDEX_T, -1); \
}

#ifdef __HIP_PLATFORM_HCC__
#define TOPK_WARP_SIZE 64
#else
#define TOPK_WARP_SIZE 32
#endif

#define RUN_T(INDEX_T) \
TensorInfo<real, INDEX_T> inputInfo = \
getTensorInfo<real, THCTensor, INDEX_T>(state, input); \
@@ -96,7 +105,7 @@ THC_API void THCTensor_(topk)(THCState* state,
THError("Slice to sort is too large"); \
} \
\
  dim3 block(std::min(THCRoundUp(sliceSize, (int64_t) TOPK_WARP_SIZE), (int64_t) 1024)); \
\
/* This is used as a template parameter to calculate indices. */ \
/* We only specialize it if all collapsed dim sizes are the */ \
@@ -124,6 +133,7 @@ THC_API void THCTensor_(topk)(THCState* state,
#undef RUN_DIM
#undef RUN_DIR
#undef RUN_K
#undef TOPK_WARP_SIZE

// Sort the results if the user wants them sorted, since our
// selection routine does not ensure sorting
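The static_casts added to RUN_K make the kernel-argument types explicit: gatherTopK is instantiated for a specific INDEX_T (possibly 32-bit), while values such as sliceSize are 64-bit on the host, and leaving that narrowing implicit at the launch boundary proved fragile under the HIP toolchain. A reduced sketch of the pattern (hypothetical kernel and names):

    #include <cstdint>

    template <typename IndexType>
    __global__ void gatherSketch(IndexType sliceSize, IndexType stride) {
      // ... k-selection body elided ...
    }

    void launchSketch(int64_t sliceSize, int64_t stride) {
      // Cast each argument to the instantiated index type at the call site,
      // making the narrowing explicit rather than relying on the compiler's
      // implicit conversion when marshalling kernel arguments.
      gatherSketch<uint32_t><<<1, 256>>>(static_cast<uint32_t>(sliceSize),
                                         static_cast<uint32_t>(stride));
    }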
1 change: 1 addition & 0 deletions cmake/Dependencies.cmake
@@ -558,6 +558,7 @@ endif()
if(USE_ROCM)
include_directories(SYSTEM ${HIP_PATH}/include)
include_directories(SYSTEM ${ROCBLAS_PATH}/include)
include_directories(SYSTEM ${ROCFFT_PATH}/include)
include_directories(SYSTEM ${HIPSPARSE_PATH}/include)
include_directories(SYSTEM ${HIPRAND_PATH}/include)
include_directories(SYSTEM ${ROCRAND_PATH}/include)
9 changes: 9 additions & 0 deletions cmake/public/LoadHIP.cmake
@@ -38,6 +38,13 @@ ELSE()
SET(ROCBLAS_PATH $ENV{ROCBLAS_PATH})
ENDIF()

# ROCFFT_PATH
IF(NOT DEFINED ENV{ROCFFT_PATH})
  SET(ROCFFT_PATH ${ROCM_PATH}/rocfft)
ELSE()
SET(ROCFFT_PATH $ENV{ROCFFT_PATH})
ENDIF()

# HIPSPARSE_PATH
IF(NOT DEFINED ENV{HIPSPARSE_PATH})
SET(HIPSPARSE_PATH ${ROCM_PATH}/hcsparse)
@@ -106,11 +113,13 @@ IF(HIP_FOUND)
set(rocblas_DIR ${ROCBLAS_PATH}/lib/cmake/rocblas)
set(miopen_DIR ${MIOPEN_PATH}/lib/cmake/miopen)
set(rocfft_DIR ${ROCFFT_PATH}/lib/cmake/rocfft)
set(hipsparse_DIR ${HIPSPARSE_PATH}/lib/cmake/hipsparse)

find_package(rocrand REQUIRED)
find_package(hiprand REQUIRED)
find_package(rocblas REQUIRED)
find_package(rocfft REQUIRED)
find_package(miopen REQUIRED)
#find_package(hipsparse REQUIRED)

11 changes: 11 additions & 0 deletions docker/caffe2/jenkins/common/install_rocm.sh
@@ -5,6 +5,7 @@ set -ex
install_ubuntu() {
apt-get update
apt-get install -y wget
apt-get install -y libopenblas-dev

DEB_ROCM_REPO=http://repo.radeon.com/rocm/apt/debian
# Add rocm repository
@@ -63,6 +64,15 @@ install_rocrand() {
dpkg -i /opt/rocm/debians/rocrand.deb
}

# Install rocSPARSE/hipSPARSE (to be released soon); they can co-exist with hcSPARSE, which will be removed soon
install_hipsparse() {
mkdir -p /opt/rocm/debians
curl https://s3.amazonaws.com/ossci-linux/rocsparse-0.1.1.0.deb -o /opt/rocm/debians/rocsparse.deb
curl https://s3.amazonaws.com/ossci-linux/hipsparse-0.1.1.0.deb -o /opt/rocm/debians/hipsparse.deb
dpkg -i /opt/rocm/debians/rocsparse.deb
dpkg -i /opt/rocm/debians/hipsparse.deb
}

# Install Python packages depending on the base OS
if [ -f /etc/lsb-release ]; then
install_ubuntu
@@ -76,3 +86,4 @@ fi
install_hip_thrust
install_rocrand
install_hcsparse
install_hipsparse
