Skip to content

Commit

Permalink
Merge pull request #5 from weigertlab/cli
Browse files Browse the repository at this point in the history
Add CLI
  • Loading branch information
bentaculum authored Jun 14, 2024
2 parents 9125010 + 647d326 commit 4c9c862
Show file tree
Hide file tree
Showing 8 changed files with 220 additions and 7 deletions.
8 changes: 7 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ author = Benjamin Gallusser, Martin Weigert
author_email = benjamin.gallusser@epfl.ch, martin.weigert@epfl.ch,
dynamic = ["version"]
license = BSD 3-Clause License
description = Tracking by Association with Transformer
description = Tracking by Association with Transformers
long_description = file: README.md
long_description_content_type = text/markdown
classifiers =
Expand Down Expand Up @@ -46,6 +46,7 @@ ilp =
motile >= 0.2
dev =
pytest
shell
ruff
black
mypy
Expand All @@ -54,3 +55,8 @@ dev =
build
test =
pytest
shell

[options.entry_points]
console_scripts =
trackastra = trackastra.cli:cli
16 changes: 16 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from shell import shell

from test_data import example_dataset


def test_cli_parser():
result = shell("trackastra")
assert result.code == 0


def test_cli_tracking():
example_dataset()
result = shell(
"trackastra track -i test_data/img -m test_data/TRA --model-pretrained general_2d" # noqa: RUF100
)
assert result.code == 0
16 changes: 15 additions & 1 deletion tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,20 @@


def example_dataset():
img_dir = Path("test_data/img")
img_dir.mkdir(exist_ok=True, parents=True)
img = np.array(
[
[0, 1, 1], # t=0
[0, 1, 0], # t=1
[1, 1, 0], # t=2
]
)
img = np.expand_dims(img, -1)

for i in range(img.shape[0]):
imwrite(img_dir / f"emp_{i}.tif", img[i])

tra_dir = Path("test_data/TRA")
tra_dir.mkdir(exist_ok=True, parents=True)

Expand All @@ -20,7 +34,7 @@ def example_dataset():
],
dtype=int,
)
np.savetxt(tra_dir / "man_track.txt", man_track)
np.savetxt(tra_dir / "man_track.txt", man_track, fmt="%i")

y = np.array(
[
Expand Down
106 changes: 106 additions & 0 deletions trackastra/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import argparse
import sys
from pathlib import Path

import torch

from .model import Trackastra
from .tracking.utils import graph_to_ctc
from .utils import str2path


def cli():
p = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter, allow_abbrev=False
)
subparsers = p.add_subparsers(help="trackastra")

p_track = subparsers.add_parser("track", help="Tracking help")
p_track.add_argument(
"-i",
"--imgs",
type=str2path,
required=True,
help="Directory with series of .tif files.",
)
p_track.add_argument(
"-m",
"--masks",
type=str2path,
required=True,
help="Directory with series of .tif files.",
)
p_track.add_argument(
"-o",
"--outdir",
type=str2path,
default=None,
help=(
"Directory for writing results (optional). Default writes to"
" `{masks}_tracked`."
),
)
p_track.add_argument(
"--model-pretrained",
type=str,
default=None,
help="Name of pretrained Trackastra model.",
)
p_track.add_argument(
"--model-custom",
type=str2path,
default=None,
help="Local folder with custom model.",
)
p_track.add_argument(
"--mode", choices=["greedy_nodiv", "greedy", "ilp"], default="greedy"
)
p_track.add_argument("--device", choices=["cuda", "cpu"], default="cuda")
p_track.set_defaults(cmd=_track_from_disk)

if len(sys.argv) == 1:
p.print_help(sys.stdout)
sys.exit(0)

args = p.parse_args()

args.cmd(args)


def _track_from_disk(args):
device = "cuda" if torch.cuda.is_available() and args.device == "cuda" else "cpu"

if args.model_pretrained is None == args.model_custom is None:
raise ValueError(
"Please pick a Trackastra model for tracking, either pretrained or a local"
" custom model."
)

if args.model_pretrained is not None:
model = Trackastra.from_pretrained(
name=args.model_pretrained,
device=device,
)
if args.model_custom is not None:
model = Trackastra.from_folder(
args.model_custom,
device=device,
)

track_graph, masks = model.track_from_disk(
args.imgs,
args.masks,
mode=args.mode,
)

if args.outdir is None:
outdir = Path(f"{args.masks}_tracked")
else:
outdir = args.outdir

outdir.mkdir(parents=True, exist_ok=True)
graph_to_ctc(
track_graph,
masks,
outdir=outdir,
)
62 changes: 58 additions & 4 deletions trackastra/model/model_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
import yaml
from tqdm import tqdm

from ..data import build_windows, get_features
from ..tracking import build_graph, track_greedy
from ..data import build_windows, get_features, load_tiff_timeseries
from ..tracking import TrackGraph, build_graph, track_greedy
from ..utils import normalize
from .model import TrackingTransformer
from .predict import predict_windows
Expand Down Expand Up @@ -140,10 +140,64 @@ def track(
self,
imgs: np.ndarray,
masks: np.ndarray,
mode: Literal["greedy", "ilp"] = "greedy",
mode: Literal["greedy_nodiv", "greedy", "ilp"] = "greedy",
progbar_class=tqdm,
**kwargs,
):
) -> TrackGraph:
predictions = self._predict(imgs, masks, progbar_class=progbar_class)
track_graph = self._track_from_predictions(predictions, mode=mode, **kwargs)
return track_graph

