@@ -1841,7 +1841,15 @@ int dictLen(Stack& stack) {
1841
1841
1842
1842
template <unsigned int Index, typename Elem>
1843
1843
c10::List<Elem> makeListForDictKeysOrValues (
1844
+ const std::pair<c10::optional<TypePtr>, c10::optional<TypePtr>>& types,
1844
1845
const std::vector<std::pair<IValue, IValue>>& order) {
1846
+ TORCH_INTERNAL_ASSERT (
1847
+ (!std::get<Index>(types).has_value ())
1848
+ || (*std::get<Index>(types) == getTypePtr<Elem>()),
1849
+ " Type mismatch when trying to get a List of keys/values from Dict. " ,
1850
+ " Type in Dict is " , toString (*std::get<Index>(types)),
1851
+ " . Type in List is " , toString (getTypePtr<Elem>()),
1852
+ " . Index is " , c10::guts::to_string (Index));
1845
1853
c10::List<Elem> values;
1846
1854
values.reserve (order.size ());
1847
1855
for (const auto & item : order) {
@@ -1852,8 +1860,12 @@ c10::List<Elem> makeListForDictKeysOrValues(
1852
1860
1853
1861
template <unsigned int Index>
1854
1862
c10::impl::GenericList makeGenericListForDictKeysOrValues (
1863
+ const std::pair<c10::optional<TypePtr>, c10::optional<TypePtr>>& types,
1855
1864
const std::vector<std::pair<IValue, IValue>>& order) {
1856
- auto values = c10::impl::GenericList (c10::impl::deprecatedUntypedList ());
1865
+ auto type = std::get<Index>(types);
1866
+ auto values = type.has_value ()
1867
+ ? c10::impl::GenericList (*type)
1868
+ : c10::impl::GenericList (c10::impl::deprecatedUntypedList ());
1857
1869
values.reserve (order.size ());
1858
1870
for (const auto & item : order) {
1859
1871
values.push_back (std::get<Index>(item));
@@ -1865,17 +1877,19 @@ template <unsigned int Index>
1865
1877
Operation dictKeysOrValues (const Node* n) {
1866
1878
auto outputType = n->output ()->type ()->expect <ListType>();
1867
1879
return [=](Stack& stack) -> int {
1868
- const auto & order = iterationOrder (pop (stack).toGenericDict ());
1880
+ auto dict = pop (stack).toGenericDict ();
1881
+ const auto & order = iterationOrder (dict);
1882
+ const auto types = std::make_pair (dict._keyType (), dict._valueType ());
1869
1883
if (outputType->getElementType ()->isSubtypeOf (TensorType::get ())) {
1870
- push (stack, makeListForDictKeysOrValues<Index, at::Tensor>(order));
1884
+ push (stack, makeListForDictKeysOrValues<Index, at::Tensor>(types, order));
1871
1885
} else if (outputType->getElementType () == IntType::get ()) {
1872
- push (stack, makeListForDictKeysOrValues<Index, int64_t >(order));
1886
+ push (stack, makeListForDictKeysOrValues<Index, int64_t >(types, order));
1873
1887
} else if (outputType->getElementType () == FloatType::get ()) {
1874
- push (stack, makeListForDictKeysOrValues<Index, double >(order));
1888
+ push (stack, makeListForDictKeysOrValues<Index, double >(types, order));
1875
1889
} else if (outputType->getElementType () == BoolType::get ()) {
1876
- push (stack, makeListForDictKeysOrValues<Index, bool >(order));
1890
+ push (stack, makeListForDictKeysOrValues<Index, bool >(types, order));
1877
1891
} else {
1878
- push (stack, makeGenericListForDictKeysOrValues<Index>(order));
1892
+ push (stack, makeGenericListForDictKeysOrValues<Index>(types, order));
1879
1893
}
1880
1894
return 0 ;
1881
1895
};
@@ -1999,7 +2013,11 @@ int dictUpdate(Stack& stack) {
1999
2013
2000
2014
int dictItems (Stack& stack) {
2001
2015
auto dict = pop (stack).toGenericDict ();
2002
- auto items = c10::impl::GenericList (c10::impl::deprecatedUntypedList ());
2016
+ auto key_type = dict._keyType ();
2017
+ auto value_type = dict._valueType ();
2018
+ auto items = (key_type.has_value () && value_type.has_value ())
2019
+ ? c10::impl::GenericList (TupleType::create ({*key_type, *value_type}))
2020
+ : c10::impl::GenericList (c10::impl::deprecatedUntypedList ());
2003
2021
items.reserve (dict.size ());
2004
2022
for (const auto & item : iterationOrder (dict)) {
2005
2023
items.emplace_back (c10::ivalue::Tuple::create ({item.first , item.second }));
0 commit comments