Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
288 changes: 264 additions & 24 deletions TensorMath.lua
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,13 @@ local unpack = unpack or table.unpack

-- specific to CUDA
local typenames = {'CudaByteTensor',
'CudaCharTensor',
'CudaShortTensor',
'CudaIntTensor',
'CudaLongTensor',
'CudaTensor',
'CudaDoubleTensor',
'CudaHalfTensor'
'CudaCharTensor',
'CudaShortTensor',
'CudaIntTensor',
'CudaLongTensor',
'CudaTensor',
'CudaDoubleTensor',
'CudaHalfTensor'
}

for _, typename in ipairs(typenames) do
Expand Down Expand Up @@ -430,12 +430,21 @@ local handledTypereals = {'unsigned char',
'double',
'half'
}
-- C accumulator type for each handled tensor type.
-- The generation loop reads handledTypeaccreals[k] with the same index k it
-- uses for handledTypenames and handledTypereals, so the three lists must stay
-- position-aligned. The value is used as the C return type of the whole-tensor
-- "sum" and "prod" wrappers.
-- NOTE(review): the last entry ('float') lines up with the last entry of
-- handledTypereals ('half'), i.e. half presumably accumulates in float — confirm.
local handledTypeaccreals = {'long',
'long',
'long',
'long',
'long',
'double',
'float'
}

for k, Tensor in pairs(handledTypenames) do
if Tensor == 'CudaHalfTensor' then
interface:print("#ifdef CUDA_HALF_TENSOR")
end
local real = handledTypereals[k]
local accreal = handledTypeaccreals[k]

function interface.luaname2wrapname(self, name)
return string.format('cutorch_%s_%s', Tensor, name)
Expand Down Expand Up @@ -515,6 +524,18 @@ for k, Tensor in pairs(handledTypenames) do
{name=real, default=1},
{name=Tensor}})

-- "mul" and "div" take the same argument layout (result tensor, source
-- tensor, one value of the element type), so both wrappers are generated
-- from a single template. Order is preserved: mul first, then div.
for _, op in ipairs({"mul", "div"}) do
   wrap(op,
        cname(op),
        {
           -- result: optional for the function form; the method form drops the default
           {name=Tensor, default=true, returned=true, method={default='nil'}},
           -- source: in the method form it defaults to argument 1
           {name=Tensor, method={default=1}},
           {name=real},
        })
end

for _, name in ipairs({"cmul", "cpow", "cdiv"}) do
wrap(name,
cname(name),
Expand All @@ -523,6 +544,224 @@ for k, Tensor in pairs(handledTypenames) do
{name=Tensor}})
end

-- min / max: each Lua name is bound to two C overloads.
--   * <name>all : takes a tensor and returns one value of the element type
--   * <name>    : takes a tensor and an index, and returns a tensor of the
--                 same type plus a CudaLongTensor
do
   local reductions = {"min", "max"}
   for i = 1, #reductions do
      local op = reductions[i]
      wrap(op,
           cname(op .. "all"),
           {
              {name=Tensor},
              {name=real, creturned=true},
           },
           cname(op),
           {
              {name=Tensor, default=true, returned=true},
              {name='CudaLongTensor', default=true, returned=true},
              {name=Tensor},
              {name="index"},
           })
   end
end

-- Logical reductions are only generated for CudaByteTensor.
-- "all" and "any" map to the C functions named "logicalall" / "logicalany"
-- (via cname) and return a boolean from C.
if Tensor == 'CudaByteTensor' then
   -- ipairs (not pairs): the list is a sequence, and the wrappers must be
   -- emitted in a fixed order. pairs gives no ordering guarantee, and the
   -- other generation loops in this file already use ipairs.
   for _, name in ipairs({'all', 'any'}) do
      wrap(name,
           cname('logical' .. name),
           {{name=Tensor},
            {name="boolean", creturned=true}})
   end
end

