"""1Title: Review Classification using Active Learning2Author: [Darshan Deshpande](https://twitter.com/getdarshan)3Date created: 2021/10/294Last modified: 2024/05/085Description: Demonstrating the advantages of active learning through review classification.6Accelerator: GPU7Converted to Keras 3 by: [Sachin Prasad](https://github.com/sachinprasadhs)8"""910"""11## Introduction1213With the growth of data-centric Machine Learning, Active Learning has grown in popularity14amongst businesses and researchers. Active Learning seeks to progressively15train ML models so that the resultant model requires lesser amount of training data to16achieve competitive scores.1718The structure of an Active Learning pipeline involves a classifier and an oracle. The19oracle is an annotator that cleans, selects, labels the data, and feeds it to the model20when required. The oracle is a trained individual or a group of individuals that21ensure consistency in labeling of new data.2223The process starts with annotating a small subset of the full dataset and training an24initial model. The best model checkpoint is saved and then tested on a balanced test25set. The test set must be carefully sampled because the full training process will be26dependent on it. Once we have the initial evaluation scores, the oracle is tasked with27labeling more samples; the number of data points to be sampled is usually determined by28the business requirements. After that, the newly sampled data is added to the training29set, and the training procedure repeats. This cycle continues until either an30acceptable score is reached or some other business metric is met.3132This tutorial provides a basic demonstration of how Active Learning works by33demonstrating a ratio-based (least confidence) sampling strategy that results in lower34overall false positive and negative rates when compared to a model trained on the entire35dataset. This sampling falls under the domain of *uncertainty sampling*, in which new36datasets are sampled based on the uncertainty that the model outputs for the37corresponding label. In our example, we compare our model's false positive and false38negative rates and annotate the new data based on their ratio.3940Some other sampling techniques include:41421. [Committee sampling](https://www.researchgate.net/publication/51909346_Committee-Based_Sample_Selection_for_Probabilistic_Classifiers):43Using multiple models to vote for the best data points to be sampled442. [Entropy reduction](https://www.researchgate.net/publication/51909346_Committee-Based_Sample_Selection_for_Probabilistic_Classifiers):45Sampling according to an entropy threshold, selecting more of the samples that produce the highest entropy score.463. [Minimum margin based sampling](https://arxiv.org/abs/1906.00025v1):47Selects data points closest to the decision boundary48"""4950"""51## Importing required libraries52"""5354import os5556os.environ["KERAS_BACKEND"] = "tensorflow" # @param ["tensorflow", "jax", "torch"]57import keras58from keras import ops59from keras import layers60import tensorflow_datasets as tfds61import tensorflow as tf62import matplotlib.pyplot as plt63import re64import string6566tfds.disable_progress_bar()6768"""69## Loading and preprocessing the data7071We will be using the IMDB reviews dataset for our experiments. This dataset has 50,00072reviews in total, including training and testing splits. 
"""
## Loading and preprocessing the data

We will be using the IMDB reviews dataset for our experiments. This dataset has 50,000
reviews in total, including training and testing splits. We will merge these splits and
sample our own balanced training, validation and testing sets.
"""

dataset = tfds.load(
    "imdb_reviews",
    split="train + test",
    as_supervised=True,
    batch_size=-1,
    shuffle_files=False,
)
reviews, labels = tfds.as_numpy(dataset)

print("Total examples:", reviews.shape[0])

"""
Active learning starts with labeling a subset of data.
For the ratio sampling technique that we will be using, we will need well-balanced training,
validation and testing splits.
"""

val_split = 2500
test_split = 2500
train_split = 7500

# Separating the negative and positive samples for manual stratification
x_positives, y_positives = reviews[labels == 1], labels[labels == 1]
x_negatives, y_negatives = reviews[labels == 0], labels[labels == 0]

# Creating training, validation and testing splits
x_val, y_val = (
    tf.concat((x_positives[:val_split], x_negatives[:val_split]), 0),
    tf.concat((y_positives[:val_split], y_negatives[:val_split]), 0),
)
x_test, y_test = (
    tf.concat(
        (
            x_positives[val_split : val_split + test_split],
            x_negatives[val_split : val_split + test_split],
        ),
        0,
    ),
    tf.concat(
        (
            y_positives[val_split : val_split + test_split],
            y_negatives[val_split : val_split + test_split],
        ),
        0,
    ),
)
x_train, y_train = (
    tf.concat(
        (
            x_positives[val_split + test_split : val_split + test_split + train_split],
            x_negatives[val_split + test_split : val_split + test_split + train_split],
        ),
        0,
    ),
    tf.concat(
        (
            y_positives[val_split + test_split : val_split + test_split + train_split],
            y_negatives[val_split + test_split : val_split + test_split + train_split],
        ),
        0,
    ),
)

# The remaining pool of samples is stored separately. These are only labeled as and when required
x_pool_positives, y_pool_positives = (
    x_positives[val_split + test_split + train_split :],
    y_positives[val_split + test_split + train_split :],
)
x_pool_negatives, y_pool_negatives = (
    x_negatives[val_split + test_split + train_split :],
    y_negatives[val_split + test_split + train_split :],
)

# Creating TF Datasets for faster prefetching and parallelization
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))
test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))

pool_negatives = tf.data.Dataset.from_tensor_slices(
    (x_pool_negatives, y_pool_negatives)
)
pool_positives = tf.data.Dataset.from_tensor_slices(
    (x_pool_positives, y_pool_positives)
)

print(f"Initial training set size: {len(train_dataset)}")
print(f"Validation set size: {len(val_dataset)}")
print(f"Testing set size: {len(test_dataset)}")
print(f"Unlabeled negative pool: {len(pool_negatives)}")
print(f"Unlabeled positive pool: {len(pool_positives)}")
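"""
As an optional sanity check (not part of the original pipeline), we can confirm that the
manual stratification above produced balanced splits by looking at the fraction of
positive labels in each split.
"""

# A fraction close to 0.5 indicates a balanced split
splits = {"train": y_train, "val": y_val, "test": y_test}
for split_name, split_labels in splits.items():
    positive_fraction = float(ops.mean(ops.cast(split_labels, "float32")))
    print(f"Positive fraction in the {split_name} split: {positive_fraction:.2f}")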
"""
### Fitting the `TextVectorization` layer

Since we are working with text data, we will need to encode the text strings as vectors which
would then be passed through an `Embedding` layer. To make this tokenization process
faster, we use the `map()` function with its parallelization functionality.
"""


vectorizer = layers.TextVectorization(
    3000, standardize="lower_and_strip_punctuation", output_sequence_length=150
)
# Adapting the dataset
vectorizer.adapt(
    train_dataset.map(lambda x, y: x, num_parallel_calls=tf.data.AUTOTUNE).batch(256)
)


def vectorize_text(text, label):
    text = vectorizer(text)
    return text, label


train_dataset = train_dataset.map(
    vectorize_text, num_parallel_calls=tf.data.AUTOTUNE
).prefetch(tf.data.AUTOTUNE)
pool_negatives = pool_negatives.map(vectorize_text, num_parallel_calls=tf.data.AUTOTUNE)
pool_positives = pool_positives.map(vectorize_text, num_parallel_calls=tf.data.AUTOTUNE)

val_dataset = val_dataset.batch(256).map(
    vectorize_text, num_parallel_calls=tf.data.AUTOTUNE
)
test_dataset = test_dataset.batch(256).map(
    vectorize_text, num_parallel_calls=tf.data.AUTOTUNE
)
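"""
As a quick, purely illustrative check (the review text below is made up and is not part of
the dataset), we can inspect what the adapted `TextVectorization` layer produces for a
single raw string: a fixed-length sequence of 150 token indices drawn from the 3000-token
vocabulary, where index 1 denotes out-of-vocabulary tokens and index 0 is used for padding.
"""

sample_review = "This movie was absolutely wonderful and the acting was great"
print(vectorizer([sample_review])[0, :12])  # first 12 token indices of the padded sequence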
"""
## Creating Helper Functions
"""


# Helper function for merging new history objects with older ones
def append_history(losses, val_losses, accuracy, val_accuracy, history):
    losses = losses + history.history["loss"]
    val_losses = val_losses + history.history["val_loss"]
    accuracy = accuracy + history.history["binary_accuracy"]
    val_accuracy = val_accuracy + history.history["val_binary_accuracy"]
    return losses, val_losses, accuracy, val_accuracy


# Plotter function
def plot_history(losses, val_losses, accuracies, val_accuracies):
    plt.plot(losses)
    plt.plot(val_losses)
    plt.legend(["train_loss", "val_loss"])
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.show()

    plt.plot(accuracies)
    plt.plot(val_accuracies)
    plt.legend(["train_accuracy", "val_accuracy"])
    plt.xlabel("Epochs")
    plt.ylabel("Accuracy")
    plt.show()


"""
## Creating the Model

We create a small bidirectional LSTM model. When using Active Learning, you should make sure
that the model architecture is capable of overfitting to the initial data.
Overfitting gives a strong hint that the model will have enough capacity for
future, unseen data.
"""


def create_model():
    model = keras.models.Sequential(
        [
            layers.Input(shape=(150,)),
            layers.Embedding(input_dim=3000, output_dim=128),
            layers.Bidirectional(layers.LSTM(32, return_sequences=True)),
            layers.GlobalMaxPool1D(),
            layers.Dense(20, activation="relu"),
            layers.Dropout(0.5),
            layers.Dense(1, activation="sigmoid"),
        ]
    )
    model.summary()
    return model


"""
## Training on the entire dataset

To show the effectiveness of Active Learning, we will first train the model on the entire
dataset containing 40,000 labeled samples. This model will be used for comparison later.
"""


def train_full_model(full_train_dataset, val_dataset, test_dataset):
    model = create_model()
    model.compile(
        loss="binary_crossentropy",
        optimizer="rmsprop",
        metrics=[
            keras.metrics.BinaryAccuracy(),
            keras.metrics.FalseNegatives(),
            keras.metrics.FalsePositives(),
        ],
    )

    # We will save the best model at every epoch and load the best one for evaluation on the test set
    history = model.fit(
        full_train_dataset.batch(256),
        epochs=20,
        validation_data=val_dataset,
        callbacks=[
            keras.callbacks.EarlyStopping(patience=4, verbose=1),
            keras.callbacks.ModelCheckpoint(
                "FullModelCheckpoint.keras", verbose=1, save_best_only=True
            ),
        ],
    )

    # Plot history
    plot_history(
        history.history["loss"],
        history.history["val_loss"],
        history.history["binary_accuracy"],
        history.history["val_binary_accuracy"],
    )

    # Loading the best checkpoint
    model = keras.models.load_model("FullModelCheckpoint.keras")

    print("-" * 100)
    print(
        "Test set evaluation: ",
        model.evaluate(test_dataset, verbose=0, return_dict=True),
    )
    print("-" * 100)
    return model


# Sampling the full train dataset to train on
full_train_dataset = (
    train_dataset.concatenate(pool_positives)
    .concatenate(pool_negatives)
    .cache()
    .shuffle(20000)
)

# Training the full model
full_dataset_model = train_full_model(full_train_dataset, val_dataset, test_dataset)

"""
## Training via Active Learning

The general Active Learning process we follow can be summarized in five parts:

1. Sample and annotate a small, balanced training dataset
2. Train the model on this small subset
3. Evaluate the model on a balanced testing set
4. If the model satisfies the business criteria, deploy it in a real-time setting
5. If it doesn't pass the criteria, sample a few more samples according to the ratio of
false positives and negatives, add them to the training set and repeat from step 2 until
the model passes the tests or until all available data is exhausted.

For the code below, we will perform sampling using the following ratio:

    number of ones (or zeros) to sample =
        sampling_size * (number of ones (or zeros) incorrectly classified
                         / total number of incorrectly classified samples)

A toy numerical example of this ratio is shown right after this section.

Active Learning techniques use callbacks extensively for progress tracking. We will be
using model checkpointing and early stopping for this example. The `patience` parameter
for Early Stopping can help minimize overfitting and the time required. We have set it to
`patience=4` for now but since the model is robust, we can increase the patience level if
desired.

Note: We are not loading the checkpoint after the first training iteration. In my
experience working on Active Learning techniques, this helps the model probe the
newly formed loss landscape. Even if the model fails to improve in the second iteration,
we will still gain insight about the possible future false positive and negative rates.
This will help us sample a better set in the next iteration where the model will have a
greater chance to improve.
"""
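"""
To make the sampling ratio above concrete, here is a toy numerical example. The error
counts below are made up purely for illustration; in the training loop they come from
`model.evaluate()` on the test set.
"""

# Hypothetical error counts for one Active Learning iteration
example_false_negatives, example_false_positives = 150.0, 50.0
example_sampling_size = 5000

example_total = example_false_negatives + example_false_positives
example_ratio_ones = example_false_positives / example_total  # 0.25
example_ratio_zeros = example_false_negatives / example_total  # 0.75

ones_to_sample = int(example_ratio_ones * example_sampling_size)
zeros_to_sample = int(example_ratio_zeros * example_sampling_size)
print(f"Samples to draw from the positive pool: {ones_to_sample}")
print(f"Samples to draw from the negative pool: {zeros_to_sample}")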

def train_active_learning_models(
    train_dataset,
    pool_negatives,
    pool_positives,
    val_dataset,
    test_dataset,
    num_iterations=3,
    sampling_size=5000,
):

    # Creating lists for storing metrics
    losses, val_losses, accuracies, val_accuracies = [], [], [], []

    model = create_model()
    # We will monitor the false positives and false negatives predicted by our model
    # These will decide the subsequent sampling ratio for every Active Learning loop
    model.compile(
        loss="binary_crossentropy",
        optimizer="rmsprop",
        metrics=[
            keras.metrics.BinaryAccuracy(),
            keras.metrics.FalseNegatives(),
            keras.metrics.FalsePositives(),
        ],
    )

    # Defining checkpoints.
    # The checkpoint callback is reused throughout the training since it only saves the best overall model.
    checkpoint = keras.callbacks.ModelCheckpoint(
        "AL_Model.keras", save_best_only=True, verbose=1
    )
    # Here, patience is set to 4. This can be set higher if desired.
    early_stopping = keras.callbacks.EarlyStopping(patience=4, verbose=1)

    print(f"Starting to train with {len(train_dataset)} samples")
    # Initial fit with a small subset of the training set
    history = model.fit(
        train_dataset.cache().shuffle(20000).batch(256),
        epochs=20,
        validation_data=val_dataset,
        callbacks=[checkpoint, early_stopping],
    )

    # Appending history
    losses, val_losses, accuracies, val_accuracies = append_history(
        losses, val_losses, accuracies, val_accuracies, history
    )

    for iteration in range(num_iterations):
        # Getting predictions from previously trained model
        predictions = model.predict(test_dataset)

        # Generating labels from the output probabilities
        rounded = ops.where(ops.greater(predictions, 0.5), 1, 0)

        # Evaluating the number of zeros and ones incorrectly classified
        _, _, false_negatives, false_positives = model.evaluate(test_dataset, verbose=0)

        print("-" * 100)
        print(
            f"Number of zeros incorrectly classified: {false_negatives}, Number of ones incorrectly classified: {false_positives}"
        )

        # This technique of Active Learning demonstrates ratio based sampling where
        # Number of ones/zeros to sample = Number of ones/zeros incorrectly classified / Total incorrectly classified
        if false_negatives != 0 and false_positives != 0:
            total = false_negatives + false_positives
            sample_ratio_ones, sample_ratio_zeros = (
                false_positives / total,
                false_negatives / total,
            )
        # In the case where all samples are correctly predicted, we can sample both classes equally
        else:
            sample_ratio_ones, sample_ratio_zeros = 0.5, 0.5

        print(
            f"Sample ratio for positives: {sample_ratio_ones}, Sample ratio for negatives: {sample_ratio_zeros}"
        )

        # Sample the required number of ones and zeros
        sampled_dataset = pool_negatives.take(
            int(sample_ratio_zeros * sampling_size)
        ).concatenate(pool_positives.take(int(sample_ratio_ones * sampling_size)))

        # Skip the sampled data points to avoid repetition of samples
        pool_negatives = pool_negatives.skip(int(sample_ratio_zeros * sampling_size))
        pool_positives = pool_positives.skip(int(sample_ratio_ones * sampling_size))

        # Concatenating the train_dataset with the sampled_dataset
        train_dataset = train_dataset.concatenate(sampled_dataset).prefetch(
            tf.data.AUTOTUNE
        )
        print(f"Starting training with {len(train_dataset)} samples")
        print("-" * 100)

        # We recompile the model to reset the optimizer states and retrain the model
        model.compile(
            loss="binary_crossentropy",
            optimizer="rmsprop",
            metrics=[
                keras.metrics.BinaryAccuracy(),
                keras.metrics.FalseNegatives(),
                keras.metrics.FalsePositives(),
            ],
        )
        history = model.fit(
            train_dataset.cache().shuffle(20000).batch(256),
            validation_data=val_dataset,
            epochs=20,
            callbacks=[
                checkpoint,
                keras.callbacks.EarlyStopping(patience=4, verbose=1),
            ],
        )

        # Appending the history
        losses, val_losses, accuracies, val_accuracies = append_history(
            losses, val_losses, accuracies, val_accuracies, history
        )

        # Loading the best model from this training loop
        model = keras.models.load_model("AL_Model.keras")

    # Plotting the overall history and evaluating the final model
    plot_history(losses, val_losses, accuracies, val_accuracies)
    print("-" * 100)
    print(
        "Test set evaluation: ",
        model.evaluate(test_dataset, verbose=0, return_dict=True),
    )
    print("-" * 100)

    return model


active_learning_model = train_active_learning_models(
    train_dataset, pool_negatives, pool_positives, val_dataset, test_dataset
)

"""
## Conclusion

Active Learning is a growing area of research. This example demonstrates the cost-efficiency
benefits of using Active Learning, as it eliminates the need to annotate large amounts of
data, saving resources.

The following are some noteworthy observations from this example:

1. We only require 30,000 samples to reach the same (if not better) scores as the model
trained on the full dataset. This means that in a real life setting, we save the effort
required for annotating 10,000 reviews!
2. The number of false negatives and false positives is well balanced at the end of the
training as compared to the skewed ratio obtained from the full training. This makes the
model slightly more useful in real life scenarios where both the labels hold equal
importance.

For further reading about the types of sampling ratios, training techniques or available
open source libraries/implementations, you can refer to the resources below:

1. [Active Learning Literature Survey](http://burrsettles.com/pub/settles.activelearning.pdf) (Burr Settles, 2010).
2. [modAL](https://github.com/modAL-python/modAL): A Modular Active Learning framework.
3. Google's unofficial [Active Learning playground](https://github.com/google/active-learning).
"""