2 changes: 1 addition & 1 deletion mipcandy/data/__init__.py
@@ -4,6 +4,6 @@
from mipcandy.data.download import download_dataset
from mipcandy.data.geometric import ensure_num_dimensions, orthographic_views, aggregate_orthographic_views, crop
from mipcandy.data.inspection import InspectionAnnotation, InspectionAnnotations, load_inspection_annotations, \
inspect, ROIDataset
inspect, ROIDataset, RandomROIDataset
from mipcandy.data.io import resample_to_isotropic, load_image, save_image
from mipcandy.data.visualization import visualize2d, visualize3d, overlay
100 changes: 90 additions & 10 deletions mipcandy/data/inspection.py
@@ -29,6 +29,7 @@ class InspectionAnnotation(object):
shape: tuple[int, ...]
foreground_bbox: tuple[int, int, int, int] | tuple[int, int, int, int, int, int]
ids: tuple[int, ...]
foreground_samples: torch.Tensor | None = None

def foreground_shape(self) -> tuple[int, int] | tuple[int, int, int]:
r = (self.foreground_bbox[1] - self.foreground_bbox[0], self.foreground_bbox[3] - self.foreground_bbox[2])
@@ -39,8 +40,11 @@ def center_of_foreground(self) -> tuple[int, int] | tuple[int, int, int]:
round((self.foreground_bbox[3] + self.foreground_bbox[2]) * .5))
return r if len(self.shape) == 2 else r + (round((self.foreground_bbox[5] + self.foreground_bbox[4]) * .5),)

def to_dict(self) -> dict[str, tuple[int, ...]]:
return asdict(self)
def to_dict(self) -> dict[str, Any]:
d = asdict(self)
if self.foreground_samples is not None:
d["foreground_samples"] = self.foreground_samples.tolist()
return d


class InspectionAnnotations(HasDevice, Sequence[InspectionAnnotation]):
@@ -71,7 +75,7 @@ def __len__(self) -> int:

def save(self, path: str | PathLike[str]) -> None:
with open(path, "w") as f:
dump({"background": self._background, "annotations": self._annotations}, f)
dump({"background": self._background, "annotations": [a.to_dict() for a in self._annotations]}, f)

def _get_shapes(self, get_shape: Callable[[InspectionAnnotation], tuple[int, ...]]) -> tuple[
tuple[int, ...] | None, tuple[int, ...], tuple[int, ...]]:
@@ -212,18 +216,23 @@ def crop_roi(self, i: int, *, percentile: float = .95) -> tuple[torch.Tensor, torch.Tensor]:


def _lists_to_tuples(pairs: Sequence[tuple[str, Any]]) -> dict[str, Any]:
return {k: tuple(v) if isinstance(v, list) else v for k, v in pairs}
return {k: tuple(v) if isinstance(v, list) and k != "foreground_samples" else v for k, v in pairs}


def load_inspection_annotations(path: str | PathLike[str], dataset: SupervisedDataset) -> InspectionAnnotations:
with open(path) as f:
obj = load(f, object_pairs_hook=_lists_to_tuples)
return InspectionAnnotations(dataset, obj["background"], *(
InspectionAnnotation(**row) for row in obj["annotations"]
))
annotations = []
for row in obj["annotations"]:
if row.get("foreground_samples") is not None:
row["foreground_samples"] = torch.tensor(row["foreground_samples"])
annotations.append(InspectionAnnotation(**row))
return InspectionAnnotations(dataset, obj["background"], *annotations)
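Taken together, the to_dict and load_inspection_annotations changes give foreground_samples a JSON round trip: the (N, ndim) index tensor is flattened to nested lists on save and rebuilt with torch.tensor on load, while _lists_to_tuples is told to skip that key so it is not coerced into tuples. A minimal sketch of the round trip (the sample values are made up for illustration):

import torch

samples = torch.tensor([[0, 14, 30], [0, 52, 7]])   # (N, ndim) voxel indices
as_json = samples.tolist()                          # [[0, 14, 30], [0, 52, 7]], JSON-serializable
restored = torch.tensor(as_json)                    # rebuilt on load
assert torch.equal(samples, restored)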


def inspect(dataset: SupervisedDataset, *, background: int = 0, console: Console = Console()) -> InspectionAnnotations:
def inspect(dataset: SupervisedDataset, *, background: int = 0, min_foreground_samples: int = 500,
max_foreground_samples: int = 10000, min_percent_coverage: float = 0.01,
console: Console = Console()) -> InspectionAnnotations:
r = []
with Progress(*Progress.get_default_columns(), SpinnerColumn(), console=console) as progress:
task = progress.add_task("Inspecting dataset...", total=len(dataset))
@@ -233,8 +242,23 @@ def inspect(dataset: SupervisedDataset, *, background: int = 0, console: Console = Console()) -> InspectionAnnotations:
mins = indices.min(dim=0)[0].tolist()
maxs = indices.max(dim=0)[0].tolist()
bbox = (mins[1], maxs[1], mins[2], maxs[2])
if len(indices) > 0:
if len(indices) <= min_foreground_samples:
foreground_samples = indices
else:
target_samples = min(
max_foreground_samples,
max(min_foreground_samples, int(np.ceil(len(indices) * min_percent_coverage)))
)
sampled_idx = torch.randperm(len(indices))[:target_samples]
foreground_samples = indices[sampled_idx]
else:
foreground_samples = None
r.append(InspectionAnnotation(
label.shape[1:], bbox if label.ndim == 3 else bbox + (mins[3], maxs[3]), tuple(label.unique())
label.shape[1:],
bbox if label.ndim == 3 else bbox + (mins[3], maxs[3]),
tuple(label.unique()),
foreground_samples
))
return InspectionAnnotations(dataset, background, *r, device=dataset.device())
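The sampling budget in inspect clamps the number of stored foreground coordinates between min_foreground_samples and max_foreground_samples, targeting min_percent_coverage of the foreground voxels in between. A worked instance with the default arguments (500, 10000, 0.01) and a made-up voxel count:

import math

n_foreground = 250_000                                   # hypothetical foreground voxel count
budget = min(10_000, max(500, math.ceil(n_foreground * 0.01)))
# budget == 2500: 1% coverage beats the floor of 500 and stays under the cap of 10000
# (the diff uses np.ceil; math.ceil is equivalent here)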

@@ -251,8 +275,64 @@ def __len__(self) -> int:

@override
def construct_new(self, images: list[torch.Tensor], labels: list[torch.Tensor]) -> Self:
return ROIDataset(self._annotations)
return self.__class__(self._annotations, percentile=self._percentile)

@override
def load(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
return self._annotations.crop_roi(idx, percentile=self._percentile)


class RandomROIDataset(ROIDataset):
def __init__(self, annotations: InspectionAnnotations, *, percentile: float = .95,
foreground_oversample_percent: float = 0.33) -> None:
super().__init__(annotations, percentile=percentile)
self._fg_oversample: float = foreground_oversample_percent

def _random_roi(self, idx: int) -> tuple[int, int, int, int] | tuple[int, int, int, int, int, int]:
annotation = self._annotations[idx]
roi_shape = self._annotations.roi_shape(percentile=self._percentile)
roi = []
for dim_size, patch_size in zip(annotation.shape, roi_shape):
left = patch_size // 2
right = patch_size - left
min_center = left
max_center = dim_size - right
center = torch.randint(min_center, max_center + 1, (1,)).item()
roi.append(center - left)
roi.append(center + right)
return tuple(roi)
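_random_roi places a patch by sampling its center uniformly over all positions that keep the whole patch inside the volume; the asymmetric left/right split handles odd patch sizes. A worked instance with made-up numbers:

patch_size = 5
left = patch_size // 2            # 2
right = patch_size - left         # 3, so left + right == patch_size
center = 10                       # any draw from [left, dim_size - right]
roi_lo, roi_hi = center - left, center + right   # (8, 13): exactly 5 indices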

def _foreground_guided_random_roi(self, idx: int) -> tuple[int, int, int, int] | tuple[
int, int, int, int, int, int]:
annotation = self._annotations[idx]
roi_shape = self._annotations.roi_shape(percentile=self._percentile)

if annotation.foreground_samples is None or len(annotation.foreground_samples) == 0:
return self._random_roi(idx)

fg_idx = torch.randint(0, len(annotation.foreground_samples), (1,)).item()
fg_position = annotation.foreground_samples[fg_idx]

roi = []
for fg_pos, dim_size, patch_size in zip(fg_position.tolist(), annotation.shape, roi_shape):
left = patch_size // 2
right = patch_size - left
center = max(left, min(fg_pos, dim_size - right))
roi.append(center - left)
roi.append(center + right)
return tuple(roi)
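The foreground-guided variant instead takes a stored foreground voxel as the patch center, then clamps it so the patch never spills past the volume boundary. A worked instance for a voxel near the border (numbers made up):

dim_size, patch_size = 64, 16
left = patch_size // 2            # 8
right = patch_size - left         # 8
fg_pos = 2                        # foreground voxel close to the edge
center = max(left, min(fg_pos, dim_size - right))   # clamped up to 8
roi_lo, roi_hi = center - left, center + right      # (0, 16): still 16 wide, in bounds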

@override
def construct_new(self, images: list[torch.Tensor], labels: list[torch.Tensor]) -> Self:
return self.__class__(self._annotations, percentile=self._percentile,
foreground_oversample_percent=self._fg_oversample)

@override
def load(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
image, label = self._annotations._dataset[idx]
force_fg = torch.rand(1).item() < self._fg_oversample
if force_fg:
roi = self._foreground_guided_random_roi(idx)
else:
roi = self._random_roi(idx)
return crop(image.unsqueeze(0), roi).squeeze(0), crop(label.unsqueeze(0), roi).squeeze(0)
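A hypothetical end-to-end use of the new pieces, with my_dataset standing in for any SupervisedDataset: with the default foreground_oversample_percent of 0.33, roughly a third of the loaded patches are centered on a stored foreground sample and the rest are placed uniformly at random.

from mipcandy.data import inspect, RandomROIDataset

annotations = inspect(my_dataset, max_foreground_samples=5000)
patches = RandomROIDataset(annotations, foreground_oversample_percent=0.33)
image_patch, label_patch = patches.load(0)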