
Commit 73e9467 (1 parent: cf2ec7f)

Add Softmax and Add. Fix infer and don't duplicate Variables for diverging paths.

2 files changed: +69 -9 lines

torch2c/__init__.py (21 additions, 9 deletions)
```diff
@@ -5,30 +5,44 @@
 
 
 def _wrap(obj, prevfns=[]):
+    if obj.__class__ not in emitters._class_map:
+        raise Exception('%s does not have an Emitter' % obj.__class__)
     return emitters._class_map[obj.__class__](obj,prevfns)
 
 
-def _traverse_graph_recursive(out, el):
+def _traverse_graph_recursive(out, id_set, el):
     if isinstance(el, Variable):
         var = _wrap(el)
     else:
         prevfns = []
         if hasattr(el,'previous_functions'):
             prevfns = [f[0] for f in el.previous_functions]
         var = _wrap(el,prevfns)
-    out.append(var)
+    var_name = var.id_var_name()
+    if var_name not in id_set:
+        out.append(var)
+        id_set.add(var_name)
     if hasattr(el, 'previous_functions'):
         for u in el.previous_functions:
-            _traverse_graph_recursive(out,u[0])
+            _traverse_graph_recursive(out,id_set,u[0])
 
 
 def _traverse_graph(node):
     nodes = []
-    _traverse_graph_recursive(nodes,node.creator)
+    id_set = set()
+    _traverse_graph_recursive(nodes,id_set,node.creator)
     nodes.reverse()
     var_dict = dict([(el.id,el) for el in nodes])
-    for el in nodes:
-        el.infer_type(var_dict)
+    prev_none_count = 0
+    while True:
+        for el in nodes:
+            el.infer_type(var_dict)
+        none_count = len([el for el in nodes if el.numtype == None])
+        if none_count == 0:
+            break
+        if none_count == prev_none_count:
+            raise Exception('Cannot infer types for all nodes in the graph')
+        prev_none_count = none_count
     return nodes
 
 
```
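Two things change in this hunk: `_traverse_graph_recursive` now threads an `id_set` through the recursion so a node reachable along several paths is appended only once, and `_traverse_graph` replaces the single inference pass with a loop that re-runs `infer_type` until every node has a `numtype`, raising if a full pass resolves nothing new. A minimal, self-contained sketch of the dedup idea (the `Node` class and the names below are illustrative stand-ins, not torch2c's API):

```python
# Illustrative sketch of dedup-on-traversal; Node stands in for autograd
# functions and 'seen' plays the role of id_set.
class Node:
    def __init__(self, name, inputs=()):
        self.name, self.inputs = name, inputs

def traverse(out, seen, node):
    if node.name not in seen:          # append each node at most once
        out.append(node.name)
        seen.add(node.name)
    for parent in node.inputs:         # still recurse so deeper paths are found
        traverse(out, seen, parent)

x = Node('x')
a = Node('softmax', (x,))
b = Node('linear', (x,))
root = Node('add', (a, b))             # two paths diverge from and reconverge on x

order = []
traverse(order, set(), root)
assert order.count('x') == 1           # 'x' is emitted once, not twice
```

The fixed-point loop is presumably needed because, once nodes are deduplicated, a shared node can land later in the reversed list than one of its consumers, so a single pass over `nodes` may leave some types unresolved; repeated passes propagate types until the count of untyped nodes stops shrinking.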

```diff
@@ -52,12 +66,10 @@ def _emit_c(nodes, out, fnname, out_path):
     last_node = nodes[-1]
     ifndef = '#ifndef __%s__\n#define __%s__\n' % (2*(fnname.upper(),))
     endif = '#endif'
-    includes = '#include "TH.h"\n#include "THNN.h"\n#include "torch2c.h"'
+    includes = '#include "TH.h"\n#include "THNN.h"\n#include "torch2c.h"\n'
     fndecl = 'void %s(%s)' % (fnname,
             ', '.join([el.emit_decl() for el in var_nodes + [out_node]]))
-    print(fndecl)
     calls = [el.emit_call(out_path,'data') for el in nodes]
-    print('\n'.join(calls))
     copy_out = out_node.emit_copy(last_node.id_var_name())
     # TODO: be smarter re: frees
     # analyze calls backwards and free right after last use
```
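Besides giving `includes` its missing trailing newline, this hunk removes two leftover debug prints. For orientation, a hypothetical sketch of how the fragments built here could be joined into the emitted C file; the actual stitching (and the free-emission the TODO refers to) lives outside this hunk, so the names and ordering below are assumptions:

```python
# Hypothetical assembly of _emit_c's fragments; the real stitching code
# is not shown in this hunk, so treat the ordering as an assumption.
def assemble(ifndef, includes, fndecl, calls, copy_out, frees, endif):
    body = '\n'.join(calls + [copy_out] + frees)
    return '\n'.join([ifndef, includes, fndecl, '{', body, '}', endif])
```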

torch2c/emitters.py (48 additions, 0 deletions)
```diff
@@ -54,6 +54,9 @@ def id_var_name(self):
     def var_names(self):
         return {k: self.var_name(v) for k,v in self.vars.items()}
 
+    def persisted_vars(self):
+        return []
+
     def call_tpl(self):
         return ''
 
```
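The base `Emitter` gains a `persisted_vars` hook that defaults to an empty list; subclasses holding learned parameters can presumably override it to name tensors that must outlive a single call. A self-contained sketch (the subclass and the `'weight'`/`'bias'` names are hypothetical):

```python
# Hypothetical illustration of the persisted_vars hook; only the base
# class default comes from the diff, the override is an assumption.
class Emitter:
    def persisted_vars(self):
        return []                        # default: nothing to persist

class LinearEmitter(Emitter):
    def persisted_vars(self):
        return ['weight', 'bias']        # assumed names for persisted tensors

print(LinearEmitter().persisted_vars())  # -> ['weight', 'bias']
```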

```diff
@@ -241,6 +244,51 @@ def free_tpl(self):
 register(Linear, torch.nn._functions.linear.Linear)
 
 
+class Add(Emitter):
+
+    def __init__(self, obj, prevfns):
+        Emitter.__init__(self, obj, prevfns)
+        self.def_vars({'input0': id(prevfns[0]),
+                       'input1': id(prevfns[1])})
+        self.infer_type_var = 'input0'
+
+    def call_tpl(self):
+        return '''
+            TH${T}Tensor *$id = TH${T}Tensor_new();
+            TH${T}Tensor_cadd($id,$input0,1.0,$input1);
+            '''
+
+    def free_tpl(self):
+        return '''
+            TH${T}Tensor_free($id);
+            '''
+
+register(Add, torch.autograd._functions.basic_ops.Add)
+
+
+
+
+class Softmax(Emitter):
+
+    def __init__(self, obj, prevfns):
+        Emitter.__init__(self, obj, prevfns)
+        self.def_vars({'input': id(prevfns[0])})
+        self.infer_type_var = 'input'
+
+    def call_tpl(self):
+        return '''
+            TH${T}Tensor *$id = TH${T}Tensor_new();
+            THNN_${T}SoftMax_updateOutput(NULL,$input,$id);
+            '''
+
+    def free_tpl(self):
+        return '''
+            TH${T}Tensor_free($id);
+            '''
+
+register(Softmax, torch.nn._functions.thnn.auto.Softmax)
+
+
 class LogSoftmax(Emitter):
 
     def __init__(self, obj, prevfns):
```
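Both new emitters follow the existing pattern: allocate a fresh output tensor, run the TH/THNN kernel into it, and release it in `free_tpl`. The `$` placeholders suggest `string.Template` substitution; assuming that, the Add call template expands for a Float-typed node roughly as follows (the variable names `x4711`, `x1`, and `x2` are illustrative, not torch2c's actual naming scheme):

```python
from string import Template

# Assuming $-placeholders are filled via string.Template; the node and
# input names below are illustrative.
tpl = Template('''
TH${T}Tensor *$id = TH${T}Tensor_new();
TH${T}Tensor_cadd($id,$input0,1.0,$input1);
''')
print(tpl.substitute(T='Float', id='x4711', input0='x1', input1='x2'))
# THFloatTensor *x4711 = THFloatTensor_new();
# THFloatTensor_cadd(x4711,x1,1.0,x2);
```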
