Path: blob/master/examples/vision/attention_mil_classification.py
3507 views
"""1Title: Classification using Attention-based Deep Multiple Instance Learning (MIL).2Author: [Mohamad Jaber](https://www.linkedin.com/in/mohamadjaber1/)3Date created: 2021/08/164Last modified: 2021/11/255Description: MIL approach to classify bags of instances and get their individual instance score.6Accelerator: GPU7"""89"""10## Introduction1112### What is Multiple Instance Learning (MIL)?1314Usually, with supervised learning algorithms, the learner receives labels for a set of15instances. In the case of MIL, the learner receives labels for a set of bags, each of which16contains a set of instances. The bag is labeled positive if it contains at least17one positive instance, and negative if it does not contain any.1819### Motivation2021It is often assumed in image classification tasks that each image clearly represents a22class label. In medical imaging (e.g. computational pathology, etc.) an *entire image*23is represented by a single class label (cancerous/non-cancerous) or a region of interest24could be given. However, one will be interested in knowing which patterns in the image25is actually causing it to belong to that class. In this context, the image(s) will be26divided and the subimages will form the bag of instances.2728Therefore, the goals are to:29301. Learn a model to predict a class label for a bag of instances.312. Find out which instances within the bag caused a position class label32prediction.3334### Implementation3536The following steps describe how the model works:37381. The feature extractor layers extract feature embeddings.392. The embeddings are fed into the MIL attention layer to get40the attention scores. The layer is designed as permutation-invariant.413. Input features and their corresponding attention scores are multiplied together.424. The resulting output is passed to a softmax function for classification.4344### References4546- [Attention-based Deep Multiple Instance Learning](https://arxiv.org/abs/1802.04712).47- Some of the attention operator code implementation was inspired from https://github.com/utayao/Atten_Deep_MIL.48- Imbalanced data [tutorial](https://www.tensorflow.org/tutorials/structured_data/imbalanced_data)49by TensorFlow.5051"""52"""53## Setup54"""5556import numpy as np57import keras58from keras import layers59from keras import ops60from tqdm import tqdm61from matplotlib import pyplot as plt6263plt.style.use("ggplot")6465"""66## Create dataset6768We will create a set of bags and assign their labels according to their contents.69If at least one positive instance70is available in a bag, the bag is considered as a positive bag. If it does not contain any71positive instance, the bag will be considered as negative.7273### Configuration parameters7475- `POSITIVE_CLASS`: The desired class to be kept in the positive bag.76- `BAG_COUNT`: The number of training bags.77- `VAL_BAG_COUNT`: The number of validation bags.78- `BAG_SIZE`: The number of instances in a bag.79- `PLOT_SIZE`: The number of bags to plot.80- `ENSEMBLE_AVG_COUNT`: The number of models to create and average together. (Optional:81often results in better performance - set to 1 for single model)82"""8384POSITIVE_CLASS = 185BAG_COUNT = 100086VAL_BAG_COUNT = 30087BAG_SIZE = 388PLOT_SIZE = 389ENSEMBLE_AVG_COUNT = 19091"""92### Prepare bags9394Since the attention operator is a permutation-invariant operator, an instance with a95positive class label is randomly placed among the instances in the positive bag.96"""979899def create_bags(input_data, input_labels, positive_class, bag_count, instance_count):100# Set up bags.101bags = []102bag_labels = []103104# Normalize input data.105input_data = np.divide(input_data, 255.0)106107# Count positive samples.108count = 0109110for _ in range(bag_count):111# Pick a fixed size random subset of samples.112index = np.random.choice(input_data.shape[0], instance_count, replace=False)113instances_data = input_data[index]114instances_labels = input_labels[index]115116# By default, all bags are labeled as 0.117bag_label = 0118119# Check if there is at least a positive class in the bag.120if positive_class in instances_labels:121# Positive bag will be labeled as 1.122bag_label = 1123count += 1124125bags.append(instances_data)126bag_labels.append(np.array([bag_label]))127128print(f"Positive bags: {count}")129print(f"Negative bags: {bag_count - count}")130131return (list(np.swapaxes(bags, 0, 1)), np.array(bag_labels))132133134# Load the MNIST dataset.135(x_train, y_train), (x_val, y_val) = keras.datasets.mnist.load_data()136137# Create training data.138train_data, train_labels = create_bags(139x_train, y_train, POSITIVE_CLASS, BAG_COUNT, BAG_SIZE140)141142# Create validation data.143val_data, val_labels = create_bags(144x_val, y_val, POSITIVE_CLASS, VAL_BAG_COUNT, BAG_SIZE145)146147"""148## Create the model149150We will now build the attention layer, prepare some utilities, then build and train the151entire model.152153### Attention operator implementation154155The output size of this layer is decided by the size of a single bag.156157The attention mechanism uses a weighted average of instances in a bag, in which the sum158of the weights must equal to 1 (invariant of the bag size).159160The weight matrices (parameters) are **w** and **v**. To include positive and negative161values, hyperbolic tangent element-wise non-linearity is utilized.162163A **Gated attention mechanism** can be used to deal with complex relations. Another weight164matrix, **u**, is added to the computation.165A sigmoid non-linearity is used to overcome approximately linear behavior for *x* ∈ [−1, 1]166by hyperbolic tangent non-linearity.167"""168169170class MILAttentionLayer(layers.Layer):171"""Implementation of the attention-based Deep MIL layer.172173Args:174weight_params_dim: Positive Integer. Dimension of the weight matrix.175kernel_initializer: Initializer for the `kernel` matrix.176kernel_regularizer: Regularizer function applied to the `kernel` matrix.177use_gated: Boolean, whether or not to use the gated mechanism.178179Returns:180List of 2D tensors with BAG_SIZE length.181The tensors are the attention scores after softmax with shape `(batch_size, 1)`.182"""183184def __init__(185self,186weight_params_dim,187kernel_initializer="glorot_uniform",188kernel_regularizer=None,189use_gated=False,190**kwargs,191):192super().__init__(**kwargs)193194self.weight_params_dim = weight_params_dim195self.use_gated = use_gated196197self.kernel_initializer = keras.initializers.get(kernel_initializer)198self.kernel_regularizer = keras.regularizers.get(kernel_regularizer)199200self.v_init = self.kernel_initializer201self.w_init = self.kernel_initializer202self.u_init = self.kernel_initializer203204self.v_regularizer = self.kernel_regularizer205self.w_regularizer = self.kernel_regularizer206self.u_regularizer = self.kernel_regularizer207208def build(self, input_shape):209# Input shape.210# List of 2D tensors with shape: (batch_size, input_dim).211input_dim = input_shape[0][1]212213self.v_weight_params = self.add_weight(214shape=(input_dim, self.weight_params_dim),215initializer=self.v_init,216name="v",217regularizer=self.v_regularizer,218trainable=True,219)220221self.w_weight_params = self.add_weight(222shape=(self.weight_params_dim, 1),223initializer=self.w_init,224name="w",225regularizer=self.w_regularizer,226trainable=True,227)228229if self.use_gated:230self.u_weight_params = self.add_weight(231shape=(input_dim, self.weight_params_dim),232initializer=self.u_init,233name="u",234regularizer=self.u_regularizer,235trainable=True,236)237else:238self.u_weight_params = None239240self.input_built = True241242def call(self, inputs):243# Assigning variables from the number of inputs.244instances = [self.compute_attention_scores(instance) for instance in inputs]245246# Stack instances into a single tensor.247instances = ops.stack(instances)248249# Apply softmax over instances such that the output summation is equal to 1.250alpha = ops.softmax(instances, axis=0)251252# Split to recreate the same array of tensors we had as inputs.253return [alpha[i] for i in range(alpha.shape[0])]254255def compute_attention_scores(self, instance):256# Reserve in-case "gated mechanism" used.257original_instance = instance258259# tanh(v*h_k^T)260instance = ops.tanh(ops.tensordot(instance, self.v_weight_params, axes=1))261262# for learning non-linear relations efficiently.263if self.use_gated:264instance = instance * ops.sigmoid(265ops.tensordot(original_instance, self.u_weight_params, axes=1)266)267268# w^T*(tanh(v*h_k^T)) / w^T*(tanh(v*h_k^T)*sigmoid(u*h_k^T))269return ops.tensordot(instance, self.w_weight_params, axes=1)270271272"""273## Visualizer tool274275Plot the number of bags (given by `PLOT_SIZE`) with respect to the class.276277Moreover, if activated, the class label prediction with its associated instance score278for each bag (after the model has been trained) can be seen.279"""280281282def plot(data, labels, bag_class, predictions=None, attention_weights=None):283""" "Utility for plotting bags and attention weights.284285Args:286data: Input data that contains the bags of instances.287labels: The associated bag labels of the input data.288bag_class: String name of the desired bag class.289The options are: "positive" or "negative".290predictions: Class labels model predictions.291If you don't specify anything, ground truth labels will be used.292attention_weights: Attention weights for each instance within the input data.293If you don't specify anything, the values won't be displayed.294"""295return ## TODO296labels = np.array(labels).reshape(-1)297298if bag_class == "positive":299if predictions is not None:300labels = np.where(predictions.argmax(1) == 1)[0]301bags = np.array(data)[:, labels[0:PLOT_SIZE]]302303else:304labels = np.where(labels == 1)[0]305bags = np.array(data)[:, labels[0:PLOT_SIZE]]306307elif bag_class == "negative":308if predictions is not None:309labels = np.where(predictions.argmax(1) == 0)[0]310bags = np.array(data)[:, labels[0:PLOT_SIZE]]311else:312labels = np.where(labels == 0)[0]313bags = np.array(data)[:, labels[0:PLOT_SIZE]]314315else:316print(f"There is no class {bag_class}")317return318319print(f"The bag class label is {bag_class}")320for i in range(PLOT_SIZE):321figure = plt.figure(figsize=(8, 8))322print(f"Bag number: {labels[i]}")323for j in range(BAG_SIZE):324image = bags[j][i]325figure.add_subplot(1, BAG_SIZE, j + 1)326plt.grid(False)327if attention_weights is not None:328plt.title(np.around(attention_weights[labels[i]][j], 2))329plt.imshow(image)330plt.show()331332333# Plot some of validation data bags per class.334plot(val_data, val_labels, "positive")335plot(val_data, val_labels, "negative")336337"""338## Create model339340First we will create some embeddings per instance, invoke the attention operator and then341use the softmax function to output the class probabilities.342"""343344345def create_model(instance_shape):346# Extract features from inputs.347inputs, embeddings = [], []348shared_dense_layer_1 = layers.Dense(128, activation="relu")349shared_dense_layer_2 = layers.Dense(64, activation="relu")350for _ in range(BAG_SIZE):351inp = layers.Input(instance_shape)352flatten = layers.Flatten()(inp)353dense_1 = shared_dense_layer_1(flatten)354dense_2 = shared_dense_layer_2(dense_1)355inputs.append(inp)356embeddings.append(dense_2)357358# Invoke the attention layer.359alpha = MILAttentionLayer(360weight_params_dim=256,361kernel_regularizer=keras.regularizers.L2(0.01),362use_gated=True,363name="alpha",364)(embeddings)365366# Multiply attention weights with the input layers.367multiply_layers = [368layers.multiply([alpha[i], embeddings[i]]) for i in range(len(alpha))369]370371# Concatenate layers.372concat = layers.concatenate(multiply_layers, axis=1)373374# Classification output node.375output = layers.Dense(2, activation="softmax")(concat)376377return keras.Model(inputs, output)378379380"""381## Class weights382383Since this kind of problem could simply turn into imbalanced data classification problem,384class weighting should be considered.385386Let's say there are 1000 bags. There often could be cases were ~90 % of the bags do not387contain any positive label and ~10 % do.388Such data can be referred to as **Imbalanced data**.389390Using class weights, the model will tend to give a higher weight to the rare class.391"""392393394def compute_class_weights(labels):395# Count number of positive and negative bags.396negative_count = len(np.where(labels == 0)[0])397positive_count = len(np.where(labels == 1)[0])398total_count = negative_count + positive_count399400# Build class weight dictionary.401return {4020: (1 / negative_count) * (total_count / 2),4031: (1 / positive_count) * (total_count / 2),404}405406407"""408## Build and train model409410The model is built and trained in this section.411"""412413414def train(train_data, train_labels, val_data, val_labels, model):415# Train model.416# Prepare callbacks.417# Path where to save best weights.418419# Take the file name from the wrapper.420file_path = "/tmp/best_model.weights.h5"421422# Initialize model checkpoint callback.423model_checkpoint = keras.callbacks.ModelCheckpoint(424file_path,425monitor="val_loss",426verbose=0,427mode="min",428save_best_only=True,429save_weights_only=True,430)431432# Initialize early stopping callback.433# The model performance is monitored across the validation data and stops training434# when the generalization error cease to decrease.435early_stopping = keras.callbacks.EarlyStopping(436monitor="val_loss", patience=10, mode="min"437)438439# Compile model.440model.compile(441optimizer="adam",442loss="sparse_categorical_crossentropy",443metrics=["accuracy"],444)445446# Fit model.447model.fit(448train_data,449train_labels,450validation_data=(val_data, val_labels),451epochs=20,452class_weight=compute_class_weights(train_labels),453batch_size=1,454callbacks=[early_stopping, model_checkpoint],455verbose=0,456)457458# Load best weights.459model.load_weights(file_path)460461return model462463464# Building model(s).465instance_shape = train_data[0][0].shape466models = [create_model(instance_shape) for _ in range(ENSEMBLE_AVG_COUNT)]467468# Show single model architecture.469print(models[0].summary())470471# Training model(s).472trained_models = [473train(train_data, train_labels, val_data, val_labels, model)474for model in tqdm(models)475]476477"""478## Model evaluation479480The models are now ready for evaluation.481With each model we also create an associated intermediate model to get the482weights from the attention layer.483484We will compute a prediction for each of our `ENSEMBLE_AVG_COUNT` models, and485average them together for our final prediction.486"""487488489def predict(data, labels, trained_models):490# Collect info per model.491models_predictions = []492models_attention_weights = []493models_losses = []494models_accuracies = []495496for model in trained_models:497# Predict output classes on data.498predictions = model.predict(data)499models_predictions.append(predictions)500501# Create intermediate model to get MIL attention layer weights.502intermediate_model = keras.Model(model.input, model.get_layer("alpha").output)503504# Predict MIL attention layer weights.505intermediate_predictions = intermediate_model.predict(data)506507attention_weights = np.squeeze(np.swapaxes(intermediate_predictions, 1, 0))508models_attention_weights.append(attention_weights)509510loss, accuracy = model.evaluate(data, labels, verbose=0)511models_losses.append(loss)512models_accuracies.append(accuracy)513514print(515f"The average loss and accuracy are {np.sum(models_losses, axis=0) / ENSEMBLE_AVG_COUNT:.2f}"516f" and {100 * np.sum(models_accuracies, axis=0) / ENSEMBLE_AVG_COUNT:.2f} % resp."517)518519return (520np.sum(models_predictions, axis=0) / ENSEMBLE_AVG_COUNT,521np.sum(models_attention_weights, axis=0) / ENSEMBLE_AVG_COUNT,522)523524525# Evaluate and predict classes and attention scores on validation data.526class_predictions, attention_params = predict(val_data, val_labels, trained_models)527528# Plot some results from our validation data.529plot(530val_data,531val_labels,532"positive",533predictions=class_predictions,534attention_weights=attention_params,535)536plot(537val_data,538val_labels,539"negative",540predictions=class_predictions,541attention_weights=attention_params,542)543544"""545## Conclusion546547From the above plot, you can notice that the weights always sum to 1. In a548positively predict bag, the instance which resulted in the positive labeling will have549a substantially higher attention score than the rest of the bag. However, in a negatively550predicted bag, there are two cases:551552* All instances will have approximately similar scores.553* An instance will have relatively higher score (but not as high as of a positive instance).554This is because the feature space of this instance is close to that of the positive instance.555556## Remarks557558- If the model is overfit, the weights will be equally distributed for all bags. Hence,559the regularization techniques are necessary.560- In the paper, the bag sizes can differ from one bag to another. For simplicity, the561bag sizes are fixed here.562- In order not to rely on the random initial weights of a single model, averaging ensemble563methods should be considered.564"""565566567