5
5
6
6
7
7
def _wrap (obj , prevfns = []):
8
+ if obj .__class__ not in emitters ._class_map :
9
+ raise Exception ('%s does not have an Emitter' % obj .__class__ )
8
10
return emitters ._class_map [obj .__class__ ](obj ,prevfns )
9
11
10
12
11
- def _traverse_graph_recursive (out , el ):
13
+ def _traverse_graph_recursive (out , id_set , el ):
12
14
if isinstance (el , Variable ):
13
15
var = _wrap (el )
14
16
else :
15
17
prevfns = []
16
18
if hasattr (el ,'previous_functions' ):
17
19
prevfns = [f [0 ] for f in el .previous_functions ]
18
20
var = _wrap (el ,prevfns )
19
- out .append (var )
21
+ var_name = var .id_var_name ()
22
+ if var_name not in id_set :
23
+ out .append (var )
24
+ id_set .add (var_name )
20
25
if hasattr (el , 'previous_functions' ):
21
26
for u in el .previous_functions :
22
- _traverse_graph_recursive (out ,u [0 ])
27
+ _traverse_graph_recursive (out ,id_set , u [0 ])
23
28
24
29
25
30
def _traverse_graph (node ):
26
31
nodes = []
27
- _traverse_graph_recursive (nodes ,node .creator )
32
+ id_set = set ()
33
+ _traverse_graph_recursive (nodes ,id_set ,node .creator )
28
34
nodes .reverse ()
29
35
var_dict = dict ([(el .id ,el ) for el in nodes ])
30
- for el in nodes :
31
- el .infer_type (var_dict )
36
+ prev_none_count = 0
37
+ while True :
38
+ for el in nodes :
39
+ el .infer_type (var_dict )
40
+ none_count = len ([el for el in nodes if el .numtype == None ])
41
+ if none_count == 0 :
42
+ break
43
+ if none_count == prev_none_count :
44
+ raise Exception ('Cannot infer types for all nodes in the graphs' )
45
+ prev_none_count = none_count
32
46
return nodes
33
47
34
48
@@ -52,12 +66,10 @@ def _emit_c(nodes, out, fnname, out_path):
52
66
last_node = nodes [- 1 ]
53
67
ifndef = '#ifndef __%s__\n #define __%s__\n ' % (2 * (fnname .upper (),))
54
68
endif = '#endif'
55
- includes = '#include "TH.h"\n #include "THNN.h"\n #include "torch2c.h"'
69
+ includes = '#include "TH.h"\n #include "THNN.h"\n #include "torch2c.h"\n '
56
70
fndecl = 'void %s(%s)' % (fnname ,
57
71
', ' .join ([el .emit_decl () for el in var_nodes + [out_node ]]))
58
- print (fndecl )
59
72
calls = [el .emit_call (out_path ,'data' ) for el in nodes ]
60
- print ('\n ' .join (calls ))
61
73
copy_out = out_node .emit_copy (last_node .id_var_name ())
62
74
# TODO: be smarter re: frees
63
75
# analyze calls backwards and free right after last use
0 commit comments