def track_from_disk(
self,
imgs_path: Path,
masks_path: Path,
mode: Literal["greedy_nodiv", "greedy", "ilp"] = "greedy",
**kwargs,
) -> tuple[TrackGraph, np.ndarray]:
"""Track directly from two series of tiff files.
Args:
imgs_path:
Directory containing a series of numbered tiff files.
Each file contains an image of shape (C),(Z),Y,X.
masks_path:
Directory containing a series of numbered tiff files.
Each file contains an image of shape (Z), Y, X.
mode (optional):
Mode for candidate graph pruning.
"""
if not imgs_path.exists():
raise FileNotFoundError(f"{imgs_path=} does not exist.")
if not masks_path.exists():
raise FileNotFoundError(f"{masks_path=} does not exist.")

if not imgs_path.is_dir() or not masks_path.is_dir():
raise NotImplementedError("Currently only tiff sequences are supported.")

imgs = load_tiff_timeseries(imgs_path)
masks = load_tiff_timeseries(masks_path)

if len(imgs) != len(masks):
raise RuntimeError(
f"#imgs and #masks do not match. Found {len(imgs)} images,"
f" {len(masks)} masks."
)

if imgs.ndim - 1 == masks.ndim:
if imgs[1] == 1:
logger.info(
"Found a channel dimension with a single channel. Removing dim."
)
masks = np.squeeze(masks, 1)
else:
raise RuntimeError(
"Trackastra currently only supports single channel images."
)

if imgs.shape != masks.shape:
raise RuntimeError(
f"Img shape {imgs.shape} and mask shape {masks. shape} do not match."
)

return self.track(imgs, masks, mode, **kwargs), masks
6 changes: 5 additions & 1 deletion trackastra/tracking/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,10 @@ def graph_to_napari_tracks(

def _check_ctc_df(df: pd.DataFrame, masks: np.ndarray):
"""Sanity check of all labels in a CTC dataframe are present in the masks."""
# Check for empty df
if len(df) == 0 and np.all(masks == 0):
return True

for t in range(df.t1.min(), df.t1.max()):
sub = df[(df.t1 <= t) & (df.t2 >= t)]
sub_lab = set(sub.label)
Expand Down Expand Up @@ -271,7 +275,7 @@ def graph_to_ctc(

rows.append([label, t1, t2, node_to_tracklets[_parent]])

df = pd.DataFrame(rows, columns=["label", "t1", "t2", "parent"])
df = pd.DataFrame(rows, columns=["label", "t1", "t2", "parent"], dtype=int)

masks = np.stack(masks)

Expand Down
1 change: 1 addition & 0 deletions trackastra/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@
render_label,
seed,
str2bool,
str2path,
)
12 changes: 12 additions & 0 deletions trackastra/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging
import random
import sys
from pathlib import Path
from timeit import default_timer

import matplotlib
Expand Down Expand Up @@ -437,6 +438,17 @@ def str2bool(x: str) -> bool:
raise ValueError(f"'{x}' does not seem to be boolean.")


def str2path(x: str) -> Path:
"""Cast string to resolved absolute path.
Useful for parsing command line arguments.
"""
if not isinstance(x, str):
raise TypeError("String expected.")
else:
return Path(x).expanduser().resolve()


if __name__ == "__main__":
A = torch.rand(50, 50)
idx = torch.tensor([0, 10, 20, A.shape[0]])
Expand Down

0 comments on commit 4c9c862

Please sign in to comment.