Skip to content

Commit

Permalink
Update config and notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
souvikmajumder26 committed Jun 15, 2023
1 parent 551c2c1 commit e6df2fe
Show file tree
Hide file tree
Showing 4 changed files with 1,299 additions and 7 deletions.
4 changes: 3 additions & 1 deletion config/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ vars:
encoder: 'efficientnet-b0' # choose from the pre-trained encoders listed here: https://smp.readthedocs.io/en/latest/encoders_timm.html
encoder_weights: 'imagenet'
activation: 'softmax2d' # sigmoid for binary classification, softmax2d for multi-class classification
init_lr: 0.0003 # initial learning rate
epochs: 20 # fine-tune model training
device: 'cpu' # 'cpu' or 'cuda' - edit according to your device
train_classes: ['background', 'building', 'woodland', 'water'] # not training on 'road' class since it's instances in the data is too less
all_classes: ['background', 'building', 'woodland', 'water', 'road'] # all the classes present in the dataset
train_classes: ['background', 'building', 'woodland', 'water'] # not training on 'road' class (need more fine-tuning)
test_classes: ['background', 'building', 'water'] # the class prompt - edit according to what you want to have in the output masks
1,287 changes: 1,287 additions & 0 deletions notebooks/training.ipynb

Large diffs are not rendered by default.

10 changes: 7 additions & 3 deletions src/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@
encoder = slice_config['vars']['encoder'] # the backbone/encoder of the model
encoder_weights = slice_config['vars']['encoder_weights']
activation = slice_config['vars']['activation']
init_lr = slice_config['vars']['init_lr']
epochs = slice_config['vars']['epochs']
all_classes = slice_config['vars']['all_classes']
classes = slice_config['vars']['train_classes']
device = slice_config['vars']['device']

Expand Down Expand Up @@ -140,15 +142,17 @@
train_dataset = SegmentationDataset(
x_train_dir,
y_train_dir,
all_classes=all_classes,
classes=classes,
augmentation=get_training_augmentation(),
preprocessing=get_preprocessing(preprocessing_fn),
classes=classes,
)
val_dataset = SegmentationDataset(
x_val_dir,
y_val_dir,
preprocessing=get_preprocessing(preprocessing_fn),
all_classes=all_classes,
classes=classes,
preprocessing=get_preprocessing(preprocessing_fn),
)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)
Expand All @@ -171,7 +175,7 @@

try:
optimizer = torch.optim.Adam([
dict(params=model.parameters(), lr=0.0003),
dict(params=model.parameters(), lr=init_lr),
])
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
print("\nInitialized the optimizer!")
Expand Down
5 changes: 2 additions & 3 deletions src/utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,11 @@ class SegmentationDataset(Dataset):
(e.g. noralization, shape manipulation, etc.)
"""

CLASSES = ['background', 'building', 'woodland', 'water', 'road']

def __init__(
self,
images_dir,
masks_dir,
all_classes,
classes=None,
augmentation=None,
preprocessing=None,
Expand All @@ -33,7 +32,7 @@ def __init__(
self.masks = [os.path.join(masks_dir, image_id) for image_id in self.ids]

# convert str names to class values on masks
self.class_values = [self.CLASSES.index(cls.lower()) for cls in classes]
self.class_values = [all_classes.index(cls.lower()) for cls in classes]

self.augmentation = augmentation
self.preprocessing = preprocessing
Expand Down

0 comments on commit e6df2fe

Please sign in to comment.