Skip to content

Commit 11f0f65

Browse files
committed
autopep8
1 parent 12e1900 commit 11f0f65

File tree

2 files changed

+113
-72
lines changed

2 files changed

+113
-72
lines changed

torch2c/__init__.py

Lines changed: 34 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,21 @@
77
def _wrap(obj, prevfns=[]):
88
if obj.__class__ not in emitters._class_map:
99
raise Exception('%s does not have an Emitter' % obj.__class__)
10-
return emitters._class_map[obj.__class__](obj,prevfns)
10+
return emitters._class_map[obj.__class__](obj, prevfns)
1111

1212

1313
def _traverse_graph_recursive(out, el):
1414
if isinstance(el, Variable):
1515
var = _wrap(el)
1616
else:
1717
prevfns = []
18-
if hasattr(el,'previous_functions'):
18+
if hasattr(el, 'previous_functions'):
1919
prevfns = [f[0] for f in el.previous_functions]
20-
var = _wrap(el,prevfns)
20+
var = _wrap(el, prevfns)
2121
out.append(var)
2222
if hasattr(el, 'previous_functions'):
2323
for u in el.previous_functions:
24-
_traverse_graph_recursive(out,u[0])
24+
_traverse_graph_recursive(out, u[0])
2525

2626

2727
def _dedup_nodes(nodes):
@@ -37,10 +37,10 @@ def _dedup_nodes(nodes):
3737

3838
def _traverse_graph(node):
3939
nodes = []
40-
_traverse_graph_recursive(nodes,node.creator)
40+
_traverse_graph_recursive(nodes, node.creator)
4141
nodes.reverse()
4242
nodes = _dedup_nodes(nodes)
43-
var_dict = dict([(el.id,el) for el in nodes])
43+
var_dict = dict([(el.id, el) for el in nodes])
4444
prev_none_count = 0
4545
while True:
4646
for el in nodes:
@@ -69,28 +69,29 @@ def _emit_c(nodes, out, fnname, out_path):
6969
# 3. free parameters (gen name from fnname)
7070
# to avoid loading from disk at each forward
7171
var_nodes = [el for el in nodes if type(el) == emitters.Variable]
72-
out_node = _wrap_out_node(nodes,out)
72+
out_node = _wrap_out_node(nodes, out)
7373
# TODO: make it more general re: last_node?
7474
last_node = nodes[-1]
75-
ifndef = '#ifndef __%s__\n#define __%s__\n' % (2*(fnname.upper(),))
75+
ifndef = '#ifndef __%s__\n#define __%s__\n' % (2 * (fnname.upper(),))
7676
endif = '#endif'
7777
includes = '#include "TH.h"\n#include "THNN.h"\n#include "torch2c.h"\n'
78-
fndecl = 'void %s(%s)' % (fnname,
79-
', '.join([el.emit_decl() for el in var_nodes + [out_node]]))
80-
calls = [el.emit_call(out_path,'data') for el in nodes]
78+
fndecl = 'void %s(%s)' % (fnname,
79+
', '.join([el.emit_decl() for el in var_nodes + [out_node]]))
80+
calls = [el.emit_call(out_path, 'data') for el in nodes]
8181
copy_out = out_node.emit_copy(last_node.id_var_name())
8282
# TODO: be smarter re: frees
8383
# analyze calls backwards and free right after last use
8484
frees = [el.emit_free() for el in nodes]
8585
frees.reverse()
8686
indent = ' ' * 2
87-
lines = [indent + el for el in '\n'.join(calls + [copy_out] + frees).split('\n') if el]
87+
lines = [
88+
indent + el for el in '\n'.join(calls + [copy_out] + frees).split('\n') if el]
8889
lines = [ifndef, includes, fndecl, '{'] + lines + ['}', endif]
8990
return '\n'.join(lines)
9091

9192

9293
def _to_persisted(var_node):
93-
persisted = emitters.PersistedVariable(var_node.obj,[])
94+
persisted = emitters.PersistedVariable(var_node.obj, [])
9495
persisted.numtype = var_node.numtype
9596
return persisted
9697

