Skip to content

Commit

Permalink
eval changes
Browse files Browse the repository at this point in the history
  • Loading branch information
fvviz committed Feb 1, 2024
1 parent 6c42fee commit a60819e
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 20 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,6 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

checkpoints/
experimental/
43 changes: 23 additions & 20 deletions evaluate.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,42 @@
import torchvision
import torch
import numpy as np
from PIL import Image

import albumentations as A
from albumentations.pytorch import ToTensorV2

from model import UNET
from utils import load_checkpoint
import argparse

def process_image(image_path, target_size= (256, 256)):
val_transforms = A.Compose(
[
A.Resize(height=256, width=256),
A.Normalize(
mean=[0.0, 0.0, 0.0],
std=[1.0, 1.0, 1.0],
max_pixel_value=255.0,
),
ToTensorV2(),
],
)
resize_transform = torchvision.transforms.Compose([
torchvision.transforms.Resize(target_size),
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(
mean=[0.0, 0.0, 0.0],
std=[1.0, 1.0, 1.0],
)
])

uneven_image = Image.open(image_path)
even_image = val_transforms(uneven_image)
uneven_image = uneven_image.convert("RGB")
even_image = resize_transform(uneven_image)
return even_image

def create_mask(image_path, out_mask_path, device='cpu'):
def create_mask(image_path, out_mask_path, device='cpu', checkpoint_path = 'checkpoints/best_checkpoint.pth.tar'):
img_tensor = process_image(image_path)
img_tensor.to(device)
img_tensor = img_tensor.unsqueeze(0)

model = UNET().to(device)
model = UNET()
if device!= 'gpu':
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint['state_dict'])
else:
checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint['state_dict'])

model.eval()
with torch.no_grad():
mask_raw = model.predict(img_tensor)
mask = torch.sigmoid(mask_raw)
mask = torch.sigmoid(model(img_tensor))
mask = (mask>0.5).float()
torchvision.utils.save_image(mask, out_mask_path)
model.train()
Expand Down

0 comments on commit a60819e

Please sign in to comment.