Skip to content

Commit 9f7d304

Browse files
committed
move arange converter to ops_evaluators.py
1 parent cd158b6 commit 9f7d304

File tree

2 files changed

+36
-0
lines changed

2 files changed

+36
-0
lines changed

py/torch_tensorrt/dynamo/conversion/ops_evaluators.py

+3
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import builtins
12
import logging
23
import operator
34
from typing import Dict, Sequence, Tuple, Union
@@ -23,7 +24,9 @@ def getitem_validator(getitem_node: Node) -> bool:
2324

2425
# TODO: Subsequent evaluators should be registered here with their own validators
2526
@dynamo_tensorrt_converter(operator.getitem, capability_validator=getitem_validator)
27+
@dynamo_tensorrt_converter(builtins.getattr)
2628
@dynamo_tensorrt_converter(torch.ops.aten.detach.default)
29+
@dynamo_tensorrt_converter(torch.ops.aten.arange.start_step)
2730
def generic_evaluator(
2831
ctx: ConversionContext,
2932
target: Target,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import torch
2+
import torch.nn as nn
3+
from parameterized import parameterized
4+
from torch.testing._internal.common_utils import run_tests
5+
6+
from .harness import DispatchTestCase
7+
8+
9+
class TestArangeConverter(DispatchTestCase):
10+
@parameterized.expand(
11+
[
12+
(0, 5, 1),
13+
(1, 5, 2),
14+
(3, 5, 3),
15+
(5, 0, -1),
16+
(5, 1, -2),
17+
(5, 3, -3),
18+
]
19+
)
20+
def test_arange(self, start, end, step):
21+
class Arange(nn.Module):
22+
def forward(self, x):
23+
return torch.ops.aten.arange.start_step(start, x.shape[0], step)
24+
25+
inputs = [torch.randn(end, 1)]
26+
self.run_test(
27+
Arange(),
28+
inputs,
29+
)
30+
31+
32+
if __name__ == "__main__":
33+
run_tests()

0 commit comments

Comments
 (0)