@@ -104,41 +105,45 @@ def _clone_var(var):
104105

105106

106107
def _emit_test(nodes, out, fnname, filename, out_path):
107-
var_nodes = [_to_persisted(el) for el in nodes if type(el) == emitters.Variable]
108-
out_node = _to_persisted(_wrap_out_node(nodes,out))
109-
out_baseline_node = _to_persisted(_wrap_out_node(nodes,_clone_var(out)))
108+
var_nodes = [_to_persisted(el)
109+
for el in nodes if type(el) == emitters.Variable]
110+
out_node = _to_persisted(_wrap_out_node(nodes, out))
111+
out_baseline_node = _to_persisted(_wrap_out_node(nodes, _clone_var(out)))
110112
out_node.obj.data.zero_()
111113
includes = '#include "%s"' % filename
112114
fndecl = 'int main(int argc, char *argv[])'
113-
calls = [el.emit_call(out_path,'data') for el in var_nodes + [out_baseline_node, out_node]]
115+
calls = [el.emit_call(out_path, 'data')
116+
for el in var_nodes + [out_baseline_node, out_node]]
114117
fncall = '%s(%s);' % (fnname,
115-
', '.join([el.id_var_name() for el in var_nodes + [out_node]]))
116-
equal_var = '%s_equal_%s' % (out_node.id_var_name(), out_baseline_node.id_var_name())
117-
equal = out_node.emit_equal(equal_var,out_baseline_node.id_var_name())
118+
', '.join([el.id_var_name() for el in var_nodes + [out_node]]))
119+
equal_var = '%s_equal_%s' % (
120+
out_node.id_var_name(), out_baseline_node.id_var_name())
121+
equal = out_node.emit_equal(equal_var, out_baseline_node.id_var_name())
118122
print_equal = 'printf("Test passed: %d\\n",' + equal_var + ');'
119-
frees = [el.emit_free() for el in var_nodes + [out_baseline_node, out_node]]
123+
frees = [el.emit_free()
124+
for el in var_nodes + [out_baseline_node, out_node]]
120125
ret = 'return %s ? EXIT_SUCCESS : EXIT_FAILURE;' % equal_var
121126
indent = ' ' * 2
122-
lines = [indent + el for el in '\n'.join(calls + [fncall, equal, print_equal] + frees + [ret]).split('\n') if el]
127+
lines = [indent + el for el in '\n'.join(
128+
calls + [fncall, equal, print_equal] + frees + [ret]).split('\n') if el]
123129
lines = [includes, fndecl, '{'] + lines + ['}']
124130
return '\n'.join(lines)
125131

126132

127133
def compile(node, fnname, out_path, compile_test=False):
128-
includedir = os.path.join(os.path.dirname(__file__),'..','include')
134+
includedir = os.path.join(os.path.dirname(__file__), '..', 'include')
129135
nodes = _traverse_graph(node)
130136
if not os.path.isdir(out_path):
131137
os.mkdir(out_path)
132-
data_path = os.path.join(out_path,'data')
138+
data_path = os.path.join(out_path, 'data')
133139
if not os.path.isdir(data_path):
134140
os.mkdir(data_path)
135141
filename = "%s.h" % fnname
136-
src = _emit_c(nodes,node,fnname,out_path)
137-
with open(os.path.join(out_path,filename),'w') as f:
142+
src = _emit_c(nodes, node, fnname, out_path)
143+
with open(os.path.join(out_path, filename), 'w') as f:
138144
f.write(src)
139145
if compile_test:
140146
test_filename = "%s_test.c" % fnname
141-
test_src = _emit_test(nodes,node,fnname,filename,out_path)
142-
with open(os.path.join(out_path,test_filename),'w') as f:
147+
test_src = _emit_test(nodes, node, fnname, filename, out_path)
148+
with open(os.path.join(out_path, test_filename), 'w') as f:
143149
f.write(test_src)
144-

0 commit comments

Comments
 (0)