# CellSighter
<img src="./CellSighterLogo.jpg" style="width: 15%; height: 15%">

CellSighter is an ensemble of convolutional neural networks for supervised cell classification in multiplexed images. Given a labeled training set, a model can be trained to predict cell classes for new images.
## Data Preparation
The data should have the following structure:
* The raw images should be in: {data_path}/CellTypes/data/antibodies
  * Each image should be saved as an npz or tif file containing a 3D image of shape HxWxC, where C is the number of proteins in the data.
* The segmentations should be in: {data_path}/CellTypes/cells
  * For each image there should be a segmentation saved as an npz or tif file of shape HxW. The segmentation file is a labeled object matrix in which all pixels belonging to a cell have the value of that cell's id. The cells should be numbered from 1 to the number of cells in the image.
* The labels should be in: {data_path}/CellTypes/cells2labels
  * For each image there should be a file in npz format (*.npz), in which each row holds the label of the cell whose id is the row index.
  * Alternatively, save it in txt format (*.txt) with lines separated by \n, where each line holds the label of the cell whose id is the line index.
  * Note that if you don't have a label for a cell you should set its label to -1, but all cells must appear in the file.
* There should also be a channels file: a txt file with the names of the proteins, ordered according to the order of the channels in the image files and separated by \n.
### Notes:

- The names of the files should be the image id.
- The labels of the cells should be integer numbers.
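A minimal sketch, not part of the repository, of writing one image in this layout; the image id, shapes, and values are illustrative, and the npz key your loader expects may differ:

```python
import os
import numpy as np

data_path = "data_path"   # same value you will later put in "root_dir" in config.json
image_id = "FOV1"         # file names must be the image id

# Raw image: H x W x C, where C is the number of proteins.
image = np.random.poisson(1.0, size=(512, 512, 20)).astype(np.float32)
img_dir = os.path.join(data_path, "CellTypes", "data", "antibodies")
os.makedirs(img_dir, exist_ok=True)
np.savez(os.path.join(img_dir, f"{image_id}.npz"), image)

# Segmentation: H x W labeled object matrix, cells numbered 1..N (0 = background).
cells = np.zeros((512, 512), dtype=np.int64)
cells[10:20, 10:20] = 1   # a single toy "cell"
seg_dir = os.path.join(data_path, "CellTypes", "cells")
os.makedirs(seg_dir, exist_ok=True)
np.savez(os.path.join(seg_dir, f"{image_id}.npz"), cells)

# Labels: one line per cell id (line index = cell id); -1 marks unlabeled cells.
labels = np.full(cells.max() + 1, -1, dtype=np.int64)
lbl_dir = os.path.join(data_path, "CellTypes", "cells2labels")
os.makedirs(lbl_dir, exist_ok=True)
np.savetxt(os.path.join(lbl_dir, f"{image_id}.txt"), labels, fmt="%d")
```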
## System requirements
1. Access to a GPU
2. See the requirements file for libraries and versions.
## Training a Model

1. Prepare the data in the format above.
2. Create a folder with a configuration file named "config.json". See "Preparing configuration file" for more information.
3. Train one model with the following command:
   `python train.py --base_path=/path/to/your/folder/with/the/configuration/file`
4. In order to train an ensemble, run the command above multiple times, once per folder, as in the sketch below.
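A hedged sketch of launching an ensemble of training runs; the folder names are illustrative, and each folder must contain its own config.json:

```python
import subprocess

# One folder per ensemble member, each with its own config.json.
model_dirs = ["ensemble/model_0", "ensemble/model_1", "ensemble/model_2"]
for model_dir in model_dirs:
    # Each run writes its weights, validation results, and logs into model_dir.
    subprocess.run(["python", "train.py", f"--base_path={model_dir}"], check=True)
```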
### Output files:
1. val_results_{epocNum}.csv - results on the validation set along training. The file contains the following columns:
   * pred - predicted label
   * pred_prob - probability of the predicted label
   * label - input label used for training
   * cell_id - cell id
   * image_id - image id
   * prob_list - list of probabilities per cell type; the index is the cell type.
2. Weights_{epocNum}_count.pth - the weights of the network.
3. event.out.### - tensorboard logs.
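For example, a quick sanity check on one epoch's validation results (the epoch number in the file name is illustrative):

```python
import pandas as pd

results = pd.read_csv("val_results_50.csv")
labeled = results[results["label"] >= 0]   # -1 marks unlabeled cells
accuracy = (labeled["pred"] == labeled["label"]).mean()
print(f"validation accuracy: {accuracy:.3f}")
```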
## Evaluating the model

1. Prepare the data in the format above.
2. Create a folder with a configuration file named "config.json". See "Preparing configuration file" for more information.
3. Change the "weight_to_eval" field in the config file to the path of the weights of the model you trained (Weights_{epocNum}_count.pth).
4. Evaluate one model with the following command:
   `python eval.py --base_path=/path/to/your/folder/with/the/configuration/file`
5. You should now have a results csv in the folder.
6. In order to evaluate an ensemble, run the command above for each model you trained, making sure to change the weight paths and to work in multiple folders, one per model (see the sketch below). You should then have multiple results files, which you can combine as you wish or with our merging scripts.
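A hedged sketch of evaluating every ensemble member; it points "weight_to_eval" at each model's weights before invoking eval.py (folder names and epoch number are illustrative):

```python
import json
import subprocess

model_dirs = ["ensemble/model_0", "ensemble/model_1", "ensemble/model_2"]
for model_dir in model_dirs:
    cfg_path = f"{model_dir}/config.json"
    with open(cfg_path) as f:
        cfg = json.load(f)
    cfg["weight_to_eval"] = f"{model_dir}/Weights_50_count.pth"
    with open(cfg_path, "w") as f:
        json.dump(cfg, f, indent=2)
    subprocess.run(["python", "eval.py", f"--base_path={model_dir}"], check=True)
```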
### Output files:
1. val_results - same format as in training
2. event.out.### - tensorboard logs
## Analyze results

- You can merge the results of the ensemble into one unified results file by running the following script: analyze_results/unified_ensemble.py
  In the script you'll need to fill in the list of paths to all the val_results.csv files you got from the ensemble. The output of the script is a unified results file named "merged_ensemble.csv" with the following columns:
  * pred - predicted label
  * pred_prob - probability of the predicted label
  * label - input label
  * cell_id - cell id
  * image_id - image id
- You can visualize a confusion matrix of the input labels against the CellSighter labels by running the following script: analyze_results/confusion_matrix.py
  Once you have a csv with results (and assuming you have valid labels as ground truth), fill in the path to the csv results file and run the script. It will generate a confusion matrix and save it as a png file.
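Beyond the provided scripts, the merged file is easy to slice further; for example, a hedged sketch of per-class agreement between the input labels and the ensemble predictions:

```python
import pandas as pd

merged = pd.read_csv("merged_ensemble.csv")
labeled = merged[merged["label"] >= 0]   # keep only cells with ground-truth labels
per_class = labeled.groupby("label").apply(lambda g: (g["pred"] == g["label"]).mean())
print(per_class)                          # fraction of agreement per input label
```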
## Preparing configuration file

The configuration file is a JSON file. The comments below explain each field; they must not appear in the actual file:

```
{
 "crop_input_size": 60,    # size of the crop that goes into the network; make sure it is sufficient to visualize a cell and a fraction of its immediate neighbors
 "crop_size": 128,         # size of the initial crop before augmentations; this should be ~2-fold the size of the input crop to allow augmentations such as shifts
 "root_dir": "data_path",  # path to the data that you've prepared in the previous steps
 "train_set": ["FOV1", "FOV2", ...],   # list of image ids to use as the training set
 "val_set": ["FOV10", "FOV12", ...],   # list of image ids to use as the validation/evaluation set
 "num_classes": 20,        # number of classes in the data set
 "epoch_max": 50,          # number of epochs to train
 "lr": 0.001,              # learning rate
 "blacklist": [],          # channels not to use in training/validation at all
 "channels_path": "",      # path to the protein list you created during data preparation
 "weight_to_eval": "",     # path to weights, relevant only for evaluation
 "sample_batch": true,     # whether to sample equally from each category in every batch during training
 "hierarchy_match": {"0": "B cell", "1": "Myeloid", ...},   # dictionary matching classes to a higher category, for balancing higher categories during training; keys are the label ids, values are the higher categories
 "size_data": 1000,        # optional; for each cell type sample size_data cells, or fewer if there aren't enough cells of that type
 "aug": true               # optional; whether to apply augmentation
}
```
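A minimal sketch of loading and sanity-checking the configuration; the required-key list below is an assumption based on the table above, and remember that the real file must be plain JSON, without the comments:

```python
import json

with open("config.json") as f:
    config = json.load(f)

required = ["crop_input_size", "crop_size", "root_dir", "train_set", "val_set",
            "num_classes", "epoch_max", "lr", "channels_path"]
missing = [key for key in required if key not in config]
assert not missing, f"config.json is missing keys: {missing}"
# The initial crop should leave room for shift augmentations around the input crop.
assert config["crop_size"] >= config["crop_input_size"]
```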
**analyze_results/confusion_matrix.py** (the confusion-matrix script referenced above):

```python
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
from matplotlib.colors import LinearSegmentedColormap


def metric(gt, pred, classes_for_cm, colorbar=True):
    sns.set(font_scale=2)
    # Row-normalized (recall) confusion matrix in percent, plus raw counts.
    cm_normed_recall = confusion_matrix(gt, pred, labels=classes_for_cm, normalize="true") * 100
    cm = confusion_matrix(gt, pred, labels=classes_for_cm)

    plt.figure(figsize=(50, 45))
    ax1 = plt.subplot2grid((50, 50), (0, 0), colspan=30, rowspan=30)
    cmap = LinearSegmentedColormap.from_list('', ['white', *plt.cm.Blues(np.arange(255))])
    # Annotate each cell with "recall% (count)", hiding near-zero entries.
    annot_labels = cm_normed_recall.round(1).astype(str)
    annot_labels = pd.DataFrame(annot_labels) + "\n (" + pd.DataFrame(cm).astype(str) + ")"

    annot_mask = cm_normed_recall.round(1) <= 0.1
    annot_labels[annot_mask] = ""

    sns.heatmap(cm_normed_recall.T, ax=ax1, annot=annot_labels.T, fmt='', cbar=colorbar,
                cmap=cmap, linewidths=1, vmin=0, vmax=100, linecolor='black', square=True)

    ax1.xaxis.tick_top()
    ax1.set_xticklabels(classes_for_cm, rotation=90)
    ax1.set_yticklabels(classes_for_cm, rotation=0)
    ax1.tick_params(axis='both', which='major', labelsize=35)

    ax1.set_xlabel("Clustering and gating", fontsize=35)
    ax1.set_ylabel("CellSighter", fontsize=35)


results = pd.read_csv(r"")  # Fill in the path to your results file
classes_for_cm = np.unique(np.concatenate([results["label"], results["pred"]]))
metric(results["label"], results["pred"], classes_for_cm)
plt.savefig("confusion_matrix.png")
```
**analyze_results/unified_ensemble.py** (the merging script referenced above; note the column suffixes were fixed to start from `_ens_0`, matching `enumerate`):

```python
import pandas as pd

# Fill in the paths to all of the val_results.csv files you got from training/validation.
val_results = []

df_all_labeled = pd.DataFrame()
ensemble_size = len(val_results)
for i, val_result in enumerate(val_results):
    curr_df = pd.read_csv(val_result, index_col=0)
    # Expand the per-cell-type probability list into one column per class.
    prob_list = curr_df["prob_list"].apply(eval)
    num_classes = len(prob_list.iloc[0])
    curr_df[[f"prob_class_{j}" for j in range(num_classes)]] = prob_list.apply(pd.Series)
    curr_df.columns = [c + f"_ens_{i}" for c in curr_df.columns]
    df_all_labeled = pd.concat([df_all_labeled, curr_df], axis=1)

# Average each class probability across the ensemble members.
for i in range(num_classes):
    df_all_labeled[f"prob_mean_class_{i}"] = df_all_labeled[[f"prob_class_{i}_ens_{j}" for j in range(ensemble_size)]].mean(axis=1)

df_all_labeled["pred"] = df_all_labeled[[f"prob_mean_class_{i}" for i in range(num_classes)]].values.argmax(1)
df_all_labeled["pred_prob"] = df_all_labeled[[f"prob_mean_class_{i}" for i in range(num_classes)]].max(axis=1)

# Labels and ids are identical across ensemble members; take them from the first one.
df_all_labeled["label"] = df_all_labeled["label_ens_0"]
df_all_labeled["cell_id"] = df_all_labeled["cell_id_ens_0"]
df_all_labeled["image_id"] = df_all_labeled["image_id_ens_0"]

df_all_labeled[["image_id", "cell_id", "label", "pred", "pred_prob"]].to_csv("merged_ensemble.csv")
```
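Note the design choice here: the script ensembles by averaging the per-class probabilities across models and taking the argmax, rather than majority-voting the hard predictions, so the reported `pred_prob` remains a meaningful ensemble confidence.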
The `CellCrop` container class, which packages a single cell's crop (the unused `cv2` import was dropped, and the removed `np.long` alias replaced with `np.int64`):

```python
import numpy as np


class CellCrop:
    """
    Represents a crop of a cell with all the channels, the mask of the cell,
    the mask of the cells in its environment, and more.
    """
    def __init__(self, cell_id, image_id, label, slices, cells, image):
        self._cell_id = cell_id
        self._image_id = image_id
        self._label = label
        self._slices = slices   # pair of slice objects locating the cell in the image
        self._cells = cells     # H x W labeled segmentation matrix
        self._image = image     # H x W x C multiplexed image

    def sample(self, mask=False):
        result = {'cell_id': self._cell_id, 'image_id': self._image_id,
                  'image': self._image[self._slices].astype(np.float32),
                  'slice_x_start': self._slices[0].start,
                  'slice_y_start': self._slices[1].start,
                  'slice_x_end': self._slices[0].stop,
                  'slice_y_end': self._slices[1].stop,
                  'label': self._label.astype(np.int64)}  # np.long was removed in recent NumPy
        if mask:
            # Binary mask of this cell, binary mask of all cells, and the raw id matrix.
            result['mask'] = (self._cells[self._slices] == self._cell_id).astype(np.float32)
            result['all_cells_mask'] = (self._cells[self._slices] > 0).astype(np.float32)
            result['all_cells_mask_seperate'] = (self._cells[self._slices]).astype(np.float32)

        return result
```
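A hedged sketch, not the repository's loader, of building `CellCrop` objects from a segmentation via bounding-box slices; the real pipeline also pads and centers crops to the configured crop size, which is omitted here:

```python
import numpy as np
from scipy.ndimage import find_objects

def build_crops(image, cells, cell_labels, image_id):
    """Build one CellCrop per cell from an H x W x C image and its H x W segmentation."""
    crops = []
    # find_objects returns one bounding-box slice pair per cell id (1..N).
    for cell_id, slices in enumerate(find_objects(cells), start=1):
        if slices is None:   # this id does not appear in the segmentation
            continue
        crops.append(CellCrop(cell_id=cell_id,
                              image_id=image_id,
                              label=np.array(cell_labels[cell_id]),
                              slices=slices,
                              cells=cells,
                              image=image))
    return crops
```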
The `CellCropsDataset` class, which wraps a list of crops for PyTorch:

```python
import numpy as np
from torch.utils.data import Dataset


class CellCropsDataset(Dataset):
    def __init__(self, crops, mask=False, transform=None):
        super().__init__()
        self._crops = crops
        self._transform = transform
        self._mask = mask

    def __len__(self):
        return len(self._crops)

    def __getitem__(self, idx):
        # Note: this implementation expects mask=True, since it stacks the
        # cell masks onto the image channels before the transform.
        sample = self._crops[idx].sample(self._mask)
        aug = self._transform(np.dstack(
            [sample['image'], sample['all_cells_mask'][:, :, np.newaxis], sample['mask'][:, :, np.newaxis]])).float()
        sample['image'] = aug[:-1, :, :]   # marker channels + environment mask
        sample['mask'] = aug[[-1], :, :]   # this cell's mask, kept as its own channel
        return sample
```
**data/shift_augmentation.py** (the import path is taken from the transforms file below):

```python
import numpy as np
import torch
from torchvision.transforms import Lambda, RandomCrop, CenterCrop


class ShiftAugmentation(torch.nn.Module):
    """
    Augmentation that shifts each marker channel a few pixels in a random direction.
    """
    def __init__(self, n_size, shift_max=0):
        super().__init__()
        self.shift_max = shift_max
        self.n_size = n_size
        p = 0.3  # probability of shifting a given channel

        # With probability p, center-crop to n_size + shift_max and then take a
        # random n_size crop, which offsets the channel by up to shift_max pixels.
        self.channel_shifter = Lambda(lambda x:
            RandomCrop(size=n_size)(
                CenterCrop(size=n_size + (self.shift_max if np.random.random() < p else 0))(x)))

    def forward(self, x):
        # x is shaped (C, H, W); each channel is shifted independently.
        aug_x = torch.zeros((x.shape[0], self.n_size, self.n_size))
        for i in range(x.shape[0]):
            aug_x[i, :, :] = self.channel_shifter(x[[i], :, :])[0, :, :]
        return aug_x
```
The training and validation transforms:

```python
import cv2
import numpy as np
import torchvision
from torchvision.transforms import Lambda
from data.shift_augmentation import ShiftAugmentation


def poisson_sampling(x):
    """
    Augmentation that resamples the data from a Poisson distribution.
    Args:
        x: (H, W, C), where C is the number of markers + 2 (one channel for the
           mask of the cell, one for the mask of all other cells in the environment).
    Returns:
        Augmented tensor of size (H, W, C), with the marker channels resampled
        from a Poisson distribution.
    """
    blur = cv2.GaussianBlur(x[:, :, :-2], (5, 5), 0)
    x[:, :, :-2] = np.random.poisson(lam=blur, size=x[:, :, :-2].shape)
    return x


def cell_shape_aug(x):
    """
    Augment the cell's mask by dilating it with a randomly sized kernel.
    """
    if np.random.random() < 0.5:
        cell_mask = x[:, :, -1]
        kernel_size = np.random.choice([2, 3, 5])
        kernel = np.ones(kernel_size, np.uint8)  # 1-D array: a 1 x k structuring element
        img_dilation = cv2.dilate(cell_mask, kernel, iterations=1)
        x[:, :, -1] = img_dilation
    return x


def env_shape_aug(x):
    """
    Augment the mask of the cells in the environment by dilating it with a
    randomly sized kernel.
    """
    if np.random.random() < 0.5:
        cell_mask = x[:, :, -2]
        kernel_size = np.random.choice([2, 3, 5])
        kernel = np.ones(kernel_size, np.uint8)
        img_dilation = cv2.dilate(cell_mask, kernel, iterations=1)
        x[:, :, -2] = img_dilation
    return x


# Validation: deterministic center crop only.
val_transform = lambda crop_size: torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.CenterCrop((crop_size, crop_size))
])

# Training: noise resampling, mask dilations, rotation, random per-channel shifts, and flips.
train_transform = lambda crop_size, shift: torchvision.transforms.Compose([
    torchvision.transforms.Lambda(poisson_sampling),
    torchvision.transforms.Lambda(cell_shape_aug),
    torchvision.transforms.Lambda(env_shape_aug),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.RandomRotation(degrees=(0, 360)),
    Lambda(lambda x: ShiftAugmentation(shift_max=shift, n_size=crop_size)(x) if np.random.random() < 0.5 else x),
    torchvision.transforms.CenterCrop((crop_size, crop_size)),
    torchvision.transforms.RandomHorizontalFlip(p=0.75),
    torchvision.transforms.RandomVerticalFlip(p=0.75),
])
```
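A hedged sketch tying the pieces together, assuming `CellCrop` and `CellCropsDataset` are in scope and that all crops have been padded to a common size (the repository handles this; the parameter values here are illustrative):

```python
from torch.utils.data import DataLoader

# crops: a list of equal-sized CellCrop objects, e.g. built as in the earlier sketch.
crop_size, shift = 60, 5   # network input size and maximum shift, both illustrative
dataset = CellCropsDataset(crops, mask=True,  # mask=True: __getitem__ needs the cell masks
                           transform=train_transform(crop_size, shift))
loader = DataLoader(dataset, batch_size=64, shuffle=True)
batch = next(iter(loader))
print(batch["image"].shape)  # (64, C + 1, crop_size, crop_size): markers + environment mask
```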