
Handle pollution of MAX, MIN and CHECK macros. (pytorch#11805)
Summary:
Pull Request resolved: pytorch#11805

Some of our headers in Caffe2 pollute the macro namespace with macros such as MAX,
MIN, and CHECK, so I renamed or removed these in the places where the collision causes problems.

This patch is courtesy of gchanan, extracted from pytorch#11721

Reviewed By: Yangqing

Differential Revision: D9917757

fbshipit-source-id: 17fc692ca04b208dcb8ae00731ed60e393284f7c
ezyang authored and facebook-github-bot committed Sep 18, 2018
1 parent 9eb7288 commit 1d399a8
Showing 5 changed files with 21 additions and 23 deletions.
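
For illustration only (this snippet is not part of the commit): a minimal sketch of the kind of collision described in the summary, where a function-like CHECK macro from one header mangles an ordinary function declaration in another, and uppercase MIN/MAX macros invite redefinition clashes. All names below are hypothetical.

// Hypothetical sketch, not PyTorch code: one header defines function-like macros.
#include <cstdlib>

#define CHECK(cond) do { if (!(cond)) std::abort(); } while (0)
#define MIN(x, y) (((x) < (y)) ? (x) : (y))

// A second header then tries to declare an ordinary function named CHECK:
//   static inline void CHECK(ncclResult_t status);
// The preprocessor expands CHECK(...) at that declaration and the build breaks.
// Renaming the function (NCCL_CHECK) avoids the clash, and using std::min /
// std::max instead of MIN/MAX avoids defining such macros at all.

#include <algorithm>

int clamp(int v, int lo, int hi) {
  // std::min/std::max are ordinary lowercase templates, untouched by the
  // uppercase MIN/MAX macros above.
  return std::max(lo, std::min(v, hi));
}

int main() {
  return clamp(5, 0, 3) == 3 ? 0 : 1;
}

The commit follows the second half of this sketch: it drops the function-like MIN/MAX macros in favor of std::min/std::max, and gives the NCCL status-checking function a name (NCCL_CHECK) that no logging macro will expand.
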
2 changes: 2 additions & 0 deletions aten/src/TH/THTensorCopy.cpp
@@ -1,6 +1,8 @@
 #include "THTensor.hpp"
 #include "THVector.h"
 
+#include <algorithm>
+
 #include "generic/THTensorCopy.cpp"
 #include "THGenerateAllTypes.h"

18 changes: 7 additions & 11 deletions aten/src/TH/generic/THTensorCopy.cpp
@@ -26,13 +26,11 @@ int THTensor_(copyTransposeValid)(THTensor *tensor, THTensor *src) {
 // special case copy where tensor is contiguous and src is a transposed matrix
 // This can be generalized to most copies, but it's tricker
 void THTensor_(copyTranspose)(THTensor *tensor, THTensor *src) {
-#define MIN(x, y) (((x) < (y)) ? (x) : (y))
-#define MAX(x, y) (((x) > (y)) ? (x) : (y))
 
 #ifdef TH_REAL_IS_BYTE
-  const int BLOCK_SZ = 120;
+  const int64_t BLOCK_SZ = 120;
 #else
-  const int BLOCK_SZ = 60;
+  const int64_t BLOCK_SZ = 60;
 #endif
 
   THTensor *buf = THTensor_(newWithSize2d)(BLOCK_SZ, BLOCK_SZ);
@@ -48,19 +46,19 @@ void THTensor_(copyTranspose)(THTensor *tensor, THTensor *src) {
       scalar_t *spo = sp + R + C * NR;
       scalar_t *rpo = rp + C + R * NC;
 
-      int nr = MIN(NR - R, BLOCK_SZ);
-      int nc = MIN(NC - C, BLOCK_SZ);
+      int nr = std::min(NR - R, BLOCK_SZ);
+      int nc = std::min(NC - C, BLOCK_SZ);
 
       // 1. copy columns from src to buf
       for (int c = 0; c < nc; c++) {
        memcpy(bp + c * BLOCK_SZ, spo + c * NR, nr * sizeof(scalar_t));
       }
 
       // 2. transpose buf in place
-      int rc_max = MAX(nr, nc);
-      int rc_min = MIN(nr, nc);
+      int rc_max = std::max(nr, nc);
+      int rc_min = std::min(nr, nc);
       for (int r = 0; r < rc_max; r++) {
-        int end = MIN(r, rc_min);
+        int end = std::min(r, rc_min);
         for (int c = 0; c < end; c++) {
           scalar_t tmp = bp[r + BLOCK_SZ * c];
           bp[r + BLOCK_SZ * c] = bp[r * BLOCK_SZ + c];
@@ -75,8 +73,6 @@ void THTensor_(copyTranspose)(THTensor *tensor, THTensor *src) {
     }
   }
   c10::raw::intrusive_ptr::decref(buf);
-#undef MIN
-#undef MAX
 }
 
 void THTensor_(copy)(THTensor *tensor, THTensor *src)
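
A note on the BLOCK_SZ change above: unlike the old MIN/MAX macros, std::min and std::max deduce a single template type, so both arguments must match; the TH tensor extents (NR - R, NC - C) are int64_t, hence the constant is widened to int64_t. A small self-contained sketch of the same constraint (names are illustrative, not taken from the patch):

#include <algorithm>
#include <cstdint>

// With the old macro, MIN(total - offset, some_int) compiled regardless of the
// operand types, silently mixing int and int64_t. std::min deduces a single
// template type, so the block-size constant must match the int64_t extents.
int64_t blocked_extent(int64_t total, int64_t offset) {
  const int64_t BLOCK_SZ = 60;               // int64_t to match total - offset
  return std::min(total - offset, BLOCK_SZ);
  // std::min(total - offset, 60) would fail: conflicting deduced types.
}

int main() {
  return blocked_extent(100, 70) == 30 ? 0 : 1;
}
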
4 changes: 2 additions & 2 deletions torch/csrc/cuda/nccl.cpp
@@ -28,7 +28,7 @@ struct NcclCommList {
   int ndevices;
   NcclCommList(const std::vector<int>& devices)
     : comms(new ncclComm_t[devices.size()]), ndevices(devices.size()) {
-    CHECK(ncclCommInitAll(comms.get(), devices.size(), devices.data()));
+    NCCL_CHECK(ncclCommInitAll(comms.get(), devices.size(), devices.data()));
   }
   NcclCommList(NcclCommList&& foo) = default;
   ~NcclCommList() {
@@ -219,7 +219,7 @@ void broadcast(TensorList tensors, const stream_list& streams, const comm_list&
     AT_CHECK(static_cast<uint64_t>(numel) <= static_cast<uint64_t>(count_max),
       "Broadcast tensor has ", numel, " elements, which exceeds the "
       "maximum NCCL supports (", count_max, ")");
-    CHECK(ncclBcast(tensors[i].data_ptr(), numel, data_type, 0, comms[i], stream));
+    NCCL_CHECK(ncclBcast(tensors[i].data_ptr(), numel, data_type, 0, comms[i], stream));
   }
 #else
   throw std::runtime_error("PyTorch built without NCCL support");

6 changes: 3 additions & 3 deletions torch/csrc/cuda/nccl.h
@@ -12,7 +12,7 @@ namespace detail {
 
 void throw_nccl_error(ncclResult_t status);
 
-static inline void CHECK(ncclResult_t status) {
+static inline void NCCL_CHECK(ncclResult_t status) {
   if (status != ncclSuccess) {
     throw_nccl_error(status);
   }
@@ -21,12 +21,12 @@ static inline void CHECK(ncclResult_t status) {
 struct AutoNcclGroup {
   AutoNcclGroup() {
 #if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
-    CHECK(ncclGroupStart());
+    NCCL_CHECK(ncclGroupStart());
 #endif
   }
   ~AutoNcclGroup() {
 #if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
-    CHECK(ncclGroupEnd());
+    NCCL_CHECK(ncclGroupEnd());
 #endif
   }
 };
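
For context, NCCL_CHECK is simply a status-to-exception wrapper; the rename matters only because logging headers pulled in elsewhere (e.g. glog via Caffe2) define a CHECK macro that would expand inside a function named CHECK. A generic, self-contained sketch of the same wrapper pattern (the Status enum and all names below are illustrative, not the real NCCL symbols):

#include <stdexcept>
#include <string>

// Illustrative stand-in for ncclResult_t; not the real NCCL type.
enum class Status { Success = 0, InternalError = 3 };

// Same shape as NCCL_CHECK: forward any non-success status to an
// exception-throwing path, keeping each call site to a single line.
static inline void STATUS_CHECK(Status status) {
  if (status != Status::Success) {
    throw std::runtime_error("call failed with status " +
                             std::to_string(static_cast<int>(status)));
  }
}

int main() {
  STATUS_CHECK(Status::Success);  // no-op; a failure status would throw
  return 0;
}
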
14 changes: 7 additions & 7 deletions torch/csrc/cuda/python_nccl.cpp
@@ -26,7 +26,7 @@ PyObject * THCPModule_nccl_version(PyObject *self, PyObject *args) {
 PyObject * THCPModule_nccl_unique_id(PyObject *self, PyObject *args) {
   HANDLE_TH_ERRORS
   ncclUniqueId id;
-  CHECK(ncclGetUniqueId(&id));
+  NCCL_CHECK(ncclGetUniqueId(&id));
   return PyBytes_FromStringAndSize((char*)&id, NCCL_UNIQUE_ID_BYTES);
   END_HANDLE_TH_ERRORS
 }
@@ -109,7 +109,7 @@ PyObject * THCPModule_nccl_init_rank(PyObject *self, PyObject *args) {
   memcpy(&commId, id, NCCL_UNIQUE_ID_BYTES);
   ncclComm_t comm;
   with_no_gil([&]{
-    CHECK(ncclCommInitRank(&comm, nranks, commId, rank));
+    NCCL_CHECK(ncclCommInitRank(&comm, nranks, commId, rank));
   });
   return PyCapsule_New(comm, COMM_CAPSULE_NAME, &destroy_nccl_comm);
   END_HANDLE_TH_ERRORS
@@ -149,7 +149,7 @@ PyObject * THCPModule_nccl_reduce(PyObject *self, PyObject *args) {
       int device = inputs[i].get_device();
       device_guard.set_index(device);
       auto stream = (streams[i] == nullptr) ? nullptr : THCStream_stream(streams[i]);
-      CHECK(ncclReduce(inputs[i].data_ptr(), outputs[i].data_ptr(),
+      NCCL_CHECK(ncclReduce(inputs[i].data_ptr(), outputs[i].data_ptr(),
         count, data_type, (ncclRedOp_t) op, root, comms[i], stream));
     }
   });
@@ -191,7 +191,7 @@ PyObject * THCPModule_nccl_all_reduce(PyObject *self, PyObject *args) {
       int device = inputs[i].get_device();
       device_guard.set_index(device);
       auto stream = (streams[i] == nullptr) ? nullptr : THCStream_stream(streams[i]);
-      CHECK(ncclAllReduce(inputs[i].data_ptr(), outputs[i].data_ptr(),
+      NCCL_CHECK(ncclAllReduce(inputs[i].data_ptr(), outputs[i].data_ptr(),
         count, data_type, (ncclRedOp_t) op, comms[i], stream));
     }
   });
@@ -255,10 +255,10 @@ PyObject * THCPModule_nccl_all_gather(PyObject *self, PyObject *args) {
       device_guard.set_index(device);
       auto stream = (streams[i] == nullptr) ? nullptr : THCStream_stream(streams[i]);
 #if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
-      CHECK(ncclAllGather(inputs[i].data_ptr(), outputs[i].data_ptr(),
+      NCCL_CHECK(ncclAllGather(inputs[i].data_ptr(), outputs[i].data_ptr(),
         count, data_type, comms[i], stream));
 #else
-      CHECK(ncclAllGather(inputs[i].data_ptr(), count, data_type,
+      NCCL_CHECK(ncclAllGather(inputs[i].data_ptr(), count, data_type,
         outputs[i].data_ptr(), comms[i], stream));
 #endif
     }
@@ -299,7 +299,7 @@ PyObject * THCPModule_nccl_reduce_scatter(PyObject *self, PyObject *args) {
       int device = inputs[i].get_device();
       device_guard.set_index(device);
       auto stream = (streams[i] == nullptr) ? nullptr : THCStream_stream(streams[i]);
-      CHECK(ncclReduceScatter(inputs[i].data_ptr(), outputs[i].data_ptr(),
+      NCCL_CHECK(ncclReduceScatter(inputs[i].data_ptr(), outputs[i].data_ptr(),
         count, data_type, (ncclRedOp_t) op, comms[i], stream));
     }
   });
