"""1Title: Text Extraction with BERT2Author: [Apoorv Nandan](https://twitter.com/NandanApoorv)3Date created: 2020/05/234Last modified: 2020/05/235Description: Fine tune pretrained BERT from HuggingFace Transformers on SQuAD.6Accelerator: TPU7"""89"""10## Introduction1112This demonstration uses SQuAD (Stanford Question-Answering Dataset).13In SQuAD, an input consists of a question, and a paragraph for context.14The goal is to find the span of text in the paragraph that answers the question.15We evaluate our performance on this data with the "Exact Match" metric,16which measures the percentage of predictions that exactly match any one of the17ground-truth answers.1819We fine-tune a BERT model to perform this task as follows:20211. Feed the context and the question as inputs to BERT.222. Take two vectors S and T with dimensions equal to that of23hidden states in BERT.243. Compute the probability of each token being the start and end of25the answer span. The probability of a token being the start of26the answer is given by a dot product between S and the representation27of the token in the last layer of BERT, followed by a softmax over all tokens.28The probability of a token being the end of the answer is computed29similarly with the vector T.304. Fine-tune BERT and learn S and T along the way.3132**References:**3334- [BERT](https://arxiv.org/abs/1810.04805)35- [SQuAD](https://arxiv.org/abs/1606.05250)36"""37"""38## Setup39"""40import os41import re42import json43import string44import numpy as np45import tensorflow as tf46from tensorflow import keras47from tensorflow.keras import layers48from tokenizers import BertWordPieceTokenizer49from transformers import BertTokenizer, TFBertModel, BertConfig5051max_len = 38452configuration = BertConfig() # default parameters and configuration for BERT5354"""55## Set-up BERT tokenizer56"""57# Save the slow pretrained tokenizer58slow_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")59save_path = "bert_base_uncased/"60if not os.path.exists(save_path):61os.makedirs(save_path)62slow_tokenizer.save_pretrained(save_path)6364# Load the fast tokenizer from saved file65tokenizer = BertWordPieceTokenizer("bert_base_uncased/vocab.txt", lowercase=True)6667"""68## Load the data69"""70train_data_url = "https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json"71train_path = keras.utils.get_file("train.json", train_data_url)72eval_data_url = "https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json"73eval_path = keras.utils.get_file("eval.json", eval_data_url)7475"""76## Preprocess the data77781. Go through the JSON file and store every record as a `SquadExample` object.792. 

"""
## Load the data
"""
train_data_url = "https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json"
train_path = keras.utils.get_file("train.json", train_data_url)
eval_data_url = "https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json"
eval_path = keras.utils.get_file("eval.json", eval_data_url)

"""
## Preprocess the data

1. Go through the JSON file and store every record as a `SquadExample` object.
2. Go through each `SquadExample` and create `x_train, y_train, x_eval, y_eval`.
"""


class SquadExample:
    def __init__(self, question, context, start_char_idx, answer_text, all_answers):
        self.question = question
        self.context = context
        self.start_char_idx = start_char_idx
        self.answer_text = answer_text
        self.all_answers = all_answers
        self.skip = False

    def preprocess(self):
        context = self.context
        question = self.question
        answer_text = self.answer_text
        start_char_idx = self.start_char_idx

        # Clean context, answer and question
        context = " ".join(str(context).split())
        question = " ".join(str(question).split())
        answer = " ".join(str(answer_text).split())

        # Find end character index of answer in context
        end_char_idx = start_char_idx + len(answer)
        if end_char_idx >= len(context):
            self.skip = True
            return

        # Mark the character indexes in context that are in answer
        is_char_in_ans = [0] * len(context)
        for idx in range(start_char_idx, end_char_idx):
            is_char_in_ans[idx] = 1

        # Tokenize context
        tokenized_context = tokenizer.encode(context)

        # Find tokens that were created from answer characters
        ans_token_idx = []
        for idx, (start, end) in enumerate(tokenized_context.offsets):
            if sum(is_char_in_ans[start:end]) > 0:
                ans_token_idx.append(idx)

        if len(ans_token_idx) == 0:
            self.skip = True
            return

        # Find start and end token index for tokens from answer
        start_token_idx = ans_token_idx[0]
        end_token_idx = ans_token_idx[-1]

        # Tokenize question
        tokenized_question = tokenizer.encode(question)

        # Create inputs
        input_ids = tokenized_context.ids + tokenized_question.ids[1:]
        token_type_ids = [0] * len(tokenized_context.ids) + [1] * len(
            tokenized_question.ids[1:]
        )
        attention_mask = [1] * len(input_ids)

        # Pad and create attention masks.
        # Skip if truncation is needed
        padding_length = max_len - len(input_ids)
        if padding_length > 0:  # pad
            input_ids = input_ids + ([0] * padding_length)
            attention_mask = attention_mask + ([0] * padding_length)
            token_type_ids = token_type_ids + ([0] * padding_length)
        elif padding_length < 0:  # skip
            self.skip = True
            return

        self.input_ids = input_ids
        self.token_type_ids = token_type_ids
        self.attention_mask = attention_mask
        self.start_token_idx = start_token_idx
        self.end_token_idx = end_token_idx
        self.context_token_to_char = tokenized_context.offsets


with open(train_path) as f:
    raw_train_data = json.load(f)

with open(eval_path) as f:
    raw_eval_data = json.load(f)


def create_squad_examples(raw_data):
    squad_examples = []
    for item in raw_data["data"]:
        for para in item["paragraphs"]:
            context = para["context"]
            for qa in para["qas"]:
                question = qa["question"]
                answer_text = qa["answers"][0]["text"]
                all_answers = [_["text"] for _ in qa["answers"]]
                start_char_idx = qa["answers"][0]["answer_start"]
                squad_eg = SquadExample(
                    question, context, start_char_idx, answer_text, all_answers
                )
                squad_eg.preprocess()
                squad_examples.append(squad_eg)
    return squad_examples
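

"""
As a quick sanity check (the question, context and character index below are made
up for illustration, not taken from SQuAD), we can preprocess a single hand-built
`SquadExample` and confirm that the character-level answer span is mapped to a
token-level span.
"""
demo_example = SquadExample(
    question="Where is the Eiffel Tower located?",
    context="The Eiffel Tower is located in Paris, France.",
    start_char_idx=31,  # character offset of "Paris" in the context
    answer_text="Paris",
    all_answers=["Paris"],
)
demo_example.preprocess()
print(demo_example.start_token_idx, demo_example.end_token_idx)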


def create_inputs_targets(squad_examples):
    dataset_dict = {
        "input_ids": [],
        "token_type_ids": [],
        "attention_mask": [],
        "start_token_idx": [],
        "end_token_idx": [],
    }
    for item in squad_examples:
        if not item.skip:
            for key in dataset_dict:
                dataset_dict[key].append(getattr(item, key))
    for key in dataset_dict:
        dataset_dict[key] = np.array(dataset_dict[key])

    x = [
        dataset_dict["input_ids"],
        dataset_dict["token_type_ids"],
        dataset_dict["attention_mask"],
    ]
    y = [dataset_dict["start_token_idx"], dataset_dict["end_token_idx"]]
    return x, y


train_squad_examples = create_squad_examples(raw_train_data)
x_train, y_train = create_inputs_targets(train_squad_examples)
print(f"{len(train_squad_examples)} training points created.")

eval_squad_examples = create_squad_examples(raw_eval_data)
x_eval, y_eval = create_inputs_targets(eval_squad_examples)
print(f"{len(eval_squad_examples)} evaluation points created.")

"""
## Create the Question-Answering Model using BERT and the Functional API
"""


def create_model():
    ## BERT encoder
    encoder = TFBertModel.from_pretrained("bert-base-uncased")

    ## QA Model
    input_ids = layers.Input(shape=(max_len,), dtype=tf.int32)
    token_type_ids = layers.Input(shape=(max_len,), dtype=tf.int32)
    attention_mask = layers.Input(shape=(max_len,), dtype=tf.int32)
    embedding = encoder(
        input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask
    )[0]

    start_logits = layers.Dense(1, name="start_logit", use_bias=False)(embedding)
    start_logits = layers.Flatten()(start_logits)

    end_logits = layers.Dense(1, name="end_logit", use_bias=False)(embedding)
    end_logits = layers.Flatten()(end_logits)

    start_probs = layers.Activation(keras.activations.softmax)(start_logits)
    end_probs = layers.Activation(keras.activations.softmax)(end_logits)

    model = keras.Model(
        inputs=[input_ids, token_type_ids, attention_mask],
        outputs=[start_probs, end_probs],
    )
    loss = keras.losses.SparseCategoricalCrossentropy(from_logits=False)
    optimizer = keras.optimizers.Adam(learning_rate=5e-5)
    model.compile(optimizer=optimizer, loss=[loss, loss])
    return model


"""
This code should preferably be run on a Google Colab TPU runtime.
With Colab TPUs, each epoch will take 5-6 minutes.
"""
use_tpu = True
if use_tpu:
    # Create distribution strategy
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver.connect()
    strategy = tf.distribute.TPUStrategy(tpu)

    # Create model
    with strategy.scope():
        model = create_model()
else:
    model = create_model()

model.summary()

"""
## Create evaluation Callback

This callback will compute the exact match score using the validation data
after every epoch.
"""


def normalize_text(text):
    text = text.lower()

    # Remove punctuation
    exclude = set(string.punctuation)
    text = "".join(ch for ch in text if ch not in exclude)

    # Remove articles
    regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
    text = re.sub(regex, " ", text)

    # Remove extra white space
    text = " ".join(text.split())
    return text
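

"""
For example (illustrative string, not from the dataset), `normalize_text` lower-cases
the text and strips punctuation, articles and extra whitespace, so the exact-match
comparison below is insensitive to those differences:
"""
print(normalize_text("The  Eiffel Tower!"))  # prints "eiffel tower"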


class ExactMatch(keras.callbacks.Callback):
    """
    Each `SquadExample` object contains the character level offsets for each token
    in its input paragraph. We use them to get back the span of text corresponding
    to the tokens between our predicted start and end tokens.
    All the ground-truth answers are also present in each `SquadExample` object.
    We calculate the percentage of data points where the span of text obtained
    from model predictions matches one of the ground-truth answers.
    """

    def __init__(self, x_eval, y_eval):
        self.x_eval = x_eval
        self.y_eval = y_eval

    def on_epoch_end(self, epoch, logs=None):
        pred_start, pred_end = self.model.predict(self.x_eval)
        count = 0
        eval_examples_no_skip = [_ for _ in eval_squad_examples if not _.skip]
        for idx, (start, end) in enumerate(zip(pred_start, pred_end)):
            squad_eg = eval_examples_no_skip[idx]
            offsets = squad_eg.context_token_to_char
            start = np.argmax(start)
            end = np.argmax(end)
            if start >= len(offsets):
                continue
            pred_char_start = offsets[start][0]
            if end < len(offsets):
                pred_char_end = offsets[end][1]
                pred_ans = squad_eg.context[pred_char_start:pred_char_end]
            else:
                pred_ans = squad_eg.context[pred_char_start:]

            normalized_pred_ans = normalize_text(pred_ans)
            normalized_true_ans = [normalize_text(_) for _ in squad_eg.all_answers]
            if normalized_pred_ans in normalized_true_ans:
                count += 1
        acc = count / len(self.y_eval[0])
        print(f"\nepoch={epoch+1}, exact match score={acc:.2f}")


"""
## Train and Evaluate
"""
exact_match_callback = ExactMatch(x_eval, y_eval)
model.fit(
    x_train,
    y_train,
    epochs=1,  # For demonstration, 3 epochs are recommended
    verbose=2,
    batch_size=64,
    callbacks=[exact_match_callback],
)
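
"""
## Try the model on a custom example

A minimal inference sketch (not part of the original training script): the helper
below reuses the tokenizer, the padding scheme and the input ordering defined above
to predict an answer span for a new question/context pair and map it back to text.
The sample question and context are made up for illustration.
"""


def answer_question(question, context):
    # Build the same [context tokens] + [question tokens minus CLS] input as in preprocessing
    tokenized_context = tokenizer.encode(context)
    tokenized_question = tokenizer.encode(question)
    input_ids = tokenized_context.ids + tokenized_question.ids[1:]
    token_type_ids = [0] * len(tokenized_context.ids) + [1] * len(
        tokenized_question.ids[1:]
    )
    attention_mask = [1] * len(input_ids)

    padding_length = max_len - len(input_ids)
    if padding_length < 0:
        return None  # input longer than max_len; a real pipeline would truncate instead
    input_ids = input_ids + [0] * padding_length
    attention_mask = attention_mask + [0] * padding_length
    token_type_ids = token_type_ids + [0] * padding_length

    x = [
        np.array([input_ids]),
        np.array([token_type_ids]),
        np.array([attention_mask]),
    ]
    pred_start, pred_end = model.predict(x)
    start = int(np.argmax(pred_start[0]))
    end = int(np.argmax(pred_end[0]))

    # Map the predicted token span back to characters in the original context
    offsets = tokenized_context.offsets
    if start >= len(offsets):
        return None  # predicted start lies outside the context tokens
    pred_char_start = offsets[start][0]
    pred_char_end = offsets[end][1] if end < len(offsets) else len(context)
    return context[pred_char_start:pred_char_end]


print(
    answer_question(
        "Where is the Eiffel Tower located?",
        "The Eiffel Tower is a wrought-iron lattice tower located in Paris, France.",
    )
)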