Skip to content

Commit 3497891

Browse files
Elias Ellisonfacebook-github-bot
Elias Ellison
authored andcommitted
add sorted keyword for lists and dicts (pytorch#23274)
Summary: Add `sorted` keyword to JIT for lists and dicts. This desugars to a list copy and a call to `list.sort()`. Since we don't have interfaces yet I implement it in terms of `list.sort()`. When we do we can re-visit implementing this op in a different manner. The test fails bc of a fix to specialized lists which is landing here: pytorch#23267 Ignore the first commit because it is formatting, plz use clang_format ppl :'( Pull Request resolved: pytorch#23274 Differential Revision: D16527323 Pulled By: eellison fbshipit-source-id: aed8faef23cb790b9af036cd6c1b9b1d7066345d
1 parent f0ebf76 commit 3497891

File tree

6 files changed

+149
-58
lines changed

6 files changed

+149
-58
lines changed

aten/src/ATen/core/interned_strings.h

+2
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ namespace c10 {
9595
_(aten, _size_if_not_equal) \
9696
_(aten, _ncf_unsqueeze) \
9797
_(aten, warn) \
98+
_(aten, sorted) \
9899
_(aten, floordiv) \
99100
_(aten, __range_length) \
100101
_(aten, __derive_index) \
@@ -119,6 +120,7 @@ namespace c10 {
119120
_(aten, __not__) \
120121
_(aten, __is__) \
121122
_(aten, __isnot__) \
123+
_(aten, copy) \
122124
_(aten, copy_) \
123125
_(aten, t_) \
124126
_(aten, addbmm_) \

test/test_jit.py

+28-12
Original file line numberDiff line numberDiff line change
@@ -16222,9 +16222,13 @@ def test_invalid_list_equality():
1622216222
def test_list_sort(self):
1622316223
template = dedent('''
1622416224
def func():
16225-
li = {list_create}
16226-
li.sort()
16227-
return li
16225+
li_1 = {list_create}
16226+
li_2 = {list_create}
16227+
li_3 = {list_create}
16228+
li_1.sort()
16229+
li_2.sort(reverse=True)
16230+
li_4 = sorted(li_3)
16231+
return li_1, li_2, li_3, li_4
1622816232
''')
1622916233

1623016234
lists = ["[]", "[1, 3, 2]", "[True, False, True]", "[1.2, .2, 3.2]",
@@ -17310,19 +17314,23 @@ def __lt__(self, other):
1731017314
def getVal(self):
1731117315
return self.x
1731217316

17313-
@torch.jit.script
1731417317
def test(li, reverse=False):
17315-
# type: (List[Foo], bool) -> List[int]
17318+
# type: (List[Foo], bool)
17319+
li_sorted = sorted(li)
17320+
ret_sorted = torch.jit.annotate(List[int], [])
17321+
for foo in li_sorted:
17322+
ret_sorted.append(foo.getVal())
17323+
1731617324
li.sort(reverse=reverse)
17317-
ret_list = torch.jit.annotate(List[int], [])
17325+
ret_sort = torch.jit.annotate(List[int], [])
1731817326
for foo in li:
17319-
ret_list.append(foo.getVal())
17320-
return ret_list
17327+
ret_sort.append(foo.getVal())
17328+
return ret_sorted, ret_sort
1732117329

17322-
self.assertEqual(test([Foo(2), Foo(1), Foo(3)]), [1, 2, 3])
17323-
self.assertEqual(test([Foo(2), Foo(1), Foo(3)], True), [3, 2, 1])
17324-
self.assertEqual(test([Foo(2)]), [2])
17325-
self.assertEqual(test([]), [])
17330+
self.checkScript(test, ([Foo(2), Foo(1), Foo(3)],))
17331+
self.checkScript(test, ([Foo(2), Foo(1), Foo(3)], True))
17332+
self.checkScript(test, ([Foo(2)],))
17333+
self.checkScript(test, ([],))
1732617334

1732717335
@torch.jit.script
1732817336
def test_list_no_reverse():
@@ -17332,6 +17340,14 @@ def test_list_no_reverse():
1733217340

1733317341
self.assertEqual(test_list_no_reverse(), 1)
1733417342

17343+
@torch.jit.script
17344+
def test_sorted_copies():
17345+
li = [Foo(3), Foo(1)]
17346+
li_sorted = sorted(li)
17347+
return li[0].getVal(), li_sorted[0].getVal()
17348+
17349+
self.assertEqual(test_sorted_copies(), (3, 1))
17350+
1733517351
with self.assertRaisesRegex(RuntimeError, "bool\' for argument \'reverse"):
1733617352
@torch.jit.script
1733717353
def test():

torch/csrc/jit/passes/python_print.cpp

+3
Original file line numberDiff line numberDiff line change
@@ -1007,6 +1007,9 @@ struct PythonPrintPass {
10071007
case prim::Print: {
10081008
printValueList(stmt, node->inputs(), "print(", ")");
10091009
} break;
1010+
case aten::sorted: {
1011+
printValueList(stmt, node->inputs(), "sorted(", ")");
1012+
} break;
10101013
case prim::TupleConstruct: {
10111014
if (auto qualname = node->output()
10121015
->type()

torch/csrc/jit/register_prim_ops.cpp

+106-37
Original file line numberDiff line numberDiff line change
@@ -1756,23 +1756,50 @@ int listSlice(Stack& stack) {
17561756

17571757
template <typename T>
17581758
int listSort(Stack& stack) {
1759+
bool reverse = pop(stack).toBool();
17591760
c10::List<T> list = pop(stack).to<c10::List<T>>();
1760-
std::sort(list.begin(), list.end(), [] (const T& a, const T& b) {
1761-
return a < b;
1761+
std::sort(list.begin(), list.end(), [reverse](const T& a, const T& b) {
1762+
return (a < b) ^ reverse;
17621763
});
17631764
return 0;
17641765
}
17651766

17661767
// Specialization for at::Tensor
17671768
template <>
17681769
int listSort<at::Tensor>(Stack& stack) {
1770+
bool reverse = pop(stack).toBool();
1771+
c10::List<at::Tensor> list = pop(stack).toTensorList();
1772+
std::sort(
1773+
list.begin(),
1774+
list.end(),
1775+
[reverse](const at::Tensor& a, const at::Tensor& b) {
1776+
return (a.lt(b).is_nonzero()) ^ reverse;
1777+
});
1778+
return 0;
1779+
}
1780+
1781+
template <typename T>
1782+
int listCopyAndSort(Stack& stack) {
1783+
c10::List<T> list = pop(stack).to<c10::List<T>>();
1784+
auto list_copied = list.copy();
1785+
std::sort(list_copied.begin(), list_copied.end(), [](const T& a, const T& b) {
1786+
return a < b;
1787+
});
1788+
push(stack, list_copied);
1789+
return 0;
1790+
}
1791+
1792+
// Specialization for at::Tensor
1793+
template <>
1794+
int listCopyAndSort<at::Tensor>(Stack& stack) {
17691795
c10::List<at::Tensor> list = pop(stack).toTensorList();
17701796
std::sort(
17711797
list.begin(),
17721798
list.end(),
17731799
[](const at::Tensor& a, const at::Tensor& b) {
17741800
return a.lt(b).is_nonzero();
17751801
});
1802+
push(stack, list);
17761803
return 0;
17771804
}
17781805

@@ -2233,21 +2260,37 @@ RegisterOperators reg2({
22332260
CREATE_LIST_OPS("t", c10::List<IValue>),
22342261
#undef CREATE_LIST_OPS
22352262
Operator(
2236-
"aten::sort(int[](a!) self) -> ()",
2263+
"aten::sort(int[](a!) self, bool reverse=False) -> ()",
22372264
listSort<int64_t>,
22382265
aliasAnalysisFromSchema()),
22392266
Operator(
2240-
"aten::sort(float[](a!) self) -> ()",
2267+
"aten::sort(float[](a!) self, bool reverse=False) -> ()",
22412268
listSort<double>,
22422269
aliasAnalysisFromSchema()),
22432270
Operator(
2244-
"aten::sort(Tensor[](a!) self) -> ()",
2271+
"aten::sort(Tensor[](a!) self, bool reverse=False) -> ()",
22452272
listSort<at::Tensor>,
22462273
aliasAnalysisFromSchema()),
22472274
Operator(
2248-
"aten::sort(bool[](a!) self) -> ()",
2275+
"aten::sort(bool[](a!) self, bool reverse=False) -> ()",
22492276
listSort<bool>,
22502277
aliasAnalysisFromSchema()),
2278+
Operator(
2279+
"aten::sorted(int[](a) input) -> (int[])",
2280+
listCopyAndSort<int64_t>,
2281+
aliasAnalysisFromSchema()),
2282+
Operator(
2283+
"aten::sorted(float[](a) input) -> (float[])",
2284+
listCopyAndSort<double>,
2285+
aliasAnalysisFromSchema()),
2286+
Operator(
2287+
"aten::sorted(Tensor[](a) input) -> (Tensor[])",
2288+
listCopyAndSort<at::Tensor>,
2289+
aliasAnalysisFromSchema()),
2290+
Operator(
2291+
"aten::sorted(bool[](a) input) -> (bool[])",
2292+
listCopyAndSort<bool>,
2293+
aliasAnalysisFromSchema()),
22512294

22522295
Operator(
22532296
"aten::eq(int[] a, int[] b) -> bool",
@@ -2816,49 +2859,75 @@ void checkSortSchema(const Node* node, const c10::TypePtr& list_element_type) {
28162859
<< class_type->python_str() << " that "
28172860
<< "returns a bool";
28182861
} else {
2819-
error_str
2820-
<< "Input to list sort must be of Tensors, ints, floats, bools or "
2821-
<< "a User Defined Class that defines the __lt__ compare method"
2822-
<< ", got list of " << list_element_type->python_str() << "\n";
2862+
error_str << "Input to " << node->kind().toUnqualString()
2863+
<< "must be of Tensors, ints, floats, bools or "
2864+
<< "a User Defined Class that defines the __lt__ compare method"
2865+
<< ", got list of " << list_element_type->python_str() << "\n";
28232866
}
28242867

28252868
auto error_msg = script::ErrorReport(node->sourceRange());
28262869
error_msg << error_str.str();
28272870
throw error_msg;
28282871
}
28292872

2873+
Operation sort_op(
2874+
Function* lt_func,
2875+
bool has_reverse_arg,
2876+
bool copy_return_list) {
2877+
return [lt_func, has_reverse_arg, copy_return_list](Stack& stack) {
2878+
bool reverse = has_reverse_arg ? pop(stack).toBool() : false;
2879+
auto g_list = pop(stack).toGenericList();
2880+
if (copy_return_list) {
2881+
g_list = g_list.copy();
2882+
}
2883+
Stack sort_stack;
2884+
std::sort(
2885+
g_list.begin(),
2886+
g_list.end(),
2887+
[lt_func, reverse, &sort_stack](IValue a, IValue b) -> bool {
2888+
// FBCode errors without this check - "strict weak ordering"
2889+
// TODO: remove when possible, since it just slows down
2890+
// sorting and doesn't do anything useful
2891+
if (a.isSameIdentity(b)) {
2892+
return false;
2893+
}
2894+
sort_stack.push_back(a);
2895+
sort_stack.push_back(b);
2896+
lt_func->run(sort_stack);
2897+
return pop(sort_stack).toBool() ^ reverse;
2898+
});
2899+
if (copy_return_list) {
2900+
push(stack, g_list);
2901+
}
2902+
return 0;
2903+
};
2904+
}
2905+
2906+
Function* getLtFuncFromListOfClassTypes(const Node* node) {
2907+
const auto list_type = node->inputs().at(0)->type()->expect<ListType>();
2908+
checkSortSchema(node, list_type->getElementType());
2909+
const auto elem = list_type->getElementType()->expect<ClassType>();
2910+
return elem->getMethod("__lt__");
2911+
}
2912+
28302913
// NB: this must be registered after the other aten::sort operators
28312914
RegisterOperators regSort({
2915+
Operator(
2916+
"aten::sorted(t[](a) self) -> (t[])",
2917+
[](const Node* node) {
2918+
return sort_op(
2919+
getLtFuncFromListOfClassTypes(node),
2920+
/*has_reverse_arg*/ false,
2921+
/*copy_return_list*/ true);
2922+
},
2923+
aliasAnalysisFromSchema()),
28322924
Operator(
28332925
"aten::sort(t[](a!) self, bool reverse=False) -> ()",
28342926
[](const Node* node) {
2835-
const auto list_type =
2836-
node->inputs().at(0)->type()->expect<ListType>();
2837-
checkSortSchema(node, list_type->getElementType());
2838-
const auto elem = list_type->getElementType()->expect<ClassType>();
2839-
auto func = elem->getMethod("__lt__");
2840-
return [func](Stack& stack) {
2841-
bool reverse = pop(stack).toBool();
2842-
auto g_list = pop(stack).toGenericList();
2843-
Stack sort_stack;
2844-
std::sort(
2845-
g_list.begin(),
2846-
g_list.end(),
2847-
[func, reverse, &sort_stack](
2848-
IValue a, IValue b) -> bool {
2849-
// FBCode errors without this check - "strict weak ordering"
2850-
// TODO: remove when possible, since it just slows down
2851-
// sorting and doesn't do anything useful
2852-
if (a.isSameIdentity(b)) {
2853-
return false;
2854-
}
2855-
sort_stack.push_back(a);
2856-
sort_stack.push_back(b);
2857-
func->run(sort_stack);
2858-
return pop(sort_stack).toBool() ^ reverse;
2859-
});
2860-
return 0;
2861-
};
2927+
return sort_op(
2928+
getLtFuncFromListOfClassTypes(node),
2929+
/*has_reverse_arg*/ true,
2930+
/*copy_return_list*/ false);
28622931
},
28632932
aliasAnalysisFromSchema()),
28642933
});

torch/csrc/jit/script/compiler.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,8 @@ struct Environment {
402402
{"enumerate", std::make_shared<IterableValue>(prim::enumerate)},
403403
{"rangelist",
404404
std::make_shared<BuiltinFunction>(prim::rangelist, at::nullopt)},
405+
{"sorted",
406+
std::make_shared<BuiltinFunction>(aten::sorted, at::nullopt)},
405407
};
406408
auto it = globals.find(ident);
407409
if (it != globals.end()) {

torch/csrc/jit/script/sugared_value.h

+8-9
Original file line numberDiff line numberDiff line change
@@ -362,9 +362,7 @@ using SugaredValuePtr = std::shared_ptr<SugaredValue>;
362362
// builtins operators and functions that call a method if it exists
363363
// on a class type, like 'len(x)' and 'x + y'
364364
struct TORCH_API MagicMethod : public SugaredValue {
365-
MagicMethod(
366-
std::string desugared_name,
367-
SugaredValuePtr base)
365+
MagicMethod(std::string desugared_name, SugaredValuePtr base)
368366
: base_value_(std::move(base)),
369367
desugared_name_(std::move(desugared_name)) {}
370368

@@ -443,7 +441,7 @@ struct TORCH_API IsInstanceValue : SugaredValue {
443441

444442
// matched against for special handling of range expressions
445443
struct TORCH_API RangeValue : SugaredValue {
446-
RangeValue(const SourceRange& loc, Function&m, std::vector<Value*> inputs);
444+
RangeValue(const SourceRange& loc, Function& m, std::vector<Value*> inputs);
447445
std::string kind() const override {
448446
return "range";
449447
}
@@ -463,25 +461,26 @@ struct TORCH_API RangeValue : SugaredValue {
463461

464462
// matched against for special handling of iterables like zip(), enumerate()
465463
struct TORCH_API IterableValue : SugaredValue {
466-
IterableValue(Symbol symbol): symbol_(symbol) {}
464+
IterableValue(Symbol symbol) : symbol_(symbol) {}
467465
std::string kind() const override {
468466
return "iterable";
469467
}
470468
Symbol symbol_;
471469
};
472470

473-
// Specialized Tree structure to matched against for special handling
471+
// Specialized Tree structure to matched against for special handling
474472
// of builtin functions iterables expressions like zip(), enumerate(), etc.
475473
// zip and enumerate can be modeled as a tree of SimpleValue/RangeValue:
476474
// zip(x, y) -> (x, y) with tuple assignment to each loop target
477475
// enumerate(x) -> (range(0, math.inf, 1), x)
478476
// So a complicated expression like zip(a, enumerate(b), range(0, 100)) will be:
479477
// (a, (range(0, math.inf, 1), b), range(0, 100))
480-
// We use those base iterables to fill in the loop information like max_trip_count
481-
// and set the value table for loop targets
478+
// We use those base iterables to fill in the loop information like
479+
// max_trip_count and set the value table for loop targets
482480
struct TORCH_API IterableTree : SugaredValue {
483481
IterableTree() = default;
484-
IterableTree(const std::vector<SugaredValuePtr> children): children_(std::move(children)) {}
482+
IterableTree(const std::vector<SugaredValuePtr> children)
483+
: children_(std::move(children)) {}
485484
std::string kind() const override {
486485
return "iterabletree";
487486
}

0 commit comments

Comments
 (0)