Skip to content

Commit 71f25ba

Browse files
committed
Add Tanh
1 parent e33558a commit 71f25ba

File tree

2 files changed

+24
-1
lines changed

2 files changed

+24
-1
lines changed

test/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def base_test():
2525
fc3.bias.data.normal_(0.0,1.0)
2626

2727
model_1 = lambda x: F.softmax(F.elu(fc3(x)))
28-
model_2 = lambda x: F.softmax(F.elu(fc3(x)))
28+
model_2 = lambda x: F.softmax(F.tanh(fc3(x)))
2929

3030
data = Variable(torch.rand(10,10))
3131

torch2c/emitters.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -429,6 +429,29 @@ def free_tpl(self):
429429
register(ELU, torch.nn._functions.thnn.auto.ELU)
430430

431431

432+
class Tanh(Emitter):
433+
434+
def __init__(self, obj, prevfns):
435+
Emitter.__init__(self, obj, prevfns)
436+
self.def_vars({
437+
'input': id(prevfns[0]),
438+
})
439+
self.infer_type_var = 'input'
440+
441+
def call_tpl(self):
442+
return '''
443+
TH${T}Tensor *$id = TH${T}Tensor_new();
444+
THNN_${T}Tanh_updateOutput(NULL,$input,$id);
445+
'''
446+
447+
def free_tpl(self):
448+
return '''
449+
TH${T}Tensor_free($id);
450+
'''
451+
452+
register(Tanh, torch.autograd._functions.pointwise.Tanh)
453+
454+
432455
class Noop(Emitter):
433456

434457
def __init__(self, obj, prevfns):

0 commit comments

Comments
 (0)