Skip to content

Commit a1c453c

Browse files
committed
Rename wrapper to emitter, format code
1 parent cab28ca commit a1c453c

File tree

2 files changed

+39
-27
lines changed

2 files changed

+39
-27
lines changed

torch2c/__init__.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
import torch
22
from torch.autograd import Variable
33
import os
4-
from . import wrappers
4+
from . import emitters
5+
56

67
def _wrap(obj, prevfns=[]):
7-
return wrappers._class_map[obj.__class__](obj,prevfns)
8+
return emitters._class_map[obj.__class__](obj,prevfns)
9+
810

911
def _traverse_graph_recursive(out, el):
1012
if isinstance(el, Variable):
@@ -19,6 +21,7 @@ def _traverse_graph_recursive(out, el):
1921
for u in el.previous_functions:
2022
_traverse_graph_recursive(out,u[0])
2123

24+
2225
def _traverse_graph(node):
2326
nodes = []
2427
_traverse_graph_recursive(nodes,node.creator)
@@ -28,20 +31,22 @@ def _traverse_graph(node):
2831
el.infer_type(var_dict)
2932
return nodes
3033

34+
3135
def _wrap_out_node(nodes, out):
3236
out_node = _wrap(out)
3337
out_creator_id = id(out.creator)
3438
out_creator_node = [el for el in nodes if el.id == out_creator_id][0]
3539
out_node.infer_type({out_creator_id: out_creator_node})
3640
return out_node
3741

42+
3843
def _generate_c(nodes, out, fnname, out_path):
3944
# TODO: generate three functions
4045
# 1. load parameters (gen name from fnname)
4146
# 2. run forward
4247
# 3. free parameters (gen name from fnname)
4348
# to avoid loading from disk at each forward
44-
var_nodes = [el for el in nodes if type(el) == wrappers.Variable]
49+
var_nodes = [el for el in nodes if type(el) == emitters.Variable]
4550
out_node = _wrap_out_node(nodes,out)
4651
# TODO: make it more general re: last_node?
4752
last_node = nodes[-1]
@@ -61,20 +66,23 @@ def _generate_c(nodes, out, fnname, out_path):
6166
lines = [ifndef, includes, fndecl, '{'] + lines + ['}', endif]
6267
return '\n'.join(lines)
6368

69+
6470
def _to_persisted(var_node):
65-
persisted = wrappers.PersistedVariable(var_node.obj,[])
71+
persisted = emitters.PersistedVariable(var_node.obj,[])
6672
persisted.numtype = var_node.numtype
6773
return persisted
6874

75+
6976
def _clone_var(var):
7077
out = Variable(data=var.data.clone(),
7178
creator=var.creator,
7279
requires_grad=var.requires_grad,
7380
volatile=var.volatile)
7481
return out
7582

83+
7684
def _generate_test(nodes, out, fnname, filename, out_path):
77-
var_nodes = [_to_persisted(el) for el in nodes if type(el) == wrappers.Variable]
85+
var_nodes = [_to_persisted(el) for el in nodes if type(el) == emitters.Variable]
7886
out_node = _to_persisted(_wrap_out_node(nodes,out))
7987
out_baseline_node = _to_persisted(_wrap_out_node(nodes,_clone_var(out)))
8088
out_node.obj.data.zero_()
@@ -93,6 +101,7 @@ def _generate_test(nodes, out, fnname, filename, out_path):
93101
lines = [includes, fndecl, '{'] + lines + ['}']
94102
return '\n'.join(lines)
95103

104+
96105
def compile(node, fnname, out_path, compile_test=False):
97106
nodes = _traverse_graph(node)
98107
if not os.path.isdir(out_path):
@@ -110,4 +119,3 @@ def compile(node, fnname, out_path, compile_test=False):
110119
with open(os.path.join(out_path,test_filename),'w') as f:
111120
f.write(test_src)
112121

113-

torch2c/wrappers.py renamed to torch2c/emitters.py

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,12 @@
1414

1515
_class_map = {}
1616

17-
def register(wrapperClass, torchClass):
18-
_class_map[torchClass] = wrapperClass
1917

18+
def register(emitterClass, torchClass):
19+
_class_map[torchClass] = emitterClass
2020

21-
class Wrapper(object):
21+
22+
class Emitter(object):
2223

2324
def __init__(self, obj, prevfns):
2425
self.id = id(obj)
@@ -148,20 +149,21 @@ def tensor_meta_tpl(size_name, stride_name, size, stride=None):
148149

149150

150151
#####################
151-
# Wrapper subclasses
152+
# Emitter subclasses
152153
#####################
153154

154155

155-
class Variable(Wrapper):
156+
class Variable(Emitter):
156157

157158
def __init__(self, obj, prevfns):
158-
Wrapper.__init__(self, obj, prevfns)
159+
Emitter.__init__(self, obj, prevfns)
159160

160161
def infer_type(self, var_dict):
161162
self.numtype = self.obj.data.__class__.__name__[:len('Tensor')-1]
162163

