Skip to content

Commit

Permalink
Try converting aten.repeat to ttnn.repeat
Browse files Browse the repository at this point in the history
Error message:
```
E       RuntimeError: TT_FATAL @ /home/runner/work/tt-metal/tt-metal/ttnn/cpp/ttnn/operations/data_movement/repeat/repeat.cpp:37: (input_tensor.get_legacy_shape()[dim] * input_tensor.element_size()) % input_tensor.buffer()->alignment() == 0
E       info:
E       Current repeat implementation requires aligned last dim when repeating on last dim
```
  • Loading branch information
jdh8 committed Sep 12, 2024
1 parent 79c0432 commit 769b86c
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 1 deletion.
1 change: 0 additions & 1 deletion tests/lowering/tensor_manipulation/test_repeat.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ def forward(self, x, sizes):
return x.repeat(sizes)


@pytest.mark.xfail(reason="lowering issue (#67)")
@pytest.mark.parametrize(
"input_shape, sizes",
[((4, 4), (3, 2))],
Expand Down
3 changes: 3 additions & 0 deletions torch_ttnn/passes/lowering/to_tt_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,9 @@ def call_function(self, target, args, kwargs):
if target == torch.ops.aten.permute.default:
return self.call_function_prop_meta(ttnn.permute, args, kwargs)

if target == torch.ops.aten.repeat.default:
return self.call_function_prop_meta(target_wrappers.repeat, args, kwargs)

if target == torch.ops.aten.view.default:
# aten.reshape is more stable if the input nodes have changed
return self.call_function_prop_meta(torch.ops.aten.reshape.default, args, kwargs)
Expand Down

0 comments on commit 769b86c

Please sign in to comment.