Skip to content

Commit

Permalink
updated
Browse files Browse the repository at this point in the history
  • Loading branch information
fvviz committed Feb 1, 2024
1 parent 110a743 commit 9424bac
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 3 deletions.
44 changes: 43 additions & 1 deletion dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,46 @@ def __getitem__(self, index):
image = augmentations["image"]
mask = augmentations["mask"]

return image, mask
return image, mask

def get_loaders(
train_dir,
train_maskdir,
val_dir,
val_maskdir,
batch_size,
train_transform,
val_transform,
num_workers=4,
pin_memory=True,
):
train_set = SatDataset(
image_dir=train_dir,
mask_dir=train_maskdir,
transform=train_transform,

)

train_loader = torch.utils.data.DataLoader(
train_set,
batch_size=batch_size,
num_workers=num_workers,
pin_memory=pin_memory,
shuffle=True,
)

val_set = SatDataset(
image_dir = val_dir,
mask_dir= val_maskdir,
transform= val_transform
)

val_loader = torch.utils.data.DataLoader(
val_set,
batch_size=batch_size,
num_workers=num_workers,
pin_memory=pin_memory,
shuffle=False,
)

return train_loader, val_loader
4 changes: 3 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,6 @@ torch
numpy
matplotlib
pillow

splitfolders
albumentations
tqdm
42 changes: 41 additions & 1 deletion utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
import torchvision
import torch
import os
import shutil
import splitfolders
import tqdm

def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
print("=> Saving checkpoint")
Expand All @@ -22,4 +27,39 @@ def save_predictions_as_imgs(
)
torchvision.utils.save_image(y.unsqueeze(1), f"{folder}{idx}.png")

model.train()
model.train()

def organise_data(input_dir = 'road-detection/train', split_ratio =(.8, 0.1,0.1) ):
output_sat_dir = os.path.join(input_dir, 'sat')
output_mask_dir = os.path.join(input_dir, 'mask')

os.makedirs(output_sat_dir, exist_ok=True)
os.makedirs(output_mask_dir, exist_ok=True)


for filename in os.listdir(input_dir):
if filename.endswith('sat.jpg'):
shutil.move(os.path.join(input_dir, filename), os.path.join(output_sat_dir, filename))
elif filename.endswith('mask.png'):
shutil.move(os.path.join(input_dir, filename), os.path.join(output_mask_dir, filename))
splitfolders.ratio('road-detection/train', output="road-detection/organised_data", seed=1337, ratio=split_ratio)


def train_fn(loader, model, optimizer, loss_fn, scaler):
loop = tqdm(loader)

for batch_idx, (data, targets) in enumerate(loop):
data = data.to(device='cuda')
targets = targets.float().unsqueeze(1).to(device='cuda')


with torch.cuda.amp.autocast():
predictions = model(data)
loss = loss_fn(predictions, targets)

optimizer.zero_grad()
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()

loop.set_postfix(loss=loss.item())

0 comments on commit 9424bac

Please sign in to comment.