@@ -1756,23 +1756,50 @@ int listSlice(Stack& stack) {
17561756
17571757template <typename T>
17581758int listSort (Stack& stack) {
1759+ bool reverse = pop (stack).toBool ();
17591760 c10::List<T> list = pop (stack).to <c10::List<T>>();
1760- std::sort (list.begin (), list.end (), [] (const T& a, const T& b) {
1761- return a < b;
1761+ std::sort (list.begin (), list.end (), [reverse] (const T& a, const T& b) {
1762+ return ( a < b) ^ reverse ;
17621763 });
17631764 return 0 ;
17641765}
17651766
17661767// Specialization for at::Tensor
17671768template <>
17681769int listSort<at::Tensor>(Stack& stack) {
1770+ bool reverse = pop (stack).toBool ();
1771+ c10::List<at::Tensor> list = pop (stack).toTensorList ();
1772+ std::sort (
1773+ list.begin (),
1774+ list.end (),
1775+ [reverse](const at::Tensor& a, const at::Tensor& b) {
1776+ return (a.lt (b).is_nonzero ()) ^ reverse;
1777+ });
1778+ return 0 ;
1779+ }
1780+
1781+ template <typename T>
1782+ int listCopyAndSort (Stack& stack) {
1783+ c10::List<T> list = pop (stack).to <c10::List<T>>();
1784+ auto list_copied = list.copy ();
1785+ std::sort (list_copied.begin (), list_copied.end (), [](const T& a, const T& b) {
1786+ return a < b;
1787+ });
1788+ push (stack, list_copied);
1789+ return 0 ;
1790+ }
1791+
1792+ // Specialization for at::Tensor
1793+ template <>
1794+ int listCopyAndSort<at::Tensor>(Stack& stack) {
17691795 c10::List<at::Tensor> list = pop (stack).toTensorList ();
17701796 std::sort (
17711797 list.begin (),
17721798 list.end (),
17731799 [](const at::Tensor& a, const at::Tensor& b) {
17741800 return a.lt (b).is_nonzero ();
17751801 });
1802+ push (stack, list);
17761803 return 0 ;
17771804}
17781805
@@ -2233,21 +2260,37 @@ RegisterOperators reg2({
22332260 CREATE_LIST_OPS (" t" , c10::List<IValue>),
22342261#undef CREATE_LIST_OPS
22352262 Operator (
2236- " aten::sort(int[](a!) self) -> ()" ,
2263+ " aten::sort(int[](a!) self, bool reverse=False ) -> ()" ,
22372264 listSort<int64_t >,
22382265 aliasAnalysisFromSchema ()),
22392266 Operator (
2240- " aten::sort(float[](a!) self) -> ()" ,
2267+ " aten::sort(float[](a!) self, bool reverse=False ) -> ()" ,
22412268 listSort<double >,
22422269 aliasAnalysisFromSchema ()),
22432270 Operator (
2244- " aten::sort(Tensor[](a!) self) -> ()" ,
2271+ " aten::sort(Tensor[](a!) self, bool reverse=False ) -> ()" ,
22452272 listSort<at::Tensor>,
22462273 aliasAnalysisFromSchema ()),
22472274 Operator (
2248- " aten::sort(bool[](a!) self) -> ()" ,
2275+ " aten::sort(bool[](a!) self, bool reverse=False ) -> ()" ,
22492276 listSort<bool >,
22502277 aliasAnalysisFromSchema ()),
2278+ Operator (
2279+ " aten::sorted(int[](a) input) -> (int[])" ,
2280+ listCopyAndSort<int64_t >,
2281+ aliasAnalysisFromSchema ()),
2282+ Operator (
2283+ " aten::sorted(float[](a) input) -> (float[])" ,
2284+ listCopyAndSort<double >,
2285+ aliasAnalysisFromSchema ()),
2286+ Operator (
2287+ " aten::sorted(Tensor[](a) input) -> (Tensor[])" ,
2288+ listCopyAndSort<at::Tensor>,
2289+ aliasAnalysisFromSchema ()),
2290+ Operator (
2291+ " aten::sorted(bool[](a) input) -> (bool[])" ,
2292+ listCopyAndSort<bool >,
2293+ aliasAnalysisFromSchema ()),
22512294
22522295 Operator (
22532296 " aten::eq(int[] a, int[] b) -> bool" ,
@@ -2816,49 +2859,75 @@ void checkSortSchema(const Node* node, const c10::TypePtr& list_element_type) {
28162859 << class_type->python_str () << " that "
28172860 << " returns a bool" ;
28182861 } else {
2819- error_str
2820- << " Input to list sort must be of Tensors, ints, floats, bools or "
2821- << " a User Defined Class that defines the __lt__ compare method"
2822- << " , got list of " << list_element_type->python_str () << " \n " ;
2862+ error_str << " Input to " << node-> kind (). toUnqualString ()
2863+ << " must be of Tensors, ints, floats, bools or "
2864+ << " a User Defined Class that defines the __lt__ compare method"
2865+ << " , got list of " << list_element_type->python_str () << " \n " ;
28232866 }
28242867
28252868 auto error_msg = script::ErrorReport (node->sourceRange ());
28262869 error_msg << error_str.str ();
28272870 throw error_msg;
28282871}
28292872
2873+ Operation sort_op (
2874+ Function* lt_func,
2875+ bool has_reverse_arg,
2876+ bool copy_return_list) {
2877+ return [lt_func, has_reverse_arg, copy_return_list](Stack& stack) {
2878+ bool reverse = has_reverse_arg ? pop (stack).toBool () : false ;
2879+ auto g_list = pop (stack).toGenericList ();
2880+ if (copy_return_list) {
2881+ g_list = g_list.copy ();
2882+ }
2883+ Stack sort_stack;
2884+ std::sort (
2885+ g_list.begin (),
2886+ g_list.end (),
2887+ [lt_func, reverse, &sort_stack](IValue a, IValue b) -> bool {
2888+ // FBCode errors without this check - "strict weak ordering"
2889+ // TODO: remove when possible, since it just slows down
2890+ // sorting and doesn't do anything useful
2891+ if (a.isSameIdentity (b)) {
2892+ return false ;
2893+ }
2894+ sort_stack.push_back (a);
2895+ sort_stack.push_back (b);
2896+ lt_func->run (sort_stack);
2897+ return pop (sort_stack).toBool () ^ reverse;
2898+ });
2899+ if (copy_return_list) {
2900+ push (stack, g_list);
2901+ }
2902+ return 0 ;
2903+ };
2904+ }
2905+
2906+ Function* getLtFuncFromListOfClassTypes (const Node* node) {
2907+ const auto list_type = node->inputs ().at (0 )->type ()->expect <ListType>();
2908+ checkSortSchema (node, list_type->getElementType ());
2909+ const auto elem = list_type->getElementType ()->expect <ClassType>();
2910+ return elem->getMethod (" __lt__" );
2911+ }
2912+
28302913// NB: this must be registered after the other aten::sort operators
28312914RegisterOperators regSort ({
2915+ Operator (
2916+ " aten::sorted(t[](a) self) -> (t[])" ,
2917+ [](const Node* node) {
2918+ return sort_op (
2919+ getLtFuncFromListOfClassTypes (node),
2920+ /* has_reverse_arg*/ false ,
2921+ /* copy_return_list*/ true );
2922+ },
2923+ aliasAnalysisFromSchema ()),
28322924 Operator (
28332925 " aten::sort(t[](a!) self, bool reverse=False) -> ()" ,
28342926 [](const Node* node) {
2835- const auto list_type =
2836- node->inputs ().at (0 )->type ()->expect <ListType>();
2837- checkSortSchema (node, list_type->getElementType ());
2838- const auto elem = list_type->getElementType ()->expect <ClassType>();
2839- auto func = elem->getMethod (" __lt__" );
2840- return [func](Stack& stack) {
2841- bool reverse = pop (stack).toBool ();
2842- auto g_list = pop (stack).toGenericList ();
2843- Stack sort_stack;
2844- std::sort (
2845- g_list.begin (),
2846- g_list.end (),
2847- [func, reverse, &sort_stack](
2848- IValue a, IValue b) -> bool {
2849- // FBCode errors without this check - "strict weak ordering"
2850- // TODO: remove when possible, since it just slows down
2851- // sorting and doesn't do anything useful
2852- if (a.isSameIdentity (b)) {
2853- return false ;
2854- }
2855- sort_stack.push_back (a);
2856- sort_stack.push_back (b);
2857- func->run (sort_stack);
2858- return pop (sort_stack).toBool () ^ reverse;
2859- });
2860- return 0 ;
2861- };
2927+ return sort_op (
2928+ getLtFuncFromListOfClassTypes (node),
2929+ /* has_reverse_arg*/ true ,
2930+ /* copy_return_list*/ false );
28622931 },
28632932 aliasAnalysisFromSchema ()),
28642933});
0 commit comments