|
3 | 3 | import torchvision.models as models
|
4 | 4 | import torch.utils.model_zoo as model_zoo
|
5 | 5 | import torch.nn.functional as F
|
| 6 | +import types |
6 | 7 |
|
7 | 8 | #################################################################
|
8 | 9 | # You can find the definitions of those models here:
|
@@ -141,9 +142,9 @@ def forward(self, input):
|
141 | 142 | return x
|
142 | 143 |
|
143 | 144 | # Modify methods
|
144 |
| - setattr(model.__class__, 'features', features) |
145 |
| - setattr(model.__class__, 'logits', logits) |
146 |
| - setattr(model.__class__, 'forward', forward) |
| 145 | + model.features = types.MethodType(features, model) |
| 146 | + model.logits = types.MethodType(logits, model) |
| 147 | + model.forward = types.MethodType(forward, model) |
147 | 148 | return model
|
148 | 149 |
|
149 | 150 | def alexnet(num_classes=1000, pretrained='imagenet'):
|
@@ -179,8 +180,8 @@ def forward(self, input):
|
179 | 180 | return x
|
180 | 181 |
|
181 | 182 | # Modify methods
|
182 |
| - setattr(model.__class__, 'logits', logits) |
183 |
| - setattr(model.__class__, 'forward', forward) |
| 183 | + model.logits = types.MethodType(logits, model) |
| 184 | + model.forward = types.MethodType(forward, model) |
184 | 185 | return model
|
185 | 186 |
|
186 | 187 | def densenet121(num_classes=1000, pretrained='imagenet'):
|
@@ -284,9 +285,9 @@ def forward(self, input):
|
284 | 285 | return x
|
285 | 286 |
|
286 | 287 | # Modify methods
|
287 |
| - setattr(model.__class__, 'features', features) |
288 |
| - setattr(model.__class__, 'logits', logits) |
289 |
| - setattr(model.__class__, 'forward', forward) |
| 288 | + model.features = types.MethodType(features, model) |
| 289 | + model.logits = types.MethodType(logits, model) |
| 290 | + model.forward = types.MethodType(forward, model) |
290 | 291 | return model
|
291 | 292 |
|
292 | 293 | ###############################################################
|
@@ -321,9 +322,9 @@ def forward(self, input):
|
321 | 322 | return x
|
322 | 323 |
|
323 | 324 | # Modify methods
|
324 |
| - setattr(model.__class__, 'features', features) |
325 |
| - setattr(model.__class__, 'logits', logits) |
326 |
| - setattr(model.__class__, 'forward', forward) |
| 325 | + model.features = types.MethodType(features, model) |
| 326 | + model.logits = types.MethodType(logits, model) |
| 327 | + model.forward = types.MethodType(forward, model) |
327 | 328 | return model
|
328 | 329 |
|
329 | 330 | def resnet18(num_classes=1000, pretrained='imagenet'):
|
@@ -402,8 +403,8 @@ def forward(self, input):
|
402 | 403 | return x
|
403 | 404 |
|
404 | 405 | # Modify methods
|
405 |
| - setattr(model.__class__, 'logits', logits) |
406 |
| - setattr(model.__class__, 'forward', forward) |
| 406 | + model.logits = types.MethodType(logits, model) |
| 407 | + model.forward = types.MethodType(forward, model) |
407 | 408 | return model
|
408 | 409 |
|
409 | 410 | def squeezenet1_0(num_classes=1000, pretrained='imagenet'):
|
@@ -468,9 +469,9 @@ def forward(self, input):
|
468 | 469 | return x
|
469 | 470 |
|
470 | 471 | # Modify methods
|
471 |
| - setattr(model.__class__, 'features', features) |
472 |
| - setattr(model.__class__, 'logits', logits) |
473 |
| - setattr(model.__class__, 'forward', forward) |
| 472 | + model.features = types.MethodType(features, model) |
| 473 | + model.logits = types.MethodType(logits, model) |
| 474 | + model.forward = types.MethodType(forward, model) |
474 | 475 | return model
|
475 | 476 |
|
476 | 477 | def vgg11(num_classes=1000, pretrained='imagenet'):
|
|
0 commit comments