↔ [Converter] Add support for aten::mul(Bool Tensor, Float Tensor) in TRTorch #341

Closed
@developer0hye

Description

TRTorch cannot compile the function below.

@torch.jit.export
def _nms(self, heatmap):
    heatmap_max = nn.functional.max_pool2d(heatmap, self.kernel_size, 1,
                                           self.kernel_size // 2)
    keep = (heatmap_max == heatmap).float()
    return keep * heatmap

When I compile a model that includes this function with TRTorch, TRTorch treats keep as a Bool tensor, even though .float() has been called on it.

The printed error is:

[TRTorch Conversion Context] - %heatmap.1 : Tensor = aten::mul(%471, %469) # C:\Users\yhkwon\Documents\project\0215\project\centernet.py:219:15: operation PROD has incompatible input types Bool and Float

It seems that TRTorch does not recognize the type-casting call .float(), so the Bool result of the comparison reaches aten::mul unchanged.
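For reference, here is a minimal sketch that reproduces the function outside of TRTorch, so the dtypes can be checked in plain PyTorch. The wrapper class name `PeakNMS`, the kernel size of 3, and the 1x1x8x8 input are assumptions made for illustration; only the body of the method comes from the report above. It shows what eager PyTorch does and prints the scripted graph for inspection. It does not show, and does not fix, how TRTorch converts that graph.

```python
import torch
import torch.nn as nn


class PeakNMS(nn.Module):  # hypothetical wrapper, not the original model
    def __init__(self, kernel_size: int = 3):  # kernel size is an assumption
        super().__init__()
        self.kernel_size = kernel_size

    def forward(self, heatmap: torch.Tensor) -> torch.Tensor:
        # Same operations as the _nms function in the report.
        heatmap_max = nn.functional.max_pool2d(
            heatmap, self.kernel_size, 1, self.kernel_size // 2)
        keep = (heatmap_max == heatmap).float()
        return keep * heatmap


torch.manual_seed(0)
heatmap = torch.rand(1, 1, 8, 8)  # arbitrary small input for the check
model = PeakNMS(kernel_size=3)
out = model(heatmap)

# Recompute the intermediate tensors to look at their dtypes.
heatmap_max = nn.functional.max_pool2d(heatmap, 3, 1, 1)
mask = heatmap_max == heatmap   # comparison result: torch.bool
keep = mask.float()             # after the cast: torch.float32
print(mask.dtype, keep.dtype, out.dtype)

# Print the TorchScript graph to see where the cast sits relative to
# aten::mul. Node names can differ between PyTorch versions.
scripted = torch.jit.script(model)
print(scripted.graph)
```

In eager mode the multiplication sees two Float tensors, which suggests the Bool/Float mismatch is introduced during conversion rather than by the model code itself.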

  • Function Schema:

aten::mul(Tensor, Tensor)

  • Original PyTorch API:

  • Relevant TensorRT Documentation:

Alternatives

Additional context
