Commit

updates
fvviz committed Feb 2, 2024
1 parent a60819e commit 7c5afc2
Showing 9 changed files with 120 additions and 1 deletion.
3 changes: 2 additions & 1 deletion .gitignore
@@ -31,10 +31,11 @@ MANIFEST
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

.DS_Store
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
flagged/

# Unit test / coverage reports
htmlcov/
4 changed files could not be displayed.
3 changes: 3 additions & 0 deletions flagged/log.csv
@@ -0,0 +1,3 @@
Input Image,Threshold,Sliding window size,Stride size,Output Image,flag,username,timestamp
"{""path"":""flagged/Input Image/0d52dbb9b40ecef7bd7b/vit2 1-2.jpg"",""url"":""http://127.0.0.1:7860/file=/private/var/folders/zy/z7lyg4k560j0lf908qw4xb440000gn/T/gradio/6dccc1f36697087dc5b85fa2818b805b00cea9f9/vit2 1-2.jpg"",""size"":258179,""orig_name"":""vit2 (1)-2.jpg"",""mime_type"":""""}",0.5,256,256,"{""path"":""flagged/Output Image/59a7eb66bc5e479b1da1/image.png"",""url"":null,""size"":null,""orig_name"":""image.png"",""mime_type"":null}",,,2024-02-02 14:26:19.097996
"{""path"":""flagged/Input Image/469a6f8b25997c85fce4/vit2 1-2.jpg"",""url"":""http://127.0.0.1:7860/file=/private/var/folders/zy/z7lyg4k560j0lf908qw4xb440000gn/T/gradio/6dccc1f36697087dc5b85fa2818b805b00cea9f9/vit2 1-2.jpg"",""size"":258179,""orig_name"":""vit2 (1)-2.jpg"",""mime_type"":""""}",0.5,256,256,"{""path"":""flagged/Output Image/b859c2e2ac2883510169/image.png"",""url"":null,""size"":null,""orig_name"":""image.png"",""mime_type"":null}",,,2024-02-02 14:26:21.500175
20 changes: 20 additions & 0 deletions gradio_demo.py
@@ -0,0 +1,20 @@
import gradio as gr
from sliding_window import run_sliding_window_pil


# Gradio front end for the sliding-window road-detection model.
iface = gr.Interface(
    fn=run_sliding_window_pil,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Slider(minimum=0.0, maximum=1.0, value=0.5, label="Threshold"),
        gr.Dropdown([256, 128, 64, 32], label="Sliding window size", value=256),
        gr.Dropdown([256, 128, 64, 32], label="Stride size", value=256),
    ],
    outputs=gr.Image(type="pil", label="Output Image"),
    title="Satellite road detection",
)

iface.launch()
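For reference, a minimal sketch of calling the same wrapped function outside the Gradio UI; the image and output paths below are hypothetical, and it assumes the checkpoint referenced in sliding_window.py is available:

from PIL import Image
from sliding_window import run_sliding_window_pil

# Hypothetical local check of the function the demo wraps.
img = Image.open("example_tile.jpg").convert("RGB")   # hypothetical input path
mask = run_sliding_window_pil(img, threshold=0.5, window_pixels=256, stride=256)
mask.save("road_mask.png")                            # hypothetical output path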
95 changes: 95 additions & 0 deletions sliding_window.py
@@ -0,0 +1,95 @@
import torch
import torch.nn.functional as F
from PIL import Image
import argparse

from torchvision import transforms
import torchvision
from model import UNET


def sliding_window_inference(model, input_tensor, window_size, stride, threshold):
    """Run the model over overlapping patches and average the binarized outputs."""
    _, _, height, width = input_tensor.size()
    result_tensor = torch.zeros((1, 1, height, width), device=input_tensor.device)
    count_tensor = torch.zeros((1, 1, height, width), device=input_tensor.device)

    model.eval()
    for h in range(0, height - window_size[2] + 1, stride):
        for w in range(0, width - window_size[3] + 1, stride):
            patch = input_tensor[:, :, h:h + window_size[2], w:w + window_size[3]]

            with torch.no_grad():
                output_patch = torch.sigmoid(model(patch))
                output_patch = (output_patch > threshold).float()

            # Accumulate predictions and a per-pixel visit count so overlapping
            # patches are averaged rather than overwritten.
            result_tensor[:, :, h:h + window_size[2], w:w + window_size[3]] += output_patch
            count_tensor[:, :, h:h + window_size[2], w:w + window_size[3]] += 1

    result_tensor /= count_tensor
    model.train()

    return result_tensor


def run_sliding_window_pil(image, threshold, window_pixels, stride=64):
    """Gradio entry point: takes a PIL image and returns the predicted mask as a PIL image."""
    model = UNET()
    checkpoint = torch.load('checkpoints/epoch_3_checkpoint.pth.tar', map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),
    ])

    input_image = transform(image).unsqueeze(0)

    window_size = (1, 3, window_pixels, window_pixels)
    print("running sliding window")
    output = sliding_window_inference(model, input_image, window_size, stride, threshold)

    normalized_output = output
    denormalized_output = (normalized_output * torch.tensor([1.0, 1.0, 1.0]).view(1, 3, 1, 1)
                           + torch.tensor([0.0, 0.0, 0.0]).view(1, 3, 1, 1))

    # Converting torch tensor to PIL Image for Gradio compatibility
    denormalized_output_pil = transforms.ToPILImage()(denormalized_output.squeeze(0))

    return denormalized_output_pil


def run_sliding_window(image_dir, output_dir, threshold):
    """CLI entry point: reads an image from disk and writes the predicted mask to a file."""
    model = UNET()
    checkpoint = torch.load('checkpoints/epoch_3_checkpoint.pth.tar', map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])

    print("loaded checkpoints")

    img = Image.open(image_dir)
    img = img.convert('RGB')

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),
    ])

    input_image = transform(img).unsqueeze(0)

    window_size = (1, 3, 256, 256)
    stride = 64
    output = sliding_window_inference(model, input_image, window_size, stride, threshold)

    normalized_output = output
    denormalized_output = (normalized_output * torch.tensor([1.0, 1.0, 1.0]).view(1, 3, 1, 1)
                           + torch.tensor([0.0, 0.0, 0.0]).view(1, 3, 1, 1))
    torchvision.utils.save_image(denormalized_output, output_dir)


def main():
    parser = argparse.ArgumentParser(description='Run sliding window inference on an image.')
    parser.add_argument('image_dir', type=str, help='Path to the input image file.')
    parser.add_argument('output_dir', type=str, help='Path to the output image file.')
    parser.add_argument('--threshold', type=float, default=0.5,
                        help='Threshold for binarizing the sigmoid output.')

    args = parser.parse_args()

    run_sliding_window(args.image_dir, args.output_dir, args.threshold)


if __name__ == "__main__":
    main()
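A quick shape check of sliding_window_inference with a stand-in model; the single 1x1 conv below is a hypothetical placeholder for UNET, used only to illustrate how patches are accumulated and averaged back into a full-size mask:

import torch
import torch.nn as nn
from sliding_window import sliding_window_inference

# Hypothetical stand-in for UNET: any module mapping (N, 3, h, w) -> (N, 1, h, w).
dummy_model = nn.Conv2d(3, 1, kernel_size=1)
image = torch.rand(1, 3, 512, 512)

mask = sliding_window_inference(dummy_model, image,
                                window_size=(1, 3, 256, 256), stride=64, threshold=0.5)
print(mask.shape)  # torch.Size([1, 1, 512, 512])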
Empty file added train.py
Empty file.
