Skip to content

Commit

Permalink
feat: accept file handles as well as paths (#161)
Browse files Browse the repository at this point in the history
* feat: accept file handles

* cleanup around file protocols
  • Loading branch information
tlambert03 authored Jul 9, 2023
1 parent 4ea166e commit fecd852
Show file tree
Hide file tree
Showing 7 changed files with 112 additions and 57 deletions.
37 changes: 19 additions & 18 deletions src/nd2/_parse/_chunk_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,15 @@

import mmap
import struct
from contextlib import contextmanager
from io import BufferedReader
from contextlib import contextmanager, nullcontext
from pathlib import Path
from typing import TYPE_CHECKING, BinaryIO, Iterator, cast
from typing import TYPE_CHECKING, BinaryIO, ContextManager, cast

import numpy as np

if TYPE_CHECKING:
from os import PathLike
from typing import Final
from typing import Final, Iterator

from numpy.typing import DTypeLike

Expand Down Expand Up @@ -71,7 +70,7 @@ def get_version(fh: BinaryIO | StrOrBytesPath) -> tuple[int, int]:
Parameters
----------
fh : BufferedReader | str | bytes | Path
fh : BinaryIO | str | bytes | Path
The file handle or path to the ND2 file.
Returns
Expand All @@ -84,12 +83,14 @@ def get_version(fh: BinaryIO | StrOrBytesPath) -> tuple[int, int]:
ValueError
If the file is not a valid ND2 file or the header chunk is corrupt.
"""
if not isinstance(fh, (BinaryIO, BufferedReader)):
with open(fh, "rb") as fh:
chunk = START_FILE_CHUNK.unpack(fh.read(START_FILE_CHUNK.size))
if hasattr(fh, "read"):
ctx: ContextManager[BinaryIO] = nullcontext(cast("BinaryIO", fh))
else:
# leave it open if it came in open
ctx = open(fh, "rb")

with ctx as fh:
fh.seek(0)
fname = str(fh.name)
chunk = START_FILE_CHUNK.unpack(fh.read(START_FILE_CHUNK.size))

magic, name_length, data_length, name, data = cast("StartFileChunk", chunk)
Expand All @@ -98,15 +99,15 @@ def get_version(fh: BinaryIO | StrOrBytesPath) -> tuple[int, int]:
if magic != ND2_CHUNK_MAGIC:
if magic == JP2_MAGIC:
return (1, 0) # legacy JP2 files are version 1.0
raise ValueError(f"Not a valid ND2 file: {fh.name}. (magic: {magic!r})")
raise ValueError(f"Not a valid ND2 file: {fname}. (magic: {magic!r})")
if name_length != 32 or data_length != 64 or name != ND2_FILE_SIGNATURE:
raise ValueError(f"Corrupt ND2 file header chunk: {fh.name}")
raise ValueError(f"Corrupt ND2 file header chunk: {fname}")

# data will now be something like Ver2.0, Ver3.0, etc.
return (int(chr(data[3])), int(chr(data[5])))


def get_chunkmap(fh: BufferedReader, error_radius: int | None = None) -> ChunkMap:
def get_chunkmap(fh: BinaryIO, error_radius: int | None = None) -> ChunkMap:
"""Read the map of the chunks at the end of an ND2 file.
A Chunkmap is a mapping of chunk names (bytes) to (offset, size) pairs.
Expand All @@ -122,7 +123,7 @@ def get_chunkmap(fh: BufferedReader, error_radius: int | None = None) -> ChunkMa
Parameters
----------
fh : BufferedReader
fh : BinaryIO
An open nd2 file. File is assumed to be a valid ND2 file. (use `get_version`)
error_radius : int, optional
If b"ND2 FILEMAP SIGNATURE NAME 0001!" is not found at expected location and
Expand Down Expand Up @@ -176,7 +177,7 @@ def get_chunkmap(fh: BufferedReader, error_radius: int | None = None) -> ChunkMa


def read_nd2_chunk(
fh: BufferedReader, start_position: int, expect_name: bytes | None = None
fh: BinaryIO, start_position: int, expect_name: bytes | None = None
) -> bytes:
"""Read a single chunk in an ND2 file at `start_position`.
Expand All @@ -191,7 +192,7 @@ def read_nd2_chunk(
Parameters
----------
fh : BufferedReader
fh : BinaryIO
An open nd2 file. File is assumed to be a valid ND2 file. (use `get_version`)
start_position : int
The position in the file to start reading the chunk.
Expand Down Expand Up @@ -229,7 +230,7 @@ def read_nd2_chunk(


def _robustly_read_named_chunk(
fh: BufferedReader,
fh: BinaryIO,
start_position: int,
expect_name: bytes = ND2_FILEMAP_SIGNATURE,
search_radius: int | None = None,
Expand All @@ -242,7 +243,7 @@ def _robustly_read_named_chunk(
Parameters
----------
fh : BufferedReader
fh : BinaryIO
An open nd2 file. File is assumed to be a valid ND2 file.
start_position : int
The position in the file to start reading the chunk.
Expand Down Expand Up @@ -276,7 +277,7 @@ def _robustly_read_named_chunk(
raise ValueError(err_msg) from e


def iter_chunks(handle: BufferedReader) -> Iterator[tuple[str, int, int]]:
def iter_chunks(handle: BinaryIO) -> Iterator[tuple[str, int, int]]:
file_size = handle.seek(0, 2)
handle.seek(0)
pos = 0
Expand Down
27 changes: 19 additions & 8 deletions src/nd2/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,16 @@
import warnings
from datetime import datetime
from itertools import product
from typing import TYPE_CHECKING, Mapping, NamedTuple
from typing import TYPE_CHECKING, BinaryIO, NamedTuple

if TYPE_CHECKING:
from os import PathLike
from typing import IO, Any, Callable, ClassVar, Sequence, Union
from typing import Any, Callable, ClassVar, Mapping, Sequence, Union

from nd2.readers import ND2Reader

StrOrBytesPath = Union[str, bytes, PathLike[str], PathLike[bytes]]
StrOrPath = Union[str, PathLike]
FileOrBinaryIO = Union[StrOrPath, BinaryIO]

ListOfDicts = list[dict[str, Any]]
DictOfLists = Mapping[str, Sequence[Any]]
Expand All @@ -24,28 +25,38 @@
VERSION = re.compile(r"^ND2 FILE SIGNATURE CHUNK NAME01!Ver([\d\.]+)$")


def _open_binary(path: StrOrPath) -> BinaryIO:
    """Open `path` for reading in binary mode and return the file object.

    Used as the default `open_` callable of `is_supported_file`.
    """
    return open(path, "rb")


def is_supported_file(
    path: FileOrBinaryIO,
    open_: Callable[[StrOrPath], BinaryIO] = _open_binary,
) -> bool:
    """Return `True` if `path` can be opened as an nd2 file.

    Parameters
    ----------
    path : Union[str, PathLike, BinaryIO]
        A path to query, or an already-open binary file handle.  A handle is
        rewound to the start before its first four bytes are read, and it is
        left open afterwards.
    open_ : Callable[[StrOrPath], BinaryIO]
        Opener used when `path` is not a file handle.  It is called with the
        path only and must return a binary file object usable as a context
        manager.  By default `_open_binary`, i.e. `open(path, "rb")`.

    Returns
    -------
    bool
        Whether the file can be opened.
    """
    # `typing.BinaryIO` is not a runtime base class of the objects returned by
    # `open()` (or of `io.BytesIO`), so `isinstance(path, BinaryIO)` is False
    # for real handles and they would wrongly be passed to `open_`.
    # Duck-type on the presence of `read` instead.
    if hasattr(path, "read"):
        handle: BinaryIO = path  # type: ignore[assignment]
        handle.seek(0)
        magic = handle.read(4)
    else:
        with open_(path) as fh:
            magic = fh.read(4)
    return magic in (NEW_HEADER_MAGIC, OLD_HEADER_MAGIC)


def is_legacy(path: StrOrBytesPath) -> bool:
def is_legacy(path: StrOrPath) -> bool:
"""Return `True` if `path` is a legacy ND2 file.
Parameters
Expand Down
37 changes: 21 additions & 16 deletions src/nd2/nd2file.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,12 @@
import threading
import warnings
from itertools import product
from pathlib import Path
from typing import TYPE_CHECKING, cast, overload

import numpy as np

from nd2 import _util

from ._parse._chunk_decode import get_version
from ._util import AXIS, is_supported_file
from .readers.protocol import ND2Reader

Expand All @@ -21,6 +19,7 @@


if TYPE_CHECKING:
from pathlib import Path
from typing import Any, Sequence, Sized, SupportsInt

import dask.array
Expand All @@ -30,7 +29,13 @@
from typing_extensions import Literal

from ._binary import BinaryLayers
from ._util import DictOfDicts, DictOfLists, ListOfDicts, StrOrBytesPath
from ._util import (
DictOfDicts,
DictOfLists,
FileOrBinaryIO,
ListOfDicts,
StrOrPath,
)
from .structures import (
ROI,
Attributes,
Expand Down Expand Up @@ -91,7 +96,7 @@ class ND2File:

def __init__(
self,
path: Path | str,
path: FileOrBinaryIO,
*,
validate_frames: bool = False,
search_window: int = 100,
Expand All @@ -104,16 +109,15 @@ def __init__(
FutureWarning,
stacklevel=2,
)
self._path = str(path)
self._error_radius: int | None = (
search_window * 1000 if validate_frames else None
)
self._rdr = ND2Reader.create(self._path, self._error_radius)
self._closed = False
self._rdr = ND2Reader.create(path, self._error_radius)
self._path = self._rdr._path
self._lock = threading.RLock()

@staticmethod
def is_supported_file(path: StrOrBytesPath) -> bool:
def is_supported_file(path: StrOrPath) -> bool:
"""Return `True` if the file is supported by this reader."""
return is_supported_file(path)

Expand All @@ -138,12 +142,12 @@ def version(self) -> tuple[int, ...]:
ValueError
If the file is not a valid nd2 file.
"""
return get_version(self._path)
return self._rdr.version()

@property
def path(self) -> str:
"""Path of the image."""
return self._path
return str(self._path)

@property
def is_legacy(self) -> bool:
Expand All @@ -166,7 +170,6 @@ def open(self) -> None:
"""
if self.closed:
self._rdr.open()
self._closed = False

def close(self) -> None:
"""Close file.
Expand All @@ -184,12 +187,11 @@ def close(self) -> None:
"""
if not self.closed:
self._rdr.close()
self._closed = True

@property
def closed(self) -> bool:
"""Return `True` if the file is closed."""
return self._closed
return self._rdr._closed

def __enter__(self) -> ND2File:
"""Open file for reading."""
Expand All @@ -198,7 +200,8 @@ def __enter__(self) -> ND2File:

def __del__(self) -> None:
"""Delete file handle on garbage collection."""
if not getattr(self, "_closed", True):
# if it came in as an open file handle, it's ok to remain open after deletion
if not getattr(self, "closed", True) and not self._rdr._was_open:
warnings.warn(
"ND2File file not closed before garbage collection. "
"Please use `with ND2File(...):` context or call `.close()`.",
Expand All @@ -215,15 +218,17 @@ def __getstate__(self) -> dict[str, Any]:
state = self.__dict__.copy()
del state["_rdr"]
del state["_lock"]
state["_closed"] = self.closed
return state

def __setstate__(self, d: dict[str, Any]) -> None:
"""Load state from pickling."""
_was_closed = d.pop("_closed", False)
self.__dict__ = d
self._lock = threading.RLock()
self._rdr = ND2Reader.create(self._path, self._error_radius)

if self._closed:
if _was_closed:
self._rdr.close()

@cached_property
Expand Down Expand Up @@ -1119,7 +1124,7 @@ def __repr__(self) -> str:
"""Return a string representation of the ND2File."""
try:
details = " (closed)" if self.closed else f" {self.dtype}: {self.sizes!r}"
extra = f": {Path(self.path).name!r}{details}"
extra = f": {self._path.name!r}{details}"
except Exception:
extra = ""
return f"<ND2File at {hex(id(self))}{extra}>"
Expand Down
11 changes: 6 additions & 5 deletions src/nd2/readers/_legacy/legacy_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import threading
import warnings
from dataclasses import replace
from typing import TYPE_CHECKING, Any, BinaryIO, DefaultDict, Mapping, cast
from typing import TYPE_CHECKING, DefaultDict, cast

import numpy as np

Expand All @@ -22,11 +22,12 @@

if TYPE_CHECKING:
from collections import defaultdict
from io import BufferedReader
from pathlib import Path
from typing import Any, BinaryIO, Mapping

from typing_extensions import TypedDict

from nd2._util import FileOrBinaryIO

class RawExperimentLoop(TypedDict, total=False):
Type: int
ApplicationDesc: str
Expand Down Expand Up @@ -143,7 +144,7 @@ class PlaneDict(TypedDict, total=False):
class LegacyReader(ND2Reader):
HEADER_MAGIC = _util.OLD_HEADER_MAGIC

def __init__(self, path: str | Path, error_radius: int | None = None) -> None:
def __init__(self, path: FileOrBinaryIO, error_radius: int | None = None) -> None:
super().__init__(path, error_radius)
self._attributes: strct.Attributes | None = None
# super().__init__ called open()
Expand Down Expand Up @@ -415,7 +416,7 @@ def header(self) -> dict:
pos = self.chunkmap[b"jp2h"][0]
except (KeyError, IndexError) as e:
raise KeyError("No valid jp2h header found in file") from e
fh = cast("BufferedReader", self._fh)
fh = cast("BinaryIO", self._fh)
fh.seek(pos + I4s.size + 4) # 4 bytes for "label"
if fh.read(4) != b"ihdr":
raise KeyError("No valid ihdr header found in jp2h header")
Expand Down
4 changes: 2 additions & 2 deletions src/nd2/readers/_modern/modern_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
if TYPE_CHECKING:
import datetime
from os import PathLike
from pathlib import Path

from typing_extensions import Literal, TypeAlias

Expand All @@ -47,6 +46,7 @@
RawTagDict,
RawTextInfoDict,
)
from nd2._util import FileOrBinaryIO

StrOrBytesPath: TypeAlias = str | bytes | PathLike[str] | PathLike[bytes]
StartFileChunk: TypeAlias = tuple[int, int, int, bytes, bytes]
Expand All @@ -55,7 +55,7 @@
class ModernReader(ND2Reader):
HEADER_MAGIC = _util.NEW_HEADER_MAGIC

def __init__(self, path: str | Path, error_radius: int | None = None) -> None:
def __init__(self, path: FileOrBinaryIO, error_radius: int | None = None) -> None:
super().__init__(path, error_radius)

self._cached_decoded_chunks: dict[bytes, Any] = {}
Expand Down
Loading

0 comments on commit fecd852

Please sign in to comment.