Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions lib/THC/THCTensorIndex.cu
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ __global__ void indexCopySmallIndex(TensorInfo<T, IndexType> dst,
for (IndexType srcIndex = 0; srcIndex < indices.sizes[0]; ++srcIndex) {
// Lua indices begin at 1
IndexType dstIndex =
indices.data[IndexToOffset<long, IndexType, IdxDim>::get(srcIndex, indices)] - 1;
indices.data[IndexToOffset<long, IndexType, IdxDim>::get(srcIndex, indices)] - TH_INDEX_BASE;

if (dstIndex < dstCopyDimSize) {
// We stride over the output ignoring the indexed dimension
Expand Down Expand Up @@ -78,7 +78,7 @@ __global__ void indexCopyLargeIndex(TensorInfo<T, IndexType> dst,

// Lua indices begin at 1
IndexType dstIndex =
indices.data[IndexToOffset<long, IndexType, IdxDim>::get(srcIndex, indices)] - 1;
indices.data[IndexToOffset<long, IndexType, IdxDim>::get(srcIndex, indices)] - TH_INDEX_BASE;

if (dstIndex < dstCopyDimSize) {
IndexType dstOffset =
Expand Down Expand Up @@ -116,7 +116,7 @@ __global__ void indexAddSmallIndex(TensorInfo<T, IndexType> dst,
for (IndexType srcIndex = 0; srcIndex < indices.sizes[0]; ++srcIndex) {
// Lua indices begin at 1
IndexType dstIndex =
indices.data[IndexToOffset<long, IndexType, IdxDim>::get(srcIndex, indices)] - 1;
indices.data[IndexToOffset<long, IndexType, IdxDim>::get(srcIndex, indices)] - TH_INDEX_BASE;

if (dstIndex < dstAddDimSize) {
// We stride over the output ignoring the indexed dimension
Expand Down Expand Up @@ -162,7 +162,7 @@ __global__ void indexAddLargeIndex(TensorInfo<T, IndexType> dst,

// Lua indices begin at 1
IndexType dstIndex =
indices.data[IndexToOffset<long, IndexType, IdxDim>::get(srcIndex, indices)] - 1;
indices.data[IndexToOffset<long, IndexType, IdxDim>::get(srcIndex, indices)] - TH_INDEX_BASE;

if (dstIndex < dstAddDimSize) {
IndexType dstOffset =
Expand Down Expand Up @@ -199,7 +199,7 @@ __global__ void indexFillSmallIndex(TensorInfo<T, IndexType> dst,
for (IndexType dstIndex = 0; dstIndex < indices.sizes[0]; ++dstIndex) {
// Lua indices begin at 1
IndexType dstIndex_ =
indices.data[IndexToOffset<long, IndexType, IdxDim>::get(dstIndex, indices)] - 1;
indices.data[IndexToOffset<long, IndexType, IdxDim>::get(dstIndex, indices)] - TH_INDEX_BASE;

if (dstIndex < dstFillDimSize) {
// We stride over the output ignoring the indexed dimension
Expand Down Expand Up @@ -240,7 +240,7 @@ __global__ void indexFillLargeIndex(TensorInfo<T, IndexType> dst,

// Lua indices begin at 1
IndexType dstIndex_ =
indices.data[IndexToOffset<long, IndexType, IdxDim>::get(dstIndex, indices)] - 1;
indices.data[IndexToOffset<long, IndexType, IdxDim>::get(dstIndex, indices)] - TH_INDEX_BASE;

if (dstIndex_ < dstFillDimSize) {
IndexType dstOffset =
Expand Down Expand Up @@ -274,7 +274,7 @@ __global__ void indexSelectSmallIndex(TensorInfo<T, IndexType> dst,
for (IndexType dstIndex = 0; dstIndex < indices.sizes[0]; ++dstIndex) {
// Lua indices begin at 1
IndexType srcIndex =
indices.data[IndexToOffset<long, IndexType, IdxDim>::get(dstIndex, indices)] - 1;
indices.data[IndexToOffset<long, IndexType, IdxDim>::get(dstIndex, indices)] - TH_INDEX_BASE;

if (srcIndex < srcSelectDimSize) {
// We stride over the output ignoring the indexed dimension
Expand Down Expand Up @@ -321,7 +321,7 @@ __global__ void indexSelectLargeIndex(TensorInfo<T, IndexType> dst,

// Lua indices begin at 1
IndexType srcIndex =
indices.data[IndexToOffset<long, IndexType, IdxDim>::get(dstIndex, indices)] - 1;
indices.data[IndexToOffset<long, IndexType, IdxDim>::get(dstIndex, indices)] - TH_INDEX_BASE;

if (srcIndex < srcSelectDimSize) {
IndexType dstOffset =
Expand Down
4 changes: 2 additions & 2 deletions lib/THC/THCTensorMathReduce.cu
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ kernelTransformReduceOuterDimIndex(K *tgt1,

for (unsigned col = 0; col < row_size; ++col) {
// +1 for Lua index
acc = binary_op(thrust::make_pair<K, Index>(*src, col+1),
acc = binary_op(thrust::make_pair<K, Index>(*src, col + TH_INDEX_BASE),
acc);
src += num_irows;
}
Expand Down Expand Up @@ -231,7 +231,7 @@ kernelTransformReduceInnermostDimIndex(K *tgt1,
K *src = src_ + row * row_size;
// Sequential reduction within a thread.
for (unsigned col = threadIdx.x; col < row_size; col += blockDim.x) {
acc = binary_op(thrust::make_pair<K, Index>(src[col], col + 1), acc);
acc = binary_op(thrust::make_pair<K, Index>(src[col], col + TH_INDEX_BASE), acc);
}
}

Expand Down
6 changes: 3 additions & 3 deletions lib/THC/THCTensorRandom.cu
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,7 @@ sampleMultinomialOnce(float* dest,
// We're done; we have the sample
// Torch indices are 1-based
// FIXME: broadcast exit flag?
dest[curDist] = cat + 1;
dest[curDist] = cat + TH_INDEX_BASE;
}

// Store the previous scan's high value for future use
Expand Down Expand Up @@ -555,7 +555,7 @@ sampleMultinomialWithReplacement(curandStateMtgp32* state,
r);

// Torch indices are 1-based
dest[curDist * totalSamples + sample] = (float) choice + 1.0f;
dest[curDist * totalSamples + sample] = (float) choice + (float)TH_INDEX_BASE;
}
}
}
Expand Down Expand Up @@ -595,7 +595,7 @@ sampleMultinomialWithoutReplacement(curandStateMtgp32* state,
r);

// Torch indices are 1-based
dest[curDist * totalSamples + sample] = (float) choice + 1.0f;
dest[curDist * totalSamples + sample] = (float) choice + (float)TH_INDEX_BASE;

// Without replacement, so update the original probability so it
// is not considered a second time
Expand Down
6 changes: 3 additions & 3 deletions lib/THC/THCTensorScatterGather.cu
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ __global__ void THCudaTensor_gatherKernel(
tensor, &tensorOffset,
src, &srcOffset);

IndexType indexValue = (IndexType)index.data[indexOffset] - 1;
IndexType indexValue = (IndexType)index.data[indexOffset] - TH_INDEX_BASE;
srcOffset += indexValue * src.strides[dim];

tensor.data[tensorOffset] = src.data[srcOffset];
Expand All @@ -118,7 +118,7 @@ __global__ void THCudaTensor_scatterKernel(
src, &srcOffset,
tensor, &tensorOffset);

IndexType indexValue = (IndexType)index.data[indexOffset] - 1;
IndexType indexValue = (IndexType)index.data[indexOffset] - TH_INDEX_BASE;
tensorOffset += indexValue * tensor.strides[dim];

tensor.data[tensorOffset] = src.data[srcOffset];
Expand All @@ -142,7 +142,7 @@ __global__ void THCudaTensor_scatterFillKernel(
index, &indexOffset,
tensor, &tensorOffset);

IndexType indexValue = (IndexType)index.data[indexOffset] - 1;
IndexType indexValue = (IndexType)index.data[indexOffset] - TH_INDEX_BASE;
tensorOffset += indexValue * tensor.strides[dim];

tensor.data[tensorOffset] = value;
Expand Down
4 changes: 2 additions & 2 deletions lib/THC/THCTensorSort.cu
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ fillSliceWithIndex(TensorInfo<long, IndexType> out,

for (long i = threadIdx.x; i < sliceSize; i += blockDim.x) {
// Torch indices are 1-based (hence the +1)
base[i * sliceStride] = i + 1;
base[i * sliceStride] = i + TH_INDEX_BASE;
}
}

Expand Down Expand Up @@ -145,7 +145,7 @@ struct GlobalIndexToPerSliceIndex {
GlobalIndexToPerSliceIndex(long size) : sliceSize(size) {}

__device__ inline void operator()(long& v) const {
v = v % sliceSize + 1;
v = v % sliceSize + TH_INDEX_BASE;
}

const long sliceSize;
Expand Down
4 changes: 2 additions & 2 deletions lib/THC/THCTensorTopK.cu
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ __global__ void gatherTopK(TensorInfo<float, IndexType> input,
IndexType indexOffset = writeIndex * indicesWithinSliceStride;

topKSliceStart[topKOffset] = v;
indicesSliceStart[indexOffset] = i + 1; // to Lua index
indicesSliceStart[indexOffset] = i + TH_INDEX_BASE; // to Lua index
}

writeIndexStart += carry;
Expand Down Expand Up @@ -364,7 +364,7 @@ __global__ void gatherTopK(TensorInfo<float, IndexType> input,
IndexType indexOffset = writeIndex * indicesWithinSliceStride;

topKSliceStart[topKOffset] = v;
indicesSliceStart[indexOffset] = i + 1; // to Lua index
indicesSliceStart[indexOffset] = i + TH_INDEX_BASE; // to Lua index
}

if (carry >= topKRemaining) {
Expand Down