Description
TRTorch cannot compile the function below.
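import torch
from torch import nn

# Method of the model class; self.kernel_size is a model attribute.
@torch.jit.export
def _nms(self, heatmap):
    # Stride-1 max-pool with padding kernel_size // 2: each position gets its neighbourhood maximum.
    heatmap_max = nn.functional.max_pool2d(heatmap, self.kernel_size, 1,
                                           self.kernel_size // 2)
    keep = (heatmap_max == heatmap).float()  # 1.0 at local maxima, 0.0 elsewhere
    return keep * heatmap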
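For reference, here is a minimal standalone sketch of what _nms computes in eager PyTorch. kernel_size is passed as an argument instead of read from self, and its default of 3 is a hypothetical choice, not taken from the model. Only positions equal to their neighbourhood maximum survive; everything else is zeroed.

```python
import torch
import torch.nn.functional as F


def nms(heatmap: torch.Tensor, kernel_size: int = 3) -> torch.Tensor:
    """Keep only local maxima of a heatmap.

    Assumes an odd kernel_size, so stride 1 with padding kernel_size // 2
    keeps the output the same spatial size as the input.
    """
    heatmap_max = F.max_pool2d(heatmap, kernel_size, 1, kernel_size // 2)
    # The comparison yields a Bool tensor; .float() casts it to 1.0 / 0.0.
    keep = (heatmap_max == heatmap).float()
    return keep * heatmap


# A 1x1x3x3 heatmap with a single peak in the centre.
hm = torch.tensor([[[[0.1, 0.2, 0.1],
                     [0.2, 0.9, 0.2],
                     [0.1, 0.2, 0.1]]]])
out = nms(hm)
print(out)
```

On this 3x3 input every 3x3 window contains the centre, so only the 0.9 peak remains in the output.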
When I compile a model that includes this function with TRTorch, TRTorch treats keep as a Bool tensor.
The error printed is:
[TRTorch Conversion Context] - %heatmap.1 : Tensor = aten::mul(%471, %469) # C:\Users\yhkwon\Documents\project\0215\project\centernet.py:219:15: operation PROD has incompatible input types Bool and Float
It seems that TRTorch does not handle the type cast .float(), so the Bool result of the comparison reaches aten::mul without being converted to Float.
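One possible workaround, sketched here under assumptions, is to build the mask arithmetically so that no Bool tensor is ever created. nms_no_bool is a hypothetical name, and whether the TRTorch version in use has converters for the ops it introduces (sign and subtraction) has not been verified; the sketch only checks in eager PyTorch that the rewrite matches the original.

```python
import torch
import torch.nn.functional as F


def nms_original(heatmap: torch.Tensor, kernel_size: int = 3) -> torch.Tensor:
    # Same logic as _nms: the comparison produces a Bool tensor.
    heatmap_max = F.max_pool2d(heatmap, kernel_size, 1, kernel_size // 2)
    return (heatmap_max == heatmap).float() * heatmap


def nms_no_bool(heatmap: torch.Tensor, kernel_size: int = 3) -> torch.Tensor:
    # Hypothetical rewrite that never creates a Bool tensor.
    heatmap_max = F.max_pool2d(heatmap, kernel_size, 1, kernel_size // 2)
    # With stride 1, odd kernel_size and padding kernel_size // 2, each window
    # contains the element itself, so heatmap_max >= heatmap everywhere and
    # the difference is 0 at local maxima and positive elsewhere.
    keep = 1.0 - torch.sign(heatmap_max - heatmap)
    return keep * heatmap


torch.manual_seed(0)
hm = torch.rand(1, 2, 8, 8)
print(torch.equal(nms_original(hm), nms_no_bool(hm)))
```

The design choice relies on max-pooling never returning a value smaller than the element at the window centre, which is what makes sign() a drop-in replacement for the equality test here.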
- Function Schema: aten::mul(Tensor, Tensor)
- Original PyTorch API:
- Relevant TensorRT Documentation: