"""
Title: 3D image classification from CT scans
Author: [Hasib Zunair](https://twitter.com/hasibzunair)
Date created: 2020/09/23
Last modified: 2024/01/11
Description: Train a 3D convolutional neural network to predict presence of pneumonia.
Accelerator: GPU
"""

"""
## Introduction

This example will show the steps needed to build a 3D convolutional neural network (CNN)
to predict the presence of viral pneumonia in computed tomography (CT) scans. 2D CNNs are
commonly used to process RGB images (3 channels). A 3D CNN is simply the 3D
equivalent: it takes as input a 3D volume or a sequence of 2D frames (e.g. slices in a CT scan).
3D CNNs are a powerful model for learning representations of volumetric data.

## References

- [A survey on Deep Learning Advances on Different 3D Data Representations](https://arxiv.org/abs/1808.01462)
- [VoxNet: A 3D Convolutional Neural Network for Real-Time Object Recognition](https://www.ri.cmu.edu/pub_files/2015/9/voxnet_maturana_scherer_iros15.pdf)
- [FusionNet: 3D Object Classification Using Multiple Data Representations](https://arxiv.org/abs/1607.05695)
- [Uniformizing Techniques to Process CT scans with 3D CNNs for Tuberculosis Prediction](https://arxiv.org/abs/2007.13224)
"""
"""
## Setup
"""

import os
import zipfile
import numpy as np
import tensorflow as tf  # for data preprocessing

import keras
from keras import layers

"""
## Downloading the MosMedData: Chest CT Scans with COVID-19 Related Findings

In this example, we use a subset of the
[MosMedData: Chest CT Scans with COVID-19 Related Findings](https://www.medrxiv.org/content/10.1101/2020.05.20.20100362v1).
This dataset consists of lung CT scans with COVID-19 related findings, as well as without such findings.

We will be using the associated radiological findings of the CT scans as labels to build
a classifier to predict presence of viral pneumonia.
Hence, the task is a binary classification problem.
"""

# Download url of normal CT scans.
url = "https://github.com/hasibzunair/3D-image-classification-tutorial/releases/download/v0.2/CT-0.zip"
filename = os.path.join(os.getcwd(), "CT-0.zip")
keras.utils.get_file(filename, url)

# Download url of abnormal CT scans.
url = "https://github.com/hasibzunair/3D-image-classification-tutorial/releases/download/v0.2/CT-23.zip"
filename = os.path.join(os.getcwd(), "CT-23.zip")
keras.utils.get_file(filename, url)

# Make a directory to store the data (no error if it already exists).
os.makedirs("MosMedData", exist_ok=True)

# Unzip data in the newly created directory.
with zipfile.ZipFile("CT-0.zip", "r") as z_fp:
    z_fp.extractall("./MosMedData/")

with zipfile.ZipFile("CT-23.zip", "r") as z_fp:
    z_fp.extractall("./MosMedData/")

"""
## Loading data and preprocessing

The files are provided in Nifti format with the extension .nii. To read the
scans, we use the `nibabel` package.
You can install the package via `pip install nibabel`. CT scans store raw voxel
intensity in Hounsfield units (HU). They range from -1024 to above 2000 in this dataset.
Values above 400 correspond to bones with different radiointensity, so 400 is used as the upper bound.
Clipping to the range -1000 to 400 HU is commonly used to normalize CT scans.

To process the data, we do the following:

* We first rotate the volumes by 90 degrees, so the orientation is fixed.
* We scale the HU values to be between 0 and 1.
* We resize width, height and depth.

Here we define several helper functions to process the data.
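
As a quick illustration of the last two steps, here is a minimal NumPy-only sketch, assuming the
same -1000 to 400 HU window and the same 128x128x64 target shape used in this example. The names
`window_and_scale` and `zoom_factors` are hypothetical: they only mirror the arithmetic of
`normalize` and `resize_volume` and are not part of the pipeline.

```python
import numpy as np

HU_MIN, HU_MAX = -1000, 400  # HU window, assumed to match `normalize`
TARGET_SHAPE = (128, 128, 64)  # (width, height, depth) target, as in `resize_volume`


def window_and_scale(volume):
    # Clip raw HU values into the window, then map [-1000, 400] linearly onto [0, 1].
    clipped = np.clip(volume, HU_MIN, HU_MAX)
    return ((clipped - HU_MIN) / (HU_MAX - HU_MIN)).astype("float32")


def zoom_factors(shape, target=TARGET_SHAPE):
    # Per-axis factor for scipy.ndimage.zoom: desired size / current size.
    return tuple(t / s for s, t in zip(shape, target))


hu = np.array([-1024.0, -1000.0, -300.0, 400.0, 2000.0])
print(window_and_scale(hu))  # values 0, 0, 0.5, 1, 1: air and bone saturate
print(zoom_factors((512, 512, 40)))  # (0.25, 0.25, 1.6) for a hypothetical 512x512x40 scan
```

Note that depth can be upsampled (a factor above 1) as well as downsampled, since every scan is
mapped to the same 64 slices regardless of how many it originally had.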

The helpers defined below are used to build the training and validation datasets.
"""


import nibabel as nib

from scipy import ndimage


def read_nifti_file(filepath):
    """Read and load volume"""
    # Read file
    scan = nib.load(filepath)
    # Get raw data
    scan = scan.get_fdata()
    return scan


def normalize(volume):
    """Normalize the volume"""
    # Clip to the HU window [-1000, 400], then rescale to [0, 1].
    # (hu_min/hu_max avoid shadowing the built-ins `min` and `max`.)
    hu_min = -1000
    hu_max = 400
    volume[volume < hu_min] = hu_min
    volume[volume > hu_max] = hu_max
    volume = (volume - hu_min) / (hu_max - hu_min)
    volume = volume.astype("float32")
    return volume


def resize_volume(img):
    """Rotate the volume and resize it to 128x128x64"""
    # Set the desired output shape
    desired_depth = 64
    desired_width = 128
    desired_height = 128
    # Get the current shape
    current_depth = img.shape[-1]
    current_width = img.shape[0]
    current_height = img.shape[1]
    # Compute the zoom factor for each axis (desired size / current size)
    depth = current_depth / desired_depth
    width = current_width / desired_width
    height = current_height / desired_height
    depth_factor = 1 / depth
    width_factor = 1 / width
    height_factor = 1 / height
    # Rotate by 90 degrees in the (width, height) plane, keeping the shape
    img = ndimage.rotate(img, 90, reshape=False)
    # Resize across all three axes with linear interpolation
    img = ndimage.zoom(img, (width_factor, height_factor, depth_factor), order=1)
    return img


def process_scan(path):
    """Read, normalize and resize volume"""
    # Read scan
    volume = read_nifti_file(path)
    # Normalize
    volume = normalize(volume)
    # Resize width, height and depth
    volume = resize_volume(volume)
    return volume


"""
Let's read the paths of the CT scans from the class directories.
"""

# Folder "CT-0" consists of CT scans having normal lung tissue,
# no CT-signs of viral pneumonia.
normal_scan_paths = [
    os.path.join(os.getcwd(), "MosMedData/CT-0", x)
    for x in os.listdir("MosMedData/CT-0")
]
# Folder "CT-23" consists of CT scans having several ground-glass opacifications,
# involvement of lung parenchyma.
abnormal_scan_paths = [
    os.path.join(os.getcwd(), "MosMedData/CT-23", x)
    for x in os.listdir("MosMedData/CT-23")
]

print("CT scans with normal lung tissue: " + str(len(normal_scan_paths)))
print("CT scans with abnormal lung tissue: " + str(len(abnormal_scan_paths)))


"""
## Build train and validation datasets
Read the scans from the class directories and assign labels. Downsample the scans to have
shape of 128x128x64. Rescale the raw HU values to the range 0 to 1.
Lastly, split the dataset into train and validation subsets.
"""

# Read and process the scans.
# Each scan is resized across height, width, and depth and rescaled.
abnormal_scans = np.array([process_scan(path) for path in abnormal_scan_paths])
normal_scans = np.array([process_scan(path) for path in normal_scan_paths])

# For the CT scans having presence of viral pneumonia
# assign 1, for the normal ones assign 0.
abnormal_labels = np.array([1 for _ in range(len(abnormal_scans))])
normal_labels = np.array([0 for _ in range(len(normal_scans))])

# Split data in the ratio 70-30 for training and validation:
# the first 70 scans of each class go to training, the rest to validation.
x_train = np.concatenate((abnormal_scans[:70], normal_scans[:70]), axis=0)
y_train = np.concatenate((abnormal_labels[:70], normal_labels[:70]), axis=0)
x_val = np.concatenate((abnormal_scans[70:], normal_scans[70:]), axis=0)
y_val = np.concatenate((abnormal_labels[70:], normal_labels[70:]), axis=0)
print(
    "Number of samples in train and validation are %d and %d."
    % (x_train.shape[0], x_val.shape[0])
)

"""
## Data augmentation

The CT scans are also augmented by rotating them at random angles during training. Since
each scan is stored as a rank-3 tensor, the dataset has shape `(samples, height, width, depth)`.
We add a dimension of size 1 at axis 4 (axis 3 of each individual scan) to be able to perform
3D convolutions on the data. The new shape is thus `(samples, height, width, depth, 1)`.
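
A minimal NumPy sketch of that shape bookkeeping, assuming a hypothetical batch of two scans that
already have the 128x128x64 shape produced above: adding the channel at axis 3 of a single scan,
as the dataset map functions in this example do with `tf.expand_dims`, is the same as adding it at
axis 4 (the last axis) of the stacked batch.

```python
import numpy as np

# Hypothetical stand-in for two preprocessed scans: (samples, height, width, depth).
batch = np.zeros((2, 128, 128, 64), dtype="float32")

# Per-scan view: (128, 128, 64) -> (128, 128, 64, 1).
single = np.expand_dims(batch[0], axis=3)

# Whole-batch view: (2, 128, 128, 64) -> (2, 128, 128, 64, 1).
batched = np.expand_dims(batch, axis=4)

print(single.shape, batched.shape)
```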

There are different kinds of preprocessing and augmentation techniques out there;
this example shows a few simple ones to get started.
"""

import random

from scipy import ndimage


def rotate(volume):
    """Rotate the volume by a few degrees"""

    def scipy_rotate(volume):
        # define some rotation angles
        angles = [-20, -10, -5, 5, 10, 20]
        # pick an angle at random
        angle = random.choice(angles)
        # rotate volume, keeping its shape
        volume = ndimage.rotate(volume, angle, reshape=False)
        # spline interpolation can overshoot, so clip back into [0, 1]
        volume[volume < 0] = 0
        volume[volume > 1] = 1
        return volume

    augmented_volume = tf.numpy_function(scipy_rotate, [volume], tf.float32)
    return augmented_volume


def train_preprocessing(volume, label):
    """Process training data by rotating and adding a channel."""
    # Rotate volume
    volume = rotate(volume)
    # Add the channel dimension: (height, width, depth) -> (height, width, depth, 1)
    volume = tf.expand_dims(volume, axis=3)
    return volume, label


def validation_preprocessing(volume, label):
    """Process validation data by only adding a channel."""
    volume = tf.expand_dims(volume, axis=3)
    return volume, label


"""
While defining the train and validation data loaders, the training data is passed through
an augmentation function which randomly rotates the volume at different angles.
Note that both training and validation data are already rescaled to have values between 0 and 1.
"""

# Define data loaders.
train_loader = tf.data.Dataset.from_tensor_slices((x_train, y_train))
validation_loader = tf.data.Dataset.from_tensor_slices((x_val, y_val))

batch_size = 2
# Augment the data on the fly during training.
train_dataset = (
    train_loader.shuffle(len(x_train))
    .map(train_preprocessing)
    .batch(batch_size)
    .prefetch(2)
)
# No augmentation for validation: only add the channel dimension.
validation_dataset = (
    validation_loader.shuffle(len(x_val))
    .map(validation_preprocessing)
    .batch(batch_size)
    .prefetch(2)
)

"""
Visualize an augmented CT scan.
"""

import matplotlib.pyplot as plt

data = train_dataset.take(1)
images, labels = list(data)[0]
images = images.numpy()
image = images[0]
print("Dimension of the CT scan is:", image.shape)
plt.imshow(np.squeeze(image[:, :, 30]), cmap="gray")


"""
Since a CT scan has many slices, let's visualize a montage of the slices.
"""


def plot_slices(num_rows, num_columns, width, height, data):
    """Plot a montage of num_rows * num_columns CT slices"""
    data = np.rot90(np.array(data))
    data = np.transpose(data)
    data = np.reshape(data, (num_rows, num_columns, width, height))
    rows_data, columns_data = data.shape[0], data.shape[1]
    heights = [slc[0].shape[0] for slc in data]
    widths = [slc.shape[1] for slc in data[0]]
    fig_width = 12.0
    fig_height = fig_width * sum(heights) / sum(widths)
    f, axarr = plt.subplots(
        rows_data,
        columns_data,
        figsize=(fig_width, fig_height),
        gridspec_kw={"height_ratios": heights},
    )
    for i in range(rows_data):
        for j in range(columns_data):
            axarr[i, j].imshow(data[i][j], cmap="gray")
            axarr[i, j].axis("off")
    plt.subplots_adjust(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)
    plt.show()


# Visualize montage of slices.
# 4 rows and 10 columns for 40 slices of the CT scan.
plot_slices(4, 10, 128, 128, image[:, :, :40])

"""
## Define a 3D convolutional neural network

To make the model easier to understand, we structure it into blocks.
The architecture of the 3D CNN used in this example
is based on [this paper](https://arxiv.org/abs/2007.13224).
"""


def get_model(width=128, height=128, depth=64):
    """Build a 3D convolutional neural network model."""

    inputs = keras.Input((width, height, depth, 1))

    x = layers.Conv3D(filters=64, kernel_size=3, activation="relu")(inputs)
    x = layers.MaxPool3D(pool_size=2)(x)
    x = layers.BatchNormalization()(x)

    x = layers.Conv3D(filters=64, kernel_size=3, activation="relu")(x)
    x = layers.MaxPool3D(pool_size=2)(x)
    x = layers.BatchNormalization()(x)

    x = layers.Conv3D(filters=128, kernel_size=3, activation="relu")(x)
    x = layers.MaxPool3D(pool_size=2)(x)
    x = layers.BatchNormalization()(x)

    x = layers.Conv3D(filters=256, kernel_size=3, activation="relu")(x)
    x = layers.MaxPool3D(pool_size=2)(x)
    x = layers.BatchNormalization()(x)

    x = layers.GlobalAveragePooling3D()(x)
    x = layers.Dense(units=512, activation="relu")(x)
    x = layers.Dropout(0.3)(x)

    outputs = layers.Dense(units=1, activation="sigmoid")(x)

    # Define the model.
    model = keras.Model(inputs, outputs, name="3dcnn")
    return model


# Build model.
model = get_model(width=128, height=128, depth=64)
model.summary()

"""
## Train model
"""

# Compile model.
initial_learning_rate = 0.0001
lr_schedule = keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate, decay_steps=100000, decay_rate=0.96, staircase=True
)
model.compile(
    loss="binary_crossentropy",
    optimizer=keras.optimizers.Adam(learning_rate=lr_schedule),
    metrics=["acc"],
    run_eagerly=True,
)

# Define callbacks.
checkpoint_cb = keras.callbacks.ModelCheckpoint(
    "3d_image_classification.keras", save_best_only=True
)
early_stopping_cb = keras.callbacks.EarlyStopping(monitor="val_acc", patience=15)

# Train the model, doing validation at the end of each epoch
epochs = 100
model.fit(
    train_dataset,
    validation_data=validation_dataset,
    epochs=epochs,
    shuffle=True,
    verbose=2,
    callbacks=[checkpoint_cb, early_stopping_cb],
)

"""
It is important to note that the number of samples is very small (only 200) and we don't
specify a random seed. As such, you can expect significant variance in the results. The full dataset,
which consists of over 1000 CT scans, can be found [here](https://www.medrxiv.org/content/10.1101/2020.05.20.20100362v1). Using the full
dataset, an accuracy of 83% was achieved. A variability of 6-7% in the classification
performance is observed in both cases.
"""

"""
## Visualizing model performance

Here the model accuracy and loss for the training and the validation sets are plotted.
Since the validation set is class-balanced, accuracy provides an unbiased representation
of the model's performance.
"""

fig, ax = plt.subplots(1, 2, figsize=(20, 3))
ax = ax.ravel()

for i, metric in enumerate(["acc", "loss"]):
    ax[i].plot(model.history.history[metric])
    ax[i].plot(model.history.history["val_" + metric])
    ax[i].set_title("Model {}".format(metric))
    ax[i].set_xlabel("epochs")
    ax[i].set_ylabel(metric)
    ax[i].legend(["train", "val"])

"""
## Make predictions on a single CT scan
"""

# Load best weights.
model.load_weights("3d_image_classification.keras")
# Add the batch and channel axes: (128, 128, 64) -> (1, 128, 128, 64, 1).
prediction = model.predict(np.expand_dims(x_val[0], axis=(0, -1)))[0]
scores = [1 - prediction[0], prediction[0]]

class_names = ["normal", "abnormal"]
for score, name in zip(scores, class_names):
    print(
        "This model is %.2f percent confident that CT scan is %s"
        % ((100 * score), name)
    )
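
"""
As a back-of-the-envelope check on `get_model`, the sketch below traces the spatial size of the
feature maps through the four convolutional blocks. It assumes the Keras defaults the model relies
on: `Conv3D` with `padding="valid"` and stride 1 (a 3-voxel kernel trims 2 voxels per axis), and
`MaxPool3D(pool_size=2)` with stride 2 and `"valid"` padding (halving, rounded down).
`BatchNormalization` does not change shapes. The helper `trace_shapes` is hypothetical plain
Python, not a Keras API.

```python
def trace_shapes(shape=(128, 128, 64), filters=(64, 64, 128, 256), kernel=3, pool=2):
    # Return (spatial shape, channels) after each conv -> pool block.
    out = []
    for f in filters:
        # "valid" convolution with stride 1: n -> n - kernel + 1
        shape = tuple(n - kernel + 1 for n in shape)
        # "valid" pooling with stride equal to the pool size: n -> (n - pool) // pool + 1
        shape = tuple((n - pool) // pool + 1 for n in shape)
        out.append((shape, f))
    return out


for block, (spatial, channels) in enumerate(trace_shapes(), start=1):
    print(f"block {block}: {spatial} x {channels} channels")
```

The last block leaves a 6x6x2 grid with 256 channels, so `GlobalAveragePooling3D` reduces each
scan to a 256-dimensional vector before the dense layers; `model.summary()` should report the
same shapes.
"""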
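
"""
It can also be worth checking how much the `ExponentialDecay` schedule actually decays during
training. With `staircase=True`, the learning rate is
`initial_learning_rate * decay_rate ** (step // decay_steps)`. The sketch below reproduces that
formula in plain Python, assuming the 140 training scans, `batch_size = 2` and the 100-epoch
budget used above; early stopping can only shorten the run. The function `staircase_lr` is a
hypothetical helper, not part of Keras.

```python
import math


def staircase_lr(step, initial=1e-4, decay_steps=100_000, decay_rate=0.96):
    # Same formula as ExponentialDecay with staircase=True (integer division of the step).
    return initial * decay_rate ** (step // decay_steps)


steps_per_epoch = math.ceil(140 / 2)  # 70 training batches per epoch
max_steps = steps_per_epoch * 100  # at most 7000 optimizer steps
print(steps_per_epoch, max_steps, staircase_lr(max_steps))
```

Since 7000 is far below `decay_steps=100000`, the exponent stays at 0 and the learning rate
remains at 1e-4 for the whole run, so in practice this configuration trains with a constant
learning rate.
"""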