Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 150 additions & 4 deletions lib/THC/THCTensorMathPairwise.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,17 @@ struct TensorAddConstantOp {
#ifdef CUDA_HALF_TENSOR
template <>
struct TensorAddConstantOp<half> {
#ifdef CUDA_HALF_INSTRUCTIONS
TensorAddConstantOp(half v) : val(v) {}
#else
TensorAddConstantOp(half v) : fval(THC_half2float(v)) {}
#endif

__device__ __forceinline__ void operator()(half* out, half* in) {
#ifdef CUDA_HALF_INSTRUCTIONS
*out = __hadd(*in, val);
#else
float fin = __half2float(*in);
float fval = __half2float(val);
float fout = fin + fval;
*out = __float2half(fout);
#endif
Expand All @@ -39,16 +43,73 @@ struct TensorAddConstantOp<half> {
*v = __hadd(*v, val);
#else
float fv = __half2float(*v);
float fval = __half2float(val);
fv += fval;
*v = __float2half(fv);
#endif
}

#ifdef CUDA_HALF_INSTRUCTIONS
const half val;
#else
const float fval;
#endif
};
#endif // CUDA_HALF_TENSOR


// Pointwise functor that subtracts a scalar, fixed at construction, from
// tensor elements. The one-pointer overload updates the element in place;
// the two-pointer overload reads from `src` and writes to `dst`.
template <typename T>
struct TensorSubConstantOp {
  TensorSubConstantOp(T subtrahend) : val(subtrahend) {}

  // In-place form: *elem <- *elem - val.
  __device__ __forceinline__ void operator()(T* elem) {
    *elem -= val;
  }

  // Out-of-place form: *dst <- *src - val.
  __device__ __forceinline__ void operator()(T* dst, T* src) {
    *dst = *src - val;
  }

  // Scalar captured by the constructor.
  const T val;
};


#ifdef CUDA_HALF_TENSOR
// half specialization of the subtract-constant functor.
//
// The constructor stores the NEGATED constant, so both operators implement
// the subtraction as an addition of the stored value.
//
// Two build variants, selected by CUDA_HALF_INSTRUCTIONS:
//  - defined:     the negated constant is kept as a half (`val`) and added
//                 with __hadd.
//  - not defined: the negated constant is kept as a float (`fval`); each
//                 element is converted to float, added, and converted back.
//
// The constructor, the operators and the member declaration must all test
// the same macro, otherwise the constructor initializes a member that the
// selected variant does not declare.
template <>
struct TensorSubConstantOp<half> {
#ifdef CUDA_HALF_INSTRUCTIONS
  // Negate through float on the host side, then store as half.
  TensorSubConstantOp(half v): val(THC_float2half(-(THC_half2float(v)))) {}
#else
  TensorSubConstantOp(half v): fval(-(THC_half2float(v))) {}
#endif

  // Out-of-place form: *out <- *in + (negated constant).
  __device__ __forceinline__ void operator()(half* out, half* in) {
#ifdef CUDA_HALF_INSTRUCTIONS
    *out = __hadd(*in, val);
#else
    float fin = __half2float(*in);
    float fout = fin + fval;
    *out = __float2half(fout);
#endif
  }

  // In-place form: *v <- *v + (negated constant).
  __device__ __forceinline__ void operator()(half* v) {
#ifdef CUDA_HALF_INSTRUCTIONS
    *v = __hadd(*v, val);
#else
    float fv = __half2float(*v);
    fv += fval;
    *v = __float2half(fv);
#endif
  }

#ifdef CUDA_HALF_INSTRUCTIONS
  const half val;    // negated constant, half precision
#else
  const float fval;  // negated constant, single precision
#endif
};
#endif // CUDA_HALF_TENSOR


template <typename T>
struct TensorMulConstantOp {
TensorMulConstantOp(T v) : val(v) {}
Expand All @@ -66,13 +127,17 @@ struct TensorMulConstantOp {
#ifdef CUDA_HALF_TENSOR
template <>
struct TensorMulConstantOp<half> {
#ifdef CUDA_HALF_INSTRUCTIONS
TensorMulConstantOp(half v) : val(v) {}
#else
TensorMulConstantOp(half v) : fval(THC_half2float(v)) {}
#endif

__device__ __forceinline__ void operator()(half* out, half* in) {
#ifdef CUDA_HALF_INSTRUCTIONS
*out = __hmul(*in, val);
#else
float fin = __half2float(*in);
float fval = __half2float(val);
float fout = fin * fval;
*out = __float2half(fout);
#endif
Expand All @@ -83,13 +148,94 @@ struct TensorMulConstantOp<half> {
*v = __hmul(*v, val);
#else
float fv = __half2float(*v);
float fval = __half2float(val);
fv *= fval;
*v = __float2half(fv);
#endif
}

#ifdef CUDA_HALF_INSTRUCTIONS
const half val;
#else
const float fval;
#endif
};
#endif // CUDA_HALF_TENSOR

// Pointwise functor that divides tensor elements by a scalar fixed at
// construction, using the element type's own `/` operator.
// The one-pointer overload updates the element in place; the two-pointer
// overload reads from `src` and writes to `dst`.
template <typename T>
struct TensorDivConstantOp {
  TensorDivConstantOp(T divisor) : val(divisor) {}

  // In-place form: *elem <- *elem / val.
  __device__ __forceinline__ void operator()(T* elem) {
    *elem /= val;
  }

  // Out-of-place form: *dst <- *src / val.
  __device__ __forceinline__ void operator()(T* dst, T* src) {
    *dst = *src / val;
  }

  // Divisor captured by the constructor.
  const T val;
};

// float specialization: the reciprocal of the divisor is computed once in
// the constructor, so each element is handled with a multiplication
// instead of a division.
template <>
struct TensorDivConstantOp<float> {
  TensorDivConstantOp(float divisor) : val(1.0f / divisor) {}

  // In-place form: *elem <- *elem * (1 / divisor).
  __device__ __forceinline__ void operator()(float* elem) {
    *elem = *elem * val;
  }

  // Out-of-place form: *dst <- *src * (1 / divisor).
  __device__ __forceinline__ void operator()(float* dst, float* src) {
    *dst = *src * val;
  }

  // Holds the reciprocal of the constructor argument, not the divisor itself.
  const float val;
};

// double specialization: the reciprocal of the divisor is computed once in
// the constructor, so each element is handled with a multiplication
// instead of a division.
template <>
struct TensorDivConstantOp<double> {
  TensorDivConstantOp(double divisor) : val(1.0 / divisor) {}

  // In-place form: *elem <- *elem * (1 / divisor).
  __device__ __forceinline__ void operator()(double* elem) {
    *elem = *elem * val;
  }

  // Out-of-place form: *dst <- *src * (1 / divisor).
  __device__ __forceinline__ void operator()(double* dst, double* src) {
    *dst = *src * val;
  }

  // Holds the reciprocal of the constructor argument, not the divisor itself.
  const double val;
};

#ifdef CUDA_HALF_TENSOR
// half specialization of the divide-by-constant functor. Both build
// variants turn the division into a multiplication by a value prepared in
// the constructor.
template <>
struct TensorDivConstantOp<half> {
#ifdef CUDA_HALF_INSTRUCTIONS
  // Native-half variant: the multiplier is kept as a half and applied with
  // __hmul.
  // NOTE(review): ScalarInv<half>::to is presumably the reciprocal of its
  // argument; it is defined outside this file section - confirm.
  TensorDivConstantOp(half divisor) : val(ScalarInv<half>::to(divisor)) {}

  // Out-of-place form: *dst <- *src * val.
  __device__ __forceinline__ void operator()(half* dst, half* src) {
    *dst = __hmul(*src, val);
  }

  // In-place form: *elem <- *elem * val.
  __device__ __forceinline__ void operator()(half* elem) {
    *elem = __hmul(*elem, val);
  }

  const half val;
#else
  // Float-fallback variant: the reciprocal is kept as a float; each element
  // is widened to float, multiplied, and narrowed back to half.
  TensorDivConstantOp(half divisor) : fval(1.f / THC_half2float(divisor)) {}

  // Out-of-place form: *dst <- half(float(*src) * fval).
  __device__ __forceinline__ void operator()(half* dst, half* src) {
    float widened = __half2float(*src);
    float scaled = widened * fval;
    *dst = __float2half(scaled);
  }

  // In-place form: *elem <- half(float(*elem) * fval).
  __device__ __forceinline__ void operator()(half* elem) {
    float widened = __half2float(*elem);
    widened *= fval;
    *elem = __float2half(widened);
  }

  const float fval;
#endif
};
#endif // CUDA_HALF_TENSOR

Expand Down
23 changes: 16 additions & 7 deletions lib/THC/generic/THCTensorMathPairwise.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,20 @@ THCTensor_(add)(THCState *state, THCTensor *self_, THCTensor *src_, real value)
// Elementwise subtraction of a scalar: self_ <- src_ - value.
//
// When self_ and src_ are the same tensor the subtraction is applied in
// place through THC_pointwiseApply1; otherwise self_ is resized to match
// src_ and filled through THC_pointwiseApply2. The work is done by
// TensorSubConstantOp<real> only - there must be no additional
// THCTensor_(add) call with the negated value here, because that would
// subtract `value` a second time in the in-place case.
//
// A false return from either pointwise-apply helper is reported through
// THArgCheck on argument 2 with CUTORCH_DIM_WARNING, and any pending CUDA
// error is surfaced by THCudaCheck before returning.
THC_API void
THCTensor_(sub)(THCState *state, THCTensor *self_, THCTensor *src_, real value)
{
  THAssert(THCTensor_(checkGPU)(state, 2, self_, src_));
  if (self_ == src_) {
    if (!THC_pointwiseApply1(state, self_, TensorSubConstantOp<real>(value))) {
      THArgCheck(false, 2, CUTORCH_DIM_WARNING);
    }
  } else {
    THCTensor_(resizeAs)(state, self_, src_);

    if (!THC_pointwiseApply2(state, self_, src_, TensorSubConstantOp<real>(value))) {
      THArgCheck(false, 2, CUTORCH_DIM_WARNING);
    }
  }

  THCudaCheck(cudaGetLastError());
}

THC_API void
Expand Down Expand Up @@ -53,17 +66,13 @@ THCTensor_(div)(THCState* state, THCTensor *self_, THCTensor *src_, real value)
THArgCheck(value != ScalarConvert<int, real>::to(0), 3, "divide by zero");

if (self_ == src_) {
if (!THC_pointwiseApply1(state, self_,
TensorMulConstantOp<real>(
ScalarInv<real>::to(value)))) {
if (!THC_pointwiseApply1(state, self_, TensorDivConstantOp<real>(value))) {
THArgCheck(false, 2, CUTORCH_DIM_WARNING);
}
} else {
THCTensor_(resizeAs)(state, self_, src_);

if (!THC_pointwiseApply2(state, self_, src_,
TensorMulConstantOp<real>(
ScalarInv<real>::to(value)))) {
if (!THC_pointwiseApply2(state, self_, src_, TensorDivConstantOp<real>(value))) {
THArgCheck(false, 2, CUTORCH_DIM_WARNING);
}
}
Expand Down