Skip to content

Commit 7f0bb2b

Browse files
committed
Add Clone
1 parent 8f5bc4c commit 7f0bb2b

File tree

2 files changed

+23
-1
lines changed

2 files changed

+23
-1
lines changed

test/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def base_test():
3232
model1 = lambda x: F.softmax(F.elu(fc3(x)))
3333
model2 = lambda x: F.softmax(F.tanh(fc3(x)))
3434
model3 = lambda x: F.softmax(F.sigmoid(fc3(x)))
35-
model4 = lambda x: softmax(F.leaky_relu(fc4(x)))
35+
model4 = lambda x: softmax(F.leaky_relu(fc4(x))).clone()
3636
model5 = lambda x: softmax(F.logsigmoid(fc4(x.transpose(0,1))))
3737
model6 = lambda x: fc3(F.max_pool2d(x.unsqueeze(dim=0),2).squeeze())
3838
model7 = lambda x: fc3(F.max_pool2d(x.unsqueeze(dim=0),2).squeeze(dim=0))

torch2c/emitters.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -663,6 +663,28 @@ def free_tpl(self):
663663
register(Abs, torch.autograd._functions.pointwise.Abs)
664664

665665

666+
class Clone(Emitter):
667+
668+
def __init__(self, obj, prevfns):
669+
Emitter.__init__(self, obj, prevfns)
670+
self.def_vars({
671+
'input': id(prevfns[0]),
672+
})
673+
self.infer_type_var = 'input'
674+
675+
def call_tpl(self):
676+
return '''
677+
TH${T}Tensor *$id = TH${T}Tensor_newClone($input);
678+
'''
679+
680+
def free_tpl(self):
681+
return '''
682+
TH${T}Tensor_free($id);
683+
'''
684+
685+
register(Clone, torch.autograd._functions.tensor.Clone)
686+
687+
666688
class Sigmoid(Emitter):
667689

668690
def __init__(self, obj, prevfns):

0 commit comments

Comments
 (0)