Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
keras-team
GitHub Repository: keras-team/keras-io
Path: blob/master/examples/nlp/md/text_extraction_with_bert.md
3508 views

Text Extraction with BERT

Author: Apoorv Nandan
Date created: 2020/05/23
Last modified: 2020/05/23

View in Colab GitHub source

Description: Fine tune pretrained BERT from HuggingFace Transformers on SQuAD.


Introduction

This demonstration uses SQuAD (Stanford Question-Answering Dataset). In SQuAD, an input consists of a question, and a paragraph for context. The goal is to find the span of text in the paragraph that answers the question. We evaluate our performance on this data with the "Exact Match" metric, which measures the percentage of predictions that exactly match any one of the ground-truth answers.

We fine-tune a BERT model to perform this task as follows:

  1. Feed the context and the question as inputs to BERT.

  2. Take two vectors S and T whose dimension equals the hidden-state size of BERT.

  3. Compute the probability of each token being the start and end of the answer span. The probability of a token being the start of the answer is given by a dot product between S and the representation of the token in the last layer of BERT, followed by a softmax over all tokens, i.e. $P^{\text{start}}_i = \frac{e^{S \cdot h_i}}{\sum_j e^{S \cdot h_j}}$, where $h_i$ is the last-layer representation of token $i$. The probability of a token being the end of the answer is computed similarly with the vector T.

  4. Fine-tune BERT and learn S and T along the way.

References:

Setup

# Setup: standard-library modules first, then the numeric / deep-learning stack.
import os
import re
import json
import string

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tokenizers import BertWordPieceTokenizer
from transformers import BertTokenizer, TFBertModel, BertConfig

# Fixed sequence length: every encoded (context + question) pair is padded up to
# this many tokens, and longer pairs are skipped during preprocessing.
max_len = 384

# Default BERT parameters and configuration (kept for reference).
configuration = BertConfig()

Set up the BERT tokenizer

# Save the slow pretrained tokenizer to disk so its vocabulary file can be reused
# by the fast WordPiece tokenizer below.
slow_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
save_path = "bert_base_uncased/"
# exist_ok=True replaces the check-then-create pattern (no race, no extra branch).
os.makedirs(save_path, exist_ok=True)
slow_tokenizer.save_pretrained(save_path)

# Load the fast tokenizer from the saved vocabulary. The path is derived from
# save_path so the two cannot drift apart; lowercase=True matches the "uncased"
# checkpoint.
tokenizer = BertWordPieceTokenizer(os.path.join(save_path, "vocab.txt"), lowercase=True)

Load the data

# SQuAD v1.1: the official training split, and the dev split used here for evaluation.
train_data_url = "https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json"
eval_data_url = "https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json"

# keras.utils.get_file fetches each URL and returns the local file path.
train_path = keras.utils.get_file("train.json", train_data_url)
eval_path = keras.utils.get_file("eval.json", eval_data_url)

Preprocess the data

  1. Go through the JSON file and store every record as a SquadExample object.

  2. Go through each SquadExample and create x_train, y_train, x_eval, y_eval.

def _cleaned_char_idx(raw_text, raw_idx):
    """Map index ``raw_idx`` of ``raw_text`` onto ``" ".join(raw_text.split())``.

    Cleaning drops leading whitespace and collapses each internal whitespace run
    to a single space, so positions after such a run shift left. SQuAD's
    ``answer_start`` indexes the raw context; this translates it to the cleaned one.
    """
    prefix = raw_text[:raw_idx]
    cleaned_prefix = " ".join(prefix.split())
    # A whitespace run separating the prefix from raw_text[raw_idx] survives as a
    # single space -- unless the prefix is all whitespace (then it is stripped).
    if cleaned_prefix and prefix[-1].isspace():
        return len(cleaned_prefix) + 1
    return len(cleaned_prefix)


