-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #5 from weigertlab/cli
Add CLI
- Loading branch information
Showing
8 changed files
with
220 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
from shell import shell | ||
|
||
from test_data import example_dataset | ||
|
||
|
||
def test_cli_parser(): | ||
result = shell("trackastra") | ||
assert result.code == 0 | ||
|
||
|
||
def test_cli_tracking(): | ||
example_dataset() | ||
result = shell( | ||
"trackastra track -i test_data/img -m test_data/TRA --model-pretrained general_2d" # noqa: RUF100 | ||
) | ||
assert result.code == 0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
import argparse | ||
import sys | ||
from pathlib import Path | ||
|
||
import torch | ||
|
||
from .model import Trackastra | ||
from .tracking.utils import graph_to_ctc | ||
from .utils import str2path | ||
|
||
|
||
def cli(): | ||
p = argparse.ArgumentParser( | ||
formatter_class=argparse.ArgumentDefaultsHelpFormatter, allow_abbrev=False | ||
) | ||
subparsers = p.add_subparsers(help="trackastra") | ||
|
||
p_track = subparsers.add_parser("track", help="Tracking help") | ||
p_track.add_argument( | ||
"-i", | ||
"--imgs", | ||
type=str2path, | ||
required=True, | ||
help="Directory with series of .tif files.", | ||
) | ||
p_track.add_argument( | ||
"-m", | ||
"--masks", | ||
type=str2path, | ||
required=True, | ||
help="Directory with series of .tif files.", | ||
) | ||
p_track.add_argument( | ||
"-o", | ||
"--outdir", | ||
type=str2path, | ||
default=None, | ||
help=( | ||
"Directory for writing results (optional). Default writes to" | ||
" `{masks}_tracked`." | ||
), | ||
) | ||
p_track.add_argument( | ||
"--model-pretrained", | ||
type=str, | ||
default=None, | ||
help="Name of pretrained Trackastra model.", | ||
) | ||
p_track.add_argument( | ||
"--model-custom", | ||
type=str2path, | ||
default=None, | ||
help="Local folder with custom model.", | ||
) | ||
p_track.add_argument( | ||
"--mode", choices=["greedy_nodiv", "greedy", "ilp"], default="greedy" | ||
) | ||
p_track.add_argument("--device", choices=["cuda", "cpu"], default="cuda") | ||
p_track.set_defaults(cmd=_track_from_disk) | ||
|
||
if len(sys.argv) == 1: | ||
p.print_help(sys.stdout) | ||
sys.exit(0) | ||
|
||
args = p.parse_args() | ||
|
||
args.cmd(args) | ||
|
||
|
||
def _track_from_disk(args): | ||
device = "cuda" if torch.cuda.is_available() and args.device == "cuda" else "cpu" | ||
|
||
if args.model_pretrained is None == args.model_custom is None: | ||
raise ValueError( | ||
"Please pick a Trackastra model for tracking, either pretrained or a local" | ||
" custom model." | ||
) | ||
|
||
if args.model_pretrained is not None: | ||
model = Trackastra.from_pretrained( | ||
name=args.model_pretrained, | ||
device=device, | ||
) | ||
if args.model_custom is not None: | ||
model = Trackastra.from_folder( | ||
args.model_custom, | ||
device=device, | ||
) | ||
|
||
track_graph, masks = model.track_from_disk( | ||
args.imgs, | ||
args.masks, | ||
mode=args.mode, | ||
) | ||
|
||
if args.outdir is None: | ||
outdir = Path(f"{args.masks}_tracked") | ||
else: | ||
outdir = args.outdir | ||
|
||
outdir.mkdir(parents=True, exist_ok=True) | ||
graph_to_ctc( | ||
track_graph, | ||
masks, | ||
outdir=outdir, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,4 +9,5 @@ | |
render_label, | ||
seed, | ||
str2bool, | ||
str2path, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters