"""
2
Title: Text Extraction with BERT
3
Author: [Apoorv Nandan](https://twitter.com/NandanApoorv)
4
Date created: 2020/05/23
5
Last modified: 2020/05/23
6
Description: Fine tune pretrained BERT from HuggingFace Transformers on SQuAD.
7
Accelerator: TPU
8
"""

"""
## Introduction

This demonstration uses SQuAD (Stanford Question-Answering Dataset).
In SQuAD, an input consists of a question and a paragraph for context.
The goal is to find the span of text in the paragraph that answers the question.
We evaluate our performance on this data with the "Exact Match" metric,
which measures the percentage of predictions that exactly match any one of the
ground-truth answers.

We fine-tune a BERT model to perform this task as follows:

1. Feed the context and the question as inputs to BERT.
2. Take two vectors S and T with dimensions equal to those of the
hidden states in BERT.
3. Compute the probability of each token being the start and end of
the answer span. The probability of a token being the start of
the answer is given by a dot product between S and the representation
of the token in the last layer of BERT, followed by a softmax over all tokens.
The probability of a token being the end of the answer is computed
similarly with the vector T (see the sketch after this introduction).
4. Fine-tune BERT and learn S and T along the way.

**References:**

- [BERT](https://arxiv.org/abs/1810.04805)
- [SQuAD](https://arxiv.org/abs/1606.05250)
"""

"""
## Setup
"""
import os
import re
import json
import string
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tokenizers import BertWordPieceTokenizer
from transformers import BertTokenizer, TFBertModel, BertConfig

max_len = 384
configuration = BertConfig()  # default parameters and configuration for BERT

"""
## Set up the BERT tokenizer
"""
# Save the slow pretrained tokenizer
slow_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
save_path = "bert_base_uncased/"
if not os.path.exists(save_path):
    os.makedirs(save_path)
slow_tokenizer.save_pretrained(save_path)

# Load the fast tokenizer from the saved vocabulary file
tokenizer = BertWordPieceTokenizer("bert_base_uncased/vocab.txt", lowercase=True)
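
"""
The fast tokenizer exposes the character offsets of each token via `encoding.offsets`,
which the preprocessing below relies on to map an answer's character span onto token
indices. A quick illustration (the sentence is an arbitrary example):
"""
example_encoding = tokenizer.encode("BERT answers questions.")
print(example_encoding.tokens)  # WordPiece tokens, including [CLS] and [SEP]
print(example_encoding.offsets)  # (start_char, end_char) pairs into the input string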

"""
## Load the data
"""
train_data_url = "https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json"
train_path = keras.utils.get_file("train.json", train_data_url)
eval_data_url = "https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json"
eval_path = keras.utils.get_file("eval.json", eval_data_url)

"""
## Preprocess the data

1. Go through the JSON file and store every record as a `SquadExample` object.
2. Go through each `SquadExample` and create `x_train, y_train, x_eval, y_eval`.
"""


class SquadExample:
    def __init__(self, question, context, start_char_idx, answer_text, all_answers):
        self.question = question
        self.context = context
        self.start_char_idx = start_char_idx
        self.answer_text = answer_text
        self.all_answers = all_answers
        self.skip = False

    def preprocess(self):
        context = self.context
        question = self.question
        answer_text = self.answer_text
        start_char_idx = self.start_char_idx

        # Clean context, answer and question
        context = " ".join(str(context).split())
        question = " ".join(str(question).split())
        answer = " ".join(str(answer_text).split())

        # Find end character index of answer in context
        end_char_idx = start_char_idx + len(answer)
        if end_char_idx >= len(context):
            self.skip = True
            return

        # Mark the character indexes in context that are in answer
        is_char_in_ans = [0] * len(context)
        for idx in range(start_char_idx, end_char_idx):
            is_char_in_ans[idx] = 1

        # Tokenize context
        tokenized_context = tokenizer.encode(context)

        # Find tokens that were created from answer characters
        ans_token_idx = []
        for idx, (start, end) in enumerate(tokenized_context.offsets):
            if sum(is_char_in_ans[start:end]) > 0:
                ans_token_idx.append(idx)

        if len(ans_token_idx) == 0:
            self.skip = True
            return

        # Find start and end token index for tokens from answer
        start_token_idx = ans_token_idx[0]
        end_token_idx = ans_token_idx[-1]

        # Tokenize question
        tokenized_question = tokenizer.encode(question)

        # Create inputs
        input_ids = tokenized_context.ids + tokenized_question.ids[1:]
        token_type_ids = [0] * len(tokenized_context.ids) + [1] * len(
            tokenized_question.ids[1:]
        )
        attention_mask = [1] * len(input_ids)

        # Pad and create attention masks.
        # Skip if truncation is needed
        padding_length = max_len - len(input_ids)
        if padding_length > 0:  # pad
            input_ids = input_ids + ([0] * padding_length)
            attention_mask = attention_mask + ([0] * padding_length)
            token_type_ids = token_type_ids + ([0] * padding_length)
        elif padding_length < 0:  # skip
            self.skip = True
            return

        self.input_ids = input_ids
        self.token_type_ids = token_type_ids
        self.attention_mask = attention_mask
        self.start_token_idx = start_token_idx
        self.end_token_idx = end_token_idx
        self.context_token_to_char = tokenized_context.offsets
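
"""
To make the offset logic in `preprocess` concrete, here is a tiny standalone walk-through
of mapping a character-level answer span onto token indices. The offsets below are made
up for illustration; in the real pipeline they come from `tokenized_context.offsets`.
"""
toy_offsets = [(0, 0), (0, 5), (6, 11), (12, 18), (0, 0)]  # [CLS], three words, [SEP]
toy_start_char, toy_end_char = 6, 18  # pretend the answer spans characters 6..17

toy_is_char_in_ans = [0] * 19
for idx in range(toy_start_char, toy_end_char):
    toy_is_char_in_ans[idx] = 1

toy_ans_token_idx = [
    idx
    for idx, (start, end) in enumerate(toy_offsets)
    if sum(toy_is_char_in_ans[start:end]) > 0
]
print(toy_ans_token_idx)  # [2, 3] -> start_token_idx=2, end_token_idx=3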


with open(train_path) as f:
    raw_train_data = json.load(f)

with open(eval_path) as f:
    raw_eval_data = json.load(f)


def create_squad_examples(raw_data):
    squad_examples = []
    for item in raw_data["data"]:
        for para in item["paragraphs"]:
            context = para["context"]
            for qa in para["qas"]:
                question = qa["question"]
                answer_text = qa["answers"][0]["text"]
                all_answers = [_["text"] for _ in qa["answers"]]
                start_char_idx = qa["answers"][0]["answer_start"]
                squad_eg = SquadExample(
                    question, context, start_char_idx, answer_text, all_answers
                )
                squad_eg.preprocess()
                squad_examples.append(squad_eg)
    return squad_examples
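
"""
For reference, the SQuAD JSON nesting that `create_squad_examples` walks looks like this.
The record below is a minimal made-up example, not actual SQuAD data:
"""
sample_record = {
    "data": [
        {
            "paragraphs": [
                {
                    "context": "Keras was released in March 2015.",
                    "qas": [
                        {
                            "question": "When was Keras released?",
                            "answers": [{"text": "March 2015", "answer_start": 22}],
                        }
                    ],
                }
            ]
        }
    ]
}
print(len(create_squad_examples(sample_record)))  # 1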


def create_inputs_targets(squad_examples):
    dataset_dict = {
        "input_ids": [],
        "token_type_ids": [],
        "attention_mask": [],
        "start_token_idx": [],
        "end_token_idx": [],
    }
    for item in squad_examples:
        if not item.skip:
            for key in dataset_dict:
                dataset_dict[key].append(getattr(item, key))
    for key in dataset_dict:
        dataset_dict[key] = np.array(dataset_dict[key])

    x = [
        dataset_dict["input_ids"],
        dataset_dict["token_type_ids"],
        dataset_dict["attention_mask"],
    ]
    y = [dataset_dict["start_token_idx"], dataset_dict["end_token_idx"]]
    return x, y


train_squad_examples = create_squad_examples(raw_train_data)
x_train, y_train = create_inputs_targets(train_squad_examples)
print(f"{len(train_squad_examples)} training points created.")

eval_squad_examples = create_squad_examples(raw_eval_data)
x_eval, y_eval = create_inputs_targets(eval_squad_examples)
print(f"{len(eval_squad_examples)} evaluation points created.")
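
"""
As an optional sanity check, each encoder input produced above should be a 2D array of
shape `(num_examples, max_len)` and each target a 1D array of token indices:
"""
for name, arr in zip(["input_ids", "token_type_ids", "attention_mask"], x_train):
    print(name, arr.shape)  # (num_examples, 384)
print("start/end targets:", y_train[0].shape, y_train[1].shape)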
"""
219
Create the Question-Answering Model using BERT and Functional API
220
"""


def create_model():
    ## BERT encoder
    encoder = TFBertModel.from_pretrained("bert-base-uncased")

    ## QA Model
    input_ids = layers.Input(shape=(max_len,), dtype=tf.int32)
    token_type_ids = layers.Input(shape=(max_len,), dtype=tf.int32)
    attention_mask = layers.Input(shape=(max_len,), dtype=tf.int32)
    embedding = encoder(
        input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask
    )[0]

    start_logits = layers.Dense(1, name="start_logit", use_bias=False)(embedding)
    start_logits = layers.Flatten()(start_logits)

    end_logits = layers.Dense(1, name="end_logit", use_bias=False)(embedding)
    end_logits = layers.Flatten()(end_logits)

    start_probs = layers.Activation(keras.activations.softmax)(start_logits)
    end_probs = layers.Activation(keras.activations.softmax)(end_logits)

    model = keras.Model(
        inputs=[input_ids, token_type_ids, attention_mask],
        outputs=[start_probs, end_probs],
    )
    loss = keras.losses.SparseCategoricalCrossentropy(from_logits=False)
    optimizer = keras.optimizers.Adam(learning_rate=5e-5)
    model.compile(optimizer=optimizer, loss=[loss, loss])
    return model
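
"""
The two `Dense(1, use_bias=False)` heads play the role of the vectors S and T from the
introduction: applied to the `(batch, max_len, hidden)` embeddings they yield one logit
per token, and `Flatten` squeezes the trailing axis to `(batch, max_len)`. A standalone
shape check with a dummy tensor (the hidden size here is a toy assumption):
"""
dummy_embedding = tf.zeros((2, max_len, 16))  # (batch, seq_len, toy hidden size)
dummy_logits = layers.Flatten()(layers.Dense(1, use_bias=False)(dummy_embedding))
print(dummy_logits.shape)  # (2, 384)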


"""
This code should preferably be run on a Google Colab TPU runtime.
With Colab TPUs, each epoch will take 5-6 minutes.
"""
use_tpu = True
if use_tpu:
    # Create distribution strategy
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver.connect()
    strategy = tf.distribute.TPUStrategy(tpu)

    # Create model
    with strategy.scope():
        model = create_model()
else:
    model = create_model()

model.summary()

"""
## Create evaluation callback

This callback will compute the exact match score using the validation data
after every epoch.
"""


def normalize_text(text):
    text = text.lower()

    # Remove punctuation
    exclude = set(string.punctuation)
    text = "".join(ch for ch in text if ch not in exclude)

    # Remove articles
    regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
    text = re.sub(regex, " ", text)

    # Remove extra white space
    text = " ".join(text.split())
    return text
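
"""
A quick check of the normalization (the input string is arbitrary): punctuation and the
articles "a", "an", "the" are stripped, and whitespace is collapsed.
"""
print(normalize_text("The Quick, Brown Fox!"))  # -> "quick brown fox"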


class ExactMatch(keras.callbacks.Callback):
    """
    Each `SquadExample` object contains the character-level offsets for each token
    in its input paragraph. We use them to get back the span of text corresponding
    to the tokens between our predicted start and end tokens.
    All the ground-truth answers are also present in each `SquadExample` object.
    We calculate the percentage of data points where the span of text obtained
    from model predictions matches one of the ground-truth answers.
    """

    def __init__(self, x_eval, y_eval):
        super().__init__()
        self.x_eval = x_eval
        self.y_eval = y_eval

    def on_epoch_end(self, epoch, logs=None):
        pred_start, pred_end = self.model.predict(self.x_eval)
        count = 0
        eval_examples_no_skip = [_ for _ in eval_squad_examples if not _.skip]
        for idx, (start, end) in enumerate(zip(pred_start, pred_end)):
            squad_eg = eval_examples_no_skip[idx]
            offsets = squad_eg.context_token_to_char
            start = np.argmax(start)
            end = np.argmax(end)
            if start >= len(offsets):
                continue
            pred_char_start = offsets[start][0]
            if end < len(offsets):
                pred_char_end = offsets[end][1]
                pred_ans = squad_eg.context[pred_char_start:pred_char_end]
            else:
                pred_ans = squad_eg.context[pred_char_start:]

            normalized_pred_ans = normalize_text(pred_ans)
            normalized_true_ans = [normalize_text(_) for _ in squad_eg.all_answers]
            if normalized_pred_ans in normalized_true_ans:
                count += 1
        acc = count / len(self.y_eval[0])
        print(f"\nepoch={epoch+1}, exact match score={acc:.2f}")


"""
## Train and Evaluate
"""
exact_match_callback = ExactMatch(x_eval, y_eval)
model.fit(
    x_train,
    y_train,
    epochs=1,  # For demonstration, 3 epochs are recommended
    verbose=2,
    batch_size=64,
    callbacks=[exact_match_callback],
)
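
"""
As an optional follow-up, the span-decoding logic from the callback can be reused to
answer an ad-hoc question with the fine-tuned model. The question/context pair below is
a made-up example, and this block is a sketch rather than part of the original tutorial.
"""
context = "Keras is a deep learning API written in Python."
question = "What language is Keras written in?"

tokenized_context = tokenizer.encode(context)
tokenized_question = tokenizer.encode(question)
input_ids = tokenized_context.ids + tokenized_question.ids[1:]
token_type_ids = [0] * len(tokenized_context.ids) + [1] * len(tokenized_question.ids[1:])
attention_mask = [1] * len(input_ids)
padding_length = max_len - len(input_ids)
input_ids += [0] * padding_length
token_type_ids += [0] * padding_length
attention_mask += [0] * padding_length

pred_start, pred_end = model.predict(
    [np.array([input_ids]), np.array([token_type_ids]), np.array([attention_mask])]
)
start, end = np.argmax(pred_start[0]), np.argmax(pred_end[0])
offsets = tokenized_context.offsets
if start < len(offsets):
    end_char = offsets[end][1] if end < len(offsets) else len(context)
    print(context[offsets[start][0] : end_char])  # ideally "Python"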