class SquadExample:
    """One SQuAD question/context pair plus its model-ready encoding.

    Construct with the raw record fields, then call ``preprocess()``. If the
    example cannot be encoded, ``skip`` is set to True and the encoded
    attributes are not created.
    """

    def __init__(self, question, context, start_char_idx, answer_text, all_answers):
        self.question = question
        self.context = context
        # Character offset of the first answer in the *raw* (uncleaned) context.
        self.start_char_idx = start_char_idx
        self.answer_text = answer_text
        # Every ground-truth answer text, used for Exact Match scoring.
        self.all_answers = all_answers
        self.skip = False

    def preprocess(self):
        """Tokenize this example into fixed-length inputs and span labels.

        On success, sets the following attributes:

        - ``input_ids``, ``token_type_ids``, ``attention_mask``: each padded to
          ``max_len``.
        - ``start_token_idx`` / ``end_token_idx``: the answer span as token
          positions in the context.
        - ``context_token_to_char``: per-token (start, end) character offsets
          into the whitespace-cleaned context.

        Otherwise sets ``self.skip = True``. That happens when the answer runs
        past the context, when no context token covers the answer, or when the
        encoded pair is longer than ``max_len``.
        """
        raw_context = str(self.context)

        # Clean context, answer and question
        context = " ".join(raw_context.split())
        question = " ".join(str(self.question).split())
        answer = " ".join(str(self.answer_text).split())

        # answer_start indexes the raw context; translate it to the cleaned one so
        # the character labels line up with the text that is actually tokenized.
        start_char_idx = _cleaned_char_idx(raw_context, self.start_char_idx)
        # Find end character index (exclusive) of answer in context
        end_char_idx = start_char_idx + len(answer)
        # An exclusive end equal to len(context) is still a valid span.
        if end_char_idx > len(context):
            self.skip = True
            return

        # Mark the character indexes in context that are in answer
        is_char_in_ans = [0] * len(context)
        for idx in range(start_char_idx, end_char_idx):
            is_char_in_ans[idx] = 1

        # Tokenize context
        tokenized_context = tokenizer.encode(context)

        # Find tokens that were created from answer characters
        ans_token_idx = [
            idx
            for idx, (start, end) in enumerate(tokenized_context.offsets)
            if sum(is_char_in_ans[start:end]) > 0
        ]
        if not ans_token_idx:
            self.skip = True
            return

        # Find start and end token index for tokens from answer
        start_token_idx = ans_token_idx[0]
        end_token_idx = ans_token_idx[-1]

        # Tokenize question
        tokenized_question = tokenizer.encode(question)

        # Create inputs: context tokens (segment 0) followed by the question tokens
        # minus the question's first token (presumably its [CLS]) (segment 1).
        input_ids = tokenized_context.ids + tokenized_question.ids[1:]
        token_type_ids = [0] * len(tokenized_context.ids) + [1] * len(
            tokenized_question.ids[1:]
        )
        attention_mask = [1] * len(input_ids)

        # Pad and create attention masks.
        # Skip if truncation is needed
        padding_length = max_len - len(input_ids)
        if padding_length > 0:  # pad
            input_ids = input_ids + ([0] * padding_length)
            attention_mask = attention_mask + ([0] * padding_length)
            token_type_ids = token_type_ids + ([0] * padding_length)
        elif padding_length < 0:  # skip
            self.skip = True
            return

        self.input_ids = input_ids
        self.token_type_ids = token_type_ids
        self.attention_mask = attention_mask
        self.start_token_idx = start_token_idx
        self.end_token_idx = end_token_idx
        self.context_token_to_char = tokenized_context.offsets


# SQuAD files are UTF-8 JSON; be explicit rather than relying on the platform default.
with open(train_path, encoding="utf-8") as f:
    raw_train_data = json.load(f)

with open(eval_path, encoding="utf-8") as f:
    raw_eval_data = json.load(f)


def create_squad_examples(raw_data):
    """Build and preprocess one SquadExample per question in a SQuAD JSON dict.

    The first listed answer provides the training span; all listed answers are
    kept for evaluation. Skipped examples are still included in the result.
    """
    squad_examples = []
    for item in raw_data["data"]:
        for para in item["paragraphs"]:
            context = para["context"]
            for qa in para["qas"]:
                question = qa["question"]
                answer_text = qa["answers"][0]["text"]
                all_answers = [ans["text"] for ans in qa["answers"]]
                start_char_idx = qa["answers"][0]["answer_start"]
                squad_eg = SquadExample(
                    question, context, start_char_idx, answer_text, all_answers
                )
                squad_eg.preprocess()
                squad_examples.append(squad_eg)
    return squad_examples


def create_inputs_targets(squad_examples):
    """Stack the non-skipped examples into model inputs ``x`` and targets ``y``.

    Returns:
        x: [input_ids, token_type_ids, attention_mask] as NumPy arrays.
        y: [start_token_idx, end_token_idx] as NumPy arrays.
        Rows follow the order of the non-skipped examples.
    """
    dataset_dict = {
        "input_ids": [],
        "token_type_ids": [],
        "attention_mask": [],
        "start_token_idx": [],
        "end_token_idx": [],
    }
    for item in squad_examples:
        if not item.skip:
            for key in dataset_dict:
                dataset_dict[key].append(getattr(item, key))
    for key in dataset_dict:
        dataset_dict[key] = np.array(dataset_dict[key])

    x = [
        dataset_dict["input_ids"],
        dataset_dict["token_type_ids"],
        dataset_dict["attention_mask"],
    ]
    y = [dataset_dict["start_token_idx"], dataset_dict["end_token_idx"]]
    return x, y


