Skip to content

Commit 12e1900

Browse files
committed
Add Concat
1 parent 3a33c45 commit 12e1900

File tree

2 files changed

+36
-1
lines changed

2 files changed

+36
-1
lines changed

test/base.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,15 +38,20 @@ def base_test():
3838
model7 = lambda x: fc3(F.max_pool2d(x.unsqueeze(dim=0),2).squeeze(dim=0))
3939
model8 = lambda x: fc3(F.max_pool3d(x.unsqueeze(0),2).squeeze())
4040
model9 = lambda x: fc3(F.max_pool1d(x.abs().view(1,1,-1),4).squeeze().view(10,10))
41+
#model10 = lambda x: fc3(x.double())
42+
#model10 = lambda x: fc3(x.view(1,10,10).select(0,0))
43+
model10 = lambda x, y: F.softmax(F.tanh(fc3(torch.cat((x,y),1))))
4144

4245
data = Variable(torch.rand(10,10))
4346
data2 = Variable(torch.rand(20,20))
47+
data1a = Variable(torch.rand(10,5))
48+
data1b = Variable(torch.rand(10,5))
4449
data3 = Variable(torch.rand(2,20,20))
4550

4651
out = model0(data) + \
4752
model1(data) * model2(data) / model3(data) / 2.0 + \
4853
2.0 * model4(data) + model5(data) + 1 - 2.0 + \
49-
model6(data2) + model7(data2) + model8(data3) + model9(data2)
54+
model6(data2) + model7(data2) + model8(data3) + model9(data2) + model10(data1a,data1b)
5055

5156
out_path = 'out'
5257
if not os.path.isdir(out_path):

torch2c/emitters.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -685,6 +685,36 @@ def free_tpl(self):
685685
register(Clone, torch.autograd._functions.tensor.Clone)
686686

687687

688+
class Concat(Emitter):
689+
690+
def __init__(self, obj, prevfns):
691+
Emitter.__init__(self, obj, prevfns)
692+
for i, el in enumerate(prevfns):
693+
self.def_vars({'input%d' % i: id(el)})
694+
self.ninputs = len(prevfns)
695+
self.infer_type_var = 'input0'
696+
self.def_args({
697+
'dim': obj.dim,
698+
'ninputs': self.ninputs
699+
})
700+
701+
def call_tpl(self):
702+
arraytpl = '\n'.join(['inputs[%d] = $input%d;' % (i,i) for i in range(self.ninputs)])
703+
return '''
704+
TH${T}Tensor *$id = TH${T}Tensor_new();
705+
TH${T}Tensor *inputs[%d];
706+
%s
707+
TH${T}Tensor_catArray($id,inputs,$ninputs,$dim);
708+
''' % (self.ninputs, arraytpl)
709+
710+
def free_tpl(self):
711+
return '''
712+
TH${T}Tensor_free($id);
713+
'''
714+
715+
register(Concat, torch.autograd._functions.tensor.Concat)
716+
717+
688718
class Sigmoid(Emitter):
689719

690720
def __init__(self, obj, prevfns):

0 commit comments

Comments
 (0)