From ed69dbeb246d8fe340a8b55b82c0a1c435a2e996 Mon Sep 17 00:00:00 2001 From: Nico Van den Hooff Date: Thu, 21 Apr 2022 15:02:34 -0700 Subject: [PATCH] refactor: minor edits to uitls --- api/ml/utils.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/api/ml/utils.py b/api/ml/utils.py index 4cbd5ab..a96a9c7 100644 --- a/api/ml/utils.py +++ b/api/ml/utils.py @@ -65,7 +65,7 @@ def bytes_to_b64(img_bytes): return img_b64 -def bytes_to_tensor(img_bytes, batch=False): +def bytes_to_tensor(img_bytes): """Transforms an image from bytes to PyTorch Tensor. Parameters @@ -81,16 +81,12 @@ def bytes_to_tensor(img_bytes, batch=False): img_tensor : torch.Tensor PyTorch Tensor representation of the image. """ - # TODO: update size at the end with final img_size = (256, 256) img = Image.open(io.BytesIO(img_bytes)) img_transforms = transforms.Compose( [transforms.Resize(img_size), transforms.ToTensor()] ) - img_tensor = img_transforms(img) - - if not batch: - img_tensor = img_tensor.unsqueeze(0) + img_tensor = img_transforms(img).unsqueeze(0) return img_tensor @@ -125,10 +121,10 @@ def fig_to_bytes(fig): Returns ------- - img_bytes : bytes + fig_bytes : bytes Byte representation of the figure. """ buffer = io.BytesIO() fig.savefig(buffer) - img_bytes = buffer.getvalue() - return img_bytes + fig_bytes = buffer.getvalue() + return fig_bytes