Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
283 changes: 228 additions & 55 deletions cellmap_flow/blockwise/blockwise_processor.py

Large diffs are not rendered by default.

36 changes: 13 additions & 23 deletions cellmap_flow/blockwise/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,41 +4,31 @@
from cellmap_flow.blockwise import CellMapFlowBlockwiseProcessor


@click.group()
@click.command()
@click.argument("yaml_config", type=click.Path(exists=True))
@click.option(
"-c",
"--client",
is_flag=True,
default=False,
help="Run as client if this flag is set.",
)
@click.option(
"--log-level",
type=click.Choice(
["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], case_sensitive=False
),
default="INFO",
)
def cli(log_level):

def cli(yaml_config, client, log_level):
logging.basicConfig(level=getattr(logging, log_level.upper()))


logger = logging.getLogger(__name__)


@cli.command()
@click.option(
"-y",
"--yaml_config",
required=True,
type=click.Path(exists=True),
help="The path to the YAML file.",
)
@click.option(
"-c",
"--client",
is_flag=True,
default=False,
help="Run as client if this flag is set.",
)
def run(yaml_config, client):
is_server = not client
process = CellMapFlowBlockwiseProcessor(yaml_config, create=is_server)
if is_server:
process.run()
else:
process.client()


logger = logging.getLogger(__name__)
5 changes: 4 additions & 1 deletion cellmap_flow/cli/fly_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ def main():
if "charge_group" not in data:
raise ValueError("charge_group is required in the YAML file")
charge_group = data["charge_group"]

input_size = tuple(data.get("input_size", (178, 178, 178)))
output_size = tuple(data.get("output_size", (56, 56, 56)))
g.charge_group = charge_group
threads = []
for run_name, run_items in data["runs"].items():
Expand All @@ -51,6 +52,8 @@ def main():
input_voxel_size=res,
output_voxel_size=res,
name=run_name,
input_size=input_size,
output_size=output_size,
)
model_command = f"fly -c {model_config.checkpoint_path} -ch {','.join(model_config.channels)} -ivs {','.join(map(str,model_config.input_voxel_size))} -ovs {','.join(map(str,model_config.output_voxel_size))}"
command = f"{SERVER_COMMAND} {model_command} -d {data_path}"
Expand Down
11 changes: 10 additions & 1 deletion cellmap_flow/image_data_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,16 @@ def __init__(
normalize=True,
):
# if multiscale dataset, get scale for voxel size
if not isinstance(zarr.open(dataset_path, mode="r"), zarr.core.Array):
logger.error(f"opening dataset path {dataset_path} in mode {mode}")
if ".n5" in dataset_path:
container_path = dataset_path[: dataset_path.rfind(".n5") + 3]
ds_path = dataset_path[dataset_path.rfind(".n5") + 4 :]
store = zarr.N5Store(container_path)
dd = zarr.open(store, mode="r")
dd = dd[ds_path]
else:
dd = zarr.open(dataset_path, mode="r")
if not isinstance(dd, zarr.core.Array):
scale, _, _ = find_closest_scale(dataset_path, voxel_size)
logger.info(f"found scale {scale} for voxel size {voxel_size}")
dataset_path = os.path.join(dataset_path, scale)
Expand Down
8 changes: 4 additions & 4 deletions cellmap_flow/inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,16 @@ def predict(read_roi, write_roi, config, **kwargs):


class Inferencer:
def __init__(self, model_config: ModelConfig, use_half_prediction=True):
def __init__(self, model_config: ModelConfig, use_half_prediction=False):

if torch.cuda.is_available():
self.device = torch.device("cuda")
else:
self.device = torch.device("cpu")
logger.error("No GPU available, using CPU")
torch.backends.cudnn.allow_tf32 = True # May help performance with newer cuDNN
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True # Find best algorithm for the hardware
# torch.backends.cudnn.allow_tf32 = True # May help performance with newer cuDNN
# torch.backends.cudnn.enabled = True
# torch.backends.cudnn.benchmark = True # Find best algorithm for the hardware

self.use_half_prediction = use_half_prediction
self.model_config = model_config
Expand Down
41 changes: 41 additions & 0 deletions cellmap_flow/utils/ds.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,47 @@
from cellmap_flow.globals import g


def generate_singlescale_metadata(
arr_name: str,
voxel_size: list,
translation: list,
units: str,
axes: list,
):
z_attrs: dict = {"multiscales": [{}]}

# Create axes with proper types - channel axis should have type "channel"
axes_list = []
for axis, unit in zip(axes, units):
if axis in ["c", "c^"]:
axes_list.append({"name": axis, "type": "channel"})
else:
axes_list.append({"name": axis, "type": "space", "unit": unit})

z_attrs["multiscales"][0]["axes"] = axes_list

# Set coordinateTransformations scale to match dimensionality
scale_transform = [1.0] * len(axes)
z_attrs["multiscales"][0]["coordinateTransformations"] = [
{"scale": scale_transform, "type": "scale"}
]

z_attrs["multiscales"][0]["datasets"] = [
{
"coordinateTransformations": [
{"scale": voxel_size, "type": "scale"},
{"translation": translation, "type": "translation"},
],
"path": arr_name,
}
]

z_attrs["multiscales"][0]["name"] = ""
z_attrs["multiscales"][0]["version"] = "0.4"

return z_attrs


def get_scale_info(zarr_grp):
attrs = zarr_grp.attrs
resolutions = {}
Expand Down
5 changes: 4 additions & 1 deletion cellmap_flow/utils/neuroglancer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ def generate_neuroglancer_url(dataset_path, extras=[]):
filetype = "precomputed"
s.layers["data"] = neuroglancer.ImageLayer(
source=f"{filetype}://{dataset_path}",
shader="""#uicontrol invlerp normalized(range=[-1, 1], window=[-1, 1]);
#uicontrol vec3 color color(default="white");
void main(){{emitRGB(color * normalized());}}""",
)
for i, extra in enumerate(extras):
logger.error(f" adding extra {i} {extra}")
Expand Down Expand Up @@ -68,7 +71,7 @@ def generate_neuroglancer_url(dataset_path, extras=[]):
color = next(color_cycle)
s.layers[model] = neuroglancer.ImageLayer(
source=f"n5://{host}/{model}{ARGS_KEY}{st_data}{ARGS_KEY}",
shader=f"""#uicontrol invlerp normalized(range=[0, 1], window=[0, 1]);
shader=f"""#uicontrol invlerp normalized(range=[0.5, 0.5], window=[0, 1]);
#uicontrol vec3 color color(default="{color}");
void main(){{emitRGB(color * normalized());}}""",
)
Expand Down
5 changes: 4 additions & 1 deletion cellmap_flow/utils/scale_pyramid.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,10 @@ def get_raw_layer(dataset_path, normalize=True):
scales=image.voxel_size,
),
voxel_offset=image.offset,
)
),
shader="""#uicontrol invlerp normalized(range=[-1, 1], window=[-1, 1]);
#uicontrol vec3 color color(default="white");
void main(){{emitRGB(color * normalized());}}""",
)


Expand Down