Skip to content

Commit

Permalink
added model
Browse files Browse the repository at this point in the history
  • Loading branch information
fvviz committed Feb 1, 2024
1 parent e493a35 commit 110a743
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 0 deletions.
29 changes: 29 additions & 0 deletions dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import torch
import os
from PIL import Image
from torch.utils.data import Dataset
import numpy as np

class SatDataset(Dataset):
def __init__(self, image_dir, mask_dir, transform=None):
self.image_dir = image_dir
self.mask_dir = mask_dir
self.transform = transform
self.images = os.listdir(image_dir)

def __len__(self):
return len(self.images)

def __getitem__(self, index):
img_path = os.path.join(self.image_dir, self.images[index])
mask_path = os.path.join(self.mask_dir, self.images[index].replace("sat.jpg", "mask.png"))
image = np.array(Image.open(img_path).convert("RGB"))
mask = np.array(Image.open(mask_path).convert("L"), dtype=np.float32)
mask[mask == 255.0] = 1.0

if self.transform is not None:
augmentations = self.transform(image=image, mask=mask)
image = augmentations["image"]
mask = augmentations["mask"]

return image, mask
55 changes: 55 additions & 0 deletions model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import torch.nn as nn
import torch


class DoubleConv(nn.Module):
def __init__(self, in_channels, out_channels):
super(DoubleConv, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),

nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
def forward(self, X):
return self.conv(X)


class UNET(nn.Module):
def __init__(self, in_channels=3, out_channels=1, features=[64, 128, 256, 512]):
super(UNET, self).__init__()
self.ups = nn.ModuleList()
self.downs =nn.ModuleList()
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

for feature in features:
self.downs.append(DoubleConv(in_channels, feature))
in_channels = feature

for feature in reversed(features):
self.ups.append(nn.ConvTranspose2d(in_channels=feature*2, out_channels= feature, kernel_size=2, stride=2)) # multiply feature by 2 to account for skip connection
self.ups.append(DoubleConv(feature*2, feature))

self.bottleneck = DoubleConv(in_channels=features[-1], out_channels=features[-1]*2)
self.final= nn.Conv2d(features[0], out_channels=out_channels, kernel_size=1)

def forward(self, X):
skip_connections = []
for i, down in enumerate(self.downs):
X = down(X)
skip_connections.append(X)
X = self.pool(X)

X = self.bottleneck(X)
skip_connections = list(reversed(skip_connections))

for i in range(0, len(self.ups), 2):
X = self.ups[i](X)
skip_conn = skip_connections[i//2]
concat_skip = torch.cat((skip_conn,X), dim=1)
X = self.ups[i+1](concat_skip)

return self.final(X)
5 changes: 5 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
torch
numpy
matplotlib
pillow

25 changes: 25 additions & 0 deletions utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import torchvision

def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
print("=> Saving checkpoint")
torch.save(state, filename)

def load_checkpoint(checkpoint, model):
print("=> Loading checkpoint")
model.load_state_dict(checkpoint["state_dict"])

def save_predictions_as_imgs(
loader, model, folder="saved_images/", device="mps"
):
model.eval()
for idx, (x, y) in enumerate(loader):
x = x.to(device=device)
with torch.no_grad():
preds = torch.sigmoid(model(x))
preds = (preds > 0.5).float()
torchvision.utils.save_image(
preds, f"{folder}/pred_{idx}.png"
)
torchvision.utils.save_image(y.unsqueeze(1), f"{folder}{idx}.png")

model.train()

0 comments on commit 110a743

Please sign in to comment.