Skip to content
This repository was archived by the owner on Mar 23, 2024. It is now read-only.

Commit 0f1ceb5

Browse files
Merge pull request #24 from brucechou1983/develop
Thread safe generator, multiple baseline models, modified cam
2 parents 9e79c81 + 93bbe55 commit 0f1ceb5

18 files changed

+491
-769
lines changed

.gitignore

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,7 @@ densenet121_weights_tf.h5
55
/.idea
66
/experiments
77
/data/Data_Entry_2017.csv
8+
/data/BBox_List_2017.csv
89
venv
9-
config.ini
10+
config.ini
11+
*.pdf

README.md

Lines changed: 16 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,50 +1,33 @@
11
# ChexNet-Keras
22
This project is a tool to build CheXNet-like models, written in Keras.
33

4-
<img width="450" height="450" src="https://stanfordmlgroup.github.io/projects/chexnet/img/chest-cam.png" alt="CheXNet from Stanford ML Group"/>
4+
<img width="450" height="450" src="cam_example.png" alt="CAM example image"/>
55

6-
## System Requirements
7-
1. Tensorflow_GPU == 1.4 (CUDA 8)
8-
2. Keras == 2.1.4
9-
3. numpy
10-
4. opencv-python (i.e. cv2) ==3.3
11-
5. At least one Nvidia 1080Ti GPU to enable batch_size = 32
12-
13-
### Important notice on CUDA users
14-
If you use >= CUDA 9.1, please modify requirements.txt, such that tensorflow_gpu == 1.5
15-
166
## What is [CheXNet](https://arxiv.org/pdf/1711.05225.pdf)?
17-
ChexNet is a deep learning algorithm that can detect and localize 14 kinds of diseases from chest X-ray images. As described in the paper, a 121-layer densely connected convolutional neural network is trained on ChestX-ray14 dataset, which contains 112,120 frontal view X-ray images from 30,805 unique patients. The result is so good that it surpasses the performance of practicing radiologists.
7+
ChexNet is a deep learning algorithm that can detect and localize 14 kinds of diseases from chest X-ray images. As described in the paper, a 121-layer densely connected convolutional neural network is trained on ChestX-ray14 dataset, which contains 112,120 frontal view X-ray images from 30,805 unique patients. The result is so good that it surpasses the performance of practicing radiologists. If you are new to this project, [Luke Oakden-Rayner's post](https://lukeoakdenrayner.wordpress.com/2017/12/18/the-chestxray14-dataset-problems/) is highly recommended.
188

199
## In this project, you can
2010
1. Train/test a **baseline model** by following the quickstart. You can get a model with performance close to the paper.
21-
2. Modify `multiply` and `use_class_balancing` parameters in `config.ini` to see if you can get better performance.
22-
3. Modify `weights.py` to customize your weights in loss function.
23-
4. Every time you do a new experiment, make sure you modify `output_dir` in `config.ini` otherwise previous training results might be overwritten. For more options check the parameter description in `config.ini`.
11+
2. Run class activation mapping to see the localization of your model.
12+
3. Modify `multiply` parameter in `config.ini` or design your own class weighting to see if you can get better performance.
13+
4. Modify `weights.py` to customize your weights in loss function. If you find something useful, feel free to make that an option and fire a PR.
14+
5. Every time you do a new experiment, make sure you modify `output_dir` in `config.ini` otherwise previous training results might be overwritten. For more options check the parameter description in `config.ini`.
2415

2516
## Quickstart
2617
**Note that currently this project can only be executed in Linux and macOS. You might run into some issues in Windows.**
27-
1. Download **all tar files** and **Data_Entry_2017.csv** of ChestX-ray14 dataset from [NIH dropbox](https://nihcc.app.box.com/v/ChestXray-NIHCC). Put them under `./data` folder and untar all tar files.
28-
2. Download DenseNet-121 ImageNet tensorflow pretrained weights from [DenseNet-Keras](https://drive.google.com/open?id=0Byy2AcGyEVxfSTA4SHJVOHNuTXc). Specify the file path in `config.ini` (field: `base_model_weights_file`)
29-
3. Create & source a new virtualenv. Python >= **3.6** is required.
30-
4. Install dependencies by running `pip3 install -r requirements.txt`.
31-
5. Copy sample_config.ini to config.ini, you may customize `batch_size` and training parameters here. Try to set `patience_reduce_lr` to 2 or 3 in the early training phase. Please note config.ini must exist before training and testing
32-
6. Run `python train.py` to train a new model. If you want to run the training using multiple GPUs, just prepend `CUDA_VISIBLE_DEVICES=0,1,...` to restrict the GPU devices. `nvidia-smi` command will be helpful if you don't know which device are available.
33-
7. Run `python test.py` to test the model.
34-
35-
## CAM
36-
Reference: [Grad-CAM](https://arxiv.org/pdf/1610.02391). CAM image is generated as accumumlated weighted activation before last global average pooling (GAP) layer. It is scaled up to 224\*224 to match original image.
37-
```
38-
python test.py
39-
```
40-
CAM images will be generated into $pwd/imgdir, please make sure you've created the target directory before running test.py
18+
1. Download **all tar files**, **Data_Entry_2017.csv** and **BBox_List_2017.csv** of ChestX-ray14 dataset from [NIH dropbox](https://nihcc.app.box.com/v/ChestXray-NIHCC). Put them under `./data` folder and untar all tar files.
19+
2. Create & source a new virtualenv. Python >= **3.6** is required.
20+
3. Install dependencies by running `pip3 install -r requirements.txt`.
21+
4. Copy sample_config.ini to config.ini, you may customize `batch_size` and training parameters here. Make sure config.ini is configured before you run training or testing
22+
5. Run `python train.py` to train a new model. If you want to run the training using multiple GPUs, just prepend `CUDA_VISIBLE_DEVICES=0,1,...` to restrict the GPU devices. `nvidia-smi` command will be helpful if you don't know which device are available.
23+
6. Run `python test.py` to evaluate your model on the test set.
24+
7. Run `python cam.py` to generate images with class activation mapping overlay and the ground bbox. The ground truth comes from the **BBox_List_2017.csv** file so make sure you have that file in `./data` folder. CAM images will be placed under the output folder.
4125

42-
Guided back-prop is still an enhancement item.
43-
44-
The function is merged into test.py so you wouldn't need test_cam.py anymore. The script will use argmax to plot CAM of the most probable diagnosis only. This version does not support multi-labeled instance at this point.
26+
### Important notice on CUDA users
27+
If you use >= CUDA 9, make sure you set tensorflow_gpu >= 1.5.
4528

4629
## TODO
47-
1. More baseline models
30+
1. Frontend
4831

4932
## Acknowledgement
5033
I would like to thank Pranav Rajpurkar (Stanford ML group) and Xinyu Weng (北京大學) for sharing their experiences on this task. Also I would like to thank Felix Yu for providing DenseNet-Keras source code.

augmenter.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from imgaug import augmenters as iaa
2+
3+
augmenter = iaa.Sequential(
4+
[
5+
iaa.Fliplr(0.5),
6+
],
7+
random_order=True,
8+
)

callback.py

Lines changed: 6 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -8,32 +8,14 @@
88
from sklearn.metrics import roc_auc_score
99

1010

11-
def load_generator_data(generator, steps, class_num):
12-
"""
13-
Return some data collected from a generator, use this to ensure all images
14-
are processed by exactly the same steps in the customized ImageDataGenerator.
15-
16-
"""
17-
batches_x = []
18-
batches_y_classes = []
19-
for i in range(class_num):
20-
batches_y_classes.append([])
21-
for i in range(steps):
22-
batch_x, batch_y = next(generator)
23-
batches_x.append(batch_x)
24-
for c, batch_y_class in enumerate(batch_y):
25-
batches_y_classes[c].append(batch_y_class)
26-
return np.concatenate(batches_x, axis=0), [np.concatenate(c, axis=0) for c in batches_y_classes]
27-
28-
2911
class MultipleClassAUROC(Callback):
3012
"""
3113
Monitor mean AUROC and update model
3214
"""
33-
def __init__(self, generator, steps, class_names, weights_path, stats=None):
15+
def __init__(self, sequence, class_names, weights_path, stats=None, workers=1):
3416
super(Callback, self).__init__()
35-
self.generator = generator
36-
self.steps = steps
17+
self.sequence = sequence
18+
self.workers = workers
3719
self.class_names = class_names
3820
self.weights_path = weights_path
3921
self.best_weights_path = os.path.join(
@@ -73,14 +55,14 @@ def on_epoch_end(self, epoch, logs={}):
7355
y_hat shape: (#samples, len(class_names))
7456
y: [(#samples, 1), (#samples, 1) ... (#samples, 1)]
7557
"""
76-
x, y = load_generator_data(self.generator, self.steps, len(self.class_names))
77-
y_hat = self.model.predict(x)
58+
y_hat = self.model.predict_generator(self.sequence, workers=self.workers)
59+
y = self.sequence.get_y_true()
7860

7961
print(f"*** epoch#{epoch + 1} dev auroc ***")
8062
current_auroc = []
8163
for i in range(len(self.class_names)):
8264
try:
83-
score = roc_auc_score(y[i], y_hat[i])
65+
score = roc_auc_score(y[:, i], y_hat[:, i])
8466
except ValueError:
8567
score = 0
8668
self.aurocs[self.class_names[i]].append(score)

cam.py

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
import cv2
2+
import numpy as np
3+
import os
4+
import pandas as pd
5+
from configparser import ConfigParser
6+
from generator import AugmentedImageSequence
7+
from models.keras import ModelFactory
8+
from keras import backend as kb
9+
10+
11+
def get_output_layer(model, layer_name):
12+
# get the symbolic outputs of each "key" layer (we gave them unique names).
13+
layer_dict = dict([(layer.name, layer) for layer in model.layers])
14+
layer = layer_dict[layer_name]
15+
return layer
16+
17+
18+
def create_cam(df_g, output_dir, image_source_dir, model, generator, class_names):
19+
"""
20+
Create a CAM overlay image for the input image
21+
22+
:param df_g: pandas.DataFrame, bboxes on the same image
23+
:param output_dir: str
24+
:param image_source_dir: str
25+
:param model: keras model
26+
:param generator: generator.AugmentedImageSequence
27+
:param class_names: list of str
28+
"""
29+
file_name = df_g["file_name"]
30+
print(f"process image: {file_name}")
31+
32+
# draw bbox with labels
33+
img_ori = cv2.imread(filename=os.path.join(image_source_dir, file_name))
34+
35+
label = df_g["label"]
36+
if label == "Infiltrate":
37+
label = "Infiltration"
38+
index = class_names.index(label)
39+
40+
output_path = os.path.join(output_dir, f"{label}.{file_name}")
41+
42+
img_transformed = generator.load_image(file_name)
43+
44+
# CAM overlay
45+
# Get the 512 input weights to the softmax.
46+
class_weights = model.layers[-1].get_weights()[0]
47+
final_conv_layer = get_output_layer(model, "bn")
48+
get_output = kb.function([model.layers[0].input], [final_conv_layer.output, model.layers[-1].output])
49+
[conv_outputs, predictions] = get_output([np.array([img_transformed])])
50+
conv_outputs = conv_outputs[0, :, :, :]
51+
52+
# Create the class activation map.
53+
cam = np.zeros(dtype=np.float32, shape=(conv_outputs.shape[:2]))
54+
for i, w in enumerate(class_weights[index]):
55+
cam += w * conv_outputs[:, :, i]
56+
# print(f"predictions: {predictions}")
57+
cam /= np.max(cam)
58+
cam = cv2.resize(cam, img_ori.shape[:2])
59+
heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
60+
heatmap[np.where(cam < 0.2)] = 0
61+
img = heatmap * 0.5 + img_ori
62+
63+
# add label & rectangle
64+
# ratio = output dimension / 1024
65+
ratio = 1
66+
x1 = int(df_g["x"] * ratio)
67+
y1 = int(df_g["y"] * ratio)
68+
x2 = int((df_g["x"] + df_g["w"]) * ratio)
69+
y2 = int((df_g["y"] + df_g["h"]) * ratio)
70+
cv2.rectangle(img, (x1, y1), (x2, y2), (255, 0, 0), 2)
71+
cv2.putText(img, text=label, org=(5, 20), fontFace=cv2.FONT_HERSHEY_SIMPLEX,
72+
fontScale=0.8, color=(0, 0, 255), thickness=1)
73+
cv2.imwrite(output_path, img)
74+
75+
76+
def main():
77+
# parser config
78+
config_file = "./config.ini"
79+
cp = ConfigParser()
80+
cp.read(config_file)
81+
82+
# default config
83+
output_dir = cp["DEFAULT"].get("output_dir")
84+
base_model_name = cp["DEFAULT"].get("base_model_name")
85+
class_names = cp["DEFAULT"].get("class_names").split(",")
86+
image_source_dir = cp["DEFAULT"].get("image_source_dir")
87+
image_dimension = cp["TRAIN"].getint("image_dimension")
88+
89+
# parse weights file path
90+
output_weights_name = cp["TRAIN"].get("output_weights_name")
91+
weights_path = os.path.join(output_dir, output_weights_name)
92+
best_weights_path = os.path.join(output_dir, f"best_{output_weights_name}")
93+
94+
# CAM config
95+
bbox_list_file = cp["CAM"].get("bbox_list_file")
96+
use_best_weights = cp["CAM"].getboolean("use_best_weights")
97+
98+
print("** load model **")
99+
if use_best_weights:
100+
print("** use best weights **")
101+
model_weights_path = best_weights_path
102+
else:
103+
print("** use last weights **")
104+
model_weights_path = weights_path
105+
model_factory = ModelFactory()
106+
model = model_factory.get_model(
107+
class_names,
108+
model_name=base_model_name,
109+
use_base_weights=False,
110+
weights_path=model_weights_path)
111+
112+
print("read bbox list file")
113+
df_images = pd.read_csv(bbox_list_file, header=None, skiprows=1)
114+
df_images.columns = ["file_name", "label", "x", "y", "w", "h"]
115+
116+
print("create a generator for loading transformed images")
117+
cam_sequence = AugmentedImageSequence(
118+
dataset_csv_file=os.path.join(output_dir, "test.csv"),
119+
class_names=class_names,
120+
source_image_dir=image_source_dir,
121+
batch_size=1,
122+
target_size=(image_dimension, image_dimension),
123+
augmenter=None,
124+
steps=1,
125+
shuffle_on_epoch_end=False,
126+
)
127+
128+
image_output_dir = os.path.join(output_dir, "cam")
129+
if not os.path.isdir(image_output_dir):
130+
os.makedirs(image_output_dir)
131+
132+
print("create CAM")
133+
df_images.apply(
134+
lambda g: create_cam(
135+
df_g=g,
136+
output_dir=image_output_dir,
137+
image_source_dir=image_source_dir,
138+
model=model,
139+
generator=cam_sequence,
140+
class_names=class_names,
141+
),
142+
axis=1,
143+
)
144+
145+
146+
if __name__ == "__main__":
147+
main()

cam_example.png

770 KB
Loading

0 commit comments

Comments
 (0)