202 changes: 151 additions & 51 deletions TensorMath.lua
@@ -28,14 +28,15 @@ static int torch_isnonemptytable(lua_State *L, int idx)
local unpack = unpack or table.unpack

-- specific to CUDA
-local typenames = {'CudaByteTensor',
-                   'CudaCharTensor',
-                   'CudaShortTensor',
-                   'CudaIntTensor',
-                   'CudaLongTensor',
-                   'CudaTensor',
-                   'CudaDoubleTensor',
-                   'CudaHalfTensor'
+local typenames = {
+   'CudaByteTensor',
+   'CudaCharTensor',
+   'CudaShortTensor',
+   'CudaIntTensor',
+   'CudaLongTensor',
+   'CudaTensor',
+   'CudaDoubleTensor',
+   'CudaHalfTensor'
}

for _, typename in ipairs(typenames) do
@@ -409,37 +410,95 @@ local function wrap(...)
method:wrap(unpack(args))
end

+local Tensor
+
+-- functions to help take in arguments that are either Tensor or CudaLongTensor (for backward compatibility)
+-- used in scatter / gather for example
+local function TensorToCudaLong_declare(dummy)
+   return function(arg)
+      local txt = {}
+      table.insert(txt, string.format("THCudaLongTensor *arg%d = NULL;", arg.i))
+      if dummy then
+         table.insert(txt, string.format("THCudaLongTensor *indexLongTensor = NULL;"))
+         table.insert(txt, string.format("TH%s *dummyIndexTensor = NULL;", Tensor))
+      end
+      return table.concat(txt, '\n')
+   end
+end
+local function TensorToCudaLong_check(arg, idx)
+   return string.format('(dummyIndexTensor = luaT_toudata(L, %d, "torch.%s"))', idx, Tensor)
+end
+local function TensorToCudaLong_read(arg, idx)
+   local copyname = Tensor:match("(%a+)Tensor")
+   if copyname == 'Cuda' then
+      copyname = 'CudaFloat'
+   end
+   local txt = {}
+   table.insert(txt, string.format('arg%d = THCudaLongTensor_new(cutorch_getstate(L));', arg.i))
+   table.insert(txt, string.format('THLongStorage *indexSize = TH%s_newSizeOf(cutorch_getstate(L), dummyIndexTensor);', Tensor))
+   table.insert(txt, string.format('THCudaLongTensor_resize(cutorch_getstate(L), arg%d, indexSize, NULL);', arg.i))
+   table.insert(txt, string.format('THLongStorage_free(indexSize);'))
+   table.insert(txt, string.format('THCudaLongTensor_copy%s(cutorch_getstate(L), arg%d, dummyIndexTensor);', copyname, arg.i))
+   table.insert(txt, string.format('indexLongTensor = arg%d;', arg.i))
+   return table.concat(txt, '\n')
+end

+local function TensorToCudaLong_postcall(arg)
+   return "if (indexLongTensor != NULL) THCudaLongTensor_free(cutorch_getstate(L), indexLongTensor);\n"
+end
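Taken together, the declare/check/read/postcall snippets above generate C code that accepts an ordinary Cuda*Tensor where a CudaLongTensor index is expected: they allocate a CudaLongTensor shaped like the given index and copy the values across, then free it after the call. A minimal Lua-level sketch of that conversion (hypothetical helper name, not part of this diff):

    -- shape-preserving cast from any Cuda*Tensor index to a CudaLongTensor
    local function toCudaLongIndex(index)
       local long = torch.CudaLongTensor(index:size())
       long:copy(index)  -- element-wise copy performs the type conversion
       return long
    end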

+-- function to initialize the gather call
+local function gatherInit(arg)
+   return table.concat(
+      {
+         arg.__metatable.init(arg),
+         string.format("TH%s_checkGPU(cutorch_getstate(L), 1, %s);",
+                       Tensor, arg.args[4]:carg()),
+         string.format(
+            [[
+THCState *state = cutorch_getstate(L);
+THLongStorage *indicesSize = THCudaLongTensor_newSizeOf(state, %s);
+TH%s_resize(state, %s, indicesSize, NULL);
+THLongStorage_free(indicesSize);
+]], arg.args[4]:carg(), Tensor, arg:carg()),
+      }, '\n')
+end
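gatherInit adds two generated steps on top of the stock init: a same-GPU check between the result and the index tensor, and a resize of the default result to the index's size, since gather's output is always shaped like its index. A usage sketch (hypothetical values, assuming cutorch is loaded):

    local src = torch.CudaTensor(4, 5):uniform()
    local idx = torch.CudaLongTensor(4, 2):fill(3)
    local dst = src:gather(2, idx)  -- dst is resized to 4x2, the size of idx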

--
-- Non-CudaTensor type math, since these are less fully implemented than
-- CudaTensor
--

-local handledTypenames = {'CudaByteTensor',
-                          'CudaCharTensor',
-                          'CudaShortTensor',
-                          'CudaIntTensor',
-                          'CudaLongTensor',
-                          'CudaDoubleTensor',
-                          'CudaHalfTensor',
+local handledTypenames = {
+   'CudaByteTensor',
+   'CudaCharTensor',
+   'CudaShortTensor',
+   'CudaIntTensor',
+   'CudaLongTensor',
+   'CudaDoubleTensor',
+   'CudaHalfTensor',
}
-local handledTypereals = {'unsigned char',
-                          'char',
-                          'short',
-                          'int',
-                          'long',
-                          'double',
-                          'half'
+local handledTypereals = {
+   'unsigned char',
+   'char',
+   'short',
+   'int',
+   'long',
+   'double',
+   'half'
}
-local handledTypeaccreals = {'long',
-                             'long',
-                             'long',
-                             'long',
-                             'long',
-                             'double',
-                             'float'
+local handledTypeaccreals = {
+   'long',
+   'long',
+   'long',
+   'long',
+   'long',
+   'double',
+   'float'
}

-for k, Tensor in pairs(handledTypenames) do
+for k, Tensor_ in pairs(handledTypenames) do
+   Tensor = Tensor_
if Tensor == 'CudaHalfTensor' then
interface:print("#ifdef CUDA_HALF_TENSOR")
end
@@ -621,6 +680,41 @@ for k, Tensor in pairs(handledTypenames) do
{name=Tensor},
{name='CudaByteTensor'}})

wrap("gather",
cname("gather"),
{{name=Tensor, default=true, returned=true, init=gatherInit},
{name=Tensor},
{name="index"},
{name='CudaLongTensor'}},
cname("gather"), -- this is for backward-compatibility, and takes in "Tensor" as the indexing tensor
{{name=Tensor, default=true, returned=true, init=gatherInit},
{name=Tensor},
{name="index"},
{name=Tensor, declare=TensorToCudaLong_declare(true), check=TensorToCudaLong_check, read=TensorToCudaLong_read, postcall=TensorToCudaLong_postcall}})

wrap("scatter",
cname("scatter"),
{{name=Tensor, returned=true},
{name="index"},
{name='CudaLongTensor'},
{name=Tensor}},
cname("scatter"), -- this is for backward-compatibility, and takes in "Tensor" as the indexing tensor
{{name=Tensor, returned=true},
{name="index"},
{name=Tensor, declare=TensorToCudaLong_declare(true), check=TensorToCudaLong_check, read=TensorToCudaLong_read, postcall=TensorToCudaLong_postcall},
{name=Tensor}},
cname("scatterFill"),
{{name=Tensor, returned=true},
{name="index"},
{name='CudaLongTensor'},
{name=real}},
cname("scatterFill"), -- this is for backward-compatibility, and takes in "Tensor" as the indexing tensor
{{name=Tensor, returned=true},
{name="index"},
{name=Tensor, declare=TensorToCudaLong_declare(false), check=TensorToCudaLong_check, read=TensorToCudaLong_read, postcall=TensorToCudaLong_postcall},
{name=real}}
)
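With these wraps in place, each non-float tensor type gains gather and scatter entry points that take a CudaLongTensor index, plus backward-compatible overloads that still accept an index of the tensor's own type. A usage sketch (hypothetical values, assuming cutorch is loaded):

    local src = torch.CudaIntTensor(3, 4):fill(7)
    local out1 = src:gather(2, torch.CudaLongTensor(3, 2):fill(1))  -- new index type
    local out2 = src:gather(2, torch.CudaIntTensor(3, 2):fill(1))   -- compat index type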

-- BLAS functions
if real == 'float' or real == 'double' or real == 'half' then
wrap("mv",
Expand Down Expand Up @@ -795,7 +889,7 @@ end
-- CudaTensor special handling, since it is more fully implemented
--

-local Tensor = "CudaTensor"
+Tensor = "CudaTensor"
local real = "float"

function interface.luaname2wrapname(self, name)
Expand Down Expand Up @@ -933,32 +1027,38 @@ wrap("maskedSelect",

wrap("gather",
cname("gather"),
-     {{name=Tensor, default=true, returned=true,
-       init=function(arg)
-          return table.concat(
-             {
-                arg.__metatable.init(arg),
-                string.format("TH%s_checkGPU(cutorch_getstate(L), 1, %s);",
-                              Tensor, arg.args[4]:carg()),
-                string.format("TH%s_resizeAs(cutorch_getstate(L), %s, %s);", Tensor, arg:carg(), arg.args[4]:carg()),
-             }, '\n')
-       end
-     },
-      {name=Tensor},
-      {name="index"},
-      {name=Tensor}})
+     {{name=Tensor, default=true, returned=true, init=gatherInit},
+      {name=Tensor},
+      {name="index"},
+      {name='CudaLongTensor'}},
+     cname("gather"), -- this is for backward-compatibility, and takes in "Tensor" as the indexing tensor
+     {{name=Tensor, default=true, returned=true, init=gatherInit},
+      {name=Tensor},
+      {name="index"},
+      {name=Tensor, declare=TensorToCudaLong_declare(true), check=TensorToCudaLong_check, read=TensorToCudaLong_read, postcall=TensorToCudaLong_postcall}})

wrap("scatter",
cname("scatter"),
{{name=Tensor, returned=true},
{name="index"},
{name=Tensor},
{name=Tensor}},
{name="index"},
{name='CudaLongTensor'},
{name=Tensor}},
cname("scatter"), -- this is for backward-compatibility, and takes in "Tensor" as the indexing tensor
{{name=Tensor, returned=true},
{name="index"},
{name=Tensor, declare=TensorToCudaLong_declare(true), check=TensorToCudaLong_check, read=TensorToCudaLong_read, postcall=TensorToCudaLong_postcall},
{name=Tensor}},
cname("scatterFill"),
{{name=Tensor, returned=true},
{name="index"},
{name=Tensor},
{name=real}})
{name="index"},
{name='CudaLongTensor'},
{name=real}},
cname("scatterFill"), -- this is for backward-compatibility, and takes in "Tensor" as the indexing tensor
{{name=Tensor, returned=true},
{name="index"},
{name=Tensor, declare=TensorToCudaLong_declare(false), check=TensorToCudaLong_check, read=TensorToCudaLong_read, postcall=TensorToCudaLong_postcall},
{name=real}}
)
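The CudaTensor versions follow the same pattern: the primary overloads take a CudaLongTensor index, while the compat overloads convert a CudaTensor index on the fly. Roughly (a sketch with hypothetical values, assuming cutorch is loaded):

    local x = torch.CudaTensor(3, 4):zero()
    local idx = torch.CudaLongTensor(3, 1):fill(2)
    x:scatter(2, idx, torch.CudaTensor(3, 1):fill(1))  -- scatter from a source tensor
    x:scatter(2, idx, 5)                               -- scatterFill with a constant
    x:scatter(2, torch.CudaTensor(3, 1):fill(2), 5)    -- compat: CudaTensor index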

wrap("sort",
cname("sort"),
7 changes: 3 additions & 4 deletions lib/THC/THCTensorMath.h
@@ -28,6 +28,9 @@
#include "generic/THCTensorMasked.h"
#include "THCGenerateAllTypes.h"

#include "generic/THCTensorScatterGather.h"
#include "THCGenerateAllTypes.h"

THC_API void THCudaTensor_tril(THCState *state, THCudaTensor *self, THCudaTensor *src, long k);
THC_API void THCudaTensor_triu(THCState *state, THCudaTensor *self, THCudaTensor *src, long k);
THC_API void THCudaTensor_diag(THCState *state, THCudaTensor *self, THCudaTensor *src, long k);
@@ -115,10 +118,6 @@ THC_API void THCudaTensor_indexAdd_long(THCState *state, THCudaTensor *res_, int
THC_API void THCudaTensor_indexFill_long(THCState *state, THCudaTensor *tensor, int dim, THLongTensor *index, float val);
THC_API void THCudaTensor_indexSelect_long(THCState *state, THCudaTensor *tensor, THCudaTensor *src, int dim, THLongTensor *index);

-THC_API void THCudaTensor_gather(THCState* state, THCudaTensor *tensor, THCudaTensor *src, int dim, THCudaTensor *index);
-THC_API void THCudaTensor_scatter(THCState* state, THCudaTensor *tensor, int dim, THCudaTensor *index, THCudaTensor *src);
-THC_API void THCudaTensor_scatterFill(THCState* state, THCudaTensor *tensor, int dim, THCudaTensor *index, float value);

THC_API int THCudaByteTensor_logicalall(THCState *state, THCudaByteTensor *self);
THC_API int THCudaByteTensor_logicalany(THCState *state, THCudaByteTensor *self);
