Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions lib/THC/THCTensorInfo.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ struct IndexToOffset {
}
};

// For contiguous tensors, the offset = index
template <typename T, typename IndexType>
struct IndexToOffset<T, IndexType, -2> {
static inline __host__ __device__ IndexType
Expand Down
26 changes: 26 additions & 0 deletions lib/THC/THCTensorTypeUtils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,32 @@ TENSOR_UTILS(THCudaHalfTensor, half, float);

#undef TENSOR_UTILS

// Utility function for constructing TensorInfo structs. In this case, the
// two template parameters are:
//
// 1. The TensorType, e.g. THCTensor in generic functions, or THCudaTensor,
// THCudaLongTensor etc.
//
// 2. The IndexType. This is always going to be an unsigned integral value,
// but depending on the size of the Tensor you may select unsigned int,
// unsigned long, unsigned long long etc.
//
// Internally we use the TensorUtils static functions to get the necessary
// dims, sizes, stride etc.
//
// For example, suppose we have a THCudaTensor t, with dim = 2, size = [3, 4],
// stride = [4, 1], offset = 8, and we set our index type to be unsigned int.
// Then we yield a TensorInfo struct templatized with float, unsigned int and
// the following fields:
//
// data is a float* to the underlying storage at position 8
// dims is 2
// sizes is a MAX_CUTORCH_DIMS element array with [3, 4] in its first two positions
// strides is a MAX_CUTORCH_DIMS element array with [4, 1] in its first two positions
//
// TensorInfos can then be passed to CUDA kernels, but we can use the static functions
// defined above to perform Tensor Operations that are appropriate for each
// TensorType.
template <typename TensorType, typename IndexType>
TensorInfo<typename TensorUtils<TensorType>::DataType, IndexType>
getTensorInfo(THCState* state, TensorType* t) {
Expand Down