diff --git a/dacapo/experiments/datasplits/__init__.py b/dacapo/experiments/datasplits/__init__.py index ad1ad4880..f70ec1a71 100644 --- a/dacapo/experiments/datasplits/__init__.py +++ b/dacapo/experiments/datasplits/__init__.py @@ -5,3 +5,4 @@ from .train_validate_datasplit import TrainValidateDataSplit from .train_validate_datasplit_config import TrainValidateDataSplitConfig from .datasplit_generator import DataSplitGenerator, DatasetSpec +from .simple_config import SimpleDataSplitConfig \ No newline at end of file diff --git a/dacapo/experiments/datasplits/datasets/__init__.py b/dacapo/experiments/datasplits/datasets/__init__.py index edcffd8ef..c886eea19 100644 --- a/dacapo/experiments/datasplits/datasets/__init__.py +++ b/dacapo/experiments/datasplits/datasets/__init__.py @@ -4,3 +4,4 @@ from .dummy_dataset_config import DummyDatasetConfig from .raw_gt_dataset import RawGTDataset from .raw_gt_dataset_config import RawGTDatasetConfig +from .simple import SimpleDataset \ No newline at end of file diff --git a/dacapo/experiments/datasplits/datasets/dummy_dataset.py b/dacapo/experiments/datasplits/datasets/dummy_dataset.py index b8e6a2ae0..b73f1a051 100644 --- a/dacapo/experiments/datasplits/datasets/dummy_dataset.py +++ b/dacapo/experiments/datasplits/datasets/dummy_dataset.py @@ -1,6 +1,8 @@ from .dataset import Dataset from funlib.persistence import Array +import warnings + class DummyDataset(Dataset): """ @@ -15,6 +17,7 @@ class DummyDataset(Dataset): Notes: This class is used to create a dataset with raw data. """ + raw: Array @@ -34,5 +37,11 @@ def __init__(self, dataset_config): This method is used to initialize the dataset. """ super().__init__() + + warnings.warn( + "DummyDataset is deprecated. Use SimpleDataset instead.", + DeprecationWarning, + ) + self.name = dataset_config.name self.raw = dataset_config.raw_config.array() diff --git a/dacapo/experiments/datasplits/datasets/raw_gt_dataset.py b/dacapo/experiments/datasplits/datasets/raw_gt_dataset.py index 8af1068f9..6da920ec0 100644 --- a/dacapo/experiments/datasplits/datasets/raw_gt_dataset.py +++ b/dacapo/experiments/datasplits/datasets/raw_gt_dataset.py @@ -4,6 +4,7 @@ from funlib.geometry import Coordinate from typing import Optional, List +import warnings class RawGTDataset(Dataset): @@ -48,6 +49,12 @@ def __init__(self, dataset_config): Notes: This method is used to initialize the dataset. """ + + warnings.warn( + "RawGTDataset is deprecated. 
Use SimpleDataset instead.",
+            DeprecationWarning,
+        )
+
         self.name = dataset_config.name
         self.raw = dataset_config.raw_config.array()
         self.gt = dataset_config.gt_config.array()
diff --git a/dacapo/experiments/datasplits/datasets/simple.py b/dacapo/experiments/datasplits/datasets/simple.py
new file mode 100644
index 000000000..5c73c2537
--- /dev/null
+++ b/dacapo/experiments/datasplits/datasets/simple.py
@@ -0,0 +1,69 @@
+from .dataset_config import DatasetConfig
+
+from funlib.persistence import Array, open_ds
+
+
+import attr
+
+from pathlib import Path
+import numpy as np
+
+@attr.s
+class SimpleDataset(DatasetConfig):
+
+    path: Path = attr.ib()
+    weight: int = attr.ib(default=1)
+    raw_name: str = attr.ib(default="raw")
+    gt_name: str = attr.ib(default="labels")
+    mask_name: str = attr.ib(default="mask")
+
+    @staticmethod
+    def dataset_type(dataset_config):
+        return dataset_config
+
+    @property
+    def raw(self) -> Array:
+        raw_array = open_ds(self.path / self.raw_name)
+        dtype = raw_array.dtype
+        if dtype == np.uint8:
+            raw_array.lazy_op(lambda data: data.astype(np.float32) / 255)
+        elif dtype == np.uint16:
+            raw_array.lazy_op(lambda data: data.astype(np.float32) / 65535)
+        elif np.issubdtype(dtype, np.floating):
+            pass
+        elif np.issubdtype(dtype, np.integer):
+            raise Exception(
+                f"Not sure how to normalize intensity data with dtype {dtype}"
+            )
+        return raw_array
+
+    @property
+    def gt(self) -> Array:
+        return open_ds(self.path / self.gt_name)
+
+    @property
+    def mask(self) -> Array | None:
+        mask_path = self.path / self.mask_name
+        if mask_path.exists():
+            mask = open_ds(mask_path)
+            assert np.issubdtype(mask.dtype, np.integer), "Mask must be integer type"
+            mask.lazy_op(lambda data: data > 0)
+            return mask
+        return None
+
+    @property
+    def sample_points(self) -> None:
+        return None
+
+
+    def __eq__(self, other) -> bool:
+        return isinstance(other, type(self)) and self.name == other.name
+
+    def __hash__(self) -> int:
+        return hash(self.name)
+
+    def __repr__(self) -> str:
+        return self.name
+
+    def __str__(self) -> str:
+        return self.name
\ No newline at end of file
diff --git a/dacapo/experiments/datasplits/dummy_datasplit.py b/dacapo/experiments/datasplits/dummy_datasplit.py
index b8bde7327..20342040d 100644
--- a/dacapo/experiments/datasplits/dummy_datasplit.py
+++ b/dacapo/experiments/datasplits/dummy_datasplit.py
@@ -2,6 +2,7 @@
 from .datasets import Dataset
 
 from typing import List
