Skip to content

Commit

Permalink
add xeno-learning project
Browse files Browse the repository at this point in the history
  • Loading branch information
JanSellner committed Oct 20, 2024
1 parent 654965a commit 973d051
Show file tree
Hide file tree
Showing 52 changed files with 496,753 additions and 0 deletions.
3 changes: 3 additions & 0 deletions htc_projects/rat/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# SPDX-FileCopyrightText: 2022 Division of Intelligent Medical Systems, DKFZ
# SPDX-License-Identifier: MIT

190 changes: 190 additions & 0 deletions htc_projects/rat/settings_rat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
# SPDX-FileCopyrightText: 2022 Division of Intelligent Medical Systems, DKFZ
# SPDX-License-Identifier: MIT

import os

from htc.settings import settings
from htc.utils.LabelMapping import LabelMapping
from htc.utils.MultiPath import MultiPath
from htc.utils.unify_path import unify_path


class SettingsRat:
    """
    Settings of the rat project: label mappings, subject selection, colors and result directories.

    The result directories are resolved lazily on first access (see `results_dir`).
    """

    def __init__(self):
        # Mapping from label names to label indices (18 classes)
        self.label_mapping = LabelMapping(
            {
                "stomach": 0,
                "small_bowel": 1,
                "colon": 2,
                "liver": 3,
                "pancreas": 4,
                "kidney": 5,
                "spleen": 6,
                "bladder": 7,
                "omentum": 8,
                "lung": 9,
                "heart": 10,
                "cartilage": 11,
                "bone": 12,
                "skin": 13,
                "muscle": 14,
                "peritoneum": 15,
                "major_vein": 16,
                "kidney_with_Gerotas_fascia": 17,
            },
            unknown_invalid=True,
        )

        # Extended mapping with 31 classes (presumably used for the standardized recordings, as the name suggests)
        self.label_mapping_standardized = LabelMapping(
            {
                "stomach": 0,
                "small_bowel": 1,
                "colon": 2,
                "liver": 3,
                "pancreas": 4,
                "kidney": 5,
                "spleen": 6,
                "bladder": 7,
                "omentum": 8,
                "lung": 9,
                "pleura": 10,
                "trachea": 11,
                "heart": 12,
                "cartilage": 13,
                "bone": 14,
                "tendon": 15,
                "ligament_pat": 16,
                "skin": 17,
                "fur": 18,
                "muscle": 19,
                "fat_subcutaneous": 20,
                "peritoneum": 21,
                "aorta": 22,
                "major_vein": 23,
                "kidney_with_Gerotas_fascia": 24,
                "diaphragm": 25,
                "tube": 26,
                "ovary": 27,
                "fat_visceral": 28,
                "thymus": 29,
                "blood": 30,
            },
            unknown_invalid=True,
        )

        # Only for those subjects, we have standardized recordings for all organs (there are more subjects with standardized recordings but not for all organs)
        self.standardized_subjects = [
            "R002",
            "R003",
            "R014",
            "R015",
            "R016",
            "R017",
            "R018",
            "R019",
            "R020",
            "R021",
            "R022",
            "R023",
            "R024",
        ]

        self.best_run_standardized = "2024-02-23_14-31-38_median_31classes"

        # Display names for labels whose internal name is not suited for figures (presumably used in the paper)
        self.labels_paper_renaming = {
            "small_bowel": "small bowel",
            "major_vein": "vena cava",
            "kidney_with_Gerotas_fascia": "kidney with\nGerota's fascia",
            "fat_visceral": "visceral fat",
            "ligament_pat": "ligament",
            "fat_subcutaneous": "subcutaneous tissue",
            "saliv_gland": "salivary gland",
            "vesic_gland": "vesicular gland",
        }

        # Hex color per label name
        # NOTE(review): "ovary" and "aorta" share the same color (#AB8600) — confirm that this is intended
        self.label_colors = {
            "stomach": "#FF1202",
            "small_bowel": "#FF9001",
            "colon": "#FFDD00",
            "liver": "#7FFD03",
            "pancreas": "#02FFF2",
            "kidney": "#0475FF",
            "spleen": "#020197",
            "bladder": "#630605",
            "omentum": "#9900ED",
            "lung": "#ED00C9",
            "heart": "#FD8EEC",
            "cartilage": "#15E7C5",
            "bone": "#A35F01",
            "skin": "#A32121",
            "muscle": "#484848",
            "peritoneum": "#8C7FB8",
            "major_vein": "#BE14C4",
            "kidney_with_Gerotas_fascia": "#BEE7C5",
            "tube": "#BDE70A",
            "ovary": "#AB8600",
            "aorta": "#AB8600",
            "pleura": "#FFF893",
            "blood": "#830000",
            "fat_visceral": "#FFC494",
            "tendon": "#89BDFF",
            "ligament_pat": "#FFB46D",
            "thymus": "#D88CFC",
            "trachea": "#00E28E",
            "fur": "#FF7830",
            "fat_subcutaneous": "#E66E6E",
            "diaphragm": "#73AF00",
            "thyroid": "#B90C00",
            "saliv_gland": "#BC6A00",
            "vesic_gland": "#00469C",
            "teeth": "#448801",
            "urine": "#FFEF88",
        }

        # Hex color per straylight condition (this dictionary was previously assigned twice with identical content)
        self.colormap_straylight = {
            "no_straylight": "#688B51",
            "ceiling": "#4D6DA1",
            "OR-right": "#8B5958",
            "OR-situs": "#9E50A1",
            "OR-situs+ceiling": "#604961",
        }

        # Cache for the results_dir property (resolved on first access)
        self._results_dir = None

    @property
    def results_dir(self) -> MultiPath:
        """
        Results directory of the rat project.

        The path is taken from the environment variable PATH_HTC_RESULTS_RAT if it is set to a non-empty value. Otherwise, the default results directory from the general settings is used and an info message is logged. The resolved value is cached for subsequent accesses.
        """
        if self._results_dir is None:
            if _path_env := os.getenv("PATH_HTC_RESULTS_RAT", False):
                self._results_dir = unify_path(_path_env)
            else:
                # If no path is set, we just use the default results directory
                self._results_dir = settings.results_dir
                settings.log.info(
                    "The environment variable PATH_HTC_RESULTS_RAT is not set. Files for the rat project"
                    f" will be written to {self._results_dir.find_best_location()}"
                )

        return self._results_dir

    def _ensure_results_subdir(self, name: str) -> MultiPath:
        """Return the subdirectory `name` of the results directory and create it if it does not exist yet."""
        target_dir = self.results_dir / name
        target_dir.mkdir(parents=True, exist_ok=True)
        return target_dir

    @property
    def figures_dir(self) -> MultiPath:
        """Directory `figures` inside the results directory (created on access)."""
        return self._ensure_results_subdir("figures")

    @property
    def paper_dir(self) -> MultiPath:
        """Directory `paper` inside the results directory (created on access)."""
        return self._ensure_results_subdir("paper")


