Skip to content

add predict_from_video to cell_nucleus_segmentor #74

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,19 @@
### Installation
```python
pip install devolearn

# or you can build from the source:
pip install git+https://github.com/DevoLearn/devolearn
```
### Example notebooks
<p align="center">
<img src = "https://raw.githubusercontent.com/DevoLearn/data-science-demos/master/Networks/nodes_matrix_long_smooth.gif" width = "40%">
<img src = "https://raw.githubusercontent.com/DevoLearn/data-science-demos/master/Networks/3d_node_map.gif" width = "40%">
<img src = "https://raw.githubusercontent.com/DevoLearn/data-science-demos/master/Networks/3d_node_map.gif" width = "40%">
</p>

* [Extracting centroid maps and making 3d centroid models](https://nbviewer.jupyter.org/github/DevoLearn/data-science-demos/blob/master/Networks/experiments_with_devolearn_node_maps.ipynb)

### Segmenting the Cell Membrane in C. elegans embryo
### Segmenting the Cell Membrane in C. elegans embryo
<p align="center">
<img src = "https://raw.githubusercontent.com/DevoLearn/devolearn/master/images/pred_centroids.gif" width = "80%">
</p>
Expand All @@ -53,7 +56,7 @@ plt.imshow(seg_pred)
plt.show()
```

* Running the model on a video and saving the predictions into a folder
* Running the model on a video and saving the predictions into a folder
```python
filenames = segmentor.predict_from_video(video_path = "sample_data/videos/seg_sample.mov", centroid_mode = False, save_folder = "preds")
```
Expand All @@ -72,7 +75,7 @@ df = segmentor.predict_from_video(video_path = "sample_data/videos/seg_sample.mo
df.to_csv("centroids.csv")
```

### Segmenting the Cell Nucleus in C. elegans embryo
### Segmenting the Cell Nucleus in C. elegans embryo
<p align="center">
<img src = "https://github.com/Mainakdeb/devolearn/blob/master/images/nucleus_segmentation.gif" width = "60%">
</p>
Expand Down Expand Up @@ -105,7 +108,7 @@ generator = embryo_generator_model()

* Generating a picture and viewing it with [matplotlib](https://matplotlib.org/)
```python
gen_image = generator.generate()
gen_image = generator.generate()
plt.imshow(gen_image)
plt.show()

Expand All @@ -125,7 +128,7 @@ generator.generate_n_images(n = 5, foldername= "generated_images", image_size= (
<img src = "https://raw.githubusercontent.com/devoworm/GSoC-2020/master/Pre-trained%20Models%20(DevLearning)/images/resnet_preds_with_input.gif" width = "60%">
</p>

* Importing the population model for inferences
* Importing the population model for inferences
```python
from devolearn import lineage_population_model
```
Expand Down Expand Up @@ -155,8 +158,8 @@ plot.show()
| **Model** | **Data source** |
|-------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| Segmenting the cell membrane in C. elegans embryo | [3DMMS: robust 3D Membrane Morphological Segmentation of C. elegans embryo](https://bmcbioinformatics.biomedcentral.com/articles/10.1186/s12859-019-2720-x#Abs1/) |
| Segmenting the nucleus in C. elegans embryo | [C. elegans Cell-Tracking-Challenge dataset](http://celltrackingchallenge.net/3d-datasets/)
| Cell lineage population prediction + embryo GAN | [EPIC dataset](https://epic.gs.washington.edu/)
| Segmenting the nucleus in C. elegans embryo | [C. elegans Cell-Tracking-Challenge dataset](http://celltrackingchallenge.net/3d-datasets/)
| Cell lineage population prediction + embryo GAN | [EPIC dataset](https://epic.gs.washington.edu/)

## Authors/maintainers:
* [Mayukh Deb](https://twitter.com/mayukh091)
Expand Down
132 changes: 120 additions & 12 deletions devolearn/cell_nucleus_segmentor/cell_nucleus_segmentor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,51 @@
import matplotlib.pyplot as plt
import segmentation_models_pytorch as smp
import warnings
warnings.filterwarnings("ignore")
warnings.filterwarnings("ignore")

from ..base_inference_engine import InferenceEngine

def generate_centroid_image(thresh):
"""Used when centroid_mode is set to True

Args:
thresh (np.array): 2d numpy array that is returned from the segmentation model

Returns:
np.array : image containing the contours and their respective centroids
list : list of all centroids for the given image as [(x1,y1), (x2,y2)...]
"""

thresh = cv2.blur(thresh, (5,5))
thresh = thresh.astype(np.uint8)
centroid_image = np.zeros(thresh.shape)
cnts = cv2.findContours(thresh, cv2.RETR_EXTERNAL,cv2.CHAIN_APPROX_SIMPLE)
cnts = imutils.grab_contours(cnts)
centroids = []
for c in cnts:
try:
# compute the center of the contour
M = cv2.moments(c)
cX = int(M["m10"] / M["m00"])
cY = int(M["m01"] / M["m00"])
# draw the contour and center of the shape on the image
cv2.drawContours(centroid_image, [c], -1, (255, 255, 255), 2)
cv2.circle(centroid_image, (cX, cY), 2, (255, 255, 255), -1)
centroids.append((cX, cY))
except:
pass

return centroid_image, centroids

class cell_nucleus_segmentor(InferenceEngine):
def __init__(self, device = "cpu"):
"""Segments the c. elegans embryo from images/videos,
"""Segments the c. elegans embryo from images/videos,
depends on segmentation-models-pytorch for the model backbone

Args:
device (str, optional): set to "cuda", runs operations on gpu and set to "cpu", runs operations on cpu. Defaults to "cpu".
"""

self.device = device
self.ENCODER = 'resnet18'
self.ENCODER_WEIGHTS = 'imagenet'
Expand All @@ -44,11 +75,11 @@ def __init__(self, device = "cpu"):
# print("at : ", os.path.dirname(__file__))

self.model = smp.FPN(
encoder_name= self.ENCODER,
encoder_weights= self.ENCODER_WEIGHTS,
classes=len(self.CLASSES),
encoder_name= self.ENCODER,
encoder_weights= self.ENCODER_WEIGHTS,
classes=len(self.CLASSES),
activation= self.ACTIVATION,
in_channels = self.in_channels
in_channels = self.in_channels
)

self.download_checkpoint()
Expand All @@ -65,23 +96,23 @@ def __init__(self, device = "cpu"):
def download_checkpoint(self):
try:
# print("model already downloaded, loading model...")
self.model = torch.load(self.model_dir + "/" + self.model_name, map_location= self.device)
self.model = torch.load(self.model_dir + "/" + self.model_name, map_location= self.device)
except:
print("model not found, downloading from:", self.model_url)
if os.path.isdir(self.model_dir) == False:
os.mkdir(self.model_dir)
filename = wget.download(self.model_url, out= self.model_dir)
# print(filename)
self.model = torch.load(self.model_dir + "/" + self.model_name, map_location= self.device)
self.model = torch.load(self.model_dir + "/" + self.model_name, map_location= self.device)

def preprocess(self, image_grayscale_numpy):

tensor = self.mini_transform(image_grayscale_numpy).unsqueeze(0).to(self.device)
return tensor

def predict(self, image_path, pred_size = (350,250)):
def predict(self, image_path, pred_size = (350,250), centroid_mode = False):
"""
Loads an image from image_path and converts it to grayscale,
Loads an image from image_path and converts it to grayscale,
then passes it through the model and returns centroids of the segmented features.
reference{
https://github.com/DevoLearn/devolearn#segmenting-the-c-elegans-embryo
Expand All @@ -106,4 +137,81 @@ def predict(self, image_path, pred_size = (350,250)):

res = self.model(tensor).detach().cpu().numpy()[0][0]
res = cv2.resize(res,pred_size)
return res
if centroid_mode == False:
return res
else:
centroid_image, centroids = generate_centroid_image(res)
return centroid_image, centroids

def predict_from_video(self, video_path, pred_size = (350,250), save_folder = "preds", centroid_mode = False, notebook_mode = False):
"""Splits a video from video_path into frames and passes the
frames through the model for predictions. Saves predicted images in save_folder.
And optionally saves all the centroid predictions into a pandas.DataFrame.

Args:
video_path (str): path to the video file.
pred_size (tuple, optional): size of output image,(width,height). Defaults to (350,250).
save_folder (str, optional): path to folder to be saved in. Defaults to "preds".
centroid_mode (bool, optional): set to true to return both the segmented image and the list of centroids. Defaults to False.
notebook_mode (bool, optional): toogle between script(False) and notebook(True), for better user interface. Defaults to False.

Returns:
centroid_mode set to True:
pd.DataFrame : containing file name and their centriods
centroid_mode set to False:
list : list containing the names of the entries in the save_folder directory
"""

vidObj = cv2.VideoCapture(video_path)
success = 1
images = deque()
count = 0

if centroid_mode == True:
filenames_centroids = []

while success:
success, image = vidObj.read()

try:
image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
images.append(image)

except:
# print("skipped possible corrupt frame number : ", count)
pass
count += 1

if os.path.isdir(save_folder) == False:
os.mkdir(save_folder)

if notebook_mode == True:
for i in tqdm_notebook(range(len(images)), desc = "saving predictions: "):
save_name = save_folder + "/" + str(i) + ".jpg"
tensor = self.mini_transform(images[i]).unsqueeze(0).to(self.device)
res = self.model(tensor).detach().cpu().numpy()[0][0]

if centroid_mode == True:
res, centroids = generate_centroid_image(res)
filenames_centroids.append([save_name, centroids])

res = cv2.resize(res,pred_size)
cv2.imwrite(save_name, res*255)
else :
for i in tqdm(range(len(images)), desc = "saving predictions: "):
save_name = save_folder + "/" + str(i) + ".jpg"
tensor = self.mini_transform(images[i]).unsqueeze(0).to(self.device)
res = self.model(tensor).detach().cpu().numpy()[0][0]

if centroid_mode == True:
res, centroids = generate_centroid_image(res)
filenames_centroids.append([save_name, centroids])

res = cv2.resize(res,pred_size)
cv2.imwrite(save_name, res*255)

if centroid_mode == True:
df = pd.DataFrame(filenames_centroids, columns = ["filenames", "centroids"])
return df
else:
return os.listdir(save_folder)
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ sklearn==0.0
threadpoolctl==2.1.0
tifffile
timm==0.3.2
torch>=1.7.0
torch==1.7.0
torchvision>=0.8.1
tqdm==4.56.0
typing-extensions>=3.7.4.3
wget==3.2
pytest-cov
pytest-cov