This repository contains the code for the "MRNet – The Multi-Task Approach" blog post.
For more details on the MRNet challenge, refer to https://stanfordmlgroup.github.io/competitions/mrnet/.
pip install git+https://github.com/ncullen93/torchsample
pip install nibabel
pip install scikit-learn
pip install pandas
Install any other dependencies as required.
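Before training, it can help to confirm that the packages above are importable. The following is a minimal sketch rather than part of this repository: the helper name `missing_modules` is hypothetical, and the module list simply mirrors the pip commands above (scikit-learn is imported as `sklearn`).

```python
import importlib.util

# Import names for the packages installed above.
# This list and the helper below are illustrative, not part of the repository.
REQUIRED = ["torchsample", "nibabel", "sklearn", "pandas"]


def missing_modules(names):
    """Return the subset of `names` that cannot be found on the current Python path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules(REQUIRED)
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        print("All listed packages are importable.")
```

Any name reported as missing can then be installed with the matching pip command.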
- Clone the repository.
- Download the dataset (~5.7 GB), and put the `train` and `valid` folders, along with all the `.csv` files, inside the `images` folder at the root directory:
    images/
        train/
            axial/
            sagittal/
            coronal/
        val/
            axial/
            sagittal/
            coronal/
        train-abnormal.csv
        train-acl.csv
        train-meniscus.csv
        valid-abnormal.csv
        valid-acl.csv
        valid-meniscus.csv
- Make a new folder called `weights` at the root directory, and inside the `weights` folder create three more folders named `acl`, `abnormal` and `meniscus`.
- All the hyperparameters are defined in the `config.py` file. Feel free to experiment with them.
- Finally, run the training using `python train.py`. All the TensorBoard logs will be stored in the `runs` directory at the root of the project.
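The folder setup described in the steps above can be scripted. This is a minimal sketch and not part of the repository: `prepare_workspace` is a hypothetical helper, and the split folder names are taken from the directory listing above and exposed as a parameter so they can be adjusted to match your download.

```python
from pathlib import Path

TASKS = ("acl", "abnormal", "meniscus")
PLANES = ("axial", "sagittal", "coronal")


def prepare_workspace(root, splits=("train", "val")):
    """Create weights/<task> folders and return the dataset paths that are missing.

    `splits` holds the split folder names under images/ (assumption: change it
    if your copy of the dataset names them differently).
    """
    root = Path(root)

    # One weights sub-folder per disease, as described in the setup steps.
    for task in TASKS:
        (root / "weights" / task).mkdir(parents=True, exist_ok=True)

    images = root / "images"
    expected = [images / split / plane for split in splits for plane in PLANES]
    # Label files follow the <prefix>-<task>.csv pattern from the listing.
    expected += [images / f"{prefix}-{task}.csv"
                 for prefix in ("train", "valid") for task in TASKS]
    return [path for path in expected if not path.exists()]
```

Calling `prepare_workspace(".")` from the project root creates the three weights folders and returns a list of any dataset folders or CSV files that are not yet in place.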
The dataset contains MRIs of different people. Each MRI consists of multiple images: it has data in 3 perpendicular planes, and each plane has a variable number of slices. Each slice is a 256x256 image.
For example, for MRI 1 we will have 3 planes:
- Plane 1 with 35 slices
- Plane 2 with 34 slices
- Plane 3 with 35 slices
Each MRI has to be classified against 3 diseases.
A major challenge while selecting the model structure was the inconsistency in the data. Although the image size remains constant, the number of slices per plane varies both within a single MRI and across MRIs.
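To make the inconsistency concrete, the sketch below builds a synthetic MRI with the slice counts from the example above and shows one common way to obtain a fixed-size representation: pooling over the slice axis. It is an illustration under stated assumptions, not the repository's code. The arrays are random stand-ins for real scans, `slice_features` is a hypothetical 8-dimensional per-slice summary, and max pooling over slices is only one possible choice.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic stand-in for one MRI: one (slices, 256, 256) array per plane.
# Slice counts follow the example above; which plane gets which count is arbitrary.
mri = {
    "axial": rng.random((35, 256, 256), dtype=np.float32),
    "sagittal": rng.random((34, 256, 256), dtype=np.float32),
    "coronal": rng.random((35, 256, 256), dtype=np.float32),
}


def slice_features(volume, n_bins=8):
    """Hypothetical per-slice feature: mean intensity of n_bins horizontal bands."""
    s, h, w = volume.shape
    bands = volume.reshape(s, n_bins, h // n_bins, w)
    return bands.mean(axis=(2, 3))  # shape: (slices, n_bins)


def pool_over_slices(features):
    """Collapse the variable slice axis with an element-wise max."""
    return features.max(axis=0)  # shape: (n_bins,)


# The raw planes cannot be stacked because their slice counts differ.
try:
    np.stack(list(mri.values()))
except ValueError as err:
    print("cannot stack raw planes:", err)

# After pooling, every plane yields a vector of the same length.
pooled = {plane: pool_over_slices(slice_features(vol)) for plane, vol in mri.items()}
combined = np.concatenate([pooled[p] for p in ("axial", "sagittal", "coronal")])
print(combined.shape)  # (24,)
```

Because the pooled vector no longer depends on the number of slices, the three planes can be concatenated into a single input of fixed length.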
In our last attempt at the MRNet challenge, we used 3 different models for each disease. Instead, we can leverage the information that the model learns for each disease to improve inference on the other diseases.
We used hard parameter sharing in this approach.
We use 3 pretrained AlexNet models as feature extractors, one for each plane. We then combine the outputs of these feature extractors as the input to a global fully connected layer for the final classification.
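The architecture described above can be sketched in PyTorch as follows. This is a minimal illustration under assumptions, not the repository's implementation: the class name `MultiTaskMRNet`, the max pooling over slices, and the single three-logit output layer are choices made for this sketch, and a small convolutional stand-in replaces AlexNet so the example runs without downloading pretrained weights.

```python
import torch
import torch.nn as nn

PLANES = ("axial", "sagittal", "coronal")
TASKS = ("abnormal", "acl", "meniscus")


def small_backbone(out_channels=16):
    """Stand-in feature extractor (assumption); swap in AlexNet's `.features` for real use."""
    return nn.Sequential(
        nn.Conv2d(3, out_channels, kernel_size=7, stride=4, padding=3),
        nn.ReLU(inplace=True),
    )


class MultiTaskMRNet(nn.Module):
    """One feature extractor per plane, one shared layer producing all three outputs."""

    def __init__(self, make_backbone=small_backbone, feat_dim=16):
        super().__init__()
        self.backbones = nn.ModuleDict({p: make_backbone() for p in PLANES})
        self.pool = nn.AdaptiveAvgPool2d(1)
        # Global fully connected layer fed by all planes and shared by all tasks.
        self.classifier = nn.Linear(feat_dim * len(PLANES), len(TASKS))

    def forward(self, mri):
        pooled = []
        for plane in PLANES:
            slices = mri[plane]                      # (num_slices, 3, H, W)
            feats = self.backbones[plane](slices)    # (num_slices, C, h, w)
            feats = self.pool(feats).flatten(1)      # (num_slices, C)
            pooled.append(feats.max(dim=0).values)   # (C,), independent of num_slices
        return self.classifier(torch.cat(pooled))    # (3,) logits, one per disease


model = MultiTaskMRNet()
# Random stand-in for one MRI with a different slice count per plane.
# Small 64x64 slices keep the example fast; grayscale slices would be repeated to 3 channels.
mri = {p: torch.randn(n, 3, 64, 64) for p, n in zip(PLANES, (35, 34, 35))}
logits = model(mri)

# One multi-label loss over all three diseases (the labels here are made up).
labels = torch.tensor([1.0, 0.0, 1.0])
loss = nn.BCEWithLogitsLoss()(logits, labels)
loss.backward()
print(logits.shape, loss.item())
```

To use AlexNet in this sketch, a factory such as `lambda: torchvision.models.alexnet(weights="DEFAULT").features` could be passed as `make_backbone` together with `feat_dim=256`, since AlexNet's convolutional part outputs 256 channels. Because every task's gradient flows through the same feature extractors, what the network learns for one disease is available to the others.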