Skip to content

Commit

Permalink
Speed up "happy path" for forbid_duplicate_keys
Browse files Browse the repository at this point in the history
  • Loading branch information
qpwo authored and Ryan Gabbard committed Jun 21, 2019
1 parent 9127fc2 commit f198390
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 26 deletions.
47 changes: 26 additions & 21 deletions immutablecollections/_immutabledict.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
TypeVar,
Tuple,
Set,
List,
Iterator,
Optional,
Union,
Expand Down Expand Up @@ -42,6 +41,8 @@ def immutabledict(
Mappings may be specified as a sequence of key-value pairs or as another ``ImmutableDict`` or
(on Python 3.7+ and CPython 3.6+) as a built-in ``dict``.
If ``forbid_duplicate_keys=True`` and *iterable* contains duplicate keys, then raise ValueError.
The iteration order of the created keys, values, and items of the resulting ``ImmutableDict``
will match *iterable*.
Expand Down Expand Up @@ -70,29 +71,33 @@ def immutabledict(
)

if forbid_duplicate_keys:
keys: List[KT]
# `for x in dict` grabs keys, but `for x in pairs` grabs pairs, so we must branch
# We check for duplicate keys by comparing the original iterable length with the output
# dict length. Some iterables don't provide a __len__ or are consumed by iteration, so we
# listify the iterable to be safe. Calling list(dict) gets just keys, so we must grab items
# in that case:
if isinstance(iterable, Mapping):
# duplicate keys are possible if input is e.g. a multidict
keys = list(iterable.keys())
iterable = list(iterable.items())
else:
# iterable must be a (key, value) pair iterable
iterable = list(iterable) # in case iterable is consumed by iteration
keys = [key for key, value in iterable]
seen: Set[KT] = set()
duplicated: Set[KT] = set()
for key in keys:
if key in seen:
duplicated.add(key)
else:
seen.add(key)
if duplicated:
raise ValueError(
"forbid_duplicate_keys=True, but some keys "
"occur multiple times in input: {}".format(duplicated)
)
iterable = list(iterable) # iterable is of key-value pairs
original_length = len(iterable) # must be recorded here for mypy to be happy

ret: ImmutableDict[KT, VT] = _RegularDictBackedImmutableDict(iterable)

if forbid_duplicate_keys and len(ret) != original_length:
seen_once: Set[KT] = set()
seen_twice: Set[KT] = set()
# iterable has been made a list and so will not be consumed by iteration:
for key, _ in iterable:
if key not in seen_once:
seen_once.add(key)
else:
seen_twice.add(key)
# seen_twice is guaranteed to be nonempty
raise ValueError(
"forbid_duplicate_keys=True, but some keys "
f"occur multiple times in input: {seen_twice}"
)

if ret:
return ret
else:
Expand All @@ -104,7 +109,7 @@ def immutabledict_from_unique_keys(
) -> "ImmutableDict[KT, VT]":
"""
Create an immutable dictionary with the given mappings, but raise ValueError if
*iterable* contains the same item twice. More information in `immutabledict`
*iterable* contains the same key twice. More information in `immutabledict`
"""
return immutabledict(iterable, forbid_duplicate_keys=True)

Expand Down
23 changes: 18 additions & 5 deletions immutablecollections/_immutableset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
Iterator,
KeysView,
List,
Set,
MutableSet,
Optional,
overload,
Expand Down Expand Up @@ -92,20 +93,32 @@ def immutableset(
"iteration order, specify disable_order_check=True"
)

duplicated = set() # only used if forbid_duplicate_elements=True
if forbid_duplicate_elements:
# We check for duplicate elements by comparing the original iterable length with the output
# set length. Some iterables don't provide a __len__ or are consumed by iteration, so we
# listify the iterable to be safe.
iterable = list(iterable)
original_length = len(iterable) # must be recorded here for mypy to be happy

iteration_order = []
containment_set: MutableSet[T] = set()
for value in iterable:
if value not in containment_set:
containment_set.add(value)
iteration_order.append(value)
elif forbid_duplicate_elements:
duplicated.add(value)

if forbid_duplicate_elements and duplicated:
if forbid_duplicate_elements and len(containment_set) != original_length:
seen_once: Set[T] = set()
seen_twice: Set[T] = set()
for item in iterable:
if item not in seen_once:
seen_once.add(item)
else:
seen_twice.add(item)
# seen_twice is guaranteed to be nonempty
raise ValueError(
"forbid_duplicate_elements=True, but some elements "
"occur multiple times in input: {}".format(duplicated)
f"occur multiple times in input: {seen_twice}"
)

if iteration_order:
Expand Down

0 comments on commit f198390

Please sign in to comment.