Skip to content

Commit fea8c9d

Browse files
committed
Add Mul and MulConstant
1 parent 8762129 commit fea8c9d

File tree

2 files changed

+45
-1
lines changed

2 files changed

+45
-1
lines changed

test/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def base_test():
3030

3131
data = Variable(torch.rand(10,10))
3232

33-
out = model_0(data) + model_1(data) - model_2(data) + model_3(data) + 1 - 2
33+
out = model_0(data) + model_1(data) * model_2(data) + 2.0 * model_3(data) + 1 - 2.0
3434

3535
out_path = 'out'
3636
if not os.path.isdir(out_path):

torch2c/emitters.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,50 @@ def free_tpl(self):
332332
register(SubConstant, torch.autograd._functions.basic_ops.SubConstant)
333333

334334

335+
class Mul(Emitter):
336+
337+
def __init__(self, obj, prevfns):
338+
Emitter.__init__(self, obj, prevfns)
339+
self.def_vars({'input0': id(prevfns[0]),
340+
'input1': id(prevfns[1])})
341+
self.infer_type_var = 'input0'
342+
343+
def call_tpl(self):
344+
return '''
345+
TH${T}Tensor *$id = TH${T}Tensor_new();
346+
TH${T}Tensor_cmul($id,$input0,$input1);
347+
'''
348+
349+
def free_tpl(self):
350+
return '''
351+
TH${T}Tensor_free($id);
352+
'''
353+
354+
register(Mul, torch.autograd._functions.basic_ops.Mul)
355+
356+
357+
class MulConstant(Emitter):
358+
359+
def __init__(self, obj, prevfns):
360+
Emitter.__init__(self, obj, prevfns)
361+
self.def_vars({'input': id(prevfns[0])})
362+
self.infer_type_var = 'input'
363+
self.def_args({'constant': obj.constant})
364+
365+
def call_tpl(self):
366+
return '''
367+
TH${T}Tensor *$id = TH${T}Tensor_new();
368+
TH${T}Tensor_mul($id,$input,$constant);
369+
'''
370+
371+
def free_tpl(self):
372+
return '''
373+
TH${T}Tensor_free($id);
374+
'''
375+
376+
register(MulConstant, torch.autograd._functions.basic_ops.MulConstant)
377+
378+
335379
class Softmax(Emitter):
336380

337381
def __init__(self, obj, prevfns):

0 commit comments

Comments
 (0)