# Module-level instance of the rat settings which is imported by other modules (e.g. htc_projects/rat/tables.py)
settings_rat = SettingsRat()
31 changes: 31 additions & 0 deletions htc_projects/rat/tables.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# SPDX-FileCopyrightText: 2022 Division of Intelligent Medical Systems, DKFZ
# SPDX-License-Identifier: MIT

import pandas as pd

from htc.utils.helper_functions import median_table, sort_labels
from htc.utils.LabelMapping import LabelMapping
from htc_projects.rat.settings_rat import settings_rat


def standardized_recordings_rat(label_mapping: LabelMapping, Camera_CamID: str = None):
    """
    Load the median spectra table of the rat dataset and keep only the standardized recordings.

    Args:
        label_mapping: Label mapping which is passed on to `median_table()`.
        Camera_CamID: Optional camera identifier. If given, only rows of this camera are kept.

    Returns: Table with median spectra (passed through `sort_labels()` and with a fresh index).
    """
    table = median_table("2023_12_07_Tivita_multiorgan_rat", label_mapping=label_mapping)

    # The mapped label names take the place of the original label names
    table.drop(columns=["label_name"], inplace=True)
    table.rename(columns={"label_name_mapped": "label_name"}, inplace=True)

    # Collect all row conditions in one boolean mask and apply it once
    keep = table["subject_name"].isin(settings_rat.standardized_subjects)
    if Camera_CamID is not None:
        keep &= table["Camera_CamID"] == Camera_CamID

    # Only standardized recordings: at least one of those attributes must be set
    keep &= table[["situs", "angle", "repetition"]].notna().any(axis=1)

    table = sort_labels(table[keep])
    return table.reset_index(drop=True)
332,907 changes: 332,907 additions & 0 deletions htc_projects/species/ProjectionExample.ipynb

Large diffs are not rendered by default.

