14
14
15
15
_class_map = {}
16
16
17
- def register (wrapperClass , torchClass ):
18
- _class_map [torchClass ] = wrapperClass
19
17
18
+ def register (emitterClass , torchClass ):
19
+ _class_map [torchClass ] = emitterClass
20
20
21
- class Wrapper (object ):
21
+
22
+ class Emitter (object ):
22
23
23
24
def __init__ (self , obj , prevfns ):
24
25
self .id = id (obj )
@@ -148,20 +149,21 @@ def tensor_meta_tpl(size_name, stride_name, size, stride=None):
148
149
149
150
150
151
#####################
151
- # Wrapper subclasses
152
+ # Emitter subclasses
152
153
#####################
153
154
154
155
155
- class Variable (Wrapper ):
156
+ class Variable (Emitter ):
156
157
157
158
def __init__ (self , obj , prevfns ):
158
- Wrapper .__init__ (self , obj , prevfns )
159
+ Emitter .__init__ (self , obj , prevfns )
159
160
160
161
def infer_type (self , var_dict ):
161
162
self .numtype = self .obj .data .__class__ .__name__ [:len ('Tensor' )- 1 ]
162
163
163
164
register (Variable , torch .autograd .Variable )
164
165
166
+
165
167
def persist_tensor (tensor , name , out_path , datadir , size_name = 'size_$id' , stride_name = 'stride_$id' ):
166
168
contiguous = tensor .contiguous ()
167
169
filename = '%s.th' % name
@@ -177,6 +179,7 @@ def persist_tensor(tensor, name, out_path, datadir, size_name='size_$id', stride
177
179
meta , meta_free = tensor_meta_tpl (size_name ,stride_name ,size ,stride )
178
180
return os .path .join (datadir ,filename ), meta , meta_free
179
181
182
+
180
183
# TODO: add this function to an auxiliary file
181
184
# call it something like TH${T}Storage_newFromFile(filename);
182
185
def read_storage (storage_name ,filepath ,numtype ):
@@ -205,6 +208,7 @@ def read_storage(storage_name,filepath,numtype):
205
208
'''
206
209
return Template (tpl ).substitute (subs )
207
210
211
+
208
212
class PersistedVariable (Variable ):
209
213
210
214
def __init__ (self , obj , prevfns ):
@@ -227,6 +231,7 @@ def free_tpl(self):
227
231
TH${T}Storage_free(storage_$id);
228
232
'''
229
233
234
+
230
235
class Parameter (PersistedVariable ):
231
236
232
237
def __init__ (self , obj , prevfns ):
@@ -235,10 +240,10 @@ def __init__(self, obj, prevfns):
235
240
register (Parameter , torch .nn .parameter .Parameter )
236
241
237
242
238
- class Linear (Wrapper ):
243
+ class Linear (Emitter ):
239
244
240
245
def __init__ (self , obj , prevfns ):
241
- Wrapper .__init__ (self , obj , prevfns )
246
+ Emitter .__init__ (self , obj , prevfns )
242
247
243
248
try :
244
249
input , weight , bias = [id (el ) for el in prevfns ]
@@ -262,14 +267,13 @@ def free_tpl(self):
262
267
TH${T}Tensor_free(addBuffer_$id);
263
268
'''
264
269
265
-
266
270
register (Linear , torch .nn ._functions .linear .Linear )
267
271
268
272
269
- class LogSoftmax (Wrapper ):
273
+ class LogSoftmax (Emitter ):
270
274
271
275
def __init__ (self , obj , prevfns ):
272
- Wrapper .__init__ (self , obj , prevfns )
276
+ Emitter .__init__ (self , obj , prevfns )
273
277
self .def_vars ({'input' : id (prevfns [0 ])})
274
278
self .infer_type_var = 'input'
275
279
@@ -287,10 +291,10 @@ def free_tpl(self):
287
291
register (LogSoftmax , torch .nn ._functions .thnn .auto .LogSoftmax )
288
292
289
293
290
- class Threshold (Wrapper ):
294
+ class Threshold (Emitter ):
291
295
292
296
def __init__ (self , obj , prevfns ):
293
- Wrapper .__init__ (self , obj , prevfns )
297
+ Emitter .__init__ (self , obj , prevfns )
294
298
self .def_vars ({
295
299
'input' : id (prevfns [0 ]),
296
300
})
@@ -315,10 +319,10 @@ def free_tpl(self):
315
319
register (Threshold , torch .nn ._functions .thnn .auto .Threshold )
316
320
317
321
318
- class Noop (Wrapper ):
322
+ class Noop (Emitter ):
319
323
320
324
def __init__ (self , obj , prevfns ):
321
- Wrapper .__init__ (self , obj , prevfns )
325
+ Emitter .__init__ (self , obj , prevfns )
322
326
self .def_vars ({'input' : id (prevfns [0 ])})
323
327
self .infer_type_var = 'input'
324
328
@@ -334,10 +338,10 @@ def free_tpl(self):
334
338
register (Noop , torch .nn ._functions .dropout .FeatureDropout )
335
339
336
340
337
- class View (Wrapper ):
341
+ class View (Emitter ):
338
342
339
343
def __init__ (self , obj , prevfns ):
340
- Wrapper .__init__ (self , obj , prevfns )
344
+ Emitter .__init__ (self , obj , prevfns )
341
345
self .def_vars ({'input' : id (prevfns [0 ])})
342
346
self .infer_type_var = 'input'
343
347
@@ -359,10 +363,10 @@ def free_tpl(self):
359
363
register (View , torch .autograd ._functions .tensor .View )
360
364
361
365
362
- class MaxPool2d (Wrapper ):
366
+ class MaxPool2d (Emitter ):
363
367
364
368
def __init__ (self , obj , prevfns ):
365
- Wrapper .__init__ (self , obj , prevfns )
369
+ Emitter .__init__ (self , obj , prevfns )
366
370
self .def_vars ({
367
371
'input' : id (prevfns [0 ])
368
372
})
@@ -392,10 +396,10 @@ def free_tpl(self):
392
396
register (MaxPool2d , torch .nn ._functions .thnn .pooling .MaxPool2d )
393
397
394
398
395
- class ConvNd (Wrapper ):
399
+ class ConvNd (Emitter ):
396
400
397
401
def __init__ (self , obj , prevfns ):
398
- Wrapper .__init__ (self , obj , prevfns )
402
+ Emitter .__init__ (self , obj , prevfns )
399
403
self .def_vars ({'input' : id (prevfns [0 ]),
400
404
'weight' : id (prevfns [1 ]),
401
405
'bias' : id (prevfns [2 ])})
0 commit comments