feat: add row filtering to nd2.index, as well as binary/roi data (#151)
* feat: extend index [wip]

* feat: finished
tlambert03 authored Jun 26, 2023
1 parent 3f95cf8 commit 757767b
Showing 4 changed files with 107 additions and 30 deletions.
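
Before the diff, a hedged usage sketch of the new row filtering through the module's `main()` entry point. The positional path argument (and whether a console script wraps this) is an assumption; `--filter`/`-F` and the example expressions come from the help text and tests below. Table output requires `rich`.

```python
# Hedged usage sketch of the new row filtering; not taken verbatim from the repo.
# --filter / -F is defined in _parse_args below and may be repeated; successive
# filters narrow the rows one after another, i.e. they are effectively AND-ed.
import nd2.index

nd2.index.main(
    [
        ".",                                           # assumed: positional path(s) to scan
        "--filter", "acquired > '2020' and kb < 500",  # python expression evaluated per row
        "--filter", "'TimeLoop' in experiment",        # column names act as variables
    ]
)
```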
112 changes: 87 additions & 25 deletions src/nd2/index.py
@@ -7,7 +7,7 @@
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from pathlib import Path
from typing import Iterable, Iterator, Sequence, cast, no_type_check
from typing import Any, Iterable, Iterator, Sequence, cast, no_type_check

from typing_extensions import TypedDict

@@ -33,6 +33,8 @@ class Record(TypedDict):
dtype: str
shape: list[int]
axes: str
binary: bool
rois: bool
software_name: str
software_version: str
grabber: str
@@ -48,9 +50,11 @@ def index_file(path: Path) -> Record:
if nd.is_legacy:
software: dict = {}
acquired: str | None = ""
binary = False
else:
software = nd._rdr._app_info() # type: ignore
acquired = nd._rdr._acquisition_date() # type: ignore
binary = nd.binary_data is not None

stat = path.stat()
exp = [(x.type, x.count) for x in nd.experiment]
@@ -69,6 +73,8 @@
"dtype": str(nd.dtype),
"shape": list(shape),
"axes": "".join(axes),
"binary": binary,
"rois": bool(nd.rois),
"software_name": software.get("SWNameString", ""),
"software_version": software.get("VersionString", ""),
"grabber": software.get("GrabberString", ""),
@@ -96,29 +102,45 @@ def _index_files(
return results


def _pretty_print_table(data: list[Record]) -> None:
def _pretty_print_table(data: list[Record], sort_column: str | None = None) -> None:
try:
from rich.console import Console
from rich.table import Table

except ImportError:
raise sys.exit(
"rich is required to print a pretty table. "
"Install it with `pip install rich`."
) from None

table = Table(show_header=True, header_style="bold")
headers = list(data[0])

# add headers, and highlight any sorted columns
sort_col = ""
if sort_column:
sort_col = (sort_column or "").rstrip("-")
direction = " ↓" if sort_column.endswith("-") else " ↑"
for header in headers:
if header == sort_col:
table.add_column(header + direction, style="green")
else:
table.add_column(header)

for header in data[0]:
table.add_column(header)
for row in data:
table.add_row(*[str(value) for value in row.values()])
table.add_row(*[_strify(value) for value in row.values()])

Console().print(table)


def _strify(val: Any) -> str:
if isinstance(val, bool):
return "✅" if val else ""
return str(val)


def _print_csv(records: list[Record], skip_header: bool = False) -> None:
import csv
import sys

writer = csv.DictWriter(sys.stdout, fieldnames=records[0].keys())
if not skip_header:
@@ -191,36 +213,79 @@ def _parse_args(argv: Sequence[str] = ()) -> argparse.Namespace:
action="store_true",
help="Don't write the CSV header",
)
parser.add_argument(
"--filter",
"-F",
type=str,
action="append",
help="Filter the output. Each filter "
"should be a python expression (string)\nthat evaluates to True or False. "
"It will be evaluated in the context\nof each row. You can use any of the "
"column names as variables.\ne.g.: \"acquired > '2020' and kb < 500\". (May "
"be used multiple times).",
)

return parser.parse_args(argv or sys.argv[1:])


@no_type_check
def _filter_data(
data: list[Record],
to_include: Sequence[str] = (),
sort_by: str | None = None,
include: str | None = None,
exclude: str | None = None,
filters: Sequence[str] = (),
) -> list[Record]:
unrecognized = set(to_include) - set(HEADERS)
"""Filter and sort the data.
Parameters
----------
data : list[Record]
the data to filter
sort_by : str | None, optional
Name of column to sort by, by default None
include : str | None, optional
Comma-separated list of columns to include, by default None
exclude : str | None, optional
Comma-separated list of columns to exclude, by default None
filters : Sequence[str], optional
Sequence of python expression strings to filter the data, by default ()
Returns
-------
list[Record]
_description_
"""
includes = include.split(",") if include else []
unrecognized = set(includes) - set(HEADERS)
if unrecognized: # pragma: no cover
print(f"Unrecognized columns: {', '.join(unrecognized)}", file=sys.stderr)
to_include = [x for x in to_include if x not in unrecognized]
includes = [x for x in includes if x not in unrecognized]

if sort_by:
if sort_by.endswith("-"):
data.sort(key=lambda x: x[sort_by[:-1]], reverse=True)
else:
data.sort(key=lambda x: x[sort_by])

if to_include:
if includes:
# preserve order of to_include
data = [{h: row[h] for h in to_include} for row in data]
data = [{h: row[h] for h in includes} for row in data]

to_exclude = cast("list[str]", exclude.split(",") if exclude else [])

if to_exclude:
data = [{h: row[h] for h in HEADERS if h not in to_exclude} for row in data]

if sort_by:
if sort_by.endswith("-"):
data.sort(key=lambda x: x[sort_by[:-1]], reverse=True)
else:
data.sort(key=lambda x: x[sort_by])
if filters:
# filters are in the form of a string expression, to be evaluated
# against each row. For example, "'TimeLoop' in experiment"
for f in filters:
try:
data = [row for row in data if bool(eval(f, None, row))]
except Exception as e: # pragma: no cover
print(f"Error evaluating filter {f!r}: {e}", file=sys.stderr)
sys.exit(1)

return data

@@ -229,20 +294,17 @@ def main(argv: Sequence[str] = ()) -> None:
"""Index ND2 files and print the results as a table."""
args = _parse_args(argv)

to_include = cast("list[str]", args.include.split(",") if args.include else [])
if args.sort_by and to_include and args.sort_by not in to_include:
raise sys.exit( # pragma: no cover
f"The sort column {args.sort_by!r} must be in the "
f"included columns: {to_include!r}."
)

data = _index_files(paths=args.paths, recurse=args.recurse, glob=args.glob_pattern)
data = _filter_data(
data, to_include=to_include, sort_by=args.sort_by, exclude=args.exclude
data,
sort_by=args.sort_by,
include=args.include,
exclude=args.exclude,
filters=args.filter,
)

if args.format == "table":
_pretty_print_table(data)
_pretty_print_table(data, args.sort_by)
elif args.format == "csv":
_print_csv(data, args.no_header)
elif args.format == "json":
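
The filter expressions above are evaluated with plain `eval()` against each record, so column names behave like local variables. A self-contained sketch of that step (the sample records are invented for illustration):

```python
# Standalone illustration of the eval-based filtering in _filter_data above.
# The records below are made up; real rows come from index_file().
records = [
    {"name": "a.nd2", "kb": 250.0, "acquired": "2021-05-01", "binary": True},
    {"name": "b.nd2", "kb": 900.0, "acquired": "2019-03-12", "binary": False},
]

expr = "acquired > '2020' and kb < 500"
kept = [row for row in records if bool(eval(expr, None, row))]  # record dict = locals
print([r["name"] for r in kept])  # -> ['a.nd2']
```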
1 change: 1 addition & 0 deletions src/nd2/nd2file.py
@@ -306,6 +306,7 @@ def rois(self) -> dict[int, ROI]:
try:
_rois = [ROI._from_meta_dict(d) for d in dicts]
except Exception as e: # pragma: no cover
return {}
raise ValueError(f"Could not parse ROI metadata: {e}") from e
return {r.id: r for r in _rois}

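
The new `binary` and `rois` index columns are derived from two existing `ND2File` properties, `binary_data` and `rois` (legacy files skip the binary check and report False). A minimal sketch of those per-file checks, with a placeholder filename:

```python
# Sketch of the per-file checks behind the new "binary" and "rois" columns,
# mirroring index_file() in src/nd2/index.py; "some_file.nd2" is a placeholder.
import nd2

with nd2.ND2File("some_file.nd2") as f:
    has_binary = f.binary_data is not None  # a binary (mask) layer is present
    has_rois = bool(f.rois)                 # ROI metadata parsed into a dict
    print(has_binary, has_rois)
```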
8 changes: 8 additions & 0 deletions src/nd2/structures.py
@@ -444,6 +444,14 @@ def _from_meta_dict(cls, val: dict) -> ROI:
)


class T(TypedDict):
Id: int
Info: dict
GUID: str
AnimParams_Size: int
# AnimParams_{i}: dict


@dataclass
class AnimParam:
"""Parameters of ROI position/shape."""
16 changes: 11 additions & 5 deletions tests/test_index.py
@@ -16,29 +16,31 @@ def test_format(records, fmt, capsys):
filtered = nd2.index._filter_data(records)

if fmt == "table":
nd2.index._pretty_print_table(filtered)
nd2.index._pretty_print_table(filtered, sort_column="name")
elif fmt == "csv":
nd2.index._print_csv(filtered)
elif fmt == "json":
nd2.index._print_json(filtered)
captured = capsys.readouterr()
assert "path" in captured.out
assert captured.out
assert not captured.err


@pytest.mark.parametrize(
"filters",
[
{},
{"to_include": ["path", "name", "version"]},
{"include": "path,name,version"},
{"sort_by": "version"},
{"sort_by": "version-"},
{"exclude": "path"},
{"filters": ("'TimeLoop' in experiment",)},
{"filters": ["acquired > '2020' and kb < 500"], "sort_by": "kb-"},
],
)
def test_filter_data(records, filters: dict):
def test_filter_data(records, filters: dict) -> None:
filtered = nd2.index._filter_data(records, **filters)
assert isinstance(filtered, list)
assert len(filtered) == len(records)
if filters.get("to_include"):
assert len(filtered[0]) == len(filters["to_include"])
sb = filters.get("sort_by")
@@ -48,6 +50,10 @@ def test_filter_data(records, filters: dict):
assert first_version == "3.0" if sb.endswith("-") else "1.0"
if filters.get("exclude"):
assert "path" not in filtered[0]
if filters.get("filters"):
assert len(filtered) < len(records)
else:
assert len(filtered) == len(records)


def test_index(capsys):
