@@ -1756,23 +1756,50 @@ int listSlice(Stack& stack) {
1756
1756
1757
1757
template <typename T>
1758
1758
int listSort (Stack& stack) {
1759
+ bool reverse = pop (stack).toBool ();
1759
1760
c10::List<T> list = pop (stack).to <c10::List<T>>();
1760
- std::sort (list.begin (), list.end (), [] (const T& a, const T& b) {
1761
- return a < b;
1761
+ std::sort (list.begin (), list.end (), [reverse] (const T& a, const T& b) {
1762
+ return ( a < b) ^ reverse ;
1762
1763
});
1763
1764
return 0 ;
1764
1765
}
1765
1766
1766
1767
// Specialization for at::Tensor
1767
1768
template <>
1768
1769
int listSort<at::Tensor>(Stack& stack) {
1770
+ bool reverse = pop (stack).toBool ();
1771
+ c10::List<at::Tensor> list = pop (stack).toTensorList ();
1772
+ std::sort (
1773
+ list.begin (),
1774
+ list.end (),
1775
+ [reverse](const at::Tensor& a, const at::Tensor& b) {
1776
+ return (a.lt (b).is_nonzero ()) ^ reverse;
1777
+ });
1778
+ return 0 ;
1779
+ }
1780
+
1781
+ template <typename T>
1782
+ int listCopyAndSort (Stack& stack) {
1783
+ c10::List<T> list = pop (stack).to <c10::List<T>>();
1784
+ auto list_copied = list.copy ();
1785
+ std::sort (list_copied.begin (), list_copied.end (), [](const T& a, const T& b) {
1786
+ return a < b;
1787
+ });
1788
+ push (stack, list_copied);
1789
+ return 0 ;
1790
+ }
1791
+
1792
+ // Specialization for at::Tensor
1793
+ template <>
1794
+ int listCopyAndSort<at::Tensor>(Stack& stack) {
1769
1795
c10::List<at::Tensor> list = pop (stack).toTensorList ();
1770
1796
std::sort (
1771
1797
list.begin (),
1772
1798
list.end (),
1773
1799
[](const at::Tensor& a, const at::Tensor& b) {
1774
1800
return a.lt (b).is_nonzero ();
1775
1801
});
1802
+ push (stack, list);
1776
1803
return 0 ;
1777
1804
}
1778
1805
@@ -2233,21 +2260,37 @@ RegisterOperators reg2({
2233
2260
CREATE_LIST_OPS (" t" , c10::List<IValue>),
2234
2261
#undef CREATE_LIST_OPS
2235
2262
Operator (
2236
- " aten::sort(int[](a!) self) -> ()" ,
2263
+ " aten::sort(int[](a!) self, bool reverse=False ) -> ()" ,
2237
2264
listSort<int64_t >,
2238
2265
aliasAnalysisFromSchema ()),
2239
2266
Operator (
2240
- " aten::sort(float[](a!) self) -> ()" ,
2267
+ " aten::sort(float[](a!) self, bool reverse=False ) -> ()" ,
2241
2268
listSort<double >,
2242
2269
aliasAnalysisFromSchema ()),
2243
2270
Operator (
2244
- " aten::sort(Tensor[](a!) self) -> ()" ,
2271
+ " aten::sort(Tensor[](a!) self, bool reverse=False ) -> ()" ,
2245
2272
listSort<at::Tensor>,
2246
2273
aliasAnalysisFromSchema ()),
2247
2274
Operator (
2248
- " aten::sort(bool[](a!) self) -> ()" ,
2275
+ " aten::sort(bool[](a!) self, bool reverse=False ) -> ()" ,
2249
2276
listSort<bool >,
2250
2277
aliasAnalysisFromSchema ()),
2278
+ Operator (
2279
+ " aten::sorted(int[](a) input) -> (int[])" ,
2280
+ listCopyAndSort<int64_t >,
2281
+ aliasAnalysisFromSchema ()),
2282
+ Operator (
2283
+ " aten::sorted(float[](a) input) -> (float[])" ,
2284
+ listCopyAndSort<double >,
2285
+ aliasAnalysisFromSchema ()),
2286
+ Operator (
2287
+ " aten::sorted(Tensor[](a) input) -> (Tensor[])" ,
2288
+ listCopyAndSort<at::Tensor>,
2289
+ aliasAnalysisFromSchema ()),
2290
+ Operator (
2291
+ " aten::sorted(bool[](a) input) -> (bool[])" ,
2292
+ listCopyAndSort<bool >,
2293
+ aliasAnalysisFromSchema ()),
2251
2294
2252
2295
Operator (
2253
2296
" aten::eq(int[] a, int[] b) -> bool" ,
@@ -2816,49 +2859,75 @@ void checkSortSchema(const Node* node, const c10::TypePtr& list_element_type) {
2816
2859
<< class_type->python_str () << " that "
2817
2860
<< " returns a bool" ;
2818
2861
} else {
2819
- error_str
2820
- << " Input to list sort must be of Tensors, ints, floats, bools or "
2821
- << " a User Defined Class that defines the __lt__ compare method"
2822
- << " , got list of " << list_element_type->python_str () << " \n " ;
2862
+ error_str << " Input to " << node-> kind (). toUnqualString ()
2863
+ << " must be of Tensors, ints, floats, bools or "
2864
+ << " a User Defined Class that defines the __lt__ compare method"
2865
+ << " , got list of " << list_element_type->python_str () << " \n " ;
2823
2866
}
2824
2867
2825
2868
auto error_msg = script::ErrorReport (node->sourceRange ());
2826
2869
error_msg << error_str.str ();
2827
2870
throw error_msg;
2828
2871
}
2829
2872
2873
+ Operation sort_op (
2874
+ Function* lt_func,
2875
+ bool has_reverse_arg,
2876
+ bool copy_return_list) {
2877
+ return [lt_func, has_reverse_arg, copy_return_list](Stack& stack) {
2878
+ bool reverse = has_reverse_arg ? pop (stack).toBool () : false ;
2879
+ auto g_list = pop (stack).toGenericList ();
2880
+ if (copy_return_list) {
2881
+ g_list = g_list.copy ();
2882
+ }
2883
+ Stack sort_stack;
2884
+ std::sort (
2885
+ g_list.begin (),
2886
+ g_list.end (),
2887
+ [lt_func, reverse, &sort_stack](IValue a, IValue b) -> bool {
2888
+ // FBCode errors without this check - "strict weak ordering"
2889
+ // TODO: remove when possible, since it just slows down
2890
+ // sorting and doesn't do anything useful
2891
+ if (a.isSameIdentity (b)) {
2892
+ return false ;
2893
+ }
2894
+ sort_stack.push_back (a);
2895
+ sort_stack.push_back (b);
2896
+ lt_func->run (sort_stack);
2897
+ return pop (sort_stack).toBool () ^ reverse;
2898
+ });
2899
+ if (copy_return_list) {
2900
+ push (stack, g_list);
2901
+ }
2902
+ return 0 ;
2903
+ };
2904
+ }
2905
+
2906
+ Function* getLtFuncFromListOfClassTypes (const Node* node) {
2907
+ const auto list_type = node->inputs ().at (0 )->type ()->expect <ListType>();
2908
+ checkSortSchema (node, list_type->getElementType ());
2909
+ const auto elem = list_type->getElementType ()->expect <ClassType>();
2910
+ return elem->getMethod (" __lt__" );
2911
+ }
2912
+
2830
2913
// NB: this must be registered after the other aten::sort operators
2831
2914
RegisterOperators regSort ({
2915
+ Operator (
2916
+ " aten::sorted(t[](a) self) -> (t[])" ,
2917
+ [](const Node* node) {
2918
+ return sort_op (
2919
+ getLtFuncFromListOfClassTypes (node),
2920
+ /* has_reverse_arg*/ false ,
2921
+ /* copy_return_list*/ true );
2922
+ },
2923
+ aliasAnalysisFromSchema ()),
2832
2924
Operator (
2833
2925
" aten::sort(t[](a!) self, bool reverse=False) -> ()" ,
2834
2926
[](const Node* node) {
2835
- const auto list_type =
2836
- node->inputs ().at (0 )->type ()->expect <ListType>();
2837
- checkSortSchema (node, list_type->getElementType ());
2838
- const auto elem = list_type->getElementType ()->expect <ClassType>();
2839
- auto func = elem->getMethod (" __lt__" );
2840
- return [func](Stack& stack) {
2841
- bool reverse = pop (stack).toBool ();
2842
- auto g_list = pop (stack).toGenericList ();
2843
- Stack sort_stack;
2844
- std::sort (
2845
- g_list.begin (),
2846
- g_list.end (),
2847
- [func, reverse, &sort_stack](
2848
- IValue a, IValue b) -> bool {
2849
- // FBCode errors without this check - "strict weak ordering"
2850
- // TODO: remove when possible, since it just slows down
2851
- // sorting and doesn't do anything useful
2852
- if (a.isSameIdentity (b)) {
2853
- return false ;
2854
- }
2855
- sort_stack.push_back (a);
2856
- sort_stack.push_back (b);
2857
- func->run (sort_stack);
2858
- return pop (sort_stack).toBool () ^ reverse;
2859
- });
2860
- return 0 ;
2861
- };
2927
+ return sort_op (
2928
+ getLtFuncFromListOfClassTypes (node),
2929
+ /* has_reverse_arg*/ true ,
2930
+ /* copy_return_list*/ false );
2862
2931
},
2863
2932
aliasAnalysisFromSchema ()),
2864
2933
});
0 commit comments