Skip to content

add merge function for NamedDataStore #8850

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions exir/_serialize/_named_data_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,3 +181,30 @@ def get_named_data_store_output(self) -> NamedDataStoreOutput:
# Clean up empty maps inside self.external_data
self.external_data = {k: v for k, v in self.external_data.items() if len(v) > 0}
return NamedDataStoreOutput(self.buffers, self.pte_data, self.external_data)

def merge_named_data_store(self, other: NamedDataStoreOutput) -> None:
"""
Merge another NamedDataStore into this one.
Args:
other (NamedDataStore): the other NamedDataStore to merge.
Raises:
ValueError: when the key exists in both stores, and corresponding
data is different between them.
"""
# Merge the pte_data.
for key, buffer_idx in other.pte_data.items():
self.add_named_data(
key,
other.buffers[buffer_idx].buffer,
other.buffers[buffer_idx].alignment,
)

# Merge the external_data.
for filename, key_to_buffer_idx in other.external_data.items():
for key, buffer_idx in key_to_buffer_idx.items():
self.add_named_data(
key,
other.buffers[buffer_idx].buffer,
other.buffers[buffer_idx].alignment,
external_tag=filename,
)
59 changes: 59 additions & 0 deletions exir/_serialize/test/test_named_data_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,62 @@ def test_add_duplicate_key_fail(self) -> None:
self.assertEqual(len(output.pte_data), 1)
self.assertEqual(output.pte_data["key"], 0)
self.assertEqual(len(output.external_data), 0)

def test_merge(self) -> None:
store1 = NamedDataStore()
store1.add_named_data("key1", b"data1", None, None)
store1.add_named_data("key2", b"data2", 16, "file1")

# Check items in the store1.
output = store1.get_named_data_store_output()
self.assertEqual(len(output.buffers), 2)
self.assertEqual(len(output.pte_data), 1)
self.assertEqual(len(output.external_data), 1)
self.assertEqual(len(output.external_data["file1"]), 1)

store2 = NamedDataStore()
store2.add_named_data("key1", b"data1", None, None)
store2.add_named_data("key3", b"data3", None, None)
store2.add_named_data("key4", b"data4", 16, "file1")
store2.add_named_data("key5", b"data5", 16, "file2")

# Check items in store2.
output2 = store2.get_named_data_store_output()
self.assertEqual(len(output2.buffers), 4)
self.assertEqual(len(output2.pte_data), 2)
self.assertEqual(len(output2.external_data), 2)
self.assertEqual(len(output2.external_data["file1"]), 1)
self.assertEqual(len(output2.external_data["file2"]), 1)

# Merge store2 into store1.
store1.merge_named_data_store(output2)

# Check items in store2 are merged into store1.
output = store1.get_named_data_store_output()
# key1, data1 exist in both store1 and store2, so we only have one copy of it.
self.assertEqual(len(output.buffers), 5)
self.assertEqual(len(output.pte_data), 2)
self.assertEqual(len(output.external_data), 2)
self.assertEqual(len(output.external_data["file1"]), 2)
self.assertEqual(len(output.external_data["file2"]), 1)

def test_merge_duplicate_error(self) -> None:
store1 = NamedDataStore()
store1.add_named_data("key1", b"data1", None, None)

# Check items in the store1.
output = store1.get_named_data_store_output()
self.assertEqual(len(output.buffers), 1)
self.assertEqual(len(output.pte_data), 1)

store2 = NamedDataStore()
store2.add_named_data("key1", b"data2", None, None)

# Check items in store2.
output2 = store2.get_named_data_store_output()
self.assertEqual(len(output2.buffers), 1)
self.assertEqual(len(output2.pte_data), 1)

# Merge store2 into store1 raises error as key1 is already in store1
# with different data.
self.assertRaises(ValueError, store1.merge_named_data_store, output2)
Loading