Skip to content

Commit 30bc19d

Browse files
smessmerfacebook-github-bot
authored andcommitted
dictKeys and dictItems ops on typed dicts return typed lists (pytorch#23270)
Summary: Pull Request resolved: pytorch#23270 ghstack-source-id: 87389530 Differential Revision: D16448942 fbshipit-source-id: e6b578f0e97776112259d7ea38e143e4716ec273
1 parent c8817f9 commit 30bc19d

File tree

3 files changed

+48
-8
lines changed

3 files changed

+48
-8
lines changed

aten/src/ATen/core/Dict.h

+6
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,12 @@ class Dict final {
348348
* having to reallocate or rehash.
349349
*/
350350
void reserve(size_type count) const;
351+
352+
353+
// private API for now because the return type will change to TypePtr
354+
// instead of optional<TypePtr> once types are mandatory.
355+
optional<TypePtr> _keyType() const;
356+
optional<TypePtr> _valueType() const;
351357
};
352358

353359
namespace impl {

aten/src/ATen/core/Dict_inl.h

+16
Original file line numberDiff line numberDiff line change
@@ -189,4 +189,20 @@ void Dict<Key, Value>::reserve(size_type count) const {
189189
impl_->dict.reserve(count);
190190
}
191191

192+
template<class Key, class Value>
193+
optional<TypePtr> Dict<Key, Value>::_keyType() const {
194+
if (!impl_->elementTypes.has_value()) {
195+
return c10::nullopt;
196+
}
197+
return impl_->elementTypes->keyType;
198+
}
199+
200+
template<class Key, class Value>
201+
optional<TypePtr> Dict<Key, Value>::_valueType() const {
202+
if (!impl_->elementTypes.has_value()) {
203+
return c10::nullopt;
204+
}
205+
return impl_->elementTypes->valueType;
206+
}
207+
192208
}

torch/csrc/jit/register_prim_ops.cpp

+26-8
Original file line numberDiff line numberDiff line change
@@ -1841,7 +1841,15 @@ int dictLen(Stack& stack) {
18411841

18421842
template <unsigned int Index, typename Elem>
18431843
c10::List<Elem> makeListForDictKeysOrValues(
1844+
const std::pair<c10::optional<TypePtr>, c10::optional<TypePtr>>& types,
18441845
const std::vector<std::pair<IValue, IValue>>& order) {
1846+
TORCH_INTERNAL_ASSERT(
1847+
(!std::get<Index>(types).has_value())
1848+
|| (*std::get<Index>(types) == getTypePtr<Elem>()),
1849+
"Type mismatch when trying to get a List of keys/values from Dict. ",
1850+
"Type in Dict is ", toString(*std::get<Index>(types)),
1851+
". Type in List is ", toString(getTypePtr<Elem>()),
1852+
". Index is ", c10::guts::to_string(Index));
18451853
c10::List<Elem> values;
18461854
values.reserve(order.size());
18471855
for (const auto& item : order) {
@@ -1852,8 +1860,12 @@ c10::List<Elem> makeListForDictKeysOrValues(
18521860

18531861
template <unsigned int Index>
18541862
c10::impl::GenericList makeGenericListForDictKeysOrValues(
1863+
const std::pair<c10::optional<TypePtr>, c10::optional<TypePtr>>& types,
18551864
const std::vector<std::pair<IValue, IValue>>& order) {
1856-
auto values = c10::impl::GenericList(c10::impl::deprecatedUntypedList());
1865+
auto type = std::get<Index>(types);
1866+
auto values = type.has_value()
1867+
? c10::impl::GenericList(*type)
1868+
: c10::impl::GenericList(c10::impl::deprecatedUntypedList());
18571869
values.reserve(order.size());
18581870
for (const auto& item : order) {
18591871
values.push_back(std::get<Index>(item));
@@ -1865,17 +1877,19 @@ template <unsigned int Index>
18651877
Operation dictKeysOrValues(const Node* n) {
18661878
auto outputType = n->output()->type()->expect<ListType>();
18671879
return [=](Stack& stack) -> int {
1868-
const auto& order = iterationOrder(pop(stack).toGenericDict());
1880+
auto dict = pop(stack).toGenericDict();
1881+
const auto& order = iterationOrder(dict);
1882+
const auto types = std::make_pair(dict._keyType(), dict._valueType());
18691883
if (outputType->getElementType()->isSubtypeOf(TensorType::get())) {
1870-
push(stack, makeListForDictKeysOrValues<Index, at::Tensor>(order));
1884+
push(stack, makeListForDictKeysOrValues<Index, at::Tensor>(types, order));
18711885
} else if (outputType->getElementType() == IntType::get()) {
1872-
push(stack, makeListForDictKeysOrValues<Index, int64_t>(order));
1886+
push(stack, makeListForDictKeysOrValues<Index, int64_t>(types, order));
18731887
} else if (outputType->getElementType() == FloatType::get()) {
1874-
push(stack, makeListForDictKeysOrValues<Index, double>(order));
1888+
push(stack, makeListForDictKeysOrValues<Index, double>(types, order));
18751889
} else if (outputType->getElementType() == BoolType::get()) {
1876-
push(stack, makeListForDictKeysOrValues<Index, bool>(order));
1890+
push(stack, makeListForDictKeysOrValues<Index, bool>(types, order));
18771891
} else {
1878-
push(stack, makeGenericListForDictKeysOrValues<Index>(order));
1892+
push(stack, makeGenericListForDictKeysOrValues<Index>(types, order));
18791893
}
18801894
return 0;
18811895
};
@@ -1999,7 +2013,11 @@ int dictUpdate(Stack& stack) {
19992013

20002014
int dictItems(Stack& stack) {
20012015
auto dict = pop(stack).toGenericDict();
2002-
auto items = c10::impl::GenericList(c10::impl::deprecatedUntypedList());
2016+
auto key_type = dict._keyType();
2017+
auto value_type = dict._valueType();
2018+
auto items = (key_type.has_value() && value_type.has_value())
2019+
? c10::impl::GenericList(TupleType::create({*key_type, *value_type}))
2020+
: c10::impl::GenericList(c10::impl::deprecatedUntypedList());
20032021
items.reserve(dict.size());
20042022
for (const auto& item : iterationOrder(dict)) {
20052023
items.emplace_back(c10::ivalue::Tuple::create({item.first, item.second}));

0 commit comments

Comments
 (0)