Skip to content

Commit

Permalink
DynamoDB: query() should use all range keys when sorting (#7795)
Browse files Browse the repository at this point in the history
  • Loading branch information
bblommers authored Jun 28, 2024
1 parent 4c99cbb commit bd71c9c
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 28 deletions.
66 changes: 38 additions & 28 deletions moto/dynamodb/models/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,30 +713,30 @@ def query(
f"Range Key comparison but no range key found for index: {index_name}"
)

actual_hash_attr = index_hash_key["AttributeName"]
hash_attrs = [index_hash_key["AttributeName"], self.hash_key_attr]
if index_range_key:
actual_range_attrs = [
range_attrs = [
index_range_key["AttributeName"],
self.range_key_attr,
]
else:
actual_range_attrs = [self.range_key_attr]
range_attrs = [self.range_key_attr]

possible_results = []
for item in self.all_items():
if not isinstance(item, Item):
continue
item_hash_key = item.attrs.get(actual_hash_attr)
if len(actual_range_attrs) == 1:
item_hash_key = item.attrs.get(hash_attrs[0])
if len(range_attrs) == 1:
if item_hash_key and item_hash_key == hash_key:
possible_results.append(item)
else:
item_range_key = item.attrs.get(actual_range_attrs[0]) # type: ignore
item_range_key = item.attrs.get(range_attrs[0]) # type: ignore
if item_hash_key and item_hash_key == hash_key and item_range_key:
possible_results.append(item)
else:
actual_hash_attr = self.hash_key_attr
actual_range_attrs = [self.range_key_attr]
hash_attrs = [self.hash_key_attr]
range_attrs = [self.range_key_attr]

possible_results = [
item
Expand All @@ -746,15 +746,15 @@ def query(

# SORT
if index_name:
if len(actual_range_attrs) == 2:
if len(range_attrs) == 2:
# Convert to float if necessary to ensure proper ordering
def conv(x: DynamoType) -> Any:
return float(x.value) if x.type == "N" else x.value

possible_results.sort(
key=lambda item: ( # type: ignore
conv(item.attrs[actual_range_attrs[0]]) # type: ignore
if item.attrs.get(actual_range_attrs[0]) # type: ignore
conv(item.attrs[range_attrs[0]]) # type: ignore
if item.attrs.get(range_attrs[0]) # type: ignore
else None
)
)
Expand All @@ -777,8 +777,8 @@ def conv(x: DynamoType) -> Any:
if self._item_comes_before_dct(
result,
exclusive_start_key,
actual_hash_attr,
actual_range_attrs,
hash_attrs,
range_attrs,
scan_index_forward,
):
continue
Expand Down Expand Up @@ -926,19 +926,16 @@ def scan(
except IndexError:
index_range_key = None

actual_hash_attr = index_hash_key["AttributeName"]
hash_attrs = [index_hash_key["AttributeName"], self.hash_key_attr]
if index_range_key:
actual_range_attrs = [
index_range_key["AttributeName"],
self.range_key_attr,
]
range_attrs = [index_range_key["AttributeName"], self.range_key_attr]
else:
actual_range_attrs = [self.range_key_attr]
range_attrs = [self.range_key_attr]

items = self.has_idx_items(index_name)
else:
actual_hash_attr = self.hash_key_attr
actual_range_attrs = [self.range_key_attr]
hash_attrs = [self.hash_key_attr]
range_attrs = [self.range_key_attr]

items = self.all_items()

Expand All @@ -951,8 +948,8 @@ def scan(
if self._item_comes_before_dct(
item,
exclusive_start_key,
actual_hash_attr,
actual_range_attrs,
hash_attrs,
range_attrs,
True,
):
continue
Expand Down Expand Up @@ -1012,25 +1009,38 @@ def _item_comes_before_dct(
self,
item: Item,
dct: Dict[str, Any],
hash_key_attr: str,
hash_key_attrs: List[str],
range_key_attrs: List[Optional[str]],
scan_index_forward: bool,
) -> bool:
"""
Does item appear before or at dct relative to sort options?
hash_key_attrs: The list of hash keys.
Includes the key of the GSI (first, if it exists) and the key of the main table.
range_key_attrs: A list of range keys (RK).
Includes the RK-name of the GSI (first, if it exists) and the RK-name of the main table (second - can be None).
When sorting a GSI, we'll try to sort by the GSI RK first.
However, because GSI RK's are not unique (by design), item and dct can have the same RK-value
If that is the case, we compare by the RK of the main table instead
Related: https://github.com/getmoto/moto/issues/7761
"""
dict_hash_val = DynamoType(dct.get(hash_key_attr)) # type: ignore[arg-type]
item_hash_val = item.attrs[hash_key_attr]
if item_hash_val != dict_hash_val:
hash_comes_before = any(
item.attrs[hash_key_attr] != DynamoType(dct.get(hash_key_attr)) # type: ignore
for hash_key_attr in hash_key_attrs
if hash_key_attr in item.attrs
)
if hash_comes_before:
# If hash keys are different, order immediately
return bool(item_hash_val < dict_hash_val) == scan_index_forward
return (
any(
item.attrs[hash_key_attr] < DynamoType(dct.get(hash_key_attr)) # type: ignore
for hash_key_attr in hash_key_attrs
if hash_key_attr in item.attrs
)
== scan_index_forward
)
if not any(range_key_attrs):
# If hash keys match and no range key, items are identical
return True
Expand Down
40 changes: 40 additions & 0 deletions tests/test_dynamodb/test_dynamodb_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -770,3 +770,43 @@ def test_query_filter_overlapping_expression_prefixes(self):
"app": "app1",
"nested": {"version": "version1", "contents": ["value1", "value2"]},
}


@pytest.mark.aws_verified
@dynamodb_aws_verified(add_range=False, add_gsi_range=True)
def test_query_gsi_pagination_with_string_gsi_range_no_sk(table_name=None):
dynamodb = boto3.resource("dynamodb", region_name="us-east-1")
table = dynamodb.Table(table_name)

for i in range(3, 7):
table.put_item(Item={"pk": f"{i}", "gsi_pk": "john", "gsi_sk": "jane"})

for i in range(9, 6, -1):
table.put_item(Item={"pk": f"{i}", "gsi_pk": "john", "gsi_sk": "jane"})

for i in range(3):
table.put_item(Item={"pk": f"{i}", "gsi_pk": "john", "gsi_sk": "jane"})

page1 = table.query(
KeyConditionExpression=Key("gsi_pk").eq("john"),
IndexName="test_gsi",
Limit=6,
)
assert page1["Count"] == 6
assert page1["ScannedCount"] == 6
assert len(page1["Items"]) == 6

page2 = table.query(
KeyConditionExpression=Key("gsi_pk").eq("john"),
IndexName="test_gsi",
Limit=6,
ExclusiveStartKey=page1["LastEvaluatedKey"],
)
assert page2["Count"] == 4
assert page2["ScannedCount"] == 4
assert len(page2["Items"]) == 4
assert "LastEvaluatedKey" not in page2

results = page1["Items"] + page2["Items"]
subjects = set([int(r["pk"]) for r in results])
assert subjects == set(range(10))
35 changes: 35 additions & 0 deletions tests/test_dynamodb/test_dynamodb_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,41 @@ def test_scan_gsi_pagination_with_string_gsi_range_and_empty_gsi_sk(table_name=N
assert set([r["sk"] for r in results]) == {"0", "1", "2", "7", "8", "9"}


@pytest.mark.aws_verified
@dynamodb_aws_verified(add_range=False, add_gsi_range=True)
def test_scan_gsi_pagination_with_string_gsi_range_no_sk(table_name=None):
dynamodb = boto3.resource("dynamodb", region_name="us-east-1")
table = dynamodb.Table(table_name)

for i in range(3, 7):
table.put_item(Item={"pk": f"{i}", "gsi_pk": "john", "gsi_sk": "jane"})

for i in range(9, 6, -1):
table.put_item(Item={"pk": f"{i}", "gsi_pk": "john", "gsi_sk": "jane"})

for i in range(3):
table.put_item(Item={"pk": f"{i}", "gsi_pk": "john", "gsi_sk": "jane"})

page1 = table.scan(IndexName="test_gsi", Limit=6)
assert page1["Count"] == 6
assert page1["ScannedCount"] == 6
assert len(page1["Items"]) == 6

page2 = table.scan(
IndexName="test_gsi",
Limit=6,
ExclusiveStartKey=page1["LastEvaluatedKey"],
)
assert page2["Count"] == 4
assert page2["ScannedCount"] == 4
assert len(page2["Items"]) == 4
assert "LastEvaluatedKey" not in page2

results = page1["Items"] + page2["Items"]
subjects = set([int(r["pk"]) for r in results])
assert subjects == set(range(10))


@mock_aws
class TestFilterExpression:
def test_scan_filter(self):
Expand Down

0 comments on commit bd71c9c

Please sign in to comment.