Skip to content

Commit 84ce906

Browse files
committed
Add ELU
1 parent 813eb3b commit 84ce906

File tree

2 files changed

+28
-2
lines changed

2 files changed

+28
-2
lines changed

test/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def base_test():
2424
fc3.weight.data.normal_(0.0,1.0)
2525
fc3.bias.data.normal_(0.0,1.0)
2626

27-
model_1 = lambda x: F.softmax(F.relu(fc3(x)))
27+
model_1 = lambda x: F.softmax(F.elu(fc3(x)))
2828

2929
data = Variable(torch.rand(10,10))
3030

torch2c/emitters.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,6 @@ class AddConstant(Emitter):
270270

271271
def __init__(self, obj, prevfns):
272272
Emitter.__init__(self, obj, prevfns)
273-
print(obj.constant)
274273
self.def_vars({'input': id(prevfns[0])})
275274
self.infer_type_var = 'input'
276275
self.def_args({'constant': obj.constant})
@@ -360,6 +359,33 @@ def free_tpl(self):
360359
register(Threshold, torch.nn._functions.thnn.auto.Threshold)
361360

362361

362+
class ELU(Emitter):
363+
364+
def __init__(self, obj, prevfns):
365+
Emitter.__init__(self, obj, prevfns)
366+
self.def_vars({
367+
'input': id(prevfns[0]),
368+
})
369+
self.infer_type_var = 'input'
370+
self.def_args({
371+
'alpha': obj.additional_args[0],
372+
'inplace': int(obj.additional_args[1])
373+
})
374+
375+
def call_tpl(self):
376+
return '''
377+
TH${T}Tensor *$id = TH${T}Tensor_new();
378+
THNN_${T}ELU_updateOutput(NULL,$input,$id,$alpha,$inplace);
379+
'''
380+
381+
def free_tpl(self):
382+
return '''
383+
TH${T}Tensor_free($id);
384+
'''
385+
386+
register(ELU, torch.nn._functions.thnn.auto.ELU)
387+
388+
363389
class Noop(Emitter):
364390

365391
def __init__(self, obj, prevfns):

0 commit comments

Comments
 (0)