Skip to content

Commit 359fe9c

Browse files
authored
Update img_plant_controller.py
1 parent dea943a commit 359fe9c

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

ai/img_plant_controller.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,18 +22,23 @@ def decode_image(base64_image):
2222

2323
def image_loader(image_name):
2424
"""load image, returns cuda tensor"""
25-
imsize = 256
26-
loader = transforms.Compose([transforms.Scale(imsize), transforms.ToTensor()])
25+
imsize = 224
26+
loader = transforms.Compose([transforms.Resize([224, 224]),
27+
transforms.ToTensor(),
28+
])
2729
image = Image.open(image_name)
2830
image = loader(image).float()
2931
image = Variable(image)
3032
image = image.unsqueeze(0) # this is for VGG, may not be needed for ResNet
3133
return image.cpu() # assumes that you're using GPU
3234

3335
def guess_type(image_path, model_path, trainloader):
36+
to_pil = transforms.ToPILImage()
3437
image = image_loader(image_path)
38+
image = to_pil(image)
3539
model = model_path
3640
out = model(image)
41+
model.eval()
3742
index = out.data.cpu().numpy().argmax()
3843
return trainloader.dataset.classes[index]
3944

@@ -58,6 +63,8 @@ def run_img_guessing(base64_image):
5863

5964
imgFilePath = decode_image(base64_image)
6065
model = torch.load(modelPath)
66+
model.eval()
67+
6168
plantName = guess_type(imgFilePath, model, trainloader)
6269

6370
return plantName

0 commit comments

Comments
 (0)