Commit 2421176

init repo

yaelAmitay committed Nov 5, 2022
Showing 18 changed files with 799 additions and 0 deletions.
Binary file added CellSighterLogo.jpg
104 changes: 104 additions & 0 deletions README.md

# CellSighter
<img src="./CellSighterLogo.jpg" style="width: 15%; height: 15%">


CellSighter is an ensemble of convolutional neural networks to perform supervised cell classification in multiplexed images. Given a labeled training set, a model can be trained to predict cell classes for new images.

## Data Preparation
The data should have the following structure:
* The raw images should be in: {data_path}/CellTypes/data/antibodies
    * Each image should be saved as an npz or tif file containing a 3D image of shape HxWxC, where C is the number of proteins in the data.

* The segmentation should be in: {data_path}/CellTypes/cells
    * For each image there should be a segmentation saved as an npz or tif file, shaped HxW. The segmentation file is a labeled object matrix: all pixels belonging to a cell have the value of that cell's id, and the cells are numbered from 1 to the number of cells in the image.

* The labels should be in: {data_path}/CellTypes/cells2labels
    * For each image there should be an npz file (*.npz) in which the row index is the cell id and each row holds that cell's label.
    * Alternatively, save a txt file (*.txt) with one label per line (separated by \n), where the line index is again the cell id.
    * Note that all cells should appear in the file; if you don't have a label for a cell, set its label to -1.

* Channels file: a txt file with the names of the proteins, one per line (separated by \n), ordered according to the order of the proteins in the image file.

### Notes:

- The file names should be the image id.

- The labels of the cells should be integer numbers.
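
Below is a minimal sketch of writing one image's files in this layout. The directory names follow this README; the use of numpy's savez and the npz key name ("data") are illustrative assumptions, so check them against the loading code:

```python
# A minimal sketch: write one image ("FOV1") in the expected layout.
# ASSUMPTIONS: the npz key name ("data") and the use of numpy.savez are
# illustrative guesses; all shapes, labels and protein names are toy values.
import os
import numpy as np

data_path = "data"   # the root that "root_dir" in config.json should point to
image_id = "FOV1"    # file names must match the image id

for sub in ("data/antibodies", "cells", "cells2labels"):
    os.makedirs(os.path.join(data_path, "CellTypes", sub), exist_ok=True)

# Raw image: H x W x C, where C is the number of proteins
image = np.random.rand(512, 512, 3).astype(np.float32)
# Segmentation: H x W labeled object matrix, cell ids 1..N (0 = background)
cells = np.zeros((512, 512), dtype=np.int32)
cells[10:40, 10:40] = 1
cells[60:90, 60:90] = 2
# One label per row, row index = cell id; -1 marks cells without a label
labels = np.array([-1, 5, 7])

np.savez(os.path.join(data_path, "CellTypes/data/antibodies", image_id + ".npz"), data=image)
np.savez(os.path.join(data_path, "CellTypes/cells", image_id + ".npz"), data=cells)
np.savez(os.path.join(data_path, "CellTypes/cells2labels", image_id + ".npz"), data=labels)

# Channels file: one protein name per line, in the image's channel order
with open("channels.txt", "w") as f:
    f.write("\n".join(["CD3", "CD20", "CD45"]))
```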

## System requirements
1. Access to a GPU
2. See the requirements file for libraries and versions.


## Training a Model

1. Prepare the data in the format above
2. Create a folder with the configuration file named "config.json".
See "Preparing configuration file" for more information.
3. Train one model with the following command:
   `python train.py --base_path=/path/to/your/folder/with/the/configuration file`
4. To run an ensemble, run the command above multiple times, each in its own folder (see the sketch below).
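
For example, here is a minimal sketch of launching an ensemble of five training runs, assuming a prepared folder "base_config" that holds your config.json (all folder names are hypothetical):

```python
# A minimal sketch: launch 5 independent training runs, one folder each.
# ASSUMPTION: "base_config" is a prepared folder containing config.json;
# all folder names here are hypothetical.
import shutil
import subprocess

for i in range(5):
    folder = f"ensemble/model_{i}"
    shutil.copytree("base_config", folder)  # each run gets its own config.json
    subprocess.run(["python", "train.py", f"--base_path={folder}"], check=True)
```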

### Output files:
1. val_results_{epocNum}.csv - results on the validation set along training.
   The file contains the following columns:
   - pred - predicted label
   - pred_prob - probability of the predicted label
   - label - input label used for training
   - cell_id - cell id
   - image_id - image id
   - prob_list - list of probabilities per cell type; the index is the cell type.
2. Weights_{epocNum}_count.pth - the weights of the network.
3. event.out.### - tensorboard logs
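
For example, a minimal sketch of computing validation accuracy from one of these files, assuming the columns above (the epoch number in the file name is hypothetical):

```python
# A minimal sketch: compute validation accuracy from val_results_{epocNum}.csv.
# ASSUMPTION: the epoch number (50) is hypothetical; -1 marks unlabeled cells.
import pandas as pd

results = pd.read_csv("val_results_50.csv")
labeled = results[results["label"] != -1]
accuracy = (labeled["pred"] == labeled["label"]).mean()
print(f"Validation accuracy over {len(labeled)} labeled cells: {accuracy:.3f}")
```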

## Evaluating the model

1. Prepare the data in the format above
2. Create a folder with the configuration file named "config.json".
See "Preparing configuration file" for more information.
3. Change the "weight_to_eval" field in the config file to be the path to the weights of the model you trained (Weights_{epocNum}_count.pth).
4. Evaluate one model with the following command:
   `python eval.py --base_path=/path/to/your/folder/with/the/configuration file`
5. You should now have a results csv in the folder.
6. To run an ensemble, run the command above for each model you trained, making sure to change the weight path and to work in a separate folder per model (see the sketch below).
   You should now have multiple results files. You can combine them as you wish, or use our merging scripts.
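
For example, a minimal sketch of evaluating every ensemble member, reusing the folder layout from the training sketch above (the epoch number in the weights file name is hypothetical):

```python
# A minimal sketch: point each member's config at its weights, then evaluate.
# ASSUMPTIONS: folder layout from the training sketch; epoch 50 is hypothetical.
import json
import subprocess
from pathlib import Path

for folder in sorted(Path("ensemble").glob("model_*")):
    cfg_path = folder / "config.json"
    cfg = json.loads(cfg_path.read_text())
    cfg["weight_to_eval"] = str(folder / "Weights_50_count.pth")
    cfg_path.write_text(json.dumps(cfg, indent=2))
    subprocess.run(["python", "eval.py", f"--base_path={folder}"], check=True)
```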

### Output files:
1. val_results - same format as in training
2. event.out.### - tensorboard logs

## Analyze results

- You can merge the results of the ensemble into one unified result by running the script analyze_results/unified_ensemble.py.
  In the script you'll need to fill in the list of paths to all the val_results.csv files you got from the ensemble.
  The output of the script is a unified results file named "merged_ensemble.csv" with the following columns:
  - pred - predicted label
  - pred_prob - probability of the predicted label
  - label - input label
  - cell_id - cell id
  - image_id - image id
- You can visualize a confusion matrix of the input labels versus the CellSighter labels by running the script analyze_results/confusion_matrix.py.
  Once you have a csv with results, and assuming the input labels are valid ground truth, fill in the path to the csv results file in the script. It will generate a confusion matrix and save it as a png file.

## Preparing configuration file
```
{
  "crop_input_size": 60, #Size of the crop that goes into the network. Make sure that it is sufficient to visualize a cell and a fraction of its immediate neighbors.
  "crop_size": 128, #Size of the initial crop before augmentations. This should be ~2-fold the size of the input crop to allow augmentations such as shifts.
  "root_dir": "data_path", #Path to the data that you've prepared in the previous steps
  "train_set": ["FOV1", "FOV2", ...], #List of image ids to use as the training set
  "val_set": ["FOV10", "FOV12", ...], #List of image ids to use as the validation/evaluation set
  "num_classes": 20, #Number of classes in the data set
  "epoch_max": 50, #Number of epochs to train
  "lr": 0.001, #Learning rate value
  "blacklist": [], #Channels not to use in training/validation at all
  "channels_path": "", #Path to the proteins list you created during data preparation
  "weight_to_eval": "", #Path to weights, relevant only for evaluation
  "sample_batch": true, #Whether to sample equally from the categories in each batch during training
  "hierarchy_match": {"0": "B cell", "1": "Myeloid", ...}, #Dictionary matching classes to a higher category, for balancing higher categories during training. The keys are the label ids and the values the higher categories.
  "size_data": 1000, #Optional; for each cell type, sample size_data samples, or fewer if there aren't enough cells of that type
  "aug": true #Optional; whether to apply augmentations
}
```
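
Note that the # comments above are explanatory only; JSON itself does not allow comments. A minimal sketch of writing a valid config.json (all values are illustrative):

```python
# A minimal sketch: write a valid config.json. All values are illustrative.
import json

config = {
    "crop_input_size": 60,
    "crop_size": 128,
    "root_dir": "data",
    "train_set": ["FOV1", "FOV2"],
    "val_set": ["FOV10", "FOV12"],
    "num_classes": 20,
    "epoch_max": 50,
    "lr": 0.001,
    "blacklist": [],
    "channels_path": "channels.txt",
    "weight_to_eval": "",
    "sample_batch": True,
    "hierarchy_match": {"0": "B cell", "1": "Myeloid"},
    "size_data": 1000,
    "aug": True,
}

with open("config.json", "w") as f:
    json.dump(config, f, indent=2)
```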
Empty file added __init__.py
Empty file.
Empty file added analyze_results/__init__.py
Empty file.
38 changes: 38 additions & 0 deletions analyze_results/confusion_matrix.py
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
from matplotlib.colors import LinearSegmentedColormap


def metric(gt, pred, classes_for_cm, colorbar=True):
    sns.set(font_scale=2)
    # Row-normalized (recall) confusion matrix in percent, plus raw counts
    cm_normed_recall = confusion_matrix(gt, pred, labels=classes_for_cm, normalize="true") * 100
    cm = confusion_matrix(gt, pred, labels=classes_for_cm)

    plt.figure(figsize=(50, 45))
    ax1 = plt.subplot2grid((50, 50), (0, 0), colspan=30, rowspan=30)
    cmap = LinearSegmentedColormap.from_list('', ['white', *plt.cm.Blues(np.arange(255))])
    # Annotate each cell with "recall% (count)"
    annot_labels = cm_normed_recall.round(1).astype(str)
    annot_labels = pd.DataFrame(annot_labels) + "\n (" + pd.DataFrame(cm).astype(str) + ")"

    # Hide annotations for (near-)empty cells
    annot_mask = cm_normed_recall.round(1) <= 0.1
    annot_labels[annot_mask] = ""

    sns.heatmap(cm_normed_recall.T, ax=ax1, annot=annot_labels.T, fmt='', cbar=colorbar,
                cmap=cmap, linewidths=1, vmin=0, vmax=100, linecolor='black', square=True)

    ax1.xaxis.tick_top()
    ax1.set_xticklabels(classes_for_cm, rotation=90)
    ax1.set_yticklabels(classes_for_cm, rotation=0)
    ax1.tick_params(axis='both', which='major', labelsize=35)

    ax1.set_xlabel("Clustering and gating", fontsize=35)
    ax1.set_ylabel("CellSighter", fontsize=35)


results = pd.read_csv(r"")  # Fill in the path to your results file
classes_for_cm = np.unique(np.concatenate([results["label"], results["pred"]]))
metric(results["label"], results["pred"], classes_for_cm)
plt.savefig("confusion_matrix.png")
25 changes: 25 additions & 0 deletions analyze_results/unified_ensemble.py
import pandas as pd

val_results = []  # Fill in the paths to all of the val_results.csv files you got from training/validation

df_all_labeled = pd.DataFrame()
ensemble_size = len(val_results)
for i, val_result in enumerate(val_results):
    curr_df = pd.read_csv(val_result, index_col=0)
    # prob_list is stored as a string; eval turns it back into a Python list
    prob_list = curr_df["prob_list"].apply(eval)
    num_classes = len(prob_list.iloc[0])
    curr_df[[f"prob_class_{j}" for j in range(num_classes)]] = prob_list.apply(pd.Series)
    curr_df.columns = [c + f"_ens_{i}" for c in curr_df.columns]
    df_all_labeled = pd.concat([df_all_labeled, curr_df], axis=1)

# Average the per-class probabilities across the ensemble members
for i in range(num_classes):
    df_all_labeled[f"prob_mean_class_{i}"] = df_all_labeled[[f"prob_class_{i}_ens_{j}" for j in range(ensemble_size)]].mean(axis=1)

df_all_labeled["pred"] = df_all_labeled[[f"prob_mean_class_{i}" for i in range(num_classes)]].values.argmax(1)
df_all_labeled["pred_prob"] = df_all_labeled[[f"prob_mean_class_{i}" for i in range(num_classes)]].max(axis=1)

# Label, cell id and image id are identical across members; take them from
# member 0 (the original indexed member 1, which fails for a single model)
df_all_labeled["label"] = df_all_labeled["label_ens_0"]
df_all_labeled["cell_id"] = df_all_labeled["cell_id_ens_0"]
df_all_labeled["image_id"] = df_all_labeled["image_id_ens_0"]

df_all_labeled[["image_id", "cell_id", "label", "pred", "pred_prob"]].to_csv("merged_ensemble.csv")
Empty file added data/__init__.py
Empty file.
30 changes: 30 additions & 0 deletions data/cell_crop.py
import numpy as np


class CellCrop:
    """
    Represents a crop of a cell with all the channels, the mask of the cell,
    the mask of the cells in the environment, and more.
    """
    def __init__(self, cell_id, image_id, label, slices, cells, image):
        self._cell_id = cell_id
        self._image_id = image_id
        self._label = label
        self._slices = slices
        self._cells = cells
        self._image = image

    def sample(self, mask=False):
        result = {'cell_id': self._cell_id, 'image_id': self._image_id,
                  'image': self._image[self._slices].astype(np.float32),
                  'slice_x_start': self._slices[0].start,
                  'slice_y_start': self._slices[1].start,
                  'slice_x_end': self._slices[0].stop,
                  'slice_y_end': self._slices[1].stop,
                  'label': self._label.astype(np.int64)}  # np.long is removed in recent numpy
        if mask:
            result['mask'] = (self._cells[self._slices] == self._cell_id).astype(np.float32)
            result['all_cells_mask'] = (self._cells[self._slices] > 0).astype(np.float32)
            result['all_cells_mask_seperate'] = (self._cells[self._slices]).astype(np.float32)

        return result
25 changes: 25 additions & 0 deletions data/data.py
import numpy as np
from torch.utils.data import Dataset


class CellCropsDataset(Dataset):
    def __init__(self,
                 crops,
                 mask=False,
                 transform=None):
        super().__init__()
        self._crops = crops
        self._transform = transform
        self._mask = mask

    def __len__(self):
        return len(self._crops)

    def __getitem__(self, idx):
        sample = self._crops[idx].sample(self._mask)
        # Stack the image channels with the environment mask and the cell mask,
        # augment them together, then split them apart again. Note that this
        # assumes mask=True and a transform; the masks are the last two channels.
        aug = self._transform(np.dstack(
            [sample['image'], sample['all_cells_mask'][:, :, np.newaxis], sample['mask'][:, :, np.newaxis]])).float()
        sample['image'] = aug[:-1, :, :]
        sample['mask'] = aug[[-1], :, :]

        return sample
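
A minimal usage sketch tying CellCrop, CellCropsDataset and train_transform together with a toy cell (all shapes, ids and labels are illustrative; in the real pipeline the crops are built from the prepared images and segmentations):

```python
# A minimal sketch: build one toy CellCrop and read it through the dataset.
# All shapes, ids and labels are illustrative.
import numpy as np
from data.cell_crop import CellCrop
from data.data import CellCropsDataset
from data.transform import train_transform

image = np.random.rand(256, 256, 3).astype(np.float32)   # H x W x C
cells = np.zeros((256, 256), dtype=np.int32)
cells[100:140, 100:140] = 1                               # one toy cell, id 1
crop = CellCrop(cell_id=1, image_id="FOV1", label=np.array(4),
                slices=(slice(64, 192), slice(64, 192)),  # 128 x 128 crop
                cells=cells, image=image)

dataset = CellCropsDataset([crop], mask=True, transform=train_transform(60, 5))
sample = dataset[0]
print(sample['image'].shape, sample['mask'].shape)        # (4, 60, 60) (1, 60, 60)
```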
25 changes: 25 additions & 0 deletions data/shift_augmentation.py
import numpy as np
import torch
from torchvision.transforms import Lambda, RandomCrop, CenterCrop


class ShiftAugmentation(torch.nn.Module):
    """
    Augmentation that shifts each marker channel a few pixels in a random direction.
    """
    def __init__(self, n_size, shift_max=0):
        super(ShiftAugmentation, self).__init__()
        self.shift_max = shift_max
        self.n_size = n_size
        p = 0.3  # probability of shifting a given channel

        # With probability p, center-crop to a slightly larger window and then
        # random-crop back to n_size, which shifts the channel; otherwise the
        # center crop is exactly n_size and the channel stays in place.
        self.chanel_shifter = Lambda(lambda x:
            RandomCrop(size=n_size)(
                CenterCrop(size=n_size + (self.shift_max if np.random.random() < p else 0))(x)))

    def forward(self, x):
        # x is shaped (C, H, W); each channel is shifted independently
        aug_x = torch.zeros((x.shape[0], self.n_size, self.n_size))
        for i in range(x.shape[0]):
            aug_x[i, :, :] = self.chanel_shifter(x[[i], :, :])[0, :, :]
        return aug_x
65 changes: 65 additions & 0 deletions data/transform.py
import cv2
import numpy as np
import torchvision
from torchvision.transforms import Lambda
from data.shift_augmentation import ShiftAugmentation


def poisson_sampling(x):
    """
    Augmentation that resamples the data from a Poisson distribution.
    Args:
        x: (H, W, C) where C is the number of markers + 2 (one channel for the
           mask of the cell, one for the mask of all the other cells in the
           environment)
    Returns:
        Augmented array of shape (H, W, C) resampled from a Poisson distribution.
    """
    blur = cv2.GaussianBlur(x[:, :, :-2], (5, 5), 0)
    x[:, :, :-2] = np.random.poisson(lam=blur, size=x[:, :, :-2].shape)
    return x


def cell_shape_aug(x):
    """
    Augment the mask of the cell by dilating it with a random kernel.
    """
    if np.random.random() < 0.5:
        cell_mask = x[:, :, -1]
        kernel_size = np.random.choice([2, 3, 5])
        kernel = np.ones(kernel_size, np.uint8)  # 1-D kernel, dilates along one axis
        img_dilation = cv2.dilate(cell_mask, kernel, iterations=1)
        x[:, :, -1] = img_dilation
    return x


def env_shape_aug(x):
    """
    Augment the mask of the cells in the environment by dilating it
    with a random kernel.
    """
    if np.random.random() < 0.5:
        cell_mask = x[:, :, -2]
        kernel_size = np.random.choice([2, 3, 5])
        kernel = np.ones(kernel_size, np.uint8)  # 1-D kernel, dilates along one axis
        img_dilation = cv2.dilate(cell_mask, kernel, iterations=1)
        x[:, :, -2] = img_dilation
    return x


val_transform = lambda crop_size: torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.CenterCrop((crop_size, crop_size))
])

train_transform = lambda crop_size, shift: torchvision.transforms.Compose([
    torchvision.transforms.Lambda(poisson_sampling),
    torchvision.transforms.Lambda(cell_shape_aug),
    torchvision.transforms.Lambda(env_shape_aug),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.RandomRotation(degrees=(0, 360)),
    Lambda(lambda x: ShiftAugmentation(shift_max=shift, n_size=crop_size)(x) if np.random.random() < 0.5 else x),
    torchvision.transforms.CenterCrop((crop_size, crop_size)),
    torchvision.transforms.RandomHorizontalFlip(p=0.75),
    torchvision.transforms.RandomVerticalFlip(p=0.75),
])