Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
5b2550b
modify cmake file according to rocm doc
Ldpe2G Jul 7, 2021
6920a6f
rm useless
Ldpe2G Jul 7, 2021
a512dce
fix compile error
Ldpe2G Jul 7, 2021
70cfe97
fix
Ldpe2G Jul 7, 2021
e623261
add
Ldpe2G Jul 7, 2021
306deee
move device register to cpp
Ldpe2G Jul 7, 2021
d4b9be9
add BUILD_ROCM option
Ldpe2G Jul 7, 2021
f10d502
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
Ldpe2G Jul 8, 2021
183d6b9
add rocm utils and some rocm kernels registeration
Ldpe2G Jul 9, 2021
ac5df12
successfully run eager module test, add, matmul and generator
Ldpe2G Jul 12, 2021
5bac3a5
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
Ldpe2G Jul 12, 2021
4e7016f
successfully run lazy op test: add, matmul
Ldpe2G Jul 12, 2021
549b298
fix eager test_add
Ldpe2G Jul 12, 2021
5dc5c89
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
Ldpe2G Jul 13, 2021
71f6866
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
Ldpe2G Jul 15, 2021
85c6375
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
Ldpe2G Jul 15, 2021
fd70f5e
support softmax on rocm
Ldpe2G Jul 16, 2021
ddb4578
successfully run bert model single gpu test
Ldpe2G Jul 19, 2021
13290db
support rccl, successfully run bert model with 2 gpu
Ldpe2G Jul 20, 2021
3b3c02b
support dropout
Ldpe2G Jul 20, 2021
77f829e
fix slice
Ldpe2G Jul 20, 2021
95f54f8
merge master
Ldpe2G Jul 21, 2021
e43d21f
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
Ldpe2G Jul 22, 2021
cd458f8
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
Ldpe2G Jul 23, 2021
575f97a
merge master
Ldpe2G Jul 26, 2021
860b34f
merge master
Ldpe2G Jul 26, 2021
ec27878
change with_rocm to with_hip
Ldpe2G Jul 26, 2021
35b2466
fix
Ldpe2G Jul 26, 2021
f75d9c6
remove dupliced codes
Ldpe2G Jul 26, 2021
26e1abf
merge master
Ldpe2G Jul 27, 2021
1323695
rename rocm_util to hip_util
Ldpe2G Jul 27, 2021
cb06bc5
rename the rocm_ prefix of rocm_device_context and rocm_stream_handle…
Ldpe2G Jul 27, 2021
678ad59
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
Ldpe2G Jul 28, 2021
68f0f6a
migrate rocm namespace to hip
Ldpe2G Jul 29, 2021
d4ed37d
refactor codes
Ldpe2G Jul 29, 2021
d582265
refactor codes under core/[kernel|ndarray|vm]
Ldpe2G Jul 30, 2021
187e536
refactor codes under user/kernels
Ldpe2G Jul 30, 2021
b6df591
refine
Ldpe2G Jul 30, 2021
f12c5aa
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
Ldpe2G Aug 1, 2021
c1fdc24
add dim_scatter hip implementation
Ldpe2G Aug 1, 2021
af7e2cd
support gpt
Ldpe2G Sep 6, 2021
daa6110
fix error when enable model parallel
Ldpe2G Sep 6, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ endif()
option(USE_CLANG_FORMAT "" OFF)
option(BUILD_RDMA "" OFF)
option(BUILD_CUDA "" ON)
option(BUILD_HIP "" OFF)
option(BUILD_TESTING "" OFF)
option(WITH_XLA "Option to build with XLA" OFF)
option(WITH_TENSORRT "Option to build with TensorRT" OFF)
Expand All @@ -31,6 +32,10 @@ set(RPC_BACKEND "GRPC,LOCAL" CACHE STRING "")
set(THIRD_PARTY_MIRROR "" CACHE STRING "")
set(PIP_INDEX_MIRROR "" CACHE STRING "")

