Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions TensorMath.lua
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,24 @@ for k, Tensor_ in pairs(handledTypenames) do
{name=Tensor, method={default=1}},
{name=real}})

-- fmod(dst?, src, value): element-wise C-style (truncated) remainder by a
-- scalar. The destination tensor is optional and returned; as a method the
-- receiver (arg 1) is the source when no destination is given.
wrap("fmod",
cname("fmod"),
{{name=Tensor, default=true, returned=true, method={default='nil'}},
{name=Tensor, method={default=1}},
{name=real}})

-- remainder(dst?, src, value): element-wise remainder by a scalar; same
-- calling convention as fmod above.
wrap("remainder",
cname("remainder"),
{{name=Tensor, default=true, returned=true, method={default='nil'}},
{name=Tensor, method={default=1}},
{name=real}})

-- equal(a, b): compares two tensors and returns the C function's boolean
-- result directly (creturned) instead of a tensor.
wrap("equal",
cname("equal"),
{{name=Tensor},
{name=Tensor},
{name="boolean", creturned=true}})

for _, name in ipairs({"cmul", "cpow", "cdiv"}) do
wrap(name,
cname(name),
Expand Down Expand Up @@ -1306,6 +1324,24 @@ wrap("div",
{name=Tensor, method={default=1}},
{name=real}})

-- fmod(dst?, src, value): element-wise C-style (truncated) remainder by a
-- scalar; dst is optional and returned (method receiver defaults to arg 1).
wrap("fmod",
cname("fmod"),
{{name=Tensor, default=true, returned=true, method={default='nil'}},
{name=Tensor, method={default=1}},
{name=real}})

-- remainder(dst?, src, value): element-wise remainder by a scalar; same
-- calling convention as fmod above.
wrap("remainder",
cname("remainder"),
{{name=Tensor, default=true, returned=true, method={default='nil'}},
{name=Tensor, method={default=1}},
{name=real}})

-- equal(a, b): compares two tensors and forwards the C boolean return
-- value (creturned) to Lua.
wrap("equal",
cname("equal"),
{{name=Tensor},
{name=Tensor},
{name="boolean", creturned=true}})