-- Element-wise comparisons. Every operator gets four C overloads:
--   <op>Value   : (CudaByteTensor result [optional], tensor, value)
--   <op>ValueT  : (same-type result, tensor, value)
--   <op>Tensor  : (CudaByteTensor result [optional], tensor, tensor)
--   <op>TensorT : (same-type result, tensor, tensor)
-- ipairs (not pairs): the operator list is a sequence, and the wrappers must
-- be emitted in a fixed order. pairs gives no ordering guarantee, and the
-- other generation loops in this file already use ipairs.
for _, name in ipairs({'lt','gt','le','ge','eq','ne'}) do
   wrap(name,
        cname(name .. 'Value'),
        {{name='CudaByteTensor',default=true, returned=true},
         {name=Tensor},
         {name=real}},
        cname(name .. 'ValueT'),
        {{name=Tensor, returned=true},
         {name=Tensor},
         {name=real}},
        cname(name .. 'Tensor'),
        {{name='CudaByteTensor',default=true, returned=true},
         {name=Tensor},
         {name=Tensor}},
        cname(name .. 'TensorT'),
        {{name=Tensor, returned=true},
         {name=Tensor},
         {name=Tensor}})
end

-- sum / prod share one layout, so both are generated from a single template
-- (sum first, then prod):
--   * <name>all : whole-tensor form, returns a value of the accumulator type
--   * <name>    : takes a tensor and an index, returns a tensor of the same type
for _, op in ipairs({"sum", "prod"}) do
   wrap(op,
        cname(op .. "all"),
        {
           {name=Tensor},
           {name=accreal, creturned=true},
        },
        cname(op),
        {
           {name=Tensor, default=true, returned=true},
           {name=Tensor},
           {name="index"},
        })
end

-- Masked operations; the mask is always a CudaByteTensor.
-- maskedFill and maskedCopy differ only in their last argument (a value of
-- the element type vs. a tensor), so they come from one template.
do
   local in_place_ops = {
      {"maskedFill", real},
      {"maskedCopy", Tensor},
   }
   for i = 1, #in_place_ops do
      local op, last_arg = in_place_ops[i][1], in_place_ops[i][2]
      wrap(op,
           cname(op),
           {
              {name=Tensor, returned=true, method={default='nil'}},
              {name='CudaByteTensor'},
              {name=last_arg},
           })
   end
end

-- maskedSelect: optional result tensor, then the source tensor and the mask.
wrap("maskedSelect",
     cname("maskedSelect"),
     {
        {name=Tensor, returned=true, default=true},
        {name=Tensor},
        {name='CudaByteTensor'},
     })

-- BLAS functions
if real == 'float' or real == 'double' or real == 'half' then
-- "mv": exposed under the Lua name "mv" but bound to the C function
-- cname("addmv"). The result tensor carries two code-generation hooks.
-- NOTE(review): arg.args[5] is presumably the 5th entry of this argument
-- list (the dim=2 tensor) — confirm against cwrap.
wrap("mv",
cname("addmv"),
{{name=Tensor, default=true, returned=true, method={default='nil'},
-- init hook: default init code, then a checkGPU call on args[5], then a
-- resize1d of the result to args[5]->size[0].
init=function(arg)
return table.concat(
{
arg.__metatable.init(arg),
string.format("TH%s_checkGPU(cutorch_getstate(L), 1, %s);",
Tensor, arg.args[5]:carg()),
string.format("TH%s_resize1d(cutorch_getstate(L), %s, %s->size[0]);", Tensor, arg:carg(), arg.args[5]:carg())
}, '\n')
end,
-- precall hook: emits a zero of the result ahead of the default precall code.
precall=function(arg)
return table.concat(
{
string.format("TH%s_zero(cutorch_getstate(L), %s);", Tensor, arg:carg()),
arg.__metatable.precall(arg)
}, '\n')
end
},
-- The next three arguments are marked invisible and carry defaults
-- (presumably hidden from the Lua-side signature — confirm against cwrap).
{name=real, default=1, invisible=true},
{name=Tensor, default=1, invisible=true},
{name=real, default=1, invisible=true},
{name=Tensor, dim=2},
{name=Tensor, dim=1}}
)