163164
register(Variable, torch.autograd.Variable)
164165

166+
165167
def persist_tensor(tensor, name, out_path, datadir, size_name='size_$id', stride_name='stride_$id'):
166168
contiguous = tensor.contiguous()
167169
filename = '%s.th' % name
@@ -177,6 +179,7 @@ def persist_tensor(tensor, name, out_path, datadir, size_name='size_$id', stride
177179
meta, meta_free = tensor_meta_tpl(size_name,stride_name,size,stride)
178180
return os.path.join(datadir,filename), meta, meta_free
179181

182+
180183
# TODO: add this function to an auxiliary file
181184
# call it something like TH${T}Storage_newFromFile(filename);
182185
def read_storage(storage_name,filepath,numtype):
@@ -205,6 +208,7 @@ def read_storage(storage_name,filepath,numtype):
205208
'''
206209
return Template(tpl).substitute(subs)
207210

211+
208212
class PersistedVariable(Variable):
209213

210214
def __init__(self, obj, prevfns):
@@ -227,6 +231,7 @@ def free_tpl(self):
227231
TH${T}Storage_free(storage_$id);
228232
'''
229233

234+
230235
class Parameter(PersistedVariable):
231236

232237
def __init__(self, obj, prevfns):
@@ -235,10 +240,10 @@ def __init__(self, obj, prevfns):
235240
register(Parameter, torch.nn.parameter.Parameter)
236241

237242

238-
class Linear(Wrapper):
243+
class Linear(Emitter):
239244

240245
def __init__(self, obj, prevfns):
241-
Wrapper.__init__(self, obj, prevfns)
246+
Emitter.__init__(self, obj, prevfns)
242247

243248
try:
244249
input, weight, bias = [id(el) for el in prevfns]
@@ -262,14 +267,13 @@ def free_tpl(self):
262267
TH${T}Tensor_free(addBuffer_$id);
263268
'''
264269

265-
266270
register(Linear, torch.nn._functions.linear.Linear)
267271

268272

269-
class LogSoftmax(Wrapper):
273+
class LogSoftmax(Emitter):
270274

271275
def __init__(self, obj, prevfns):
272-
Wrapper.__init__(self, obj, prevfns)
276+
Emitter.__init__(self, obj, prevfns)
273277
self.def_vars({'input': id(prevfns[0])})
274278
self.infer_type_var = 'input'
275279

@@ -287,10 +291,10 @@ def free_tpl(self):
287291
register(LogSoftmax, torch.nn._functions.thnn.auto.LogSoftmax)
288292

289293

290-
class Threshold(Wrapper):
294+
class Threshold(Emitter):
291295

292296
def __init__(self, obj, prevfns):
293-
Wrapper.__init__(self, obj, prevfns)
297+
Emitter.__init__(self, obj, prevfns)
294298
self.def_vars({
295299
'input': id(prevfns[0]),
296300
})
@@ -315,10 +319,10 @@ def free_tpl(self):
315319
register(Threshold, torch.nn._functions.thnn.auto.Threshold)
316320

317321

318-
class Noop(Wrapper):
322+
class Noop(Emitter):
319323

320324
def __init__(self, obj, prevfns):
321-
Wrapper.__init__(self, obj, prevfns)
325+
Emitter.__init__(self, obj, prevfns)
322326
self.def_vars({'input': id(prevfns[0])})
323327
self.infer_type_var = 'input'
324328

@@ -334,10 +338,10 @@ def free_tpl(self):
334338
register(Noop, torch.nn._functions.dropout.FeatureDropout)
335339

336340

337-
class View(Wrapper):
341+
class View(Emitter):
338342

339343
def __init__(self, obj, prevfns):
340-
Wrapper.__init__(self, obj, prevfns)
344+
Emitter.__init__(self, obj, prevfns)
341345
self.def_vars({'input': id(prevfns[0])})
342346
self.infer_type_var = 'input'
343347

@@ -359,10 +363,10 @@ def free_tpl(self):
359363
register(View, torch.autograd._functions.tensor.View)
360364

361365

362-
class MaxPool2d(Wrapper):
366+
class MaxPool2d(Emitter):
363367

364368
def __init__(self, obj, prevfns):
365-
Wrapper.__init__(self, obj, prevfns)
369+
Emitter.__init__(self, obj, prevfns)
366370
self.def_vars({
367371
'input': id(prevfns[0])
368372
})
@@ -392,10 +396,10 @@ def free_tpl(self):
392396
register(MaxPool2d, torch.nn._functions.thnn.pooling.MaxPool2d)
393397

394398

395-
class ConvNd(Wrapper):
399+
class ConvNd(Emitter):
396400

397401
def __init__(self, obj, prevfns):
398-
Wrapper.__init__(self, obj, prevfns)
402+
Emitter.__init__(self, obj, prevfns)
399403
self.def_vars({'input': id(prevfns[0]),
400404
'weight': id(prevfns[1]),
401405
'bias': id(prevfns[2])})

0 commit comments

Comments
 (0)