Skip to content

Commit

Permalink
Skip duplicate data during copy from one catalog to another (bluesky#737
Browse files Browse the repository at this point in the history
)

* Add basic duplicate checking to sync.copy

* fix lint

* Gate error override behind skip_duplicates flag

* Change skip_duplicates to on_conflict

* Add changelog entry for conflict handling

* Add tests for conflict behavior on copy
  • Loading branch information
pbeaucage authored May 10, 2024
1 parent addcda6 commit 307524d
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 15 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ Write the date in place of the "Unreleased" in the case a new version is release
dictionary with 'selected' as the key, to match default type/behavior.
- The method `BaseClient.data_sources()` returns dataclass objects instead of
raw dict objects.
- `tiled.client.sync` has conflict handling, with initial options of 'error'
(default), 'warn', and 'skip'

### Fixed

Expand Down
31 changes: 31 additions & 0 deletions tiled/_tests/test_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import h5py
import numpy
import pandas
import pytest
import sparse
import tifffile

Expand All @@ -14,6 +15,7 @@
from tiled.client.register import register
from tiled.client.smoke import read
from tiled.client.sync import copy
from tiled.client.utils import ClientError
from tiled.queries import Key
from tiled.server.app import build_app

Expand Down Expand Up @@ -88,6 +90,35 @@ def test_copy_internal():
read(dest, strict=True)


def test_copy_skip_conflict():
with client_factory() as dest:
with client_factory() as source:
populate_internal(source)
copy(source, dest)
copy(source, dest, on_conflict="skip")
assert list(source) == list(dest)
assert list(source["c"]) == list(dest["c"])
read(dest, strict=True)


def test_copy_warn_conflict():
with client_factory() as dest:
with client_factory() as source:
populate_internal(source)
copy(source, dest)
with pytest.warns(UserWarning):
copy(source, dest, on_conflict="warn")


def test_copy_error_conflict():
with client_factory() as dest:
with client_factory() as source:
populate_internal(source)
copy(source, dest)
with pytest.raises(ClientError):
copy(source, dest)


def test_copy_external(tmp_path):
with client_factory(readable_storage=[tmp_path]) as dest:
with client_factory() as source:
Expand Down
53 changes: 38 additions & 15 deletions tiled/client/sync.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
import itertools
import warnings

import httpx

from ..structures.core import StructureFamily
from ..structures.data_source import DataSource, Management
from .base import BaseClient
from .utils import ClientError


def copy(
source: BaseClient,
dest: BaseClient,
on_conflict: str = "error",
):
"""
Copy data from one Tiled instance to another.
Expand All @@ -16,6 +21,7 @@ def copy(
----------
source : tiled node
dest : tiled node
on_conflict : str, default 'error', other options 'warn', 'skip'
Examples
--------
Expand All @@ -34,45 +40,50 @@ def copy(
>>> copy(a.items().head(), b)
>>> copy(a.search(...), b)
Copy and ignore duplicates.
>>> copy(a, b, on_conflict = 'skip')
"""
if hasattr(source, "structure_family"):
# looks like a client object
_DISPATCH[source.structure_family](source.include_data_sources(), dest)
_DISPATCH[source.structure_family](
source.include_data_sources(), dest, on_conflict
)
else:
_DISPATCH[StructureFamily.container](dict(source), dest)
_DISPATCH[StructureFamily.container](dict(source), dest, on_conflict)


def _copy_array(source, dest):
def _copy_array(source, dest, on_conflict):
num_blocks = (range(len(n)) for n in source.chunks)
# Loop over each block index --- e.g. (0, 0), (0, 1), (0, 2) ....
for block in itertools.product(*num_blocks):
array = source.read_block(block)
dest.write_block(array, block)


def _copy_awkward(source, dest):
def _copy_awkward(source, dest, on_conflict):
import awkward

array = source.read()
_form, _length, container = awkward.to_buffers(array)
dest.write(container)


def _copy_sparse(source, dest):
def _copy_sparse(source, dest, on_conflict):
num_blocks = (range(len(n)) for n in source.chunks)
# Loop over each block index --- e.g. (0, 0), (0, 1), (0, 2) ....
for block in itertools.product(*num_blocks):
array = source.read_block(block)
dest.write_block(array.coords, array.data, block)


def _copy_table(source, dest):
def _copy_table(source, dest, on_conflict):
for partition in range(source.structure().npartitions):
df = source.read_partition(partition)
dest.write_partition(df, partition)


def _copy_container(source, dest):
def _copy_container(source, dest, on_conflict):
for key, child_node in source.items():
original_data_sources = child_node.include_data_sources().data_sources()
num_data_sources = len(original_data_sources)
Expand Down Expand Up @@ -108,21 +119,33 @@ def _copy_container(source, dest):
raise NotImplementedError(
"Multiple Data Sources in one Node is not supported."
)
node = dest.new(
key=key,
structure_family=child_node.structure_family,
data_sources=data_sources,
metadata=dict(child_node.metadata),
specs=child_node.specs,
)
try:
node = dest.new(
key=key,
structure_family=child_node.structure_family,
data_sources=data_sources,
metadata=dict(child_node.metadata),
specs=child_node.specs,
)
except ClientError as err:
if (
on_conflict == "skip" or on_conflict == "warn"
) and err.response.status_code == httpx.codes.CONFLICT:
if on_conflict == "warn":
warnings.warn("Skipped existing entry")
continue
else:
raise err
if (
original_data_sources
and (original_data_sources[0].management != Management.external)
) or (
child_node.structure_family == StructureFamily.container
and (not original_data_sources)
):
_DISPATCH[child_node.structure_family](child_node, node)
_DISPATCH[child_node.structure_family](
child_node, node, on_conflict=on_conflict
)


_DISPATCH = {
Expand Down

0 comments on commit 307524d

Please sign in to comment.