Path: blob/master/examples/nlp/tweet-classification-using-tfdf.py
"""1Title: Text classification using Decision Forests and pretrained embeddings2Author: Gitesh Chawda3Date created: 09/05/20224Last modified: 09/05/20225Description: Using Tensorflow Decision Forests for text classification.6Accelerator: GPU7"""89"""10## Introduction1112[TensorFlow Decision Forests](https://www.tensorflow.org/decision_forests) (TF-DF)13is a collection of state-of-the-art algorithms for Decision Forest models that are14compatible with Keras APIs. The module includes Random Forests, Gradient Boosted Trees,15and CART, and can be used for regression, classification, and ranking tasks.1617In this example we will use Gradient Boosted Trees with pretrained embeddings to18classify disaster-related tweets.1920### See also:2122- [TF-DF beginner tutorial](https://www.tensorflow.org/decision_forests/tutorials/beginner_colab)23- [TF-DF intermediate tutorial](https://www.tensorflow.org/decision_forests/tutorials/intermediate_colab).24"""2526"""27Install Tensorflow Decision Forest using following command :28`pip install tensorflow_decision_forests`29"""303132"""33## Imports34"""3536import pandas as pd37import numpy as np38import tensorflow as tf39from tensorflow import keras40import tensorflow_hub as hub41from tensorflow.keras import layers42import tensorflow_decision_forests as tfdf43import matplotlib.pyplot as plt4445"""46## Get the data4748The Dataset is available on [Kaggle](https://www.kaggle.com/c/nlp-getting-started)4950Dataset description:5152**Files:**5354- train.csv: the training set5556**Columns:**5758- id: a unique identifier for each tweet59- text: the text of the tweet60- location: the location the tweet was sent from (may be blank)61- keyword: a particular keyword from the tweet (may be blank)62- target: in train.csv only, this denotes whether a tweet is about a real disaster (1) or not (0)63"""6465# Turn .csv files into pandas DataFrame's66df = pd.read_csv(67"https://raw.githubusercontent.com/IMvision12/Tweets-Classification-NLP/main/train.csv"68)69print(df.head())7071"""72The dataset includes 7613 samples with 5 columns:73"""7475print(f"Training dataset shape: {df.shape}")7677"""78Shuffling and dropping unnecessary columns:79"""8081df_shuffled = df.sample(frac=1, random_state=42)82# Dropping id, keyword and location columns as these columns consists of mostly nan values83# we will be using only text and target columns84df_shuffled.drop(["id", "keyword", "location"], axis=1, inplace=True)85df_shuffled.reset_index(inplace=True, drop=True)86print(df_shuffled.head())8788"""89Printing information about the shuffled dataframe:90"""9192print(df_shuffled.info())9394"""95Total number of "disaster" and "non-disaster" tweets:96"""9798print(99"Total Number of disaster and non-disaster tweets: "100f"{df_shuffled.target.value_counts()}"101)102103"""104Let's preview a few samples:105"""106107for index, example in df_shuffled[:5].iterrows():108print(f"Example #{index}")109print(f"\tTarget : {example['target']}")110print(f"\tText : {example['text']}")111112"""113Splitting dataset into training and test sets:114"""115test_df = df_shuffled.sample(frac=0.1, random_state=42)116train_df = df_shuffled.drop(test_df.index)117print(f"Using {len(train_df)} samples for training and {len(test_df)} for validation")118119"""120Total number of "disaster" and "non-disaster" tweets in the training data:121"""122print(train_df["target"].value_counts())123124"""125Total number of "disaster" and "non-disaster" tweets in the test data:126"""127128print(test_df["target"].value_counts())129130"""131## Convert 
"""
## Convert data to a `tf.data.Dataset`
"""


def create_dataset(dataframe):
    dataset = tf.data.Dataset.from_tensor_slices(
        (dataframe["text"].to_numpy(), dataframe["target"].to_numpy())
    )
    dataset = dataset.batch(100)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    return dataset


train_ds = create_dataset(train_df)
test_ds = create_dataset(test_df)

"""
## Downloading pretrained embeddings

The Universal Sentence Encoder embeddings encode text into high-dimensional vectors that can be
used for text classification, semantic similarity, clustering and other natural language
tasks. They're trained on a variety of data sources and a variety of tasks. Their input is
variable-length English text and their output is a 512-dimensional vector.

To learn more about these pretrained embeddings, see
[Universal Sentence Encoder](https://tfhub.dev/google/universal-sentence-encoder/4).
"""

sentence_encoder_layer = hub.KerasLayer(
    "https://tfhub.dev/google/universal-sentence-encoder/4"
)

"""
## Creating our models

We create two models. In the first model (model_1), raw text will first be encoded via
pretrained embeddings and then passed to a Gradient Boosted Trees model for
classification. In the second model (model_2), raw text will be directly passed to
the Gradient Boosted Trees model.
"""

"""
Building model_1
"""

inputs = layers.Input(shape=(), dtype=tf.string)
outputs = sentence_encoder_layer(inputs)
preprocessor = keras.Model(inputs=inputs, outputs=outputs)
model_1 = tfdf.keras.GradientBoostedTreesModel(preprocessing=preprocessor)

"""
Building model_2
"""

model_2 = tfdf.keras.GradientBoostedTreesModel()

"""
## Train the models

We compile our models by passing the metrics `Accuracy`, `Recall`, `Precision` and
`AUC`. When it comes to the loss, TF-DF automatically detects the best loss for the task
(classification or regression). It is printed in the model summary.

Also, because they're batch-training models rather than mini-batch gradient descent models,
TF-DF models do not need a validation dataset to monitor overfitting, or to stop
training early. Some algorithms do not use a validation dataset (e.g. Random Forest)
while some others do (e.g. Gradient Boosted Trees). If a validation dataset is
needed, it will be extracted automatically from the training dataset.
"""
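"""
The size of that automatically extracted validation split, like the size of the
ensemble, is a constructor hyperparameter of `GradientBoostedTreesModel`. A sketch
(the values below are the TF-DF defaults, and this extra model is not used further
in the example):
"""

# Sketch: a GBT model with explicit hyperparameters. `num_trees` caps the ensemble
# size; `validation_ratio` is the fraction of training data held out to monitor
# training when no validation dataset is given.
model_with_options = tfdf.keras.GradientBoostedTreesModel(
    num_trees=300,
    validation_ratio=0.1,
)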
If a validation dataset is199needed, it will be extracted automatically from the training dataset.200"""201202# Compiling model_1203model_1.compile(metrics=["Accuracy", "Recall", "Precision", "AUC"])204# Here we do not specify epochs as, TF-DF trains exactly one epoch of the dataset205model_1.fit(train_ds)206207# Compiling model_2208model_2.compile(metrics=["Accuracy", "Recall", "Precision", "AUC"])209# Here we do not specify epochs as, TF-DF trains exactly one epoch of the dataset210model_2.fit(train_ds)211212"""213Prints training logs of model_1214"""215216logs_1 = model_1.make_inspector().training_logs()217print(logs_1)218219"""220Prints training logs of model_2221"""222223logs_2 = model_2.make_inspector().training_logs()224print(logs_2)225226"""227The model.summary() method prints a variety of information about your decision tree model, including model type, task, input features, and feature importance.228"""229230print("model_1 summary: ")231print(model_1.summary())232print()233print("model_2 summary: ")234print(model_2.summary())235236"""237## Plotting training metrics238"""239240241def plot_curve(logs):242plt.figure(figsize=(12, 4))243244plt.subplot(1, 2, 1)245plt.plot([log.num_trees for log in logs], [log.evaluation.accuracy for log in logs])246plt.xlabel("Number of trees")247plt.ylabel("Accuracy")248249plt.subplot(1, 2, 2)250plt.plot([log.num_trees for log in logs], [log.evaluation.loss for log in logs])251plt.xlabel("Number of trees")252plt.ylabel("Loss")253254plt.show()255256257plot_curve(logs_1)258plot_curve(logs_2)259260"""261## Evaluating on test data262"""263264results = model_1.evaluate(test_ds, return_dict=True, verbose=0)265print("model_1 Evaluation: \n")266for name, value in results.items():267print(f"{name}: {value:.4f}")268269results = model_2.evaluate(test_ds, return_dict=True, verbose=0)270print("model_2 Evaluation: \n")271for name, value in results.items():272print(f"{name}: {value:.4f}")273274"""275## Predicting on validation data276"""277278test_df.reset_index(inplace=True, drop=True)279for index, row in test_df.iterrows():280text = tf.expand_dims(row["text"], axis=0)281preds = model_1.predict_step(text)282preds = tf.squeeze(tf.round(preds))283print(f"Text: {row['text']}")284print(f"Prediction: {int(preds)}")285print(f"Ground Truth : {row['target']}")286if index == 10:287break288289"""290## Concluding remarks291292The TensorFlow Decision Forests package provides powerful models293that work especially well with structured data. In our experiments,294the Gradient Boosted Tree model with pretrained embeddings achieved 81.6%295test accuracy while the plain Gradient Boosted Tree model had 54.4% accuracy.296"""297298299