GitHub Repository: keras-team/keras-io
Path: blob/master/examples/vision/3D_image_classification.py
"""
Title: 3D image classification from CT scans
Author: [Hasib Zunair](https://twitter.com/hasibzunair)
Date created: 2020/09/23
Last modified: 2024/01/11
Description: Train a 3D convolutional neural network to predict the presence of pneumonia.
Accelerator: GPU
"""

"""
## Introduction

This example will show the steps needed to build a 3D convolutional neural network (CNN)
to predict the presence of viral pneumonia in computed tomography (CT) scans. 2D CNNs are
commonly used to process RGB images (3 channels). A 3D CNN is simply the 3D
equivalent: it takes as input a 3D volume or a sequence of 2D frames (e.g. slices in a CT scan).
3D CNNs are a powerful model for learning representations of volumetric data.

## References

- [A survey on Deep Learning Advances on Different 3D Data Representations](https://arxiv.org/abs/1808.01462)
- [VoxNet: A 3D Convolutional Neural Network for Real-Time Object Recognition](https://www.ri.cmu.edu/pub_files/2015/9/voxnet_maturana_scherer_iros15.pdf)
- [FusionNet: 3D Object Classification Using Multiple Data Representations](https://arxiv.org/abs/1607.05695)
- [Uniformizing Techniques to Process CT scans with 3D CNNs for Tuberculosis Prediction](https://arxiv.org/abs/2007.13224)
"""
"""
## Setup
"""

import os
import zipfile
import numpy as np
import tensorflow as tf  # for data preprocessing

import keras
from keras import layers

"""
## Downloading the MosMedData: Chest CT Scans with COVID-19 Related Findings

In this example, we use a subset of the
[MosMedData: Chest CT Scans with COVID-19 Related Findings](https://www.medrxiv.org/content/10.1101/2020.05.20.20100362v1).
This dataset consists of lung CT scans with COVID-19 related findings, as well as scans without such findings.

We will be using the associated radiological findings of the CT scans as labels to build
a classifier to predict the presence of viral pneumonia.
Hence, the task is a binary classification problem.
"""

# Download url of normal CT scans.
url = "https://github.com/hasibzunair/3D-image-classification-tutorial/releases/download/v0.2/CT-0.zip"
filename = os.path.join(os.getcwd(), "CT-0.zip")
keras.utils.get_file(filename, url)

# Download url of abnormal CT scans.
url = "https://github.com/hasibzunair/3D-image-classification-tutorial/releases/download/v0.2/CT-23.zip"
filename = os.path.join(os.getcwd(), "CT-23.zip")
keras.utils.get_file(filename, url)

# Make a directory to store the data.
os.makedirs("MosMedData")

# Unzip data in the newly created directory.
with zipfile.ZipFile("CT-0.zip", "r") as z_fp:
    z_fp.extractall("./MosMedData/")

with zipfile.ZipFile("CT-23.zip", "r") as z_fp:
    z_fp.extractall("./MosMedData/")

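"""
As an optional sanity check (not part of the original tutorial), you can confirm that both
class folders were extracted before moving on. The folder names below match the zip archives
downloaded above.
"""

# Optional: verify that the two class folders exist after extraction.
print(sorted(os.listdir("MosMedData")))  # expected to contain 'CT-0' and 'CT-23'
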
"""
## Loading data and preprocessing

The files are provided in the NIfTI format with the extension .nii. To read the
scans, we use the `nibabel` package.
You can install the package via `pip install nibabel`. CT scans store raw voxel
intensity in Hounsfield units (HU), which range from -1024 to above 2000 in this dataset.
Values above 400 correspond to bones with different radiointensity, so 400 is used as an
upper bound. A threshold between -1000 and 400 is commonly used to normalize CT scans:
values are clipped to this range and then linearly rescaled, so that, for example, a voxel
of 0 HU maps to (0 - (-1000)) / (400 - (-1000)) ≈ 0.71.

To process the data, we do the following:

* We first rotate the volumes by 90 degrees, so the orientation is fixed.
* We scale the HU values to be between 0 and 1.
* We resize width, height and depth.

Here we define several helper functions to process the data. These functions
will be used when building training and validation datasets.
"""

import nibabel as nib

from scipy import ndimage


def read_nifti_file(filepath):
    """Read and load volume"""
    # Read file
    scan = nib.load(filepath)
    # Get raw data
    scan = scan.get_fdata()
    return scan


def normalize(volume):
    """Normalize the volume"""
    min_value = -1000
    max_value = 400
    volume[volume < min_value] = min_value
    volume[volume > max_value] = max_value
    volume = (volume - min_value) / (max_value - min_value)
    volume = volume.astype("float32")
    return volume


def resize_volume(img):
    """Resize across z-axis"""
    # Set the desired depth
    desired_depth = 64
    desired_width = 128
    desired_height = 128
    # Get current depth
    current_depth = img.shape[-1]
    current_width = img.shape[0]
    current_height = img.shape[1]
    # Compute depth factor
    depth = current_depth / desired_depth
    width = current_width / desired_width
    height = current_height / desired_height
    depth_factor = 1 / depth
    width_factor = 1 / width
    height_factor = 1 / height
    # Rotate
    img = ndimage.rotate(img, 90, reshape=False)
    # Resize across z-axis
    img = ndimage.zoom(img, (width_factor, height_factor, depth_factor), order=1)
    return img


def process_scan(path):
    """Read and resize volume"""
    # Read scan
    volume = read_nifti_file(path)
    # Normalize
    volume = normalize(volume)
    # Resize width, height and depth
    volume = resize_volume(volume)
    return volume


"""
Let's read the paths of the CT scans from the class directories.
"""

# Folder "CT-0" consists of CT scans with normal lung tissue,
# i.e. no CT signs of viral pneumonia.
normal_scan_paths = [
    os.path.join(os.getcwd(), "MosMedData/CT-0", x)
    for x in os.listdir("MosMedData/CT-0")
]
# Folder "CT-23" consists of CT scans with several ground-glass opacifications,
# i.e. involvement of lung parenchyma.
abnormal_scan_paths = [
    os.path.join(os.getcwd(), "MosMedData/CT-23", x)
    for x in os.listdir("MosMedData/CT-23")
]

print("CT scans with normal lung tissue: " + str(len(normal_scan_paths)))
print("CT scans with abnormal lung tissue: " + str(len(abnormal_scan_paths)))


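"""
A minimal usage sketch of the helpers above (an optional addition, not in the original
script): process one scan and check that it comes out with roughly the expected
`(128, 128, 64)` shape before processing the whole dataset.
"""

# Optional: process a single scan and inspect its shape.
sample_volume = process_scan(normal_scan_paths[0])
print("Shape of a processed scan:", sample_volume.shape)

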
"""
## Build train and validation datasets

Read the scans from the class directories and assign labels. Downsample the scans to have
a shape of 128x128x64. Rescale the raw HU values to the range 0 to 1.
Lastly, split the dataset into train and validation subsets.
"""

# Read and process the scans.
# Each scan is resized across height, width, and depth and rescaled.
abnormal_scans = np.array([process_scan(path) for path in abnormal_scan_paths])
normal_scans = np.array([process_scan(path) for path in normal_scan_paths])

# For the CT scans that show signs of viral pneumonia assign 1,
# for the normal ones assign 0.
abnormal_labels = np.array([1 for _ in range(len(abnormal_scans))])
normal_labels = np.array([0 for _ in range(len(normal_scans))])

# Split data in the ratio 70-30 for training and validation.
x_train = np.concatenate((abnormal_scans[:70], normal_scans[:70]), axis=0)
y_train = np.concatenate((abnormal_labels[:70], normal_labels[:70]), axis=0)
x_val = np.concatenate((abnormal_scans[70:], normal_scans[70:]), axis=0)
y_val = np.concatenate((abnormal_labels[70:], normal_labels[70:]), axis=0)
print(
    "Number of samples in train and validation are %d and %d."
    % (x_train.shape[0], x_val.shape[0])
)

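"""
Optional sanity check (an addition to the original script): the first 70 scans of each class
go to training, so the remaining scans of each class end up in validation and the validation
labels should be roughly balanced.
"""

# Fraction of abnormal scans in the validation set (expected to be close to 0.5).
print("Positive fraction in validation set:", y_val.mean())
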
"""
## Data augmentation

The CT scans are also augmented by rotating them at random angles during training. Since
each CT scan is stored as a rank-3 tensor of shape `(height, width, depth)`,
we add a dimension of size 1 at axis 3 to be able to perform 3D convolutions on
the data. The new shape of a batch is thus `(samples, height, width, depth, 1)`. There are
different kinds of preprocessing and augmentation techniques out there;
this example shows a few simple ones to get started.
"""

import random

from scipy import ndimage


def rotate(volume):
    """Rotate the volume by a few degrees"""

    def scipy_rotate(volume):
        # define some rotation angles
        angles = [-20, -10, -5, 5, 10, 20]
        # pick angles at random
        angle = random.choice(angles)
        # rotate volume
        volume = ndimage.rotate(volume, angle, reshape=False)
        volume[volume < 0] = 0
        volume[volume > 1] = 1
        return volume

    augmented_volume = tf.numpy_function(scipy_rotate, [volume], tf.float32)
    return augmented_volume


def train_preprocessing(volume, label):
    """Process training data by rotating and adding a channel."""
    # Rotate volume
    volume = rotate(volume)
    volume = tf.expand_dims(volume, axis=3)
    return volume, label


def validation_preprocessing(volume, label):
    """Process validation data by only adding a channel."""
    volume = tf.expand_dims(volume, axis=3)
    return volume, label


"""
While defining the train and validation data loaders, the training data is passed through
an augmentation function which randomly rotates the volumes at different angles. Note that both
training and validation data are already rescaled to have values between 0 and 1.
"""

# Define data loaders.
train_loader = tf.data.Dataset.from_tensor_slices((x_train, y_train))
validation_loader = tf.data.Dataset.from_tensor_slices((x_val, y_val))

batch_size = 2
# Augment the data on the fly during training.
train_dataset = (
    train_loader.shuffle(len(x_train))
    .map(train_preprocessing)
    .batch(batch_size)
    .prefetch(2)
)
# Only rescale.
validation_dataset = (
    validation_loader.shuffle(len(x_val))
    .map(validation_preprocessing)
    .batch(batch_size)
    .prefetch(2)
)

"""
Visualize an augmented CT scan.
"""

import matplotlib.pyplot as plt

data = train_dataset.take(1)
images, labels = list(data)[0]
images = images.numpy()
image = images[0]
print("Dimension of the CT scan is:", image.shape)
plt.imshow(np.squeeze(image[:, :, 30]), cmap="gray")


"""
Since a CT scan has many slices, let's visualize a montage of the slices.
"""


def plot_slices(num_rows, num_columns, width, height, data):
    """Plot a montage of CT slices"""
    data = np.rot90(np.array(data))
    data = np.transpose(data)
    data = np.reshape(data, (num_rows, num_columns, width, height))
    rows_data, columns_data = data.shape[0], data.shape[1]
    heights = [slc[0].shape[0] for slc in data]
    widths = [slc.shape[1] for slc in data[0]]
    fig_width = 12.0
    fig_height = fig_width * sum(heights) / sum(widths)
    f, axarr = plt.subplots(
        rows_data,
        columns_data,
        figsize=(fig_width, fig_height),
        gridspec_kw={"height_ratios": heights},
    )
    for i in range(rows_data):
        for j in range(columns_data):
            axarr[i, j].imshow(data[i][j], cmap="gray")
            axarr[i, j].axis("off")
    plt.subplots_adjust(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)
    plt.show()


# Visualize a montage of slices.
# 4 rows and 10 columns for 40 slices of the CT scan.
plot_slices(4, 10, 128, 128, image[:, :, :40])

"""
## Define a 3D convolutional neural network

To make the model easier to understand, we structure it into blocks.
The architecture of the 3D CNN used in this example
is based on [this paper](https://arxiv.org/abs/2007.13224).
"""


def get_model(width=128, height=128, depth=64):
    """Build a 3D convolutional neural network model."""

    inputs = keras.Input((width, height, depth, 1))

    x = layers.Conv3D(filters=64, kernel_size=3, activation="relu")(inputs)
    x = layers.MaxPool3D(pool_size=2)(x)
    x = layers.BatchNormalization()(x)

    x = layers.Conv3D(filters=64, kernel_size=3, activation="relu")(x)
    x = layers.MaxPool3D(pool_size=2)(x)
    x = layers.BatchNormalization()(x)

    x = layers.Conv3D(filters=128, kernel_size=3, activation="relu")(x)
    x = layers.MaxPool3D(pool_size=2)(x)
    x = layers.BatchNormalization()(x)

    x = layers.Conv3D(filters=256, kernel_size=3, activation="relu")(x)
    x = layers.MaxPool3D(pool_size=2)(x)
    x = layers.BatchNormalization()(x)

    x = layers.GlobalAveragePooling3D()(x)
    x = layers.Dense(units=512, activation="relu")(x)
    x = layers.Dropout(0.3)(x)

    outputs = layers.Dense(units=1, activation="sigmoid")(x)

    # Define the model.
    model = keras.Model(inputs, outputs, name="3dcnn")
    return model


# Build model.
model = get_model(width=128, height=128, depth=64)
model.summary()

"""
## Train model
"""

# Compile model.
initial_learning_rate = 0.0001
lr_schedule = keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate, decay_steps=100000, decay_rate=0.96, staircase=True
)
model.compile(
    loss="binary_crossentropy",
    optimizer=keras.optimizers.Adam(learning_rate=lr_schedule),
    metrics=["acc"],
    run_eagerly=True,
)

# Define callbacks.
checkpoint_cb = keras.callbacks.ModelCheckpoint(
    "3d_image_classification.keras", save_best_only=True
)
early_stopping_cb = keras.callbacks.EarlyStopping(monitor="val_acc", patience=15)

# Train the model, doing validation at the end of each epoch.
epochs = 100
model.fit(
    train_dataset,
    validation_data=validation_dataset,
    epochs=epochs,
    shuffle=True,
    verbose=2,
    callbacks=[checkpoint_cb, early_stopping_cb],
)

"""
It is important to note that the number of samples is very small (only 200) and we don't
specify a random seed. As such, you can expect significant variance in the results. The full dataset,
which consists of over 1000 CT scans, can be found [here](https://www.medrxiv.org/content/10.1101/2020.05.20.20100362v1). Using the full
dataset, an accuracy of 83% was achieved. A variability of 6-7% in the classification
performance is observed in both cases.
"""

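"""
A minimal sketch for reducing run-to-run variance (an optional addition, not part of the
original tutorial): fix the random seeds before the data is shuffled and the model is built.
The seed value of 42 is arbitrary.
"""

# Uncomment and place near the top of the script, before building the datasets and the model.
# keras.utils.set_random_seed(42)  # seeds Python's `random`, NumPy, and the backend RNG
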
"""
## Visualizing model performance

Here the model accuracy and loss for the training and the validation sets are plotted.
Since the validation set is class-balanced, accuracy provides an unbiased representation
of the model's performance.
"""

fig, ax = plt.subplots(1, 2, figsize=(20, 3))
ax = ax.ravel()

for i, metric in enumerate(["acc", "loss"]):
    ax[i].plot(model.history.history[metric])
    ax[i].plot(model.history.history["val_" + metric])
    ax[i].set_title("Model {}".format(metric))
    ax[i].set_xlabel("epochs")
    ax[i].set_ylabel(metric)
    ax[i].legend(["train", "val"])

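"""
In addition to the curves above, a single summary number can be useful. The snippet below is a
small optional addition (not in the original script) that evaluates the model, with its
end-of-training weights, on the validation dataset using the compiled loss and accuracy metrics.
"""

# Optional: report final validation loss and accuracy numerically.
val_loss, val_acc = model.evaluate(validation_dataset, verbose=0)
print("Validation loss: %.4f, validation accuracy: %.4f" % (val_loss, val_acc))
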
"""
## Make predictions on a single CT scan
"""

# Load best weights.
model.load_weights("3d_image_classification.keras")
prediction = model.predict(np.expand_dims(x_val[0], axis=0))[0]
scores = [1 - prediction[0], prediction[0]]

class_names = ["normal", "abnormal"]
for score, name in zip(scores, class_names):
    print(
        "This model is %.2f percent confident that the CT scan is %s"
        % ((100 * score), name)
    )

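"""
As a final optional sketch (not part of the original tutorial), the sigmoid output can be
turned into a hard class decision. The 0.5 threshold below is an assumption, not a value tuned
for this dataset.
"""

# Optional: convert the predicted probability into a class label (0 = normal, 1 = abnormal).
predicted_class = class_names[int(prediction[0] > 0.5)]
print("Predicted class:", predicted_class)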