This repository was archived by the owner on Apr 28, 2023. It is now read-only.
Allow using scalars for bounds inference #73
Open
Description
We currently can't use scalar inputs in bounds inference. For example:
import torch
import tensor_comprehensions as tc

# Strided average pooling: kernel size (kH, kW) and strides (sH, sW) are scalar inputs.
LANG = """
def avgpool(float(B, C, H, W) input, float kH, float kW, float sH, float sW) -> (output) {
    output(b, c, h, w) += input(b, c, h * sH + kh, w * sW + kw) where kh in 0:kH, kw in 0:kW
}
"""

# Simpler variant without strides; kH and kW are still scalar loop bounds.
LANG_NO_STRIDE = """
def avgpool(float(B, C, H, W) input, float kH, float kW) -> (output) {
    output(b, c, h, w) += input(b, c, h + kh, w + kw) where kh in 0:kH, kw in 0:kW
}
"""

avgpool = tc.define(LANG, name="avgpool")
inp = torch.ones(1, 1, 4, 4).cuda()
# Scalar arguments are passed as 1-element CUDA tensors.
kH = torch.randn(1).fill_(2.0).cuda()
kW = torch.randn(1).fill_(2.0).cuda()
sH = torch.randn(1).fill_(1.0).cuda()
sW = torch.randn(1).fill_(1.0).cuda()
out = avgpool(inp, kH, kW, sH, sW)
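This call fails: bounds inference cannot use the scalar inputs kH, kW, sH and sW to derive the ranges of kh, kw, h and w.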
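For reference, here is a minimal sketch (not TC's actual implementation) of the range bounds inference would have to derive for h if kH and sH were known constants. It assumes inference picks the largest range for which every access h * sH + kh, with kh in 0:kH, stays inside 0:H, i.e. h * sH + (kH - 1) <= H - 1, so h <= floor((H - kH) / sH). The helper name inferred_output_extent is hypothetical.

```python
def inferred_output_extent(in_extent: int, kernel: int, stride: int) -> int:
    """Number of output positions o such that o * stride + k < in_extent
    for every k in 0:kernel (every input access stays in bounds)."""
    if kernel < 1 or stride < 1:
        raise ValueError("kernel and stride must be positive integers")
    if kernel > in_extent:
        # Not even o = 0 keeps all accesses in bounds.
        return 0
    # o * stride + (kernel - 1) <= in_extent - 1  =>  o <= (in_extent - kernel) / stride
    return (in_extent - kernel) // stride + 1


if __name__ == "__main__":
    # 4x4 input with kH = 2, sH = 1, as in the example above.
    print(inferred_output_extent(4, 2, 1))
```

With the 4x4 input and kH = 2, sH = 1 from the example, this gives 3 per spatial dimension, i.e. a 1x1x3x3 output. The point is that the extent depends directly on the values of kH and sH, which is exactly what bounds inference cannot see when they are runtime scalars.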
The current workaround is to substitute the concrete scalar values directly into the TC source before passing it to the backend.
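A minimal sketch of that workaround, assuming the scalar values are known on the host before the TC is defined: drop the scalars from the signature and splice them in as integer literals. The template AVGPOOL_TEMPLATE and the helper specialize_avgpool are hypothetical. string.Template is used instead of str.format because the TC body already contains { and }, which str.format would treat as replacement fields.

```python
from string import Template

# Hypothetical TC template: the scalar parameters are removed from the
# signature and each use is replaced by a ${...} placeholder.
AVGPOOL_TEMPLATE = Template("""
def avgpool(float(B, C, H, W) input) -> (output) {
    output(b, c, h, w) += input(b, c, h * ${sH} + kh, w * ${sW} + kw) where kh in 0:${kH}, kw in 0:${kW}
}
""")


def specialize_avgpool(kH: int, kW: int, sH: int, sW: int) -> str:
    """Return TC source with the kernel sizes and strides baked in as literals."""
    for name, value in (("kH", kH), ("kW", kW), ("sH", sH), ("sW", sW)):
        # Loop bounds and strides must be positive integers to be usable in ranges.
        if not isinstance(value, int) or value < 1:
            raise ValueError(f"{name} must be a positive int, got {value!r}")
    # substitute() raises KeyError if any placeholder is left unfilled.
    return AVGPOOL_TEMPLATE.substitute(kH=kH, kW=kW, sH=sH, sW=sW)


if __name__ == "__main__":
    print(specialize_avgpool(kH=2, kW=2, sH=1, sW=1))
```

The resulting string would then presumably be passed to tc.define(..., name="avgpool") in place of LANG, with the call reduced to avgpool(inp). The cost of this approach is that every distinct (kH, kW, sH, sW) combination produces a separate TC definition.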