-- "mm": exposed under the Lua name "mm" but bound to the C function
-- cname("addmm"). Unlike "mv", there is no precall hook that zeroes the
-- result; instead the first invisible scalar defaults to 0.
-- NOTE(review): arg.args[5] / arg.args[6] are presumably the two dim=2
-- tensors at the end of this argument list — confirm against cwrap.
wrap("mm",
cname("addmm"),
{{name=Tensor, default=true, returned=true, method={default='nil'},
-- init hook: default init code, then a checkGPU call on args[5] and args[6],
-- then a resize2d of the result to (args[5]->size[0], args[6]->size[1]).
init=function(arg)
return table.concat(
{
arg.__metatable.init(arg),
string.format("TH%s_checkGPU(cutorch_getstate(L), 2, %s, %s);",
Tensor, arg.args[5]:carg(), arg.args[6]:carg()),
string.format("TH%s_resize2d(cutorch_getstate(L), %s, %s->size[0], %s->size[1]);",
Tensor, arg:carg(), arg.args[5]:carg(), arg.args[6]:carg())
}, '\n')
end,
},
-- Invisible arguments with defaults (0, argument 1, 1).
{name=real, default=0, invisible=true},
{name=Tensor, default=1, invisible=true},
{name=real, default=1, invisible=true},
{name=Tensor, dim=2},
{name=Tensor, dim=2}}
)

-- "bmm": exposed under the Lua name "bmm" but bound to the C function
-- cname("baddbmm"). No precall hook; the first invisible scalar defaults to 0.
-- NOTE(review): arg.args[5] / arg.args[6] are presumably the two dim=3
-- tensors at the end of this argument list — confirm against cwrap.
wrap("bmm",
cname("baddbmm"),
{{name=Tensor, default=true, returned=true, method={default='nil'},
-- init hook: default init code, then a checkGPU call on args[5] and args[6],
-- then a resize3d of the result to
-- (args[5]->size[0], args[5]->size[1], args[6]->size[2]).
init=function(arg)
return table.concat(
{
arg.__metatable.init(arg),
string.format("TH%s_checkGPU(cutorch_getstate(L), 2, %s, %s);",
Tensor, arg.args[5]:carg(), arg.args[6]:carg()),
string.format("TH%s_resize3d(cutorch_getstate(L), %s, %s->size[0], %s->size[1], %s->size[2]);",
Tensor, arg:carg(), arg.args[5]:carg(), arg.args[5]:carg(), arg.args[6]:carg())
}, '\n')
end,
},
-- Invisible arguments with defaults (0, argument 1, 1).
{name=real, default=0, invisible=true},
{name=Tensor, default=1, invisible=true},
{name=real, default=1, invisible=true},
{name=Tensor, dim=3},
{name=Tensor, dim=3}}
)

-- "ger": exposed under the Lua name "ger" but bound to the C function
-- cname("addr"). As with "mv", the result is zeroed in a precall hook and
-- all three invisible arguments default to 1.
-- NOTE(review): arg.args[5] / arg.args[6] are presumably the two dim=1
-- tensors at the end of this argument list — confirm against cwrap.
wrap("ger",
cname("addr"),
{{name=Tensor, default=true, returned=true, method={default='nil'},
-- init hook: default init code, then a checkGPU call on args[5] and args[6],
-- then a resize2d of the result to (args[5]->size[0], args[6]->size[0]).
init=function(arg)
return table.concat(
{
arg.__metatable.init(arg),
string.format("TH%s_checkGPU(cutorch_getstate(L), 2, %s, %s);",
Tensor, arg.args[5]:carg(), arg.args[6]:carg()),
string.format("TH%s_resize2d(cutorch_getstate(L), %s, %s->size[0], %s->size[0]);", Tensor, arg:carg(), arg.args[5]:carg(), arg.args[6]:carg())
}, '\n')
end,
-- precall hook: emits a zero of the result ahead of the default precall code.
precall=function(arg)
return table.concat(
{
string.format("TH%s_zero(cutorch_getstate(L), %s);", Tensor, arg:carg()),
arg.__metatable.precall(arg)
}, '\n')
end
},
{name=real, default=1, invisible=true},
{name=Tensor, default=1, invisible=true},
{name=real, default=1, invisible=true},
{name=Tensor, dim=1},
{name=Tensor, dim=1}}
)

