Skip to content

Commit 8f5bc4c

Browse files
committed
Add Abs
1 parent 886c639 commit 8f5bc4c

File tree

2 files changed

+24
-1
lines changed

2 files changed

+24
-1
lines changed

test/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def base_test():
3737
model6 = lambda x: fc3(F.max_pool2d(x.unsqueeze(dim=0),2).squeeze())
3838
model7 = lambda x: fc3(F.max_pool2d(x.unsqueeze(dim=0),2).squeeze(dim=0))
3939
model8 = lambda x: fc3(F.max_pool3d(x.unsqueeze(0),2).squeeze())
40-
model9 = lambda x: fc3(F.max_pool1d(x.view(1,1,-1),4).squeeze().view(10,10))
40+
model9 = lambda x: fc3(F.max_pool1d(x.abs().view(1,1,-1),4).squeeze().view(10,10))
4141

4242
data = Variable(torch.rand(10,10))
4343
data2 = Variable(torch.rand(20,20))

torch2c/emitters.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -640,6 +640,29 @@ def free_tpl(self):
640640
register(Tanh, torch.autograd._functions.pointwise.Tanh)
641641

642642

643+
class Abs(Emitter):
644+
645+
def __init__(self, obj, prevfns):
646+
Emitter.__init__(self, obj, prevfns)
647+
self.def_vars({
648+
'input': id(prevfns[0]),
649+
})
650+
self.infer_type_var = 'input'
651+
652+
def call_tpl(self):
653+
return '''
654+
TH${T}Tensor *$id = TH${T}Tensor_new();
655+
THNN_${T}Abs_updateOutput(NULL,$input,$id);
656+
'''
657+
658+
def free_tpl(self):
659+
return '''
660+
TH${T}Tensor_free($id);
661+
'''
662+
663+
register(Abs, torch.autograd._functions.pointwise.Abs)
664+
665+
643666
class Sigmoid(Emitter):
644667

645668
def __init__(self, obj, prevfns):

0 commit comments

Comments
 (0)