AMD support (#1430)
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Jithun Nair <jithun.nair@amd.com>
Co-authored-by: rraminen <rraminen@amd.com>
Co-authored-by: Jeff Daily <jeff.daily@amd.com>
Co-authored-by: okakarpa <okakarpa@amd.com>
Co-authored-by: Jithun Nair <37884920+jithunnair-amd@users.noreply.github.com>
Co-authored-by: Ramya Ramineni <62723901+rraminen@users.noreply.github.com>
8 people authored Mar 3, 2022
1 parent f0304bd commit c3c8d5d
Showing 44 changed files with 1,471 additions and 134 deletions.
41 changes: 39 additions & 2 deletions .github/workflows/main.yml
@@ -1,4 +1,4 @@
-name: Build
+name: unit-tests

on:
  push:
@@ -14,7 +14,7 @@ on:
jobs:
  # unit tests running on nvidia gpus
  nv-torch12-p40:
-    runs-on: [self-hosted, nvidia, torch12]
+    runs-on: [self-hosted, nvidia, torch12, p40]

    steps:
      - uses: actions/checkout@v2
@@ -102,6 +102,43 @@ jobs:
          find examples/pytorch -regextype posix-egrep -regex '.*(language-modeling|question-answering|summarization|image-classification|text-classification|translation).*/requirements.txt' -exec pip install -r {} \;
          TORCH_EXTENSIONS_DIR=./torch-extensions RUN_SLOW=1 pytest --color=yes --durations=0 --verbose tests/deepspeed
+  # unit tests running on amd gpus
+  amd:
+    # The type of runner that the job will run on
+    runs-on: [self-hosted, amd]
+
+    # Steps represent a sequence of tasks that will be executed as part of the job
+    steps:
+      # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it
+      - uses: actions/checkout@v2
+
+      # Runs a single command using the runners shell
+      - name: environment
+        run: |
+          rocm-smi --showhw
+          which python
+          python --version
+          which hipcc
+          hipcc --version
+          python -c "import torch; print('torch:', torch.__version__, torch)"
+          python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
+          sudo apt-get update
+          sudo apt-get install -y libaio-dev
+      # Runs a set of commands using the runners shell
+      - name: Install deepspeed
+        run: |
+          pip install .[dev,1bit,autotuning]
+          python -c "from deepspeed.env_report import cli_main; cli_main()"
+          #ds_report
+      # Runs a set of commands using the runners shell
+      - name: Unit tests
+        run: |
+          if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi
+          cd tests
+          #TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose unit/
+          TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose -n 4 -m 'not sequential' unit/
+          TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose -m 'sequential' unit/
+
  nv-lightning-v100:
    runs-on: [self-hosted, nvidia, torch18, v100]

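In short, the new amd job checks out the repository, prints the ROCm environment (rocm-smi, hipcc, and the installed PyTorch build), installs DeepSpeed with the dev, 1bit, and autotuning extras, and runs the unit suite in two passes: four parallel workers for tests not marked sequential, then a serial pass for the sequential ones. The env_report invocation is the programmatic equivalent of the ds_report command, as the adjacent comment hints. Note that on ROCm builds of PyTorch, torch.cuda.is_available() returns True because the HIP backend is exposed through the torch.cuda namespace, which is why a CUDA-availability check appears in an AMD-only job.
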
18 changes: 18 additions & 0 deletions csrc/includes/cublas_wrappers.h
@@ -5,7 +5,9 @@
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
+#ifndef __HIP_PLATFORM_HCC__
#include <mma.h>
+#endif
#include <stdio.h>

int cublas_gemm_ex(cublasHandle_t handle,
@@ -19,7 +21,11 @@ int cublas_gemm_ex(cublasHandle_t handle,
const float* A,
const float* B,
float* C,
+#ifdef __HIP_PLATFORM_HCC__
+rocblas_gemm_algo algo = rocblas_gemm_algo_standard);
+#else
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT);
+#endif

int cublas_gemm_ex(cublasHandle_t handle,
cublasOperation_t transa,
@@ -32,7 +38,11 @@ int cublas_gemm_ex(cublasHandle_t handle,
const __half* A,
const __half* B,
__half* C,
+#ifdef __HIP_PLATFORM_HCC__
+rocblas_gemm_algo algo = rocblas_gemm_algo_standard);
+#else
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP);
+#endif

int cublas_strided_batched_gemm(cublasHandle_t handle,
int m,
@@ -49,7 +59,11 @@ int cublas_strided_batched_gemm(cublasHandle_t handle,
int stride_B,
int stride_C,
int batch,
+#ifdef __HIP_PLATFORM_HCC__
+rocblas_gemm_algo algo = rocblas_gemm_algo_standard);
+#else
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT);
+#endif

int cublas_strided_batched_gemm(cublasHandle_t handle,
int m,
@@ -66,4 +80,8 @@ int cublas_strided_batched_gemm(cublasHandle_t handle,
int stride_B,
int stride_C,
int batch,
+#ifdef __HIP_PLATFORM_HCC__
+rocblas_gemm_algo algo = rocblas_gemm_algo_standard);
+#else
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP);
+#endif
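
The repeated #ifdef blocks above exist because rocBLAS and cuBLAS name their GEMM-algorithm enums differently, so each default argument needs a platform-specific type and value. A minimal sketch of the underlying mapping, assuming the rocBLAS or cuBLAS headers are on the include path (gemm_algo_t and GEMM_ALGO_DEFAULT are hypothetical names for illustration, not DeepSpeed API):

#ifdef __HIP_PLATFORM_HCC__
#include <rocblas.h>
// Under hipcc, the rocBLAS enum plays the role of cublasGemmAlgo_t.
using gemm_algo_t = rocblas_gemm_algo;
constexpr gemm_algo_t GEMM_ALGO_DEFAULT = rocblas_gemm_algo_standard;
#else
#include <cublas_v2.h>
using gemm_algo_t = cublasGemmAlgo_t;
constexpr gemm_algo_t GEMM_ALGO_DEFAULT = CUBLAS_GEMM_DEFAULT;
#endif

An alias like this could collapse the four guarded defaults above into one; the commit instead repeats the #ifdef per signature, which keeps each declaration explicit about both platforms.
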
8 changes: 8 additions & 0 deletions csrc/includes/custom_cuda_layers.h
@@ -5,7 +5,15 @@
#include <stdio.h>
#include <stdlib.h>

+#ifdef __HIP_PLATFORM_HCC__
+#define HALF_PRECISION_AVAILABLE = 1
+#include <hip/hip_cooperative_groups.h>
+#else
+#if __CUDA_ARCH__ >= 700
+#define HALF_PRECISION_AVAILABLE = 1
+#endif
#include <cooperative_groups.h>
+#endif
#include <curand_kernel.h>

#include "context.h"
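
Two details in this hunk are easy to miss: ROCm defines HALF_PRECISION_AVAILABLE unconditionally, while CUDA builds gate it on __CUDA_ARCH__ >= 700; and because the macro expands to the token sequence "= 1", it can only be tested with #ifdef, never with #if. A minimal illustration of how such a feature macro is typically consumed (a hypothetical kernel, not DeepSpeed code):

#include <cuda_fp16.h>

#ifdef HALF_PRECISION_AVAILABLE
// Hypothetical example: only compile __half arithmetic where the
// feature macro says the target supports it.
__global__ void scale_half(__half* x, float s, int n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) x[i] = __float2half(__half2float(x[i]) * s);
}
#endif
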
12 changes: 12 additions & 0 deletions csrc/includes/feed_forward.h
@@ -43,7 +43,11 @@ class FeedForward {
weights,
input_ptr,
out,
+#ifdef __HIP_PLATFORM_HCC__
+rocblas_gemm_algo(config_.gemm_algos[0]));
+#else
cublasGemmAlgo_t(config_.gemm_algos[0]));
+#endif
}
void Backward(int bsz,
const T* out_grad,
@@ -68,7 +72,11 @@
input_ptr,
out_grad,
weights_grad,
+#ifdef __HIP_PLATFORM_HCC__
+rocblas_gemm_algo(config_.gemm_algos[1]));
+#else
cublasGemmAlgo_t(config_.gemm_algos[1]));
+#endif

cublas_gemm_ex(_cublasHandle,
CUBLAS_OP_N,
@@ -81,7 +89,11 @@
weights,
out_grad,
inp_grad_out,
+#ifdef __HIP_PLATFORM_HCC__
+rocblas_gemm_algo(config_.gemm_algos[2]));
+#else
cublasGemmAlgo_t(config_.gemm_algos[2]));
+#endif

launch_fuse_transpose_bias_kernel<T>(out_grad, bias_grad, bsz, config_.outputSize, stream);
}
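
For orientation, with a standard linear layer y = x·Wᵀ the backward pass needs two GEMMs, dW = dyᵀ·x and dx = dy·W, plus a reduction of dy over the batch dimension for the bias gradient. That is why config_.gemm_algos carries three entries: index 0 selects the algorithm for the forward GEMM, index 1 for the weight-gradient GEMM, and index 2 for the input-gradient GEMM, while the bias gradient comes from launch_fuse_transpose_bias_kernel rather than a GEMM.
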
34 changes: 34 additions & 0 deletions csrc/includes/gemm_test.h
@@ -2,7 +2,9 @@
#pragma once

#include <cuda_fp16.h>
+#ifndef __HIP_PLATFORM_HCC__
#include <cuda_profiler_api.h>
+#endif
#include <array>
#include <cstdio>
#include <cstdlib>
@@ -58,7 +60,11 @@ class GemmTest {
B,
A,
C,
+#ifdef __HIP_PLATFORM_HCC__
+static_cast<rocblas_gemm_algo>(algo));
+#else
static_cast<cublasGemmAlgo_t>(algo));
+#endif
});

int algo_bw1 = Run(loops, [=](int algo) {
@@ -73,7 +79,11 @@
A,
C,
B,
+#ifdef __HIP_PLATFORM_HCC__
+static_cast<rocblas_gemm_algo>(algo));
+#else
static_cast<cublasGemmAlgo_t>(algo));
+#endif
});

int algo_bw2 = Run(loops, [=](int algo) {
@@ -88,7 +98,11 @@
B,
C,
A,
+#ifdef __HIP_PLATFORM_HCC__
+static_cast<rocblas_gemm_algo>(algo));
+#else
static_cast<cublasGemmAlgo_t>(algo));
+#endif
});

return std::array<int, 3>({algo_fw, algo_bw1, algo_bw2});
@@ -100,8 +114,12 @@ class GemmTest {
float fast_latency = (std::numeric_limits<float>::max)();
int fast_algo = 0;

+#ifdef __HIP_PLATFORM_HCC__
+for (int algo = (int)rocblas_gemm_algo_standard; algo <= (int)rocblas_gemm_algo_standard;
+#else
for (int algo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
algo <= (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
+#endif
algo++) {
int warm_up = 5;
for (int i = 0; i < warm_up; ++i) f(algo);
@@ -186,7 +204,11 @@ class StridedGemmTest {
stride_b,
stride_c,
bsz,
+#ifdef __HIP_PLATFORM_HCC__
+static_cast<rocblas_gemm_algo>(algo));
+#else
static_cast<cublasGemmAlgo_t>(algo));
+#endif
});

int algo_bw1 = Run(loops, [=](int algo) {
@@ -216,7 +238,11 @@ class StridedGemmTest {
stride_b,
stride_c,
bsz,
+#ifdef __HIP_PLATFORM_HCC__
+static_cast<rocblas_gemm_algo>(algo));
+#else
static_cast<cublasGemmAlgo_t>(algo));
+#endif
});

int algo_bw2 = Run(loops, [=](int algo) {
@@ -243,7 +269,11 @@
stride_b,
stride_c,
bsz,
+#ifdef __HIP_PLATFORM_HCC__
+static_cast<rocblas_gemm_algo>(algo));
+#else
static_cast<cublasGemmAlgo_t>(algo));
+#endif
});

return std::array<int, 3>({algo_fw, algo_bw1, algo_bw2});
@@ -255,8 +285,12 @@ class StridedGemmTest {
float fast_latency = (std::numeric_limits<float>::max)();
int fast_algo = 0;

+#ifdef __HIP_PLATFORM_HCC__
+for (int algo = (int)rocblas_gemm_algo_standard; algo <= (int)rocblas_gemm_algo_standard;
+#else
for (int algo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
algo <= (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
+#endif
algo++) {
int warm_up = 5;
for (int i = 0; i < warm_up; ++i) f(algo);
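
The Run helper shared by GemmTest and StridedGemmTest is a straightforward timed sweep: warm up five iterations, time each candidate algorithm over the requested number of loops, and keep the fastest. Note that the ROCm loop bounds above run from rocblas_gemm_algo_standard to itself, so on AMD the sweep visits a single candidate and the tuning step is effectively a no-op. A minimal host-side sketch of the same pattern (a hypothetical helper using wall-clock time; the real code may time differently, e.g. with CUDA events):

#include <chrono>
#include <limits>

// Illustrative sketch: run f(algo) for each candidate and return the fastest.
// A real version would synchronize the device before reading the clock.
template <typename F>
int pick_fastest_algo(int first, int last, int loops, F&& f)
{
    float best_ms = std::numeric_limits<float>::max();
    int best = first;
    for (int algo = first; algo <= last; ++algo) {
        for (int i = 0; i < 5; ++i) f(algo);  // warm-up, as in the code above
        auto t0 = std::chrono::steady_clock::now();
        for (int i = 0; i < loops; ++i) f(algo);
        auto t1 = std::chrono::steady_clock::now();
        float ms = std::chrono::duration<float, std::milli>(t1 - t0).count() / loops;
        if (ms < best_ms) { best_ms = ms; best = algo; }
    }
    return best;
}
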
4 changes: 4 additions & 0 deletions csrc/includes/general_kernels.h
@@ -3,7 +3,11 @@
#include <stdio.h>
#include <stdlib.h>

+#ifdef __HIP_PLATFORM_HCC__
+#include <hip/hip_cooperative_groups.h>
+#else
#include <cooperative_groups.h>
+#endif
#include <curand_kernel.h>

#include "context.h"
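
Only the include needs guarding here because ROCm ships a source-compatible cooperative_groups implementation in hip/hip_cooperative_groups.h, so kernel code written against the cg API compiles unchanged on both platforms. A small illustrative kernel (an assumed example, not part of this commit):

#ifdef __HIP_PLATFORM_HCC__
#include <hip/hip_cooperative_groups.h>
#else
#include <cooperative_groups.h>
#endif

namespace cg = cooperative_groups;

// Hypothetical kernel: the cooperative-groups calls resolve against
// whichever header was selected above.
__global__ void double_in_place(float* x)
{
    cg::thread_block tb = cg::this_thread_block();
    x[tb.thread_rank()] *= 2.0f;
    tb.sync();
}
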
(Diff truncated; the remaining 38 changed files are not shown.)