-- Accumulating BLAS-style wrappers. Each row gives the C function name and
-- the required dimensionality of its three tensor operands.
do
   local blas_ops = {
      {name="addmv",   dim1=1, dim2=2, dim3=1},
      {name="addmm",   dim1=2, dim2=2, dim3=2},
      {name="addr",    dim1=2, dim2=1, dim3=1},
      {name="baddbmm", dim1=3, dim2=3, dim3=3},
      {name="addbmm",  dim1=2, dim2=3, dim3=3},
   }

   for i = 1, #blas_ops do
      local op = blas_ops[i]
      local fname = op.name
      local d1, d2, d3 = op.dim1, op.dim2, op.dim3

      -- Function form: optional result tensor, both scalars default to 1.
      interface:wrap(fname,
                     cname(fname),
                     {
                        {name=Tensor, default=true, returned=true},
                        {name=real, default=1},
                        {name=Tensor, dim=d1},
                        {name=real, default=1},
                        {name=Tensor, dim=d2},
                        {name=Tensor, dim=d3},
                     })

      -- there is an ambiguity here, hence the more complicated setup:
      -- the method form registers two overloads, one where the first scalar
      -- is invisible with a default, and one where both scalars are required.
      method:wrap(fname,
                  cname(fname),
                  {
                     {name=Tensor, returned=true, dim=d1},
                     {name=real, default=1, invisible=true},
                     {name=Tensor, default=1, dim=d1},
                     {name=real, default=1},
                     {name=Tensor, dim=d2},
                     {name=Tensor, dim=d3},
                  },
                  cname(fname),
                  {
                     {name=Tensor, returned=true, dim=d1},
                     {name=real},
                     {name=Tensor, default=1, dim=d1},
                     {name=real},
                     {name=Tensor, dim=d2},
                     {name=Tensor, dim=d3},
                  })
   end
end
end

-- "dot": takes two tensors of this type; the C function returns one value of
-- the element type.
do
   local op = "dot"
   wrap(op,
        cname(op),
        {
           {name=Tensor},
           {name=Tensor},
           {name=real, creturned=true},
        })
end

-- Finish this tensor type: register the accumulated method wrappers under the
-- name "m_cutorch_<Tensor>Math__", print the method object's string form
-- through the interface, and clear the method history.
-- NOTE(review): clearing presumably stops the next tensor type from inheriting
-- these wrappers — confirm against cwrap.
method:register("m_cutorch_" .. Tensor .. "Math__")
interface:print(method:tostring())
method:clearhistory()
Expand Down Expand Up @@ -677,20 +916,20 @@ wrap("addcdiv",
wrap("maskedFill",
cname("maskedFill"),
{{name=Tensor, returned=true, method={default='nil'}},
{name=Tensor},
{name='CudaByteTensor'},
{name=real}})

wrap("maskedCopy",
cname("maskedCopy"),
{{name=Tensor, returned=true, method={default='nil'}},
{name=Tensor},
{name='CudaByteTensor'},
{name=Tensor}})

wrap("maskedSelect",
cname("maskedSelect"),
{{name=Tensor, returned=true, default=true},
{name=Tensor},
{name=Tensor}})
{name='CudaByteTensor'}})

wrap("gather",
cname("gather"),
Expand Down Expand Up @@ -914,7 +1153,7 @@ for _,name in ipairs({"min", "max"}) do
{name=real, creturned=true}},
cname(name),
{{name=Tensor, default=true, returned=true},
{name=Tensor, default=true, returned=true},
{name='CudaLongTensor', default=true, returned=true},
{name=Tensor},
{name="index"}})
end
Expand Down Expand Up @@ -1018,20 +1257,21 @@ wrap("clamp",
for _,name in pairs({'lt','gt','le','ge','eq','ne'}) do
wrap(name,
cname(name .. 'Value'),
{{name=Tensor, default=true, returned=true},
{name=Tensor},
{name=real}},
{{name='CudaByteTensor',default=true, returned=true},
{name=Tensor},
{name=real}},
cname(name .. 'ValueT'),
{{name=Tensor, returned=true},
{name=Tensor},
{name=real}},
cname(name .. 'Tensor'),
{{name=Tensor, default=true, returned=true},
{name=Tensor},
{name=Tensor}})
end

for _,name in pairs({'all', 'any'}) do
wrap(name,
cname('logical' .. name),
{{name=Tensor},
{name="boolean", creturned=true}})
{{name='CudaByteTensor',default=true, returned=true},
{name=Tensor},
{name=Tensor}},
cname(name .. 'TensorT'),
{{name=Tensor, returned=true},
{name=Tensor},
{name=Tensor}})
end

wrap("cat",
Expand Down
Loading