-
Notifications
You must be signed in to change notification settings - Fork 69
/
Copy pathbackbone.py
71 lines (60 loc) · 4.4 KB
/
backbone.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import os
import torch
import torchvision
from lydorn_utils import print_utils
def get_backbone(backbone_params):
    """Build the backbone network named by ``backbone_params["name"]``.

    ``set_download_dir()`` is called first so that any weights downloaded
    while building the model go to the configured torch cache directory.

    Supported names and the ``backbone_params`` keys each one reads:

    - ``"unet"``: ``input_features``, ``features``
    - ``"fcn50"``, ``"fcn101"``, ``"deeplab50"``, ``"deeplab101"``:
      ``pretrained``, ``features``
    - ``"unet_resnet"``: ``encoder_depth``, ``num_filters``, ``dropout_2d``,
      ``pretrained``, ``is_deconv``
    - ``"ictnet"``: ``in_channels``, ``out_channels``, ``preset_model``,
      ``dropout_2d``, ``efficient``

    The project backbones (``unet``, ``unet_resnet``, ``ictnet``) are wrapped in
    torchvision's ``_SimpleSegmentationModel`` with an identity classifier.
    The torchvision models are created with ``num_classes=21`` (presumably so
    the pretrained checkpoints load), and then the last layer of their
    ``classifier`` is replaced by a 1x1 convolution with ``features`` output
    channels.

    Args:
        backbone_params (dict): backbone configuration; must contain ``"name"``.

    Returns:
        torch.nn.Module: the constructed backbone.

    Raises:
        KeyError: if a key required by the selected backbone is missing.
        RuntimeError: if ``backbone_params["name"]`` is not a known backbone.
    """
    set_download_dir()
    name = backbone_params["name"]

    if name == "unet":
        from frame_field_learning.unet import UNetBackbone
        backbone = UNetBackbone(backbone_params["input_features"], backbone_params["features"])
        return _wrap_in_segmentation_model(backbone)

    # torchvision segmentation models whose head only needs its last conv swapped.
    torchvision_builders = {
        "fcn50": torchvision.models.segmentation.fcn_resnet50,
        "fcn101": torchvision.models.segmentation.fcn_resnet101,
        "deeplab50": torchvision.models.segmentation.deeplabv3_resnet50,
        "deeplab101": torchvision.models.segmentation.deeplabv3_resnet101,
    }
    if name in torchvision_builders:
        backbone = torchvision_builders[name](pretrained=backbone_params["pretrained"], num_classes=21)
        _replace_final_classifier_conv(backbone, backbone_params["features"])
        return backbone

    if name == "unet_resnet":
        from frame_field_learning.unet_resnet import UNetResNetBackbone
        backbone = UNetResNetBackbone(backbone_params["encoder_depth"],
                                      num_filters=backbone_params["num_filters"],
                                      dropout_2d=backbone_params["dropout_2d"],
                                      pretrained=backbone_params["pretrained"],
                                      is_deconv=backbone_params["is_deconv"])
        return _wrap_in_segmentation_model(backbone)

    if name == "ictnet":
        from frame_field_learning.ictnet import ICTNetBackbone
        backbone = ICTNetBackbone(in_channels=backbone_params["in_channels"],
                                  out_channels=backbone_params["out_channels"],
                                  preset_model=backbone_params["preset_model"],
                                  dropout_2d=backbone_params["dropout_2d"],
                                  efficient=backbone_params["efficient"])
        return _wrap_in_segmentation_model(backbone)

    # Note the trailing space after "backbone!": the literals are concatenated.
    print_utils.print_error("ERROR: config[\"backbone_params\"][\"name\"] = \"{}\" is an unknown backbone! "
                            "If it is a new backbone you want to use, "
                            "add it in backbone.py's get_backbone() function.".format(name))
    raise RuntimeError("Specified backbone {} unknown".format(name))


def _wrap_in_segmentation_model(backbone):
    """Wrap ``backbone`` in ``_SimpleSegmentationModel`` with an identity classifier."""
    # Imported lazily: ``_utils`` is a private torchvision module, only needed here.
    from torchvision.models.segmentation._utils import _SimpleSegmentationModel
    return _SimpleSegmentationModel(backbone, classifier=torch.nn.Identity())


def _replace_final_classifier_conv(model, out_channels):
    """Replace the last layer of ``model.classifier`` with a 1x1 conv giving ``out_channels`` channels.

    The new conv's input width is read from the layer being replaced, which is
    assumed to be a ``torch.nn.Conv2d`` (the case for torchvision's FCN and
    DeepLabV3 heads). This avoids hard-coding 512 / 256.
    """
    layers = list(model.classifier.children())
    in_channels = layers[-1].in_channels
    model.classifier = torch.nn.Sequential(
        *layers[:-1],
        torch.nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1), stride=(1, 1)),
    )
def set_download_dir(download_dir="models"):
    """Set the ``TORCH_HOME`` environment variable to ``download_dir``.

    torch uses ``TORCH_HOME`` as its cache directory, e.g. for pretrained
    weights downloaded by torch.hub / torchvision.

    Any ``TORCH_HOME`` already present in the environment is overwritten.
    A relative path is interpreted relative to the process's current working
    directory.

    Args:
        download_dir (str | os.PathLike): target directory. Defaults to
            ``"models"``.
    """
    # os.environ only accepts str values; fspath also lets callers pass a pathlib.Path.
    os.environ["TORCH_HOME"] = os.fspath(download_dir)