Skip to content

Commit 1432199

Browse files
Linking functions to instances and not classes. (fixes Cadene#71)
1 parent b89a5b5 commit 1432199

File tree

1 file changed

+17
-16
lines changed

1 file changed

+17
-16
lines changed

pretrainedmodels/models/torchvision_models.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import torchvision.models as models
44
import torch.utils.model_zoo as model_zoo
55
import torch.nn.functional as F
6+
import types
67

78
#################################################################
89
# You can find the definitions of those models here:
@@ -141,9 +142,9 @@ def forward(self, input):
141142
return x
142143

143144
# Modify methods
144-
setattr(model.__class__, 'features', features)
145-
setattr(model.__class__, 'logits', logits)
146-
setattr(model.__class__, 'forward', forward)
145+
model.features = types.MethodType(features, model)
146+
model.logits = types.MethodType(logits, model)
147+
model.forward = types.MethodType(forward, model)
147148
return model
148149

149150
def alexnet(num_classes=1000, pretrained='imagenet'):
@@ -179,8 +180,8 @@ def forward(self, input):
179180
return x
180181

181182
# Modify methods
182-
setattr(model.__class__, 'logits', logits)
183-
setattr(model.__class__, 'forward', forward)
183+
model.logits = types.MethodType(logits, model)
184+
model.forward = types.MethodType(forward, model)
184185
return model
185186

186187
def densenet121(num_classes=1000, pretrained='imagenet'):
@@ -284,9 +285,9 @@ def forward(self, input):
284285
return x
285286

286287
# Modify methods
287-
setattr(model.__class__, 'features', features)
288-
setattr(model.__class__, 'logits', logits)
289-
setattr(model.__class__, 'forward', forward)
288+
model.features = types.MethodType(features, model)
289+
model.logits = types.MethodType(logits, model)
290+
model.forward = types.MethodType(forward, model)
290291
return model
291292

292293
###############################################################
@@ -321,9 +322,9 @@ def forward(self, input):
321322
return x
322323

323324
# Modify methods
324-
setattr(model.__class__, 'features', features)
325-
setattr(model.__class__, 'logits', logits)
326-
setattr(model.__class__, 'forward', forward)
325+
model.features = types.MethodType(features, model)
326+
model.logits = types.MethodType(logits, model)
327+
model.forward = types.MethodType(forward, model)
327328
return model
328329

329330
def resnet18(num_classes=1000, pretrained='imagenet'):
@@ -402,8 +403,8 @@ def forward(self, input):
402403
return x
403404

404405
# Modify methods
405-
setattr(model.__class__, 'logits', logits)
406-
setattr(model.__class__, 'forward', forward)
406+
model.logits = types.MethodType(logits, model)
407+
model.forward = types.MethodType(forward, model)
407408
return model
408409

409410
def squeezenet1_0(num_classes=1000, pretrained='imagenet'):
@@ -468,9 +469,9 @@ def forward(self, input):
468469
return x
469470

470471
# Modify methods
471-
setattr(model.__class__, 'features', features)
472-
setattr(model.__class__, 'logits', logits)
473-
setattr(model.__class__, 'forward', forward)
472+
model.features = types.MethodType(features, model)
473+
model.logits = types.MethodType(logits, model)
474+
model.forward = types.MethodType(forward, model)
474475
return model
475476

476477
def vgg11(num_classes=1000, pretrained='imagenet'):

0 commit comments

Comments
 (0)