+import warnings
 
 
 class DummyDataSplit(DataSplit):
@@ -41,6 +42,10 @@ def __init__(self, datasplit_config):
         This function is called by the DummyDataSplit class to initialize the DummyDataSplit class with specified config to split the data into training and validation datasets.
         """
         super().__init__()
+        warnings.warn(
+            "DummyDataSplit is deprecated. Use SimpleDataSplitConfig instead.",
+            DeprecationWarning,
+        )
 
         self.train = [
             datasplit_config.train_config.dataset_type(datasplit_config.train_config)
diff --git a/dacapo/experiments/datasplits/simple_config.py b/dacapo/experiments/datasplits/simple_config.py
new file mode 100644
index 000000000..8e65f56b8
--- /dev/null
+++ b/dacapo/experiments/datasplits/simple_config.py
@@ -0,0 +1,69 @@
+from .datasets.simple import SimpleDataset
+from .datasplit_config import DataSplitConfig
+
+import attr
+
+from pathlib import Path
+
+import glob
+
+@attr.s
+class SimpleDataSplitConfig(DataSplitConfig):
+    """
+    A convention over configuration datasplit that can handle many of the most
+    basic cases.
+    """
+
+    path: Path = attr.ib()
+    name: str = attr.ib()
+    train_group_name: str = attr.ib(default="train")
+    validate_group_name: str = attr.ib(default="test")
+    raw_name: str = attr.ib(default="raw")
+    gt_name: str = attr.ib(default="labels")
+    mask_name: str = attr.ib(default="mask")
+
+    @staticmethod
+    def datasplit_type(datasplit_config):
+        return datasplit_config
+
+    def get_paths(self, group_name: str) -> list[Path]:
+        level_0 = f"{self.path}/{self.raw_name}"
+        level_1 = f"{self.path}/{group_name}/{self.raw_name}"
+        level_2 = f"{self.path}/{group_name}/**/{self.raw_name}"
+        level_0_matches = glob.glob(level_0)
+        level_1_matches = glob.glob(level_1)
+        level_2_matches = glob.glob(level_2)
+        if len(level_0_matches) > 0:
+            assert (
+                len(level_1_matches) == len(level_2_matches) == 0
+            ), f"Found raw data at {level_0} and {level_1} and {level_2}"
+            return [Path(x).parent for x in level_0_matches]
+        elif len(level_1_matches) > 0:
+            assert (
+                len(level_2_matches) == 0
+            ), f"Found raw data at {level_1} and {level_2}"
+            return [Path(x).parent for x in level_1_matches]
+        elif len(level_2_matches) > 0:
+            return [Path(x).parent for x in level_2_matches]
+
+        raise Exception(f"No raw data found at {level_0} or {level_1} or {level_2}")
+
+    @property
+    def train(self) -> list[SimpleDataset]:
+        return [
+            SimpleDataset(
+                name=x.stem,
+                path=x,
+            )
+            for x in self.get_paths(self.train_group_name)
+        ]
+
+    @property
+    def validate(self) -> list[SimpleDataset]:
+        return [
+            SimpleDataset(
+                name=x.stem,
+                path=x,
+            )
+            for x in self.get_paths(self.validate_group_name)
+        ]
diff --git a/dacapo/experiments/datasplits/train_validate_datasplit.py b/dacapo/experiments/datasplits/train_validate_datasplit.py
index 0b93663a3..abf57e9b4 100644
--- a/dacapo/experiments/datasplits/train_validate_datasplit.py
+++ b/dacapo/experiments/datasplits/train_validate_datasplit.py
@@ -2,6 +2,7 @@
 from .datasets import Dataset
 
 from typing import List
+import warnings
 
 
 class TrainValidateDataSplit(DataSplit):
@@ -47,6 +48,10 @@ def __init__(self, datasplit_config):
         into training and validation datasets.
         """
         super().__init__()
