Skip to content

Commit b3f8bda

Browse files
committed
Add Div and DivConstant
1 parent fea8c9d commit b3f8bda

File tree

2 files changed

+45
-1
lines changed

2 files changed

+45
-1
lines changed

test/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def base_test():
3030

3131
data = Variable(torch.rand(10,10))
3232

33-
out = model_0(data) + model_1(data) * model_2(data) + 2.0 * model_3(data) + 1 - 2.0
33+
out = model_0(data) + model_1(data) * model_2(data) / model_3(data) / 2.0 + 2.0 * model_3(data) + 1 - 2.0
3434

3535
out_path = 'out'
3636
if not os.path.isdir(out_path):

torch2c/emitters.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,50 @@ def free_tpl(self):
376376
register(MulConstant, torch.autograd._functions.basic_ops.MulConstant)
377377

378378

379+
class Div(Emitter):
380+
381+
def __init__(self, obj, prevfns):
382+
Emitter.__init__(self, obj, prevfns)
383+
self.def_vars({'input0': id(prevfns[0]),
384+
'input1': id(prevfns[1])})
385+
self.infer_type_var = 'input0'
386+
387+
def call_tpl(self):
388+
return '''
389+
TH${T}Tensor *$id = TH${T}Tensor_new();
390+
TH${T}Tensor_cdiv($id,$input0,$input1);
391+
'''
392+
393+
def free_tpl(self):
394+
return '''
395+
TH${T}Tensor_free($id);
396+
'''
397+
398+
register(Div, torch.autograd._functions.basic_ops.Div)
399+
400+
401+
class DivConstant(Emitter):
402+
403+
def __init__(self, obj, prevfns):
404+
Emitter.__init__(self, obj, prevfns)
405+
self.def_vars({'input': id(prevfns[0])})
406+
self.infer_type_var = 'input'
407+
self.def_args({'constant': obj.constant})
408+
409+
def call_tpl(self):
410+
return '''
411+
TH${T}Tensor *$id = TH${T}Tensor_new();
412+
TH${T}Tensor_div($id,$input,$constant);
413+
'''
414+
415+
def free_tpl(self):
416+
return '''
417+
TH${T}Tensor_free($id);
418+
'''
419+
420+
register(DivConstant, torch.autograd._functions.basic_ops.DivConstant)
421+
422+
379423
class Softmax(Emitter):
380424

381425
def __init__(self, obj, prevfns):

0 commit comments

Comments
 (0)