THCReduceApplyUtils.cu
#include "THCReduceApplyUtils.cuh"

#include <assert.h>
#include <stdlib.h>

// Maximum size per grid dimension that we assume (compute capability >= 2.0)
#define MAX_GRID_SIZE 65535L

void THCCheckTensorDims(THCState* state, THCudaTensor* tensor, int arg) {
  long dims = THCudaTensor_nDimension(state, tensor);
  THArgCheck(dims <= MAX_CUTORCH_DIMS, arg, CUTORCH_DIM_WARNING);
}
bool THC_canUse32BitIndexMath(THCState* state, THCudaTensor* t) {
  long elements = THCudaTensor_nElement(state, t);
  if (elements >= UINT_MAX) {
    return false;
  }

  // Compute the largest element offset the tensor addresses. The linear
  // index `elements - 1` decomposes to (size - 1) in every dimension, so
  // the offset accumulated below is the maximum reachable offset.
  long offset = 0;
  long linearId = elements - 1;

  for (int i = THCudaTensor_nDimension(state, t) - 1; i >= 0; --i) {
    long curDimIndex = linearId % THCudaTensor_size(state, t, i);
    long curDimOffset = curDimIndex * THCudaTensor_stride(state, t, i);
    offset += curDimOffset;
    linearId /= THCudaTensor_size(state, t, i);
  }

  if (offset >= UINT_MAX) {
    return false;
  }

  return true;
}
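
// Illustrative worked example (not part of the original file), for a
// hypothetical 2-D tensor with sizes {2, 3} and strides {3, 1}:
//   elements = 6, linearId = 5
//   i = 1: curDimIndex = 5 % 3 = 2, offset += 2 * 1 -> offset = 2, linearId = 1
//   i = 0: curDimIndex = 1 % 2 = 1, offset += 1 * 3 -> offset = 5
// Both 6 and 5 are far below UINT_MAX, so the function returns true and the
// caller may index this tensor with 32-bit arithmetic.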

bool THC_getGridFromTiles(long gridTiles, dim3& grid) {
  if (gridTiles > MAX_GRID_SIZE * MAX_GRID_SIZE * MAX_GRID_SIZE) {
    return false;
  }

  // Spill tiles over into y and z only when a single grid dimension is not
  // enough; each dimension is capped at MAX_GRID_SIZE.
  long gridX = gridTiles > MAX_GRID_SIZE ? MAX_GRID_SIZE : gridTiles;
  long gridY = 1;
  long gridZ = 1;

  if (gridTiles > MAX_GRID_SIZE) {
    gridTiles = THCCeilDiv(gridTiles, MAX_GRID_SIZE);
    gridY = gridTiles > MAX_GRID_SIZE ? MAX_GRID_SIZE : gridTiles;

    if (gridTiles > MAX_GRID_SIZE) {
      gridTiles = THCCeilDiv(gridTiles, MAX_GRID_SIZE);
      gridZ = gridTiles > MAX_GRID_SIZE ? MAX_GRID_SIZE : gridTiles;
    }
  }

  grid = dim3(gridX, gridY, gridZ);
  return true;
}
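
// Illustrative worked example (not part of the original file), for a
// hypothetical gridTiles = 200000:
//   gridX = 65535 (clamped to MAX_GRID_SIZE)
//   gridTiles = THCCeilDiv(200000, 65535) = 4 -> gridY = 4, gridZ = 1
// so the call sets grid = dim3(65535, 4, 1), i.e. 262140 blocks, which is
// enough to cover all 200000 tiles.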

namespace {

struct SizeAndStride {
  long size;
  long stride;
};

// qsort comparator: orders dimensions by stride, descending, so the
// innermost (smallest-stride) dimension ends up at the back of the array.
int compareSizeAndStride(const void* a, const void* b) {
  const SizeAndStride* aS = (const SizeAndStride*) a;
  const SizeAndStride* bS = (const SizeAndStride*) b;

  if (aS->stride > bS->stride) return -1;
  if (aS->stride < bS->stride) return 1;
  return 0;
}

}

bool THC_overlappingIndices(THCState* state, THCudaTensor* t) {
  // In this function, we don't care about permutations of the
  // size/stride arrays (transpositions).
  //
  // We order the size/stride arrays by stride, skipping dimensions of
  // size 1. Strides of dimensions of size 1 don't matter, since there
  // is only one addressing point in them.
  //
  // In the reordered view, read from outermost to innermost (dim + 1 is
  // more inner than dim), the tensor is contiguous if
  //   stride[dim] == size[dim + 1] * stride[dim + 1] for all `dim`.
  // The tensor has holes if
  //   stride[dim] > size[dim + 1] * stride[dim + 1] for one or more `dim`.
  // The tensor has overlaps if
  //   stride[dim] < size[dim + 1] * stride[dim + 1] for one or more `dim`,
  // or the innermost stride is 0.

  // Extract size/stride arrays; only consider size >1 dims.
  SizeAndStride info[MAX_CUTORCH_DIMS];

  int dims = THCudaTensor_nDimension(state, t);
  int nonSize1Dims = 0;
  for (int i = 0; i < dims; ++i) {
    long size = THCudaTensor_size(state, t, i);
    if (size > 1) {
      info[nonSize1Dims].size = size;
      info[nonSize1Dims].stride = THCudaTensor_stride(state, t, i);
      ++nonSize1Dims;
    }
  }

  if (nonSize1Dims == 0) {
    // no overlap
    return false;
  }

  // Descending order by stride (innermost dimension in sorted view is at
  // [nonSize1Dims - 1])
  qsort(info, nonSize1Dims, sizeof(SizeAndStride), compareSizeAndStride);

  // Base case: innermost dimension must have stride >= 1
  if (info[nonSize1Dims - 1].stride < 1) {
    return true;
  }

  // Subsequent dimensions, if any
  for (int i = nonSize1Dims - 2; i >= 0; --i) {
    if (info[i].stride < info[i + 1].size * info[i + 1].stride) {
      // There are overlaps
      return true;
    }
  }

  // Tensor has holes or is contiguous
  return false;
}
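
// Illustrative worked examples (not part of the original file), for
// hypothetical tensors, with each sorted entry written as {size, stride}:
//   - sizes {5, 3}, strides {0, 1} (a broadcast/expanded view): sorted by
//     descending stride -> [{3, 1}, {5, 0}]; the innermost stride is 0, so
//     the base case fires and the function returns true (overlap).
//   - sizes {4, 3}, strides {1, 4} (a transposed contiguous tensor):
//     sorted -> [{3, 4}, {4, 1}]; 4 < 4 * 1 is false, so it returns false.
//   - sizes {3, 2}, strides {4, 1} (a narrowed view): sorted -> [{3, 4},
//     {2, 1}]; 4 < 2 * 1 is false, so it returns false: the view has holes,
//     but holes are not overlaps.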