+        warnings.warn(
+            "TrainValidateDataSplit is deprecated. Use SimpleDataSplitConfig instead.",
+            DeprecationWarning,
+        )
 
         self.train = [
             train_config.dataset_type(train_config)
diff --git a/dacapo/experiments/validation_scores.py b/dacapo/experiments/validation_scores.py
index 18e2df029..08eac748c 100644
--- a/dacapo/experiments/validation_scores.py
+++ b/dacapo/experiments/validation_scores.py
@@ -234,7 +234,7 @@ def to_xarray(self) -> xr.DataArray:
                 "iterations": [
                     iteration_score.iteration for iteration_score in self.scores
                 ],
-                "datasets": self.datasets,
+                "datasets": [d.name for d in self.datasets],
                 "parameters": self.parameters,
                 "criteria": self.criteria,
             },
diff --git a/dacapo/plot.py b/dacapo/plot.py
index 1ac82e965..5589488d7 100644
--- a/dacapo/plot.py
+++ b/dacapo/plot.py
@@ -427,7 +427,7 @@ def plot_runs(
         )
         colors_val = itertools.cycle(plt.cm.tab20.colors)
         for dataset in run.validation_scores.datasets:
-            dataset_data = validation_score_data.sel(datasets=dataset)
+            dataset_data = validation_score_data.sel(datasets=dataset.name)
             include_validation_figure = True
             x = [score.iteration for score in run.validation_scores.scores]
             for i, cc in zip(range(dataset_data.data.shape[1]), colors_val):
diff --git a/docs/source/data.rst b/docs/source/data.rst
new file mode 100644
index 000000000..8d15567a3
--- /dev/null
+++ b/docs/source/data.rst
@@ -0,0 +1,111 @@
+.. _sec_data:
+
+Data Formatting
+===============
+
+Overview
+--------
+
+We support any data format that can be opened with the `zarr.open` convenience function from
+`zarr `_. We also expect some specific metadata to come
+with the data.
+
+Metadata
+--------
+
+- `voxel_size`: The size of each voxel in the dataset. This is expected to be a tuple of ints
+  with the same length as the number of spatial dimensions in the dataset.
+- `offset`: The offset of the dataset. This is expected to be a tuple of ints with the same length
+  as the number of spatial dimensions in the dataset.
+- `axis_names`: The name of each axis. This is expected to be a tuple of strings with the same length
+  as the total number of dimensions in the dataset. For example a 3D dataset with channels would have
+  `axis_names=('c^', 'z', 'y', 'x')`. Note that we expect non-spatial dimensions to include a "^" character.
+  See [1]_ for expected future changes.
+- `units`: The units of each axis. This is expected to be a tuple of strings with the same length
+  as the number of spatial dimensions in the dataset. For example a 3D dataset with channels would have
+  `units=('nanometers', 'nanometers', 'nanometers')`.
+
+Organization
+------------
+
+Ideally all of your data will be contained in a single zarr container.
+The simplest possible dataset would look like this:
+::
+
+    data.zarr
+    ├── raw
+    └── labels
+
+If this is what your data looks like, then your data configuration will look like this:
+
+.. code-block::
+    :caption: A simple data configuration
+
+    data_config = DataConfig(
+        path="/path/to/data.zarr"
+    )
+
+Note that a lot of assumptions will be made.
+
+1. We assume your raw data is normalized based on the `dtype`. I.e. if your data is
+   stored as an unsigned int (we recommend uint8) we will assume a range and normalize
+   it to [0,1] by dividing by the appropriate value (255 for `uint8` or 65535 for `uint16`).
+   If your data is stored as any `float` we will assume it is already in the range [0, 1].
+2. We assume your labels are stored as unsigned integers. If you want to generate instance segmentations, you will need
+   to assign a unique id to every object of the class you are interested in. If you want semantic segmentations you
+   can simply assign a unique id to each class. 0 is reserved for the background class.
+3. We assume that the labels are provided densely. The entire volume will be used for training.
+4. We will be training and validating on the same data. This is not ideal, but it is an ok starting point for testing
+   and debugging.
+
+Next we can add a little bit of complexity by separating train and test data. This can also be handled
+by the same data configuration as above since it will detect the presence of the `train` and `test` groups.
+
+::
+
+    data.zarr
+    ├── train
+    │   ├── raw
+    │   └── labels
+    └── test
+        ├── raw
+        └── labels
+
+We can go further with our basic data configuration, since the layouts above will often not be enough to
+describe your data. You may have multiple crops, and your data may be sparsely annotated. The same data
+configuration from above will also work for the slightly more complicated dataset below:
+
+::
+
+    data.zarr
+    ├── train
+    │   ├── crop_01
+    │   │   ├── raw
+    │   │   ├── labels
+    │   │   └── mask
+    │   └── crop_02
+    │       ├── raw
+    │       └── labels
+    └── test
+        ├── crop_03
+        │   ├── raw
+        │   ├── labels
+        │   └── mask
+        └── crop_04
+            ├── raw
+            └── labels
+
+Note that `crop_01` and `crop_03` have masks associated with them. We assume a value of `0` in the mask indicates
+unknown data.
+We will never use this data for supervised training, regardless of the corresponding label value.
+If multiple test datasets are provided, this will increase the amount of information to review after training.
+You will have e.g. `crop_03_voi` and `crop_04_voi` stored in the validation scores. Since we also take care to
+save the "best" model checkpoint, this may double the number of checkpoints saved, since the checkpoint that
+achieves the optimal `voi` on `crop_03` may not be the same as the one that achieves the optimal `voi` on `crop_04`.
+
+Footnotes
+---------
+
+.. [1] The specification of axis names is expected to change, since we plan to support a `type` field which
+   can be one of ["time", "space", "{anything-else}"]. This would allow you to specify dimensions as "channel"
+   or "batch" or whatever else you want. It will bring us more in line with OME-Zarr and allow us to more easily
+   handle a larger variety of common data specification formats.
\ No newline at end of file
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 8861ad601..3d9b1d529 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -10,6 +10,7 @@
    overview
    install
    notebooks/minimal_tutorial
