Skip to content

Commit

Permalink
Fix bug with LegalizeLayoutTranform which added duplicate ops (#81)
Browse files Browse the repository at this point in the history
  • Loading branch information
trevor-m authored Jan 30, 2020
1 parent 98b8ca4 commit 3188a68
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions python/tvm/relay/tensorrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,16 +56,15 @@ def visit_call(self, expr):
src_layout = expr.attrs['src_layout']
dst_layout = expr.attrs['dst_layout']
if src_layout == "NCHW" and dst_layout == "NHWC":
return relay.transpose(visit, axes=[0, 2, 3, 1])
return relay.transpose(visit.args[0], axes=[0, 2, 3, 1])
elif src_layout == "NHWC" and dst_layout == "NCHW":
return relay.transpose(visit, axes=[0, 3, 1, 2])
return relay.transpose(visit.args[0], axes=[0, 3, 1, 2])
elif src_layout == "HWIO" and dst_layout == "OIHW":
return relay.transpose(visit, axes=[3, 2, 0, 1])
return relay.transpose(visit.args[0], axes=[3, 2, 0, 1])
elif src_layout == "HWOI" and dst_layout == "OIHW":
return relay.transpose(visit, axes=[2, 3, 0, 1])
# may be uneeded
return relay.transpose(visit.args[0], axes=[2, 3, 0, 1])
elif src_layout == "HWIO" and dst_layout == "IOHW":
return relay.transpose(visit, axes=[2, 3, 0, 1])
return relay.transpose(visit.args[0], axes=[2, 3, 0, 1])
return visit

class RemoveDropout(ExprMutator):
Expand Down

0 comments on commit 3188a68

Please sign in to comment.