7
7
def _wrap (obj , prevfns = []):
8
8
if obj .__class__ not in emitters ._class_map :
9
9
raise Exception ('%s does not have an Emitter' % obj .__class__ )
10
- return emitters ._class_map [obj .__class__ ](obj ,prevfns )
10
+ return emitters ._class_map [obj .__class__ ](obj , prevfns )
11
11
12
12
13
13
def _traverse_graph_recursive (out , el ):
14
14
if isinstance (el , Variable ):
15
15
var = _wrap (el )
16
16
else :
17
17
prevfns = []
18
- if hasattr (el ,'previous_functions' ):
18
+ if hasattr (el , 'previous_functions' ):
19
19
prevfns = [f [0 ] for f in el .previous_functions ]
20
- var = _wrap (el ,prevfns )
20
+ var = _wrap (el , prevfns )
21
21
out .append (var )
22
22
if hasattr (el , 'previous_functions' ):
23
23
for u in el .previous_functions :
24
- _traverse_graph_recursive (out ,u [0 ])
24
+ _traverse_graph_recursive (out , u [0 ])
25
25
26
26
27
27
def _dedup_nodes (nodes ):
@@ -37,10 +37,10 @@ def _dedup_nodes(nodes):
37
37
38
38
def _traverse_graph (node ):
39
39
nodes = []
40
- _traverse_graph_recursive (nodes ,node .creator )
40
+ _traverse_graph_recursive (nodes , node .creator )
41
41
nodes .reverse ()
42
42
nodes = _dedup_nodes (nodes )
43
- var_dict = dict ([(el .id ,el ) for el in nodes ])
43
+ var_dict = dict ([(el .id , el ) for el in nodes ])
44
44
prev_none_count = 0
45
45
while True :
46
46
for el in nodes :
@@ -69,28 +69,29 @@ def _emit_c(nodes, out, fnname, out_path):
69
69
# 3. free parameters (gen name from fnname)
70
70
# to avoid loading from disk at each forward
71
71
var_nodes = [el for el in nodes if type (el ) == emitters .Variable ]
72
- out_node = _wrap_out_node (nodes ,out )
72
+ out_node = _wrap_out_node (nodes , out )
73
73
# TODO: make it more general re: last_node?
74
74
last_node = nodes [- 1 ]
75
- ifndef = '#ifndef __%s__\n #define __%s__\n ' % (2 * (fnname .upper (),))
75
+ ifndef = '#ifndef __%s__\n #define __%s__\n ' % (2 * (fnname .upper (),))
76
76
endif = '#endif'
77
77
includes = '#include "TH.h"\n #include "THNN.h"\n #include "torch2c.h"\n '
78
- fndecl = 'void %s(%s)' % (fnname ,
79
- ', ' .join ([el .emit_decl () for el in var_nodes + [out_node ]]))
80
- calls = [el .emit_call (out_path ,'data' ) for el in nodes ]
78
+ fndecl = 'void %s(%s)' % (fnname ,
79
+ ', ' .join ([el .emit_decl () for el in var_nodes + [out_node ]]))
80
+ calls = [el .emit_call (out_path , 'data' ) for el in nodes ]
81
81
copy_out = out_node .emit_copy (last_node .id_var_name ())
82
82
# TODO: be smarter re: frees
83
83
# analyze calls backwards and free right after last use
84
84
frees = [el .emit_free () for el in nodes ]
85
85
frees .reverse ()
86
86
indent = ' ' * 2
87
- lines = [indent + el for el in '\n ' .join (calls + [copy_out ] + frees ).split ('\n ' ) if el ]
87
+ lines = [
88
+ indent + el for el in '\n ' .join (calls + [copy_out ] + frees ).split ('\n ' ) if el ]
88
89
lines = [ifndef , includes , fndecl , '{' ] + lines + ['}' , endif ]
89
90
return '\n ' .join (lines )
90
91
91
92
92
93
def _to_persisted (var_node ):
93
- persisted = emitters .PersistedVariable (var_node .obj ,[])
94
+ persisted = emitters .PersistedVariable (var_node .obj , [])
94
95
persisted .numtype = var_node .numtype
95
96
return persisted
96
97
@@ -104,41 +105,45 @@ def _clone_var(var):
104
105
105
106
106
107
def _emit_test (nodes , out , fnname , filename , out_path ):
107
- var_nodes = [_to_persisted (el ) for el in nodes if type (el ) == emitters .Variable ]
108
- out_node = _to_persisted (_wrap_out_node (nodes ,out ))
109
- out_baseline_node = _to_persisted (_wrap_out_node (nodes ,_clone_var (out )))
108
+ var_nodes = [_to_persisted (el )
109
+ for el in nodes if type (el ) == emitters .Variable ]
110
+ out_node = _to_persisted (_wrap_out_node (nodes , out ))
111
+ out_baseline_node = _to_persisted (_wrap_out_node (nodes , _clone_var (out )))
110
112
out_node .obj .data .zero_ ()
111
113
includes = '#include "%s"' % filename
112
114
fndecl = 'int main(int argc, char *argv[])'
113
- calls = [el .emit_call (out_path ,'data' ) for el in var_nodes + [out_baseline_node , out_node ]]
115
+ calls = [el .emit_call (out_path , 'data' )
116
+ for el in var_nodes + [out_baseline_node , out_node ]]
114
117
fncall = '%s(%s);' % (fnname ,
115
- ', ' .join ([el .id_var_name () for el in var_nodes + [out_node ]]))
116
- equal_var = '%s_equal_%s' % (out_node .id_var_name (), out_baseline_node .id_var_name ())
117
- equal = out_node .emit_equal (equal_var ,out_baseline_node .id_var_name ())
118
+ ', ' .join ([el .id_var_name () for el in var_nodes + [out_node ]]))
119
+ equal_var = '%s_equal_%s' % (
120
+ out_node .id_var_name (), out_baseline_node .id_var_name ())
121
+ equal = out_node .emit_equal (equal_var , out_baseline_node .id_var_name ())
118
122
print_equal = 'printf("Test passed: %d\\ n",' + equal_var + ');'
119
- frees = [el .emit_free () for el in var_nodes + [out_baseline_node , out_node ]]
123
+ frees = [el .emit_free ()
124
+ for el in var_nodes + [out_baseline_node , out_node ]]
120
125
ret = 'return %s ? EXIT_SUCCESS : EXIT_FAILURE;' % equal_var
121
126
indent = ' ' * 2
122
- lines = [indent + el for el in '\n ' .join (calls + [fncall , equal , print_equal ] + frees + [ret ]).split ('\n ' ) if el ]
127
+ lines = [indent + el for el in '\n ' .join (
128
+ calls + [fncall , equal , print_equal ] + frees + [ret ]).split ('\n ' ) if el ]
123
129
lines = [includes , fndecl , '{' ] + lines + ['}' ]
124
130
return '\n ' .join (lines )
125
131
126
132
127
133
def compile (node , fnname , out_path , compile_test = False ):
128
- includedir = os .path .join (os .path .dirname (__file__ ),'..' ,'include' )
134
+ includedir = os .path .join (os .path .dirname (__file__ ), '..' , 'include' )
129
135
nodes = _traverse_graph (node )
130
136
if not os .path .isdir (out_path ):
131
137
os .mkdir (out_path )
132
- data_path = os .path .join (out_path ,'data' )
138
+ data_path = os .path .join (out_path , 'data' )
133
139
if not os .path .isdir (data_path ):
134
140
os .mkdir (data_path )
135
141
filename = "%s.h" % fnname
136
- src = _emit_c (nodes ,node ,fnname ,out_path )
137
- with open (os .path .join (out_path ,filename ),'w' ) as f :
142
+ src = _emit_c (nodes , node , fnname , out_path )
143
+ with open (os .path .join (out_path , filename ), 'w' ) as f :
138
144
f .write (src )
139
145
if compile_test :
140
146
test_filename = "%s_test.c" % fnname
141
- test_src = _emit_test (nodes ,node ,fnname ,filename ,out_path )
142
- with open (os .path .join (out_path ,test_filename ),'w' ) as f :
147
+ test_src = _emit_test (nodes , node , fnname , filename , out_path )
148
+ with open (os .path .join (out_path , test_filename ), 'w' ) as f :
143
149
f .write (test_src )
144
-
0 commit comments