for _, name in ipairs({"cmul", "cpow", "cdiv"}) do
wrap(name,
cname(name),
Expand Down
125 changes: 125 additions & 0 deletions lib/THC/THCTensorMathPairwise.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "THCTensorCopy.h"
#include "THCApply.cuh"
#include "THCNumerics.cuh"
#include "THCTensorMathCompareT.cuh"

template <typename T>
struct TensorAddConstantOp {
Expand Down Expand Up @@ -239,6 +240,130 @@ struct TensorDivConstantOp<half> {
};
#endif // CUDA_HALF_TENSOR

// Element-wise remainder by a scalar.  This generic version is only
// instantiated for integral element types; float/double/half have
// dedicated floor-based specializations below.
//
// Fix: the previous form `*in - val * (*in / val)` is C truncated
// remainder, which disagrees with the floorf/floor specializations for
// negative operands.  Shift nonzero results with a sign mismatch into the
// divisor's interval so integers follow the same floor-mod convention.
// (val == 0 remains undefined, as before.)
template <typename T>
struct TensorRemainderOp {
  TensorRemainderOp(T v) : val(v) {}

  // out-of-place: *out = *in mod val (floor convention)
  __device__ __forceinline__ void operator()(T* out, T* in) {
    T r = *in % val;
    if (r != 0 && ((r < 0) != (val < 0))) {
      r += val;
    }
    *out = r;
  }

  // in-place variant
  __device__ __forceinline__ void operator()(T* v) {
    T r = *v % val;
    if (r != 0 && ((r < 0) != (val < 0))) {
      r += val;
    }
    *v = r;
  }

  const T val;
};

// float specialization: floorf-based, so any nonzero result carries the
// sign of the divisor `val` (floor-mod convention).
template <>
struct TensorRemainderOp<float> {
TensorRemainderOp(float v) : val(v) {}
// out-of-place: *out = *in - val * floor(*in / val)
__device__ __forceinline__ void operator()(float* out, float* in) {
*out = *in - val * floorf(*in / val);
}

// in-place variant
__device__ __forceinline__ void operator()(float* v) {
*v = *v - val * floorf(*v / val);
}

const float val;
};

// double specialization: identical to the float version but computed in
// double precision via floor().
template <>
struct TensorRemainderOp<double> {
TensorRemainderOp(double v) : val(v) {}
// out-of-place: *out = *in - val * floor(*in / val)
__device__ __forceinline__ void operator()(double* out, double* in) {
*out = *in - val * floor(*in / val);
}

// in-place variant
__device__ __forceinline__ void operator()(double* v) {
*v = *v - val * floor(*v / val);
}

const double val;
};

#ifdef CUDA_HALF_TENSOR
// half specialization: uses native half intrinsics when the hardware
// supports them (CUDA_HALF_INSTRUCTIONS); otherwise promotes to float,
// computes the floor-based remainder, and converts back to half.
template <>
struct TensorRemainderOp<half> {
#ifdef CUDA_HALF_INSTRUCTIONS
TensorRemainderOp(half v) : val(v) {}
#else
// No native half arithmetic: pre-convert the scalar to float once.
TensorRemainderOp(half v): fval(THC_half2float(v)) {}
#endif

// out-of-place variant
__device__ __forceinline__ void operator()(half* out, half* in) {
#ifdef CUDA_HALF_INSTRUCTIONS
*out = __hsub(*in, __hmul(val, hfloor(__hdiv(*in, val))));
#else
float fin = __half2float(*in);
float fout = fin - fval * floorf(fin / fval);
*out = __float2half(fout);
#endif
}

// in-place variant
__device__ __forceinline__ void operator()(half* v) {
#ifdef CUDA_HALF_INSTRUCTIONS
*v = __hsub(*v, __hmul(val, hfloor(__hdiv(*v, val))));
#else
float fv = __half2float(*v);
fv = fv - fval * floorf(fv / fval);
*v = __float2half(fv);
#endif
}

// Scalar stored as half (native path) or float (promoted path).
#ifdef CUDA_HALF_INSTRUCTIONS
const half val;
#else
const float fval;
#endif
};
#endif // CUDA_HALF_TENSOR

// Element-wise C-style fmod (truncated remainder) by a scalar, computed in
// single precision via fmodf.  double and half have specializations below;
// this generic version also covers float and the integral types.
// NOTE(review): integral element types round-trip through float here, so
// 64-bit values outside float's 24-bit mantissa range can lose precision --
// TODO confirm whether wide integer types need an integral '%' path.
template <typename T>
struct TensorFmodOp {
TensorFmodOp(T v) : val((float)v) {}
// out-of-place variant
__device__ __forceinline__ void operator()(T* out, T* in) {
*out = (T) fmodf((float) *in, val);
}

// in-place variant
__device__ __forceinline__ void operator()(T* v) {
*v = (T) fmodf((float) *v, val);
}

const float val;
};

// double specialization: full double-precision fmod (no float round-trip).
template <>
struct TensorFmodOp<double> {
TensorFmodOp(double v) : val(v) {}
// out-of-place variant
__device__ __forceinline__ void operator()(double* out, double* in) {
*out = fmod(*in, val);
}

// in-place variant
__device__ __forceinline__ void operator()(double* v) {
*v = fmod(*v, val);
}

const double val;
};

#ifdef CUDA_HALF_TENSOR
// half specialization: promotes to float, applies fmodf, converts back.
// The scalar is pre-converted to float once in the constructor.
template <>
struct TensorFmodOp<half> {
TensorFmodOp(half v): fval(THC_half2float(v)) {}

// out-of-place variant
__device__ __forceinline__ void operator()(half* out, half* in) {
*out = __float2half(fmodf(__half2float(*in), fval));
}

// in-place variant
__device__ __forceinline__ void operator()(half* v) {
*v = __float2half(fmodf(__half2float(*v), fval));
}

const float fval;
};
#endif // CUDA_HALF_TENSOR

template <typename T, int Upper>
struct TensorTriOp {
TensorTriOp(T *start_, long stride0_, long stride1_, long k_)
Expand Down
62 changes: 62 additions & 0 deletions lib/THC/generic/THCTensorMathPairwise.cu
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,44 @@ THCTensor_(div)(THCState* state, THCTensor *self_, THCTensor *src_, real value)
THCudaCheck(cudaGetLastError());
}

// self_ = fmod(src_, value), element-wise.  Operates in place when
// self_ == src_; otherwise self_ is resized to match src_ first.
// Raises a Torch argument error (via THArgCheck) if the pointwise apply
// rejects the tensors (e.g. too many dimensions).
THC_API void
THCTensor_(fmod)(THCState *state, THCTensor *self_, THCTensor *src_, real value)
{
THAssert(THCTensor_(checkGPU)(state, 2, self_, src_));
if (self_ == src_) {
// in-place path: single-tensor apply
if (!THC_pointwiseApply1(state, self_, TensorFmodOp<real>(value))) {
THArgCheck(false, 2, CUTORCH_DIM_WARNING);
}
} else {
THCTensor_(resizeAs)(state, self_, src_);

if (!THC_pointwiseApply2(state, self_, src_, TensorFmodOp<real>(value))) {
THArgCheck(false, 2, CUTORCH_DIM_WARNING);
}
}

THCudaCheck(cudaGetLastError());
}

// self_ = remainder(src_, value), element-wise (floor-based for the
// floating-point types; see TensorRemainderOp).  Operates in place when
// self_ == src_; otherwise self_ is resized to match src_ first.
THC_API void
THCTensor_(remainder)(THCState *state, THCTensor *self_, THCTensor *src_, real value)
{
THAssert(THCTensor_(checkGPU)(state, 2, self_, src_));
if (self_ == src_) {
// in-place path: single-tensor apply
if (!THC_pointwiseApply1(state, self_, TensorRemainderOp<real>(value))) {
THArgCheck(false, 2, CUTORCH_DIM_WARNING);
}
} else {
THCTensor_(resizeAs)(state, self_, src_);

if (!THC_pointwiseApply2(state, self_, src_, TensorRemainderOp<real>(value))) {
THArgCheck(false, 2, CUTORCH_DIM_WARNING);
}
}

THCudaCheck(cudaGetLastError());
}

void THCTensor_(tril)(THCState *state, THCTensor *self_, THCTensor *src_, long k)
{
THAssert(THCTensor_(checkGPU)(state, 2, self_, src_));
Expand Down Expand Up @@ -146,4 +184,28 @@ void THCTensor_(triu)(THCState *state, THCTensor *self_, THCTensor *src_, long k
THCudaCheck(cudaGetLastError());
}

// Returns 1 iff self_ and src_ have the same size and hold equal elements,
// 0 otherwise.
THC_API int THCTensor_(equal)(THCState *state, THCTensor *self_, THCTensor *src_)
{
  THAssert(THCTensor_(checkGPU)(state, 2, self_, src_));
  // Fix: the type macro must wrap only the function identifier --
  // THCTensor_(isSameSizeAs)(args), not THCTensor_(isSameSizeAs(args)).
  if (!THCTensor_(isSameSizeAs)(state, self_, src_)) {
    return 0;
  }

  // This is not as efficient as TH, but the basic idea: create a buffer that
  // stores 1 if the two tensors are equal at a position, otherwise 0. If the
  // minimum value in this buffer is 1, the two tensors are equal, otherwise
  // they are not.
  THLongStorage *size = THCTensor_(newSizeOf)(state, self_);
  THCudaByteTensor *buf = THCudaByteTensor_newWithSize(state, size, NULL);
  // Fix: newWithSize copies the sizes, so the storage must be released here
  // or it leaks on every call.
  THLongStorage_free(size);

  if (!THC_pointwiseApply3(state, buf, self_, src_, TensorEQOp<real, unsigned char>())) {
    THArgCheck(false, 2, CUTORCH_DIM_WARNING);
  }

  unsigned char min = THCudaByteTensor_minall(state, buf);
  THCudaByteTensor_free(state, buf);

  return min != 0;
}

#endif
4 changes: 4 additions & 0 deletions lib/THC/generic/THCTensorMathPairwise.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,9 @@ THC_API void THCTensor_(add)(THCState *state, THCTensor *self, THCTensor *src, r
THC_API void THCTensor_(sub)(THCState *state, THCTensor *self, THCTensor *src, real value);
THC_API void THCTensor_(mul)(THCState *state, THCTensor *self, THCTensor *src, real value);
THC_API void THCTensor_(div)(THCState *state, THCTensor *self, THCTensor *src, real value);
// Element-wise scalar fmod (truncated) and remainder (floor-based for the
// floating-point types); in place when self == src.
THC_API void THCTensor_(fmod)(THCState *state, THCTensor *self, THCTensor *src, real value);
THC_API void THCTensor_(remainder)(THCState *state, THCTensor *self, THCTensor *src, real value);

// Returns nonzero iff the two tensors have the same size and equal contents.
THC_API int THCTensor_(equal)(THCState *state, THCTensor *self, THCTensor *src);

#endif
75 changes: 75 additions & 0 deletions test/test.lua
Original file line number Diff line number Diff line change
Expand Up @@ -1001,6 +1001,81 @@ function test.addcdiv()
checkMultiDevice(r, 'addcdiv', x, torch.uniform(), y, z)
end

-- Compares CUDA fmod against the CPU reference for every tested type.
function test.fmod()
  local sz1 = chooseInt(minsize, maxsize)
  local sz2 = chooseInt(minsize, maxsize)
  -- Scale normally-distributed values by a random integer factor so the
  -- operands span well beyond the divisor's magnitude.
  local x = torch.FloatTensor():randn(sz1, sz2)
  x:apply(function(v)  -- renamed: previously shadowed the outer `x`
    return v * torch.random(1, 100)
  end)
  local r = torch.normal(0, 25)
  -- Fix: removed leftover debug `print(x, r)` that spammed test output.

  for _, typename in ipairs(typenames) do
    local x = x:type(t2cpu[typename])
    compareCPUAndCUDATypeTensorArgs(typename, nil, x, 'fmod', r)
  end
end

-- Compares CUDA remainder against the CPU reference for every tested type.
function test.remainder()
  local sz1 = chooseInt(minsize, maxsize)
  local sz2 = chooseInt(minsize, maxsize)
  -- Scale normally-distributed values by a random integer factor so the
  -- operands span well beyond the divisor's magnitude.
  local x = torch.FloatTensor():randn(sz1, sz2)
  x:apply(function(v)  -- renamed: previously shadowed the outer `x`
    return v * torch.random(1, 100)
  end)
  local r = torch.normal(0, 25)
  -- Fix: removed leftover debug `print(x, r)` that spammed test output.

  for _, typename in ipairs(typenames) do
    local x = x:type(t2cpu[typename])
    compareCPUAndCUDATypeTensorArgs(typename, nil, x, 'remainder', r)
  end
end

-- Exercises Tensor:equal() across every tested CUDA type: empty tensors,
-- size mismatches, value mismatches, and genuine equality.
function test.equal()
  -- Convert both operands to each tested type and assert the expected
  -- outcome of :equal() with the given failure message.
  local function checkAll(t1, t2, expected, msg)
    for _, typename in ipairs(typenames) do
      local a = t1:type(typename)
      local b = t2:type(typename)
      if expected then
        tester:assert(a:equal(b), msg)
      else
        tester:assert(not a:equal(b), msg)
      end
    end
  end

  -- empty tensors are equal
  checkAll(torch.FloatTensor(), torch.FloatTensor(), true,
           'Empty Tensors should be considered equal')

  -- mismatched size tensors are not equal
  checkAll(torch.FloatTensor(5):fill(1), torch.FloatTensor(3):fill(1), false,
           'Tensors of different sizes not equal')

  -- tensors of same size but different value are not equal
  local rows = chooseInt(minsize, maxsize)
  local cols = chooseInt(minsize, maxsize)
  local base = torch.FloatTensor(rows, cols):apply(function()
    return torch.random(0, 255)
  end)
  checkAll(base, torch.add(base, 1), false, 'Tensors should not be equal')

  -- actual equality
  checkAll(base, base, true, 'Tensors should be equal')
end

function test.logicalValue()
local sz1 = chooseInt(minsize, maxsize)
local sz2 = chooseInt(minsize, maxsize)
Expand Down