train_squad_examples = create_squad_examples(raw_train_data)
x_train, y_train = create_inputs_targets(train_squad_examples)
print(f"{len(train_squad_examples)} training points created.")

eval_squad_examples = create_squad_examples(raw_eval_data)
x_eval, y_eval = create_inputs_targets(eval_squad_examples)
print(f"{len(eval_squad_examples)} evaluation points created.")
``` 87599 training points created. 10570 evaluation points created.
# Create the Question-Answering Model using BERT and Functional API


def create_model(model_name="bert-base-uncased", learning_rate=5e-5):
    """Build and compile the span-extraction QA model.

    A pretrained BERT encoder yields one hidden vector per token. Two bias-free
    Dense(1) heads (the S and T vectors) score every token as the answer start
    and answer end. A softmax over the sequence turns those scores into
    probabilities.

    Args:
        model_name: Checkpoint name passed to ``TFBertModel.from_pretrained``.
        learning_rate: Adam learning rate.

    Returns:
        A compiled ``keras.Model``. Its inputs are
        ``[input_ids, token_type_ids, attention_mask]``, each of length
        ``max_len``. Its outputs are ``[start_probs, end_probs]``.
    """
    ## BERT encoder
    encoder = TFBertModel.from_pretrained(model_name)

    ## QA Model
    input_ids = layers.Input(shape=(max_len,), dtype=tf.int32)
    token_type_ids = layers.Input(shape=(max_len,), dtype=tf.int32)
    attention_mask = layers.Input(shape=(max_len,), dtype=tf.int32)
    # Element 0 of the encoder output is the per-token hidden-state sequence.
    embedding = encoder(
        input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask
    )[0]

    start_logits = layers.Dense(1, name="start_logit", use_bias=False)(embedding)
    start_logits = layers.Flatten()(start_logits)

    end_logits = layers.Dense(1, name="end_logit", use_bias=False)(embedding)
    end_logits = layers.Flatten()(end_logits)

    start_probs = layers.Activation(keras.activations.softmax)(start_logits)
    end_probs = layers.Activation(keras.activations.softmax)(end_logits)

    model = keras.Model(
        inputs=[input_ids, token_type_ids, attention_mask],
        outputs=[start_probs, end_probs],
    )
    # The heads already apply softmax, so the loss consumes probabilities.
    loss = keras.losses.SparseCategoricalCrossentropy(from_logits=False)
    # `learning_rate` replaces the deprecated `lr` alias, which recent Keras
    # versions no longer accept.
    optimizer = keras.optimizers.Adam(learning_rate=learning_rate)
    model.compile(optimizer=optimizer, loss=[loss, loss])
    return model

This code should preferably be run on a Google Colab TPU runtime. With Colab TPUs, each epoch takes about 5-6 minutes.

# Prefer a TPU when requested. If none can be reached, build the model without a
# distribution strategy instead of crashing the notebook.
use_tpu = True
if use_tpu:
    try:
        # Create distribution strategy
        tpu = tf.distribute.cluster_resolver.TPUClusterResolver.connect()
    except ValueError as err:
        # NOTE(review): ValueError is what the resolver raises when no TPU is
        # configured; other runtimes may surface different errors -- confirm.
        print(f"No TPU available ({err}); building the model without TPUStrategy.")
        use_tpu = False

if use_tpu:
    strategy = tf.distribute.TPUStrategy(tpu)
    # Create model inside the scope so its variables are created under the TPU strategy.
    with strategy.scope():
        model = create_model()
else:
    model = create_model()

model.summary()
``` INFO:absl:Entering into master device scope: /job:worker/replica:0/task:0/device:CPU:0

INFO:tensorflow:Initializing the TPU system: grpc://10.48.159.170:8470

INFO:tensorflow:Clearing out eager caches

INFO:tensorflow:Finished initializing TPU system.

INFO:tensorflow:Found TPU system:

INFO:tensorflow:*** Num TPU Cores: 8

INFO:tensorflow:*** Num TPU Workers: 1

INFO:tensorflow:*** Num TPU Cores Per Worker: 8

Model: "model"


Layer (type) Output Shape Param # Connected to

input_1 (InputLayer) [(None, 384)] 0


input_3 (InputLayer) [(None, 384)] 0


input_2 (InputLayer) [(None, 384)] 0


