Commit

orlando - for local updates of pytorch
llv22 committed Jun 30, 2022
1 parent 67ece03 commit b27d168
Showing 17 changed files with 199 additions and 63 deletions.
3 changes: 3 additions & 0 deletions CMakeLists.txt
@@ -308,6 +308,9 @@ option(USE_DISTRIBUTED "Use distributed" ON)
cmake_dependent_option(
USE_MPI "Use MPI for Caffe2. Only available if USE_DISTRIBUTED is on." ON
"USE_DISTRIBUTED" OFF)
cmake_dependent_option(
USE_CUDA_MPI "Force CUDA-Aware MPI for Caffe2. Only available if USE_DISTRIBUTED and USE_MPI is on." OFF
"USE_DISTRIBUTED AND USE_MPI" OFF)
cmake_dependent_option(
USE_GLOO "Use Gloo. Only available if USE_DISTRIBUTED is on." ON
"USE_DISTRIBUTED" OFF)
84 changes: 58 additions & 26 deletions README.md
@@ -1,3 +1,34 @@
<!-- markdownlint-disable MD033 -->
<!-- markdownlint-disable MD004 -->
<!-- markdownlint-disable MD029 -->
# pytorch 1.12.0 with Nvidia GPU on macOS

--------------------------------------------------------------------------------
As PyTorch does not officially support CUDA on macOS, I use this repository to build PyTorch with CUDA on macOS. **The 1.12.0-fixed branch is the current stable branch.** MPI+CUDA is currently still disabled; an investigation will follow later. My gut feeling is that the cause lies in the CMakeLists.txt configuration, which doesn't enable the MPI and CUDA settings appropriately at the same time.

- macOS 10.13.6, CUDA 10.1, cuDNN 7.6.5 (the last official CUDA and cuDNN releases Nvidia shipped with macOS support)
- [NCCL on macOS 2.9.6.1](https://github.com/llv22/nccl-osx) and [test suite](https://github.com/llv22/nccl-tests-macOS-cuda)
- Xcode 10.1, libuv 1.2.6
- magma 2.6 built on macOS, provided by a [cloned magma repository from The University of Tennessee, Knoxville](https://github.com/llv22/magma-macOS)
- distributed options supported with TensorPipe, fixed via [Orlando's tensorpipe](https://github.com/llv22/tensorpipe-macos-cuda/tree/torch-1.11.0)

```bash
-- USE_DISTRIBUTED : ON
-- USE_MPI : ON
-- USE_GLOO : ON
-- USE_TENSORPIPE : ON
```

The consolidated patch [torch-1.11.0-mac.patch](https://github.com/llv22/pytorch-macOS-cuda/blob/v1.12.1-fixed/mac-with-tensorpipe-cuda-mpi-enabling.patch) is generated with

```bash
git format-patch -2 --stdout > torch-1.12.0-mac-with-tensorpipe-cuda-mpi-enabling.patch
```

Refer to <https://www.ivankristianto.com/create-patch-files-from-multiple-commits-in-git/>.
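To apply the consolidated patch onto a clean checkout, something like the following should work (a sketch; the tag and the patch's location are assumptions):

```bash
git clone https://github.com/pytorch/pytorch.git && cd pytorch
git checkout v1.12.0
# dry-run first, then apply with commit metadata preserved
git apply --check ../torch-1.12.0-mac-with-tensorpipe-cuda-mpi-enabling.patch
git am ../torch-1.12.0-mac-with-tensorpipe-cuda-mpi-enabling.patch
```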

--------------------------------------------------------------------------------

![PyTorch Logo](https://github.com/pytorch/pytorch/blob/master/docs/source/_static/img/pytorch-logo-dark.png)

--------------------------------------------------------------------------------
@@ -12,32 +43,33 @@ Our trunk health (Continuous Integration signals) can be found at [hud.pytorch.o

<!-- toc -->

- [pytorch 1.12.0 with Nvidia GPU on macOS](#pytorch-1120-with-nvidia-gpu-on-macos)
  - [More About PyTorch](#more-about-pytorch)
    - [A GPU-Ready Tensor Library](#a-gpu-ready-tensor-library)
    - [Dynamic Neural Networks: Tape-Based Autograd](#dynamic-neural-networks-tape-based-autograd)
    - [Python First](#python-first)
    - [Imperative Experiences](#imperative-experiences)
    - [Fast and Lean](#fast-and-lean)
    - [Extensions Without Pain](#extensions-without-pain)
  - [Installation](#installation)
    - [Binaries](#binaries)
      - [NVIDIA Jetson Platforms](#nvidia-jetson-platforms)
    - [From Source](#from-source)
      - [Install Dependencies](#install-dependencies)
      - [Get the PyTorch Source](#get-the-pytorch-source)
      - [Install PyTorch](#install-pytorch)
        - [Adjust Build Options (Optional)](#adjust-build-options-optional)
    - [Docker Image](#docker-image)
      - [Using pre-built images](#using-pre-built-images)
      - [Building the image yourself](#building-the-image-yourself)
    - [Building the Documentation](#building-the-documentation)
    - [Previous Versions](#previous-versions)
  - [Getting Started](#getting-started)
  - [Resources](#resources)
  - [Communication](#communication)
  - [Releases and Contributing](#releases-and-contributing)
  - [The Team](#the-team)
  - [License](#license)

<!-- tocstop -->

10 changes: 10 additions & 0 deletions aten/src/ATen/CMakeLists.txt
@@ -273,6 +273,12 @@ if(USE_TBB)
list(APPEND ATen_CPU_DEPENDENCY_LIBS TBB::tbb)
endif()

if(USE_OPENMP)
message("ATen is compiled with OPEN_MP (/Users/llv23/opt/miniconda3/lib/libomp.dylib)")
list(APPEND ATen_CPU_DEPENDENCY_LIBS /Users/llv23/opt/miniconda3/lib/libomp.dylib)
endif()


if(BLAS_FOUND)
if($ENV{TH_BINARY_BUILD})
message(STATUS "TH_BINARY_BUILD detected. Enabling special linkage.")
@@ -421,22 +427,26 @@ if(USE_CUDA AND NOT USE_ROCM)
${CUDA_TOOLKIT_ROOT_DIR}/lib64/libcusparse_static.a
${CUDA_TOOLKIT_ROOT_DIR}/lib64/libcurand_static.a
${CUDA_TOOLKIT_ROOT_DIR}/lib64/libcufft_static_nocallback.a
/Users/llv23/opt/miniconda3/lib/libomp.dylib # test parallel for symbol _omp_in_parallel
)
if(NOT BUILD_LAZY_CUDA_LINALG)
list(APPEND ATen_CUDA_DEPENDENCY_LIBS
${CUDA_TOOLKIT_ROOT_DIR}/lib64/libcusolver_static.a
${CUDA_TOOLKIT_ROOT_DIR}/lib64/liblapack_static.a # needed for libcusolver_static
/Users/llv23/opt/miniconda3/lib/libomp.dylib # test parallel for symbol _omp_in_parallel
)
endif()
else()
list(APPEND ATen_CUDA_DEPENDENCY_LIBS
${CUDA_LIBRARIES}
${CUDA_cusparse_LIBRARY}
${CUDA_curand_LIBRARY}
/Users/llv23/opt/miniconda3/lib/libomp.dylib # test parallel for symbol _omp_in_parallel
)
if(NOT BUILD_LAZY_CUDA_LINALG)
list(APPEND ATen_CUDA_DEPENDENCY_LIBS
${CUDA_cusolver_LIBRARY}
/Users/llv23/opt/miniconda3/lib/libomp.dylib # test parallel for symbol _omp_in_parallel
)
endif()
endif()
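The `_omp_in_parallel` comments above can be verified by hand; a sketch, assuming the hard-coded miniconda path exists:

```bash
# confirm the dylib exports the OpenMP symbols these link lines rely on
nm -g /Users/llv23/opt/miniconda3/lib/libomp.dylib | grep omp_in_parallel
# and inspect its install name and dependencies
otool -L /Users/llv23/opt/miniconda3/lib/libomp.dylib
```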
8 changes: 8 additions & 0 deletions aten/src/ATen/native/ReduceOps.cpp
@@ -658,6 +658,14 @@ template<typename T>
inline typename std::enable_if<!std::is_integral<T>::value, bool>::type isnan_(T x) {
return std::isnan(x);
}
#elif defined(__APPLE__) && defined(__MACH__)
template<typename T>
inline bool isnan_(T x) {
return std::isnan(x);
}
inline bool isnan_(const c10::BFloat16 x) {
return std::isnan(x.x);
}
#else
template<typename T>
inline bool isnan_(T x) {
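The same `isnan_` shim recurs in the two pooling kernels below; it hands Apple clang an exact-match overload for `c10::BFloat16`, whose raw bits live in the `.x` member. A standalone sketch of the pattern (hypothetical `BF16` stand-in, not PyTorch code; the bit test follows the bfloat16 layout of 1 sign, 8 exponent, and 7 mantissa bits):

```cpp
#include <cmath>
#include <cstdint>
#include <iostream>

struct BF16 { uint16_t x; };  // raw storage, like c10::BFloat16::x

template <typename T>
inline bool isnan_(T v) {
  return std::isnan(v);  // float/double resolve to the std:: overload
}

inline bool isnan_(BF16 v) {
  // NaN <=> exponent bits all ones and mantissa non-zero
  return (v.x & 0x7F80) == 0x7F80 && (v.x & 0x007F) != 0;
}

int main() {
  std::cout << isnan_(1.0f) << " " << isnan_(BF16{0x7FC0}) << "\n";  // 0 1
}
```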
16 changes: 15 additions & 1 deletion aten/src/ATen/native/cpu/AdaptiveMaxPoolKernel.cpp
@@ -12,6 +12,16 @@ namespace at { namespace native {

namespace {

#if defined(__APPLE__) && defined(__MACH__)
template<typename T>
inline bool isnan_(T x) {
return std::isnan(x);
}
inline bool isnan_(const c10::BFloat16 x) {
return std::isnan(x.x);
}
#endif

template <typename scalar_t, typename accscalar_t>
void cpu_adaptive_max_pool(
const Tensor& output_,
@@ -56,7 +66,11 @@ void cpu_adaptive_max_pool(
for (int64_t iw = iw0; iw < iw1; iw ++) {
int64_t index = ih * input_width + iw;
scalar_t val = input_ptr[index];
-if ((val > maxval) || std::isnan(val)) {
+#if defined(__APPLE__) && defined(__MACH__)
+if ((val > maxval) || isnan_(val)) {
+#else
+if ((val > maxval) || std::isnan(val)) {
+#endif
maxval = val;
maxindex = index;
}
16 changes: 15 additions & 1 deletion aten/src/ATen/native/cpu/MaxPoolKernel.cpp
@@ -13,6 +13,16 @@ namespace at { namespace native {

namespace {

#if defined(__APPLE__) && defined(__MACH__)
template<typename T>
inline bool isnan_(T x) {
return std::isnan(x);
}
inline bool isnan_(const c10::BFloat16 x) {
return std::isnan(x.x);
}
#endif

template <typename scalar_t, typename accscalar_t>
void cpu_max_pool(
const Tensor& output_,
@@ -64,7 +74,11 @@ void cpu_max_pool(
for (int64_t iw = iw0; iw < iw1; iw += dilationW) {
int64_t index = ih * input_width + iw;
accscalar_t val = accscalar_t(input_ptr[index]);
-if ((val > maxval) || std::isnan(val)) {
+#if defined(__APPLE__) && defined(__MACH__)
+if ((val > maxval) || isnan_(val)) {
+#else
+if ((val > maxval) || std::isnan(val)) {
+#endif
maxval = val;
maxindex = index;
}
3 changes: 2 additions & 1 deletion aten/src/ATen/native/cuda/EmbeddingBag.cu
@@ -189,7 +189,8 @@ Tensor embedding_bag_backward_cuda_sum_avg(
Tensor count;

AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_backward_cuda_sum_avg", [&] () {
-auto range = at::arange(num_indices, indices.options());
+// https://github.com/pytorch/pytorch/issues/42271
+auto range = at::arange(c10::Scalar((int64_t)num_indices), indices.options());
// int64_t nbits = cuda::cub::get_num_bits(num_weights);
cuda::cub::radix_sort_pairs(
indices.data_ptr<index_t>(), sorted_indices.data_ptr<index_t>(),
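For context, wrapping the count in `c10::Scalar` sidesteps the `at::arange` overload ambiguity tracked in the linked issue, which bites on macOS where `int64_t` is `long long`. A minimal libtorch illustration with hypothetical values:

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  int64_t num_indices = 8;
  // the explicit Scalar picks arange(Scalar end, TensorOptions) directly
  // instead of relying on integral-promotion overload resolution
  auto range = torch::arange(c10::Scalar(num_indices),
                             torch::TensorOptions().dtype(torch::kLong));
  std::cout << range << "\n";
}
```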
12 changes: 9 additions & 3 deletions caffe2/CMakeLists.txt
@@ -101,8 +101,8 @@ if(INTERN_BUILD_ATEN_OPS)
list(APPEND Caffe2_GPU_INCLUDE ${ATen_CUDA_INCLUDE})
list(APPEND Caffe2_HIP_INCLUDE ${ATen_HIP_INCLUDE})
list(APPEND Caffe2_VULKAN_INCLUDE ${ATen_VULKAN_INCLUDE})
-list(APPEND Caffe2_DEPENDENCY_LIBS ${ATen_CPU_DEPENDENCY_LIBS})
-list(APPEND Caffe2_CUDA_DEPENDENCY_LIBS ${ATen_CUDA_DEPENDENCY_LIBS})
+list(APPEND Caffe2_DEPENDENCY_LIBS ${ATen_CPU_DEPENDENCY_LIBS} /Users/llv23/opt/miniconda3/lib/libomp.dylib)
+list(APPEND Caffe2_CUDA_DEPENDENCY_LIBS ${ATen_CUDA_DEPENDENCY_LIBS} /Users/llv23/opt/miniconda3/lib/libomp.dylib)
list(APPEND Caffe2_HIP_DEPENDENCY_LIBS ${ATen_HIP_DEPENDENCY_LIBS})
list(APPEND Caffe2_DEPENDENCY_INCLUDE ${ATen_THIRD_PARTY_INCLUDE})
endif()
@@ -1362,6 +1362,9 @@ if(USE_DISTRIBUTED)
"${TORCH_SRC_DIR}/csrc/distributed/c10d/ProcessGroupMPI.cpp"
PROPERTIES COMPILE_FLAGS -Wno-deprecated-declarations)
endif()
if(USE_CUDA_MPI)
add_definitions(-DUSE_CUDA_MPI=1)
endif()
target_compile_definitions(torch_cpu PUBLIC USE_C10D_MPI)
endif()
# Pass USE_RPC in order to reduce use of
@@ -1814,7 +1817,7 @@ if(BUILD_TEST)
foreach(test_src ${Caffe2_CPU_TEST_SRCS})
get_filename_component(test_name ${test_src} NAME_WE)
add_executable(${test_name} "${test_src}")
-target_link_libraries(${test_name} torch_library gtest_main)
+target_link_libraries(${test_name} torch_library gtest_main /Users/llv23/opt/miniconda3/lib/libomp.dylib)
if(USE_OPENMP)
# -fopenmp is a compile time flag and as result not guaranteed
# to link executable against OpenMP runtime library
@@ -2001,6 +2004,9 @@ if(BUILD_PYTHON)
target_compile_definitions(torch_python PRIVATE BUILD_CAFFE2)
if(USE_NUMPY)
target_compile_options(caffe2_pybind11_state PRIVATE "-DUSE_NUMPY")
# Orlando: fix for "../caffe2/python/pybind_state.h:27:10: fatal error: 'numpy/arrayobject.h' file not found"
find_package(Python3 REQUIRED COMPONENTS NumPy)
target_include_directories(caffe2_pybind11_state PRIVATE ${Python3_NumPy_INCLUDE_DIRS})
target_link_libraries(caffe2_pybind11_state PRIVATE numpy::numpy)
endif()
if(NOT MSVC)
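The NumPy include fix above can be cross-checked against the interpreter driving the build; a sketch, assuming that Python is first on PATH:

```bash
# the directory CMake's Python3_NumPy_INCLUDE_DIRS should resolve to
python -c "import numpy; print(numpy.get_include())"
# arrayobject.h must exist beneath it
ls "$(python -c 'import numpy; print(numpy.get_include())')/numpy/arrayobject.h"
```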
2 changes: 2 additions & 0 deletions caffe2/core/macros.h.in
@@ -25,6 +25,7 @@ static_assert(

#cmakedefine CAFFE2_BUILD_SHARED_LIBS
#cmakedefine CAFFE2_FORCE_FALLBACK_CUDA_MPI
#cmakedefine CAFFE2_USE_CUDA_MPI
#cmakedefine CAFFE2_HAS_MKL_DNN
#cmakedefine CAFFE2_HAS_MKL_SGEMM_PACK
#cmakedefine CAFFE2_PERF_WITH_AVX
@@ -64,6 +65,7 @@ static_assert(
{"CUDNN_VERSION", "${CUDNN_VERSION}"}, \
{"USE_NCCL", "${USE_NCCL}"}, \
{"USE_MPI", "${USE_MPI}"}, \
{"USE_CUDA_MPI", "${USE_CUDA_MPI}"}, \
{"USE_GFLAGS", "${USE_GFLAGS}"}, \
{"USE_GLOG", "${USE_GLOG}"}, \
{"USE_GLOO", "${USE_GLOI}"}, \
25 changes: 24 additions & 1 deletion caffe2/mpi/mpi_ops_gpu.cc
@@ -34,7 +34,22 @@ namespace caffe2 {
#define CAFFE2_HAS_CUDA_MPI_ALLREDUCE 0
#endif // CAFFE2_OMPI_VERSION >= 10805
#endif // CAFFE2_OMPI_VERSION >= 2000
-#else // !OPEN_MPI
+#elif MVAPICH2_NUMVERSION // !OPEN_MPI
+#define CAFFE2_MV2_VERSION MVAPICH2_NUMVERSION
+#if CAFFE2_MV2_VERSION >= 20305300
+#include "mpi-ext.h"
+#if MPIX_CUDA_AWARE_SUPPORT
+#define CAFFE2_HAS_CUDA_MPI_BASICS 1
+#define CAFFE2_HAS_CUDA_MPI_ALLREDUCE 1
+#endif // MPIX_CUDA_AWARE_SUPPORT
+#else // CAFFE2_MV2_VERSION >= 20305300
+// In the case of MVAPICH2-GDR before 2.3.5, we don't have compile-time flags
+// to figure out if CUDA is supported; as a result, we will assume that the
+// user has built MVAPICH2-GDR with CUDA support.
+#define CAFFE2_HAS_CUDA_MPI_BASICS 1
+#define CAFFE2_HAS_CUDA_MPI_ALLREDUCE 1
+#endif // CAFFE2_MV2_VERSION >= 20305300
+#else // !OPEN_MPI && !MVAPICH_GDR
// We have not really tested against other MPI environments, so let's go for a
// safe path and basically say we don't have cuda-aware functions.
#define CAFFE2_HAS_CUDA_MPI_BASICS 0
@@ -49,6 +64,14 @@ namespace caffe2 {
#define CAFFE2_HAS_CUDA_MPI_ALLREDUCE 0
#endif // CAFFE2_FORCE_FALLBACK_CUDA_MPI

// We allow a macro to force using CUDA functions
#ifdef CAFFE2_USE_CUDA_MPI
#undef CAFFE2_HAS_CUDA_MPI_BASICS
#undef CAFFE2_HAS_CUDA_MPI_ALLREDUCE
#define CAFFE2_HAS_CUDA_MPI_BASICS 1
#define CAFFE2_HAS_CUDA_MPI_ALLREDUCE 1
#endif // CAFFE2_USE_CUDA_MPI

REGISTER_CUDA_OPERATOR(
MPICreateCommonWorld,
MPICreateCommonWorldOp<CUDAContext>);
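The compile-time macros above have a runtime counterpart in Open MPI; a hedged sketch (the `MPIX_Query_cuda_support` extension is only declared when the MPI library itself was built with CUDA support):

```cpp
#include <mpi.h>
#include <cstdio>
#if defined(OPEN_MPI)
#include <mpi-ext.h>  // declares MPIX_CUDA_AWARE_SUPPORT when available
#endif

int main(int argc, char** argv) {
  MPI_Init(&argc, &argv);
#if defined(MPIX_CUDA_AWARE_SUPPORT) && MPIX_CUDA_AWARE_SUPPORT
  // compile-time claim, confirmed (or denied) at run time
  std::printf("CUDA-aware MPI runtime support: %d\n",
              MPIX_Query_cuda_support());
#else
  std::printf("CUDA awareness not advertised at compile time\n");
#endif
  MPI_Finalize();
  return 0;
}
```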
30 changes: 27 additions & 3 deletions cmake/Dependencies.cmake
@@ -1134,11 +1134,35 @@ if(USE_MPI)
execute_process(COMMAND ${OMPI_INFO}
OUTPUT_VARIABLE _output)
if(_output MATCHES "smcuda")
message(STATUS "Found OpenMPI with CUDA support built.")
else()
message(WARNING "OpenMPI found, but it is not built with CUDA support.")
set(CAFFE2_FORCE_FALLBACK_CUDA_MPI 1)
if(USE_CUDA_MPI)
if(USE_CUDA)
message(WARNING "OpenMPI with CUDA support not found, but forcing anyway.")
else()
message(WARNING "Force building for OpenMPI with CUDA.")
endif()
set(CAFFE2_USE_CUDA_MPI 1)
else()
message(WARNING "OpenMPI found, but it is not built with CUDA support.")
set(CAFFE2_FORCE_FALLBACK_CUDA_MPI 1)
endif()
endif()
+else()
+find_program(MV2_INFO NAMES mpiname HINTS ${MPI_CXX_LIBRARIES}/../bin)
+if(MV2_INFO)
+execute_process(COMMAND ${MV2_INFO} "-a" OUTPUT_VARIABLE _output)
+if(_output MATCHES "enable-cuda")
+message(STATUS "Found MVAPICH2 with CUDA support built.")
+else()
+if(USE_CUDA_MPI)
+message(WARNING "MVAPICH2 with CUDA support not found, but forcing anyway.")
+set(CAFFE2_USE_CUDA_MPI 1)
+else()
+message(WARNING "MVAPICH2 found, but it is not built with CUDA support.")
+set(CAFFE2_FORCE_FALLBACK_CUDA_MPI 1)
+endif()
+endif()
+endif()
endif()
else()
message(WARNING "Not compiling with MPI. Suppress this warning with -DUSE_MPI=OFF")
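The same probes can be run by hand before configuring; a sketch, assuming `ompi_info` and `mpiname` are on PATH:

```bash
# Open MPI: the CMake above greps ompi_info output for "smcuda"
ompi_info | grep -i smcuda
# MVAPICH2: the CMake above greps `mpiname -a` for "enable-cuda"
mpiname -a | grep -o -- "enable-cuda"
```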
1 change: 1 addition & 0 deletions cmake/Summary.cmake
@@ -180,6 +180,7 @@ function(caffe2_print_configuration_summary)
message(STATUS " USE_DISTRIBUTED : ${USE_DISTRIBUTED}")
if(${USE_DISTRIBUTED})
message(STATUS " USE_MPI : ${USE_MPI}")
message(STATUS " USE_CUDA_MPI : ${USE_CUDA_MPI}")
message(STATUS " USE_GLOO : ${USE_GLOO}")
message(STATUS " USE_GLOO_WITH_OPENSSL : ${USE_GLOO_WITH_OPENSSL}")
message(STATUS " USE_TENSORPIPE : ${USE_TENSORPIPE}")
4 changes: 2 additions & 2 deletions cmake/public/cuda.cmake
@@ -51,8 +51,8 @@ set(CMAKE_CUDA_STANDARD_REQUIRED ON)
message(STATUS "Caffe2: CUDA detected: " ${CUDA_VERSION})
message(STATUS "Caffe2: CUDA nvcc is: " ${CUDA_NVCC_EXECUTABLE})
message(STATUS "Caffe2: CUDA toolkit directory: " ${CUDA_TOOLKIT_ROOT_DIR})
-if(CUDA_VERSION VERSION_LESS 10.2)
-message(FATAL_ERROR "PyTorch requires CUDA 10.2 or above.")
+if(CUDA_VERSION VERSION_LESS 10.1)
+message(FATAL_ERROR "PyTorch requires CUDA 10.1 or above.")
endif()

if(CUDA_FOUND)
6 changes: 5 additions & 1 deletion modules/detectron/CMakeLists.txt
@@ -4,7 +4,11 @@ file(GLOB_RECURSE Detectron_HIP_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/*.hip)

if(BUILD_CAFFE2_OPS)
if(USE_OPENMP AND OPENMP_FOUND)
-Set(OpenMP_link ${OpenMP_CXX_LIBRARIES})
+if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
+Set(OpenMP_link -Xpreprocessor -fopenmp /Users/llv23/opt/miniconda3/lib/libomp.dylib /Users/llv23/opt/miniconda3/lib/libgomp.dylib)
+else()
+Set(OpenMP_link ${OpenMP_CXX_LIBRARIES})
+endif()
endif()

# Note(ilijar): Since Detectron ops currently have no
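For reference, `-Xpreprocessor -fopenmp` is the usual way to compile OpenMP code with Apple clang, whose driver lacks a native `-fopenmp` flag; a minimal probe against the same hard-coded libomp:

```bash
cat > omp_probe.cpp <<'EOF'
#include <omp.h>
#include <cstdio>
int main() {
  #pragma omp parallel
  std::printf("hello from thread %d\n", omp_get_thread_num());
}
EOF
clang++ -Xpreprocessor -fopenmp omp_probe.cpp \
  -L/Users/llv23/opt/miniconda3/lib -lomp -o omp_probe && ./omp_probe
```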
(The remaining changed files are not shown.)