+   data
    unet_architectures
    tutorial
    docker
diff --git a/examples/starter_tutorial/minimal_tutorial.py b/examples/starter_tutorial/minimal_tutorial.py
index bf54a893f..677b72331 100644
--- a/examples/starter_tutorial/minimal_tutorial.py
+++ b/examples/starter_tutorial/minimal_tutorial.py
@@ -184,57 +184,10 @@
 # experiments, but is useful for this tutorial.
 
 # %%
-from dacapo.experiments.datasplits import TrainValidateDataSplitConfig
-from dacapo.experiments.datasplits.datasets import RawGTDatasetConfig
-from dacapo.experiments.datasplits.datasets.arrays import (
-    ZarrArrayConfig,
-    IntensitiesArrayConfig,
-)
+from dacapo.experiments.datasplits.simple_config import SimpleDataSplitConfig
 from funlib.geometry import Coordinate
 
-datasplit_config = TrainValidateDataSplitConfig(
-    name="example_datasplit",
-    train_configs=[
-        RawGTDatasetConfig(
-            name="example_dataset",
-            raw_config=IntensitiesArrayConfig(
-                name="example_raw_normalized",
-                source_array_config=ZarrArrayConfig(
-                    name="example_raw",
-                    file_name="cells3d.zarr",
-                    dataset="raw",
-                ),
-                min=0,
-                max=255,
-            ),
-            gt_config=ZarrArrayConfig(
-                name="example_gt",
-                file_name="cells3d.zarr",
-                dataset="mask",
-            ),
-        )
-    ],
-    validate_configs=[
-        RawGTDatasetConfig(
-            name="example_dataset",
-            raw_config=IntensitiesArrayConfig(
-                name="example_raw_normalized",
-                source_array_config=ZarrArrayConfig(
-                    name="example_raw",
-                    file_name="cells3d.zarr",
-                    dataset="raw",
-                ),
-                min=0,
-                max=255,
-            ),
-            gt_config=ZarrArrayConfig(
-                name="example_gt",
-                file_name="cells3d.zarr",
-                dataset="labels",
-            ),
-        )
-    ],
-)
+datasplit_config = SimpleDataSplitConfig(name="cells3d", path="cells3d.zarr")
 
 datasplit = datasplit_config.datasplit_type(datasplit_config)
 config_store.store_datasplit_config(datasplit_config)
@@ -497,8 +450,8 @@
 #     break
     raw = zarr.open(f"{run_path}/validation.zarr/inputs/{dataset}/raw")
     gt = zarr.open(f"{run_path}/validation.zarr/inputs/{dataset}/gt")
-    pred_path = f"{run_path}/validation.zarr/{validation_it}/ds_{dataset}/prediction"
-    out_path = f"{run_path}/validation.zarr/{validation_it}/ds_{dataset}/output/WatershedPostProcessorParameters(id=2, bias=0.5, context=(32, 32, 32))"
+    pred_path = f"{run_path}/validation.zarr/{validation_it}/{dataset}/prediction"
+    out_path = f"{run_path}/validation.zarr/{validation_it}/{dataset}/output/WatershedPostProcessorParameters(id=2, bias=0.5, context=(32, 32, 32))"
     output = zarr.open(out_path)[:]
     prediction = zarr.open(pred_path)[0]
     c = (raw.shape[2] - gt.shape[1]) // 2