88 changes: 88 additions & 0 deletions htc_projects/species/ProjectionLearner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# SPDX-FileCopyrightText: 2022 Division of Intelligent Medical Systems, DKFZ
# SPDX-License-Identifier: MIT

import torch
import torch.nn as nn

from htc.models.image.DatasetImage import DatasetImage
from htc.tivita.DataPath import DataPath
from htc.utils.Config import Config
from htc.utils.specular_highlights import specular_highlights_mask_lab


class ProjectionLearner(nn.Module):
    def __init__(
        self,
        config: Config,
        mode: str = "weights+bias",
        highlights_threshold: int = None,
        n_channels: int = 100,
    ):
        """
        This class can be used to learn a projection matrix that maps the spectra from one image to the spectra of another image. This is useful if the images show the same object but in different states, for example physiological and ischemic states.

        The number of pixels do not have to be the same in both images since this optimization is carried out indirectly by enforcing that the mean and standard deviation of the spectra as well as the histogram of values are similar in both images. See the ProjectionExample.ipynb notebook for an example of the usage of this class.

        Args:
            config: The configuration object which is used to load the spectral data and the valid pixels.
            mode: The general mode of the projection: set to "weights" to only use a projection matrix, "bias" to only use a bias vector, "weights+bias" to use both.
            highlights_threshold: An optional threshold to filter out specular highlight pixels in case you want to exclude them from the optimizations.
            n_channels: Number of spectral channels, i.e. the size of the projection matrix and the bias vector.
        """
        super().__init__()
        self.config = config
        self.mode = mode
        self.highlights_threshold = highlights_threshold

        # Start with the identity transformation (identity matrix, zero bias)
        self.projection_matrix = nn.Parameter(torch.eye(n_channels, dtype=torch.float32))
        self.bias = nn.Parameter(torch.zeros(n_channels, dtype=torch.float32))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the projection to the spectra.

        Args:
            x: The spectra to transform. The last dimension must match the number of channels of the projection.

        Returns: The transformed spectra.

        Raises:
            ValueError: If the mode of this object is unknown.
        """
        if self.mode == "weights":
            x = x @ self.projection_matrix
        elif self.mode == "bias":
            x = x + self.bias
        elif self.mode == "weights+bias":
            x = x @ self.projection_matrix + self.bias
        else:
            raise ValueError(f"Unknown mode: {self.mode}")

        return x

    def fit_pair(self, path_from: DataPath, path_to: DataPath, n_steps: int = 100) -> float:
        """
        Optimize the projection so that the transformed spectra of one image become similar to the spectra of another image.

        The loss is the sum of three mean squared errors: between the channel-wise means, between the channel-wise standard deviations and between the normalized value histograms of the transformed and the target spectra.

        Args:
            path_from: Image with the spectra which are transformed.
            path_to: Image with the target spectra.
            n_steps: Number of optimization steps (must be at least 1).

        Returns: The loss value of the last optimization step (computed before the last parameter update).

        Raises:
            ValueError: If n_steps is smaller than 1 (no loss can be reported in this case).
        """
        if n_steps < 1:
            raise ValueError(f"n_steps must be at least 1 (got {n_steps})")

        spectra_from = self._load_spectra(path_from)
        spectra_to = self._load_spectra(path_to)

        optimizer = torch.optim.Adam(self.parameters(), lr=0.01)
        mse_loss = torch.nn.MSELoss()

        # The histogram range and the target statistics are fixed by the target spectra
        min_hist = spectra_to.min()
        max_hist = spectra_to.max()
        n_bins = 50
        hist_to = torch.histc(spectra_to, bins=n_bins, min=min_hist, max=max_hist) / spectra_to.numel()

        spectra_to_mean = spectra_to.mean(dim=0)
        spectra_to_std = spectra_to.std(dim=0)

        # Make sure that gradients from previous backward passes do not leak into the first step
        optimizer.zero_grad()

        for _ in range(n_steps):
            spectra_transformed = self(spectra_from)

            # NOTE(review): torch.histc appears to be non-differentiable so that this term likely only affects the reported loss value and not the gradients — confirm
            hist_transformed = (
                torch.histc(spectra_transformed, bins=n_bins, min=min_hist, max=max_hist) / spectra_transformed.numel()
            )

            loss = mse_loss(spectra_transformed.mean(dim=0), spectra_to_mean)
            loss += mse_loss(spectra_transformed.std(dim=0), spectra_to_std)
            loss += mse_loss(hist_transformed, hist_to)

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        return loss.item()

    def _load_spectra(self, path: DataPath) -> torch.Tensor:
        """
        Load the spectra of all valid pixels of an image.

        Args:
            path: The image to load.

        Returns: The features of the valid pixels (optionally without specular highlight pixels) with the same dtype and on the same device as the projection matrix.
        """
        sample = DatasetImage([path], train=False, config=self.config)[0]
        valid_pixels = sample["valid_pixels"]

        if self.highlights_threshold is not None:
            # Specular highlight pixels are treated as invalid
            highlights = specular_highlights_mask_lab(path, threshold=self.highlights_threshold)
            valid_pixels.masked_fill_(highlights, False)

        spectra = sample["features"][valid_pixels].to(
            dtype=self.projection_matrix.dtype, device=self.projection_matrix.device
        )

        return spectra
3 changes: 3 additions & 0 deletions htc_projects/species/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# SPDX-FileCopyrightText: 2022 Division of Intelligent Medical Systems, DKFZ
# SPDX-License-Identifier: MIT

55 changes: 55 additions & 0 deletions htc_projects/species/apply_transforms_paths.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# SPDX-FileCopyrightText: 2022 Division of Intelligent Medical Systems, DKFZ
# SPDX-License-Identifier: MIT

import pandas as pd
from lightning import seed_everything

from htc.models.common.torch_helpers import copy_sample
from htc.models.common.transforms import HTCTransformation
from htc.models.image.DatasetImageBatch import DatasetImageBatch
from htc.tivita.DataPath import DataPath
from htc.utils.Config import Config
from htc.utils.LabelMapping import LabelMapping


def apply_transforms_paths_median(paths: list[DataPath], config: Config, epoch_size: int = 1) -> pd.DataFrame:
    """
    Apply transformations on some images and compute the median spectra.

    Args:
        paths: The data paths to apply the transforms to.
        config: The configuration object which defines the loading and the transformation of the data (`input/transforms_gpu`).
        epoch_size: The number of times the transformation should be applied to the same image (to mimic a similar behavior as during training).

    Returns: A table with the computed median spectrum based on the transformed data (one row per image, epoch index and label).
    """
    # Fixed seed so that the transformations are reproducible
    seed_everything(0, workers=True)
    dataloader = DatasetImageBatch.batched_iteration(paths, config)
    mapping = LabelMapping.from_config(config)
    aug = HTCTransformation.parse_transforms(config["input/transforms_gpu"], config=config, device="cuda")

    rows = []
    for batch in dataloader:
        for e in range(epoch_size):
            # Apply the transformation multiple times to the same image as also done during training
            batch_copy = copy_sample(batch)
            batch_copy = HTCTransformation.apply_valid_transforms(batch_copy, aug)

            for b in range(batch_copy["features"].size(0)):
                labels = batch_copy["labels"][b]
                features = batch_copy["features"][b]

                # Only labels which occur in the valid pixels of the image are considered
                label_indices = labels[batch_copy["valid_pixels"][b]].unique()
                if label_indices.numel() == 0:
                    continue

                # The image metadata is identical for all labels of the image and hence only computed once
                path = DataPath.from_image_name(batch_copy["image_name"][b])
                image_info = {"image_name": path.image_name()}
                image_info |= path.image_name_typed()

                for label_index in label_indices:
                    # NOTE(review): the selection is made on the complete label image and is not restricted to the valid pixels — confirm that invalid pixels never carry a valid label index
                    selection = labels == label_index
                    spectra = features[selection]

                    rows.append(
                        image_info
                        | {
                            "epoch_index": e,
                            "label_name": mapping.index_to_name(label_index),
                            "median_normalized_spectrum": spectra.quantile(q=0.5, dim=0).cpu().numpy(),
                        }
                    )

    return pd.DataFrame(rows)
13 changes: 13 additions & 0 deletions htc_projects/species/configs/baseline_human.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
{
"inherits": "htc_projects/context/models/configs/organ_transplantation_0.8.json",
"label_mapping": "htc_projects.species.settings_species>label_mapping",
"input": {
"data_spec": "human_semantic-only_physiological-kidney_5folds_nested-0-2_mapping-12_seed-0.json",
"hierarchical_sampling": "label",
"target_domain": ["no_domain"]
},
"dataloader_kwargs": {
"batch_size": 8,
"num_workers": 2
}
}
Loading

0 comments on commit 973d051

Please sign in to comment.