Add DLPack export + streaming tensor views to OME-Arrow #33

@d33bs

Goal

Make OME-Arrow pixel data (from the OME struct / Parquet / Vortex) consumable by PyTorch and JAX with minimal copies by exposing a small tensor view API that can export DLPack capsules for selected slices, tiles, or batches.


Scope / Deliverables

1. Tensor View Abstraction

  • Define a TensorView (name flexible) returned from the OME-Arrow struct.
  • Supports selecting pixel blocks by indices:
    • scene/image
    • t, z, c
    • spatial ROI (x, y, w, h) or tile index
  • Exposes metadata:
    • shape
    • dtype
    • device
    • layout (e.g., CHW, HWC, TZCHW)
    • strides (if applicable)
  • Enforce a canonical default layout:
    • 2D: CHW
    • 5D: TZCHW
  • Allow explicit layout overrides.
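
A minimal sketch of how such a view could look on CPU, assuming the pixel data has already been decoded into a NumPy array in TZCHW order. The class name `TensorViewSketch` and its `select` and `with_layout` methods are hypothetical placeholders, not an existing OME-Arrow API.

```python
from dataclasses import dataclass

import numpy as np


@dataclass(frozen=True)
class TensorViewSketch:
    """Hypothetical stand-in for the proposed TensorView (CPU only)."""

    pixels: np.ndarray     # backing array, axes ordered as `layout`
    layout: str = "TZCHW"  # one letter per axis; canonical 5D default

    def __post_init__(self):
        if len(self.layout) != self.pixels.ndim:
            raise ValueError(
                f"layout {self.layout!r} does not match {self.pixels.ndim} axes"
            )

    @property
    def shape(self):
        return self.pixels.shape

    @property
    def dtype(self):
        return self.pixels.dtype

    @property
    def device(self):
        return "cpu"  # this sketch never leaves host memory

    @property
    def strides(self):
        return self.pixels.strides  # in bytes, as NumPy reports them

    def select(self, t, z, roi=None):
        """Pick one (t, z) plane and an optional (x, y, w, h) ROI as a CHW view."""
        if self.layout != "TZCHW":
            raise ValueError("select() expects the canonical TZCHW layout")
        plane = self.pixels[t, z]  # basic indexing: a view, no copy
        if roi is not None:
            x, y, w, h = roi
            plane = plane[:, y:y + h, x:x + w]
        return TensorViewSketch(plane, "CHW")

    def with_layout(self, layout):
        """Explicit layout override, e.g. CHW -> HWC, via a strided transpose."""
        if sorted(layout) != sorted(self.layout):
            raise ValueError(f"cannot permute {self.layout!r} into {layout!r}")
        axes = tuple(self.layout.index(axis) for axis in layout)
        return TensorViewSketch(self.pixels.transpose(axes), layout)


# Synthetic data: T=1, Z=1, C=2, H=8, W=8.
data = np.arange(2 * 8 * 8, dtype=np.uint16).reshape(1, 1, 2, 8, 8)
view = TensorViewSketch(data)
chw = view.select(t=0, z=0, roi=(2, 1, 4, 3))
hwc = chw.with_layout("HWC")
print(chw.shape, hwc.shape, hwc.strides)
```

Both the ROI selection and the layout override stay strided views over the original buffer in this sketch, which is what makes a later zero-copy export possible.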

2. DLPack Export

  • Implement TensorView.to_dlpack(...)
  • Signature example:
    to_dlpack(device=..., contiguous=...)   # device: "cpu" or "cuda"; contiguous: True or False
  • If the view is not representable as a single strided tensor:
    • Raise a clear, actionable error OR
    • Materialize once into a contiguous buffer (preferred: directly on target device).
  • Define and document ownership/lifetime semantics for the returned DLPack capsule.
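
One possible shape for the export path, sketched on CPU with NumPy as the buffer owner. The helper name `to_dlpack_sketch` is hypothetical, and the "materialize once" branch is only one of the two options listed above; the sketch relies on NumPy's own `ndarray.__dlpack__()` to build the capsule.

```python
import numpy as np


def to_dlpack_sketch(pixels: np.ndarray, device: str = "cpu", contiguous: bool = True):
    """Export `pixels` as a DLPack capsule (hypothetical helper, CPU only)."""
    if device == "cuda":
        # A real implementation might copy straight to the GPU; not modeled here.
        raise NotImplementedError("this sketch only exports CPU buffers")
    if device != "cpu":
        raise ValueError(f"unknown device {device!r}; expected 'cpu' or 'cuda'")
    if contiguous and not pixels.flags.c_contiguous:
        # Materialize once into a fresh C-contiguous buffer.
        pixels = np.ascontiguousarray(pixels)
    # The producing array stays referenced through the capsule until a
    # consumer takes it over or the capsule is garbage collected.
    return pixels.__dlpack__()


hwc = np.arange(2 * 8 * 8, dtype=np.uint8).reshape(2, 8, 8).transpose(1, 2, 0)
capsule = to_dlpack_sketch(hwc, contiguous=True)      # copies once
strided = to_dlpack_sketch(hwc, contiguous=False)     # exports the strided view
```

A DLPack capsule is meant to be consumed at most once, so callers that need the data in two frameworks would call the export twice rather than reuse a capsule.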

3. Streaming / Batching

  • Implement an iterator-based API:
    iter_dlpack(
        batch_size=...,
        tiles=...,
        shuffle=False,
        prefetch=...
    )
  • Iterator yields one DLPack capsule per batch or tile.
  • Deterministic behavior when a seed is provided (if shuffling is supported).
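
The batching behavior above can be sketched as a generator over a stack of tiles. Everything here is an assumption for illustration: the names `batch_indices` and `iter_dlpack_sketch`, the `seed` parameter (the signature above does not list one), and the choice to let the last batch be smaller. Prefetching is not modeled.

```python
import random

import numpy as np


def batch_indices(n_items, batch_size, shuffle=False, seed=None):
    """Return a list of index lists, one per batch (last one may be short)."""
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")
    order = list(range(n_items))
    if shuffle:
        # A private Random instance keeps the order reproducible for a given seed.
        random.Random(seed).shuffle(order)
    return [order[i:i + batch_size] for i in range(0, n_items, batch_size)]


def iter_dlpack_sketch(tiles, batch_size, shuffle=False, seed=None):
    """Yield one DLPack capsule per batch from an (N, C, H, W) array of tiles."""
    for indices in batch_indices(len(tiles), batch_size, shuffle, seed):
        batch = tiles[indices]  # index-list selection gathers into a new array
        yield batch.__dlpack__()


tiles = np.arange(10 * 2 * 4 * 4, dtype=np.float32).reshape(10, 2, 4, 4)
capsules = list(iter_dlpack_sketch(tiles, batch_size=4))
```

Keeping the index plan in a separate pure function makes the determinism requirement easy to test without importing any consumer framework.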

4. Optional Framework Shims (Nice-to-have)

  • to_torch() thin wrapper using DLPack import utilities.
  • to_jax() thin wrapper using DLPack import utilities.
  • Keep DLPack as the core interoperability surface (avoid hard framework coupling).
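
A thin shim could be as small as the following sketch, which imports PyTorch lazily so the core package has no hard dependency on it. The function name `to_torch_sketch` is hypothetical; `torch.utils.dlpack.from_dlpack` is the existing PyTorch import utility named in the tests below. A JAX shim would follow the same pattern and is left out here.

```python
def to_torch_sketch(capsule):
    """Hypothetical shim: hand a DLPack capsule to PyTorch if it is installed."""
    try:
        # Imported lazily so that merely loading this module never requires torch.
        from torch.utils import dlpack as torch_dlpack
    except ImportError as exc:
        raise ImportError(
            "to_torch requires PyTorch; install torch or consume the "
            "DLPack capsule with another framework"
        ) from exc
    return torch_dlpack.from_dlpack(capsule)
```

Because the shim only forwards a capsule, device placement and contiguity are decided entirely by the export step, not by the wrapper.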

Tests (Required)

1. Round-Trip Correctness (CPU)

  • Create a small synthetic OME-Arrow struct (e.g., C=2, H=8, W=8) with known values.
  • Export via to_dlpack() and import into:
    • PyTorch (torch.utils.dlpack.from_dlpack)
    • JAX (jax.dlpack.from_dlpack, if available)
  • Assert:
    • Shapes match
    • Dtypes match
    • Values match exactly

2. Contiguity & Layout

  • Verify default layout (CHW) is correct.
  • Verify requested layouts (e.g., HWC) produce expected permutations.
  • If contiguous=True, assert resulting tensor is contiguous in the consumer framework.

3. Zero-Copy Expectations (CPU)

  • Where feasible, assert buffer sharing (e.g., pointer/storage checks).
  • At minimum, assert that contiguous=False does not copy the data when the Arrow buffer is already representable as a strided tensor.
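
A buffer-sharing check of this kind can be sketched with NumPy acting as both producer and consumer, assuming a recent NumPy release that provides `np.from_dlpack`. The helper `data_pointer` is a hypothetical name; a PyTorch-based test would compare `tensor.data_ptr()` against the same address instead.

```python
import numpy as np


def data_pointer(arr: np.ndarray) -> int:
    """Address of the first element, read from the array interface."""
    return arr.__array_interface__["data"][0]


source = np.arange(2 * 8 * 8, dtype=np.uint16).reshape(2, 8, 8)

# Import through the DLPack protocol; no copy is expected on this path.
imported = np.from_dlpack(source)
same_pointer = data_pointer(imported) == data_pointer(source)
shares = np.shares_memory(imported, source)

# Forcing contiguity on a permuted (HWC) view has to materialize a new buffer.
materialized = np.ascontiguousarray(source.transpose(1, 2, 0))
copied = not np.shares_memory(materialized, source)

print(same_pointer, shares, copied)
```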

4. Error Paths

  • Non-representable Arrow layouts trigger the documented behavior (error or materialization).
  • Invalid device requests raise clear, helpful errors.

5. Iterator Semantics

  • iter_dlpack(batch_size=...) yields the correct number of batches.
  • Concatenated batch content matches expected values.
  • If shuffle is supported:
    • Deterministic output with a fixed seed.

Documentation (Minimal)

  • Short docs section: “Exporting OME-Arrow pixel data via DLPack”
  • One minimal example snippet for:
    • PyTorch
    • JAX

Acceptance Criteria

  • PyTorch can consume OME-Arrow pixel slices/tiles via DLPack with correct values and layout.
  • JAX can consume the same API (or tests skip cleanly if JAX is unavailable).
  • Device placement, contiguity, and ownership semantics are explicit and predictable.