tf_bert_model (TFBertModel) ((None, 384, 768), ( 109482240 input_1[0][0]


start_logit (Dense) (None, 384, 1) 768 tf_bert_model[0][0]


end_logit (Dense) (None, 384, 1) 768 tf_bert_model[0][0]


flatten (Flatten) (None, 384) 0 start_logit[0][0]


flatten_1 (Flatten) (None, 384) 0 end_logit[0][0]


activation_7 (Activation) (None, 384) 0 flatten[0][0]


activation_8 (Activation) (None, 384) 0 flatten_1[0][0]

Total params: 109,483,776 Trainable params: 109,483,776 Non-trainable params: 0


# Create evaluation Callback
#
# This callback will compute the exact match score using the validation data
# after every epoch.

# Hoisted once: characters and articles removed by normalize_text.
_PUNCTUATION = frozenset(string.punctuation)
_ARTICLES = re.compile(r"\b(a|an|the)\b", re.UNICODE)


def normalize_text(text):
    """Normalize an answer string for Exact Match comparison.

    Lower-cases the text, drops ASCII punctuation, removes the articles "a",
    "an" and "the", and collapses whitespace runs to single spaces.
    """
    text = text.lower()

    # Remove punctuations
    text = "".join(ch for ch in text if ch not in _PUNCTUATION)

    # Remove articles
    text = _ARTICLES.sub(" ", text)

    # Remove extra white space
    text = " ".join(text.split())
    return text


class ExactMatch(keras.callbacks.Callback):
    """
    Each `SquadExample` object contains the character level offsets for each token
    in its input paragraph. We use them to get back the span of text corresponding
    to the tokens between our predicted start and end tokens.
    All the ground-truth answers are also present in each `SquadExample` object.
    We calculate the percentage of data points where the span of text obtained
    from model predictions matches one of the ground-truth answers.

    Args:
        x_eval: Model inputs for the evaluation set (non-skipped examples only).
        y_eval: Targets for the evaluation set; ``len(y_eval[0])`` is the
            denominator of the score.
        squad_examples: The SquadExample list that ``x_eval`` was built from.
            If omitted, the notebook-level ``eval_squad_examples`` is used.
    """

    def __init__(self, x_eval, y_eval, squad_examples=None):
        super().__init__()
        self.x_eval = x_eval
        self.y_eval = y_eval
        self.squad_examples = squad_examples

    def on_epoch_end(self, epoch, logs=None):
        examples = self.squad_examples
        if examples is None:
            examples = eval_squad_examples
        num_eval = len(self.y_eval[0])
        if num_eval == 0:
            # Nothing to score; avoids a division by zero below.
            return

        pred_start, pred_end = self.model.predict(self.x_eval)
        count = 0
        # Prediction rows line up with the non-skipped examples, in order.
        eval_examples_no_skip = [eg for eg in examples if not eg.skip]
        for idx, (start_scores, end_scores) in enumerate(zip(pred_start, pred_end)):
            squad_eg = eval_examples_no_skip[idx]
            offsets = squad_eg.context_token_to_char
            start = np.argmax(start_scores)
            end = np.argmax(end_scores)
            if start >= len(offsets):
                continue
            # The offsets were computed on the whitespace-normalized context, so
            # slice that same string (the raw context may have extra whitespace).
            context = " ".join(str(squad_eg.context).split())
            pred_char_start = offsets[start][0]
            if end < len(offsets):
                pred_char_end = offsets[end][1]
                pred_ans = context[pred_char_start:pred_char_end]
            else:
                pred_ans = context[pred_char_start:]

            normalized_pred_ans = normalize_text(pred_ans)
            normalized_true_ans = [normalize_text(ans) for ans in squad_eg.all_answers]
            if normalized_pred_ans in normalized_true_ans:
                count += 1
        acc = count / num_eval
        if logs is not None:
            # Expose the score to other callbacks / the training History.
            logs["exact_match"] = acc
        print(f"\nepoch={epoch+1}, exact match score={acc:.2f}")

Train and Evaluate

# Train the model, reporting the Exact Match score on the evaluation split at the
# end of every epoch via the ExactMatch callback.
exact_match_callback = ExactMatch(x_eval, y_eval)
model.fit(
    x=x_train,
    y=y_train,
    batch_size=64,
    epochs=1,  # For demonstration, 3 epochs are recommended
    verbose=2,
    callbacks=[exact_match_callback],
)
``` epoch=1, exact match score=0.78 1346/1346 - 350s - activation_7_loss: 1.3488 - loss: 2.5905 - activation_8_loss: 1.2417

<tensorflow.python.keras.callbacks.History at 0x7fc78b4458d0>

</div>