Skip to content

Commit

Permalink
Fix exporting ONNX
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiqwang committed Feb 1, 2022
1 parent 9652ae8 commit 411145a
Showing 1 changed file with 12 additions and 7 deletions.
19 changes: 12 additions & 7 deletions yolort/models/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,6 @@ def _onnx_batch_images(self, images: List[Tensor]) -> Tensor:
stride = self.size_divisible
max_size[1] = (torch.ceil((max_size[1].to(torch.float32)) / stride) * stride).to(torch.int64)
max_size[2] = (torch.ceil((max_size[2].to(torch.float32)) / stride) * stride).to(torch.int64)
max_size = tuple(max_size)

# work around for
# pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
Expand All @@ -183,14 +182,20 @@ def _onnx_batch_images(self, images: List[Tensor]) -> Tensor:

img_h, img_w = img.shape[-2:]

dh = self.new_shape[0] - img_w
dh = max_size[2] - img_w
dh = dh % stride
dh = dh / 2 # divide padding into 2 sides

dw = self.new_shape[1] - img_h
dw = max_size[1] - img_h
dw = dw % stride
dw = dw / 2
padding = int(round(dh - 0.1)), int(round(dh + 0.1)), int(round(dw - 0.1)), int(round(dw + 0.1))

padding = (
int(torch.round(dh - 0.1)),
int(torch.round(dh + 0.1)),
int(torch.round(dw - 0.1)),
int(torch.round(dw + 0.1)),
)
padded_img = F.pad(img, padding, value=self.fill_color)

padded_imgs.append(padded_img)
Expand Down Expand Up @@ -265,10 +270,10 @@ def _resize_image_and_masks(
target: Optional[Dict[str, Tensor]] = None,
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:

im_shape = list(image.shape[-2:])
r = min(new_shape[0] / im_shape[0], new_shape[1] / im_shape[1])
im_shape = torch.tensor(image.shape[-2:])
ratio = torch.min(new_shape[0] / im_shape[0], new_shape[1] / im_shape[1])

new_unpad = int(round(im_shape[0] * r)), int(round(im_shape[1] * r))
new_unpad = int(torch.round(im_shape[0] * ratio)), int(torch.round(im_shape[1] * ratio))
image = F.interpolate(image[None], size=new_unpad, mode="bilinear", align_corners=False)[0]

if target is None:
Expand Down

0 comments on commit 411145a

Please sign in to comment.