if (BUILD_CUDA AND BUILD_HIP)
message(FATAL_ERROR "Enable cuda and hip simultaneously are not supported.")
endif()

if (APPLE)
set(RPC_BACKEND "LOCAL")
set(BUILD_CUDA OFF)
Expand All @@ -45,6 +50,11 @@ endif()

project(oneflow C CXX)

if (BUILD_HIP)
# Search for rocm in common locations
list(APPEND CMAKE_PREFIX_PATH /opt/rocm/hip /opt/rocm)
endif()

set(oneflow_cmake_dir ${PROJECT_SOURCE_DIR}/cmake)

get_filename_component(real_src_dir "${CMAKE_SOURCE_DIR}" REALPATH)
Expand Down
2 changes: 1 addition & 1 deletion cmake/oneflow.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ add_dependencies(of_pyext_obj of_ccobj)
if(APPLE)
set(of_libs -Wl,-force_load ${ONEFLOW_CUDA_LIBS} of_ccobj of_protoobj of_cfgobj)
elseif(UNIX)
set(of_libs -Wl,--whole-archive ${ONEFLOW_CUDA_LIBS} of_ccobj of_protoobj of_cfgobj -Wl,--no-whole-archive -ldl -lrt)
set(of_libs -Wl,--whole-archive ${ONEFLOW_CUDA_LIBS} of_ccobj of_protoobj of_cfgobj -Wl,--no-whole-archive -ldl -lrt -ludev)
elseif(WIN32)
set(of_libs of_ccobj of_protoobj of_cfgobj)
set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} /WHOLEARCHIVE:of_ccobj")
Expand Down
20 changes: 20 additions & 0 deletions cmake/third_party.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,26 @@ if (BUILD_CUDA)
)
endif()

if (BUILD_HIP)
# Find hip
find_package(hip)
find_package(hipblas)
find_package(hipcub)
find_package(rocrand)
find_package(hiprand)
find_package(miopen)
find_package(rccl)
add_definitions(-DWITH_HIP)
list(APPEND oneflow_third_party_libs hip::device)
list(APPEND oneflow_third_party_libs roc::hipblas)
list(APPEND oneflow_third_party_libs hip::hipcub)
list(APPEND oneflow_third_party_libs roc::rocrand)
list(APPEND oneflow_third_party_libs hip::hiprand)
list(APPEND oneflow_third_party_libs MIOpen)
link_directories(/opt/rocm/rccl/lib)
list(APPEND oneflow_third_party_libs rccl)
endif()

if(BUILD_RDMA)
if(UNIX)
include(CheckIncludeFiles)
Expand Down
4 changes: 2 additions & 2 deletions oneflow/api/python/flags.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ namespace oneflow {

ONEFLOW_API_PYBIND11_MODULE("flags", m) {
m.def("with_cuda", []() {
#ifdef WITH_CUDA
#if defined (WITH_CUDA) || defined (WITH_HIP)
return true;
#else
return false;
#endif // WITH_CUDA
#endif // defined (WITH_CUDA) || defined (WITH_HIP)
});

m.def("use_cxx11_abi", []() {
Expand Down
4 changes: 2 additions & 2 deletions oneflow/core/actor/acc_compute_actor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ void AccCompActor::Init(const TaskProto& task_proto, int32_t max_acc_cnt) {
if (GetDeviceType() == DeviceType::kCPU) {
cpy_func_ = std::bind(Memcpy<DeviceType::kCPU>, _1, _2, _3, _4);
} else {
#ifdef WITH_CUDA
#if defined(WITH_CUDA) || defined(WITH_HIP)
cpy_func_ = std::bind(Memcpy<DeviceType::kGPU>, _1, _2, _3, _4);
#else
UNIMPLEMENTED();
Expand All @@ -80,7 +80,7 @@ void AccCompActor::Act() {
Memcpy<DeviceType::kCPU>(kernel_ctx.device_ctx, out_blob->ForceMutDptr(), in_blob->dptr(),
out_blob->ByteSizeOfBlobBody());
} else if (GetDeviceType() == DeviceType::kGPU) {
#ifdef WITH_CUDA
#if defined(WITH_CUDA) || defined(WITH_HIP)
Memcpy<DeviceType::kGPU>(kernel_ctx.device_ctx, out_blob->ForceMutDptr(), in_blob->dptr(),
out_blob->ByteSizeOfBlobBody());
#else
Expand Down
2 changes: 1 addition & 1 deletion oneflow/core/actor/copy_hd_actor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ limitations under the License.

namespace oneflow {

#ifdef WITH_CUDA
#if defined(WITH_CUDA) || defined(WITH_HIP)

void CopyHdActor::VirtualActorInit(const TaskProto& task_proto) {
OF_SET_MSG_HANDLER(&CopyHdActor::HandlerNormal);
Expand Down
2 changes: 1 addition & 1 deletion oneflow/core/actor/copy_hd_actor.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ limitations under the License.

namespace oneflow {

#ifdef WITH_CUDA
#if defined(WITH_CUDA) || defined(WITH_HIP)

class CopyHdActor final : public Actor {
public:
Expand Down
2 changes: 1 addition & 1 deletion oneflow/core/common/data_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ void CheckDataType() {
static_assert(sizeof(int32_t) == sizeof(int), "sizeof(int32_t) != sizeof(int)");
static_assert(sizeof(int64_t) == sizeof(long long), "sizeof(int64_t) != sizeof(long long)");

#if defined(WITH_CUDA)
#if defined(WITH_CUDA) || defined(WITH_HIP)

#define CHECK_DEVICE_FP16(get_val) \
do { \
Expand Down
9 changes: 6 additions & 3 deletions oneflow/core/common/data_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ limitations under the License.
#if defined(WITH_CUDA)
#include <cuda_fp16.h>
#endif
#if defined(WITH_HIP)
#include <hip/hip_fp16.h>
#endif
#include "oneflow/core/common/fp16_data_type.h"
#include "oneflow/core/common/data_type.pb.h"
#include "oneflow/core/common/data_type_seq.h"
Expand All @@ -29,7 +32,7 @@ limitations under the License.

namespace oneflow {

#if defined(WITH_CUDA)
#if defined(WITH_CUDA) || defined(WITH_HIP)
#define DEVICE_TYPE_SEQ \
OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCPU) \
OF_PP_MAKE_TUPLE_SEQ(DeviceType::kGPU)
Expand Down Expand Up @@ -81,7 +84,7 @@ struct GetDataType<T, typename std::enable_if<IsFloat16<T>::value>::type>
template<DataType type>
using DataTypeToType = decltype(GetTypeByDataType(std::integral_constant<DataType, type>{}));

#if defined(__CUDACC__)
#if defined(__CUDACC__) || defined(__HIP_PLATFORM_HCC__)
#define OF_DEVICE_FUNC __device__ __host__ __forceinline__
#else
#define OF_DEVICE_FUNC inline
Expand Down Expand Up @@ -198,7 +201,7 @@ struct DevDType {
typedef T type;
};

#if defined(WITH_CUDA)
#if defined(WITH_CUDA) || defined(WITH_HIP)
template<>
struct DevDType<DeviceType::kGPU, float16> {
static_assert(sizeof(float16) == sizeof(half), "sizeof(float16) != sizeof(half)");
Expand Down
2 changes: 1 addition & 1 deletion oneflow/core/common/data_type_seq.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ limitations under the License.

#define FLOAT16_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(float16, DataType::kFloat16)

#if defined(WITH_CUDA)
#if defined(WITH_CUDA) || defined(WITH_HIP)
#define HALF_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(half, DataType::kFloat16)
#endif

Expand Down
2 changes: 2 additions & 0 deletions oneflow/core/device/cuda_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -186,13 +186,15 @@ int GetCudaPtxVersion();

#else

#if !defined(WITH_HIP)
namespace oneflow {

enum class CudaWorkType {};

inline size_t GetCudaWorkTypeSize() { return 0; }

} // namespace oneflow
#endif

#endif // WITH_CUDA

Expand Down
9 changes: 9 additions & 0 deletions oneflow/core/device/device_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
#define ONEFLOW_CORE_DEVICE_DEVICE_CONTEXT_H_

#include "oneflow/core/device/cuda_util.h"
#include "oneflow/core/device/hip_util.hip.h"
#include "oneflow/core/common/auto_registration_factory.h"

namespace oneflow {
Expand All @@ -40,6 +41,14 @@ class DeviceCtx {
virtual const cudnnHandle_t& cudnn_handle() const { UNIMPLEMENTED(); }
#endif

#ifdef WITH_HIP
virtual const hipStream_t& hip_stream() const { UNIMPLEMENTED(); }
virtual const hipblasHandle_t& hipblas_pmh_handle() const { UNIMPLEMENTED(); }
virtual const hipblasHandle_t& hipblas_pmd_handle() const { UNIMPLEMENTED(); }
virtual const hipblasHandle_t& hipblas_tensor_op_math_handle() const { UNIMPLEMENTED(); }
virtual const miopenHandle_t& miopen_handle() const { UNIMPLEMENTED(); }
#endif

virtual void SyncDevice() { UNIMPLEMENTED(); }
virtual void AddCallBack(std::function<void()>) const { UNIMPLEMENTED(); }

Expand Down
31 changes: 31 additions & 0 deletions oneflow/core/device/hip_device_context.hip.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/device/hip_device_context.hip.h"
#include "oneflow/core/thread/thread_context.h"

namespace oneflow {

#ifdef WITH_HIP

REGISTER_DEVICE_CONTEXT(DeviceType::kGPU, ([](const ThreadCtx& thread_ctx) -> DeviceCtx* {
HipStreamHandle* hip_handle = nullptr;
hip_handle = thread_ctx.g_hip_stream.get();
return new HipDeviceCtx(hip_handle);
}));

#endif // WITH_HIP

} // namespace oneflow
63 changes: 63 additions & 0 deletions oneflow/core/device/hip_device_context.hip.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_DEVICE_HIP_DEVICE_CONTEXT_H_
#define ONEFLOW_CORE_DEVICE_HIP_DEVICE_CONTEXT_H_

#include "oneflow/core/kernel/kernel_context.h"
#include "oneflow/core/device/device_context.h"
#include "oneflow/core/device/hip_stream_handle.hip.h"

namespace oneflow {

#ifdef WITH_HIP

class HipDeviceCtx : public DeviceCtx {
public:
OF_DISALLOW_COPY_AND_MOVE(HipDeviceCtx);
HipDeviceCtx() = delete;
~HipDeviceCtx() override = default;

explicit HipDeviceCtx(HipStreamHandle* hip_handler) : hip_handler_(hip_handler) {}

const hipStream_t& hip_stream() const override { return *(hip_handler_->hip_stream()); }

const hipblasHandle_t& hipblas_pmh_handle() const override {
return *(hip_handler_->hipblas_pmh_handle());
}
const hipblasHandle_t& hipblas_tensor_op_math_handle() const override {
return *(hip_handler_->hipblas_tensor_op_math_handle());
}
const hipblasHandle_t& hipblas_pmd_handle() const override {
return *(hip_handler_->hipblas_pmd_handle());
}

const miopenHandle_t& miopen_handle() const override { return *(hip_handler_->miopen_handle()); }

void SyncDevice() override { OF_HIP_CHECK(hipStreamSynchronize(hip_stream())); }

void AddCallBack(std::function<void()> callback) const override {
hip_handler_->AddCallBack(callback);
}

protected:
HipStreamHandle* hip_handler_;
};

#endif // WITH_HIP

} // namespace oneflow

#endif // ONEFLOW_CORE_DEVICE_HIP_DEVICE_CONTEXT_H_
Loading