GitHub Repository: keras-team/keras-io
Path: blob/master/examples/nlp/ner_transformers.py
"""
Title: Named Entity Recognition using Transformers
Author: [Varun Singh](https://www.linkedin.com/in/varunsingh2/)
Date created: 2021/06/23
Last modified: 2024/04/05
Description: NER using the Transformers and data from CoNLL 2003 shared task.
Accelerator: GPU
Converted to Keras 3 by: [Sitam Meur](https://github.com/sitamgithub-MSIT)
"""

"""
## Introduction

Named Entity Recognition (NER) is the process of identifying named entities in text.
Examples of named entities are: "Person", "Location", "Organization", "Date", etc. NER is
essentially a token classification task where every token is classified into one or more
predetermined categories.

In this exercise, we will train a simple Transformer-based model to perform NER. We will
be using the data from the CoNLL 2003 shared task. For more information about the dataset,
please visit [the dataset website](https://www.clips.uantwerpen.be/conll2003/ner/).
However, since obtaining this data requires an additional step of getting a free license,
we will be using Hugging Face's `datasets` library, which contains a processed version of this dataset.
"""

"""
## Install the open source datasets library from HuggingFace

We also download the script used to evaluate NER models.
"""

"""shell
pip3 install datasets
wget https://raw.githubusercontent.com/sighsmile/conlleval/master/conlleval.py
"""

import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import keras
from keras import ops
import numpy as np
import tensorflow as tf
from keras import layers
from datasets import load_dataset
from collections import Counter
from conlleval import evaluate

"""
We will be using the transformer implementation from this fantastic
[example](https://keras.io/examples/nlp/text_classification_with_transformer/).

Let's start by defining a `TransformerBlock` layer:
"""


class TransformerBlock(layers.Layer):
    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1):
        super().__init__()
        self.att = keras.layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim
        )
        self.ffn = keras.Sequential(
            [
                keras.layers.Dense(ff_dim, activation="relu"),
                keras.layers.Dense(embed_dim),
            ]
        )
        self.layernorm1 = keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = keras.layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = keras.layers.Dropout(rate)
        self.dropout2 = keras.layers.Dropout(rate)

    def call(self, inputs, training=False):
        attn_output = self.att(inputs, inputs)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(inputs + attn_output)
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.layernorm2(out1 + ffn_output)


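"""
As a quick, illustrative sanity check (not part of the original example), the block maps a
batch of embeddings to an output of the same shape:
"""

# Illustrative only: run the block on a dummy batch of shape (batch, seq_len, embed_dim).
_demo_block = TransformerBlock(embed_dim=32, num_heads=2, ff_dim=32)
print(_demo_block(ops.ones((4, 10, 32))).shape)  # expected: (4, 10, 32)
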
"""
Next, let's define a `TokenAndPositionEmbedding` layer:
"""


class TokenAndPositionEmbedding(layers.Layer):
    def __init__(self, maxlen, vocab_size, embed_dim):
        super().__init__()
        self.token_emb = keras.layers.Embedding(
            input_dim=vocab_size, output_dim=embed_dim
        )
        self.pos_emb = keras.layers.Embedding(input_dim=maxlen, output_dim=embed_dim)

    def call(self, inputs):
        maxlen = ops.shape(inputs)[-1]
        positions = ops.arange(start=0, stop=maxlen, step=1)
        position_embeddings = self.pos_emb(positions)
        token_embeddings = self.token_emb(inputs)
        return token_embeddings + position_embeddings


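"""
As a small illustration (added here, not in the original example), the layer turns integer
token IDs of shape `(batch, seq_len)` into embeddings of shape `(batch, seq_len, embed_dim)`:
"""

# Illustrative only: embed a dummy batch of token IDs.
_demo_embedding = TokenAndPositionEmbedding(maxlen=10, vocab_size=100, embed_dim=32)
print(_demo_embedding(ops.ones((2, 10), dtype="int32")).shape)  # expected: (2, 10, 32)
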
"""
## Build the NER model class as a `keras.Model` subclass
"""


class NERModel(keras.Model):
    def __init__(
        self, num_tags, vocab_size, maxlen=128, embed_dim=32, num_heads=2, ff_dim=32
    ):
        super().__init__()
        self.embedding_layer = TokenAndPositionEmbedding(maxlen, vocab_size, embed_dim)
        self.transformer_block = TransformerBlock(embed_dim, num_heads, ff_dim)
        self.dropout1 = layers.Dropout(0.1)
        self.ff = layers.Dense(ff_dim, activation="relu")
        self.dropout2 = layers.Dropout(0.1)
        self.ff_final = layers.Dense(num_tags, activation="softmax")

    def call(self, inputs, training=False):
        x = self.embedding_layer(inputs)
        x = self.transformer_block(x)
        x = self.dropout1(x, training=training)
        x = self.ff(x)
        x = self.dropout2(x, training=training)
        x = self.ff_final(x)
        return x


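"""
A quick smoke test (illustrative only; the real model is built further below with the CoNLL
vocabulary): calling the model on a batch of token IDs yields per-token probabilities over
the tag set.
"""

# Illustrative only: 2 sequences of 10 dummy token IDs -> softmax over 10 hypothetical tags.
_demo_model = NERModel(num_tags=10, vocab_size=100)
print(_demo_model(ops.ones((2, 10), dtype="int32")).shape)  # expected: (2, 10, 10)
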
"""
## Load the CoNLL 2003 dataset from the datasets library and process it
"""

conll_data = load_dataset("conll2003")

"""
We will export this data to a tab-separated file format which will be easy to read as a
`tf.data.Dataset` object.
"""


def export_to_file(export_file_path, data):
    with open(export_file_path, "w") as f:
        for record in data:
            ner_tags = record["ner_tags"]
            tokens = record["tokens"]
            if len(tokens) > 0:
                f.write(
                    str(len(tokens))
                    + "\t"
                    + "\t".join(tokens)
                    + "\t"
                    + "\t".join(map(str, ner_tags))
                    + "\n"
                )


os.mkdir("data")
export_to_file("./data/conll_train.txt", conll_data["train"])
export_to_file("./data/conll_val.txt", conll_data["validation"])

"""
## Make the NER label lookup table

NER labels are usually provided in IOB, IOB2 or IOBES formats. Check out this link for
more information:
[Wikipedia](https://en.wikipedia.org/wiki/Inside%E2%80%93outside%E2%80%93beginning_(tagging))

Note that we start our label numbering from 1, since 0 will be reserved for padding. We
have a total of 10 labels: 9 from the NER dataset and one for padding.
"""


def make_tag_lookup_table():
    iob_labels = ["B", "I"]
    ner_labels = ["PER", "ORG", "LOC", "MISC"]
    all_labels = [(label1, label2) for label2 in ner_labels for label1 in iob_labels]
    all_labels = ["-".join([a, b]) for a, b in all_labels]
    all_labels = ["[PAD]", "O"] + all_labels
    return dict(zip(range(0, len(all_labels) + 1), all_labels))


mapping = make_tag_lookup_table()
print(mapping)

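"""
For instance (a hand-constructed, illustrative sequence of tag IDs), the sentence
"eu rejects german call" tagged as ORG / O / MISC / O decodes as follows:
"""

# Illustrative only: decode a hand-written sequence of tag IDs with the lookup table.
print([mapping[i] for i in [4, 1, 8, 1]])  # ['B-ORG', 'O', 'B-MISC', 'O']
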
"""
Get a list of all tokens in the training dataset. This will be used to create the
vocabulary.
"""

all_tokens = sum(conll_data["train"]["tokens"], [])
all_tokens_array = np.array(list(map(str.lower, all_tokens)))

counter = Counter(all_tokens_array)
print(len(counter))

num_tags = len(mapping)
vocab_size = 20000

# We only take (vocab_size - 2) most common words from the training data since
# the `StringLookup` class uses 2 additional tokens - one denoting an unknown
# token and another one denoting a masking token.
vocabulary = [token for token, count in counter.most_common(vocab_size - 2)]

# The `StringLookup` class will convert tokens to token IDs.
lookup_layer = keras.layers.StringLookup(vocabulary=vocabulary)

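"""
A quick illustration (not part of the original script) of what the lookup layer does:
tokens are mapped to integer IDs, and tokens outside the vocabulary are mapped to the
reserved OOV (`"[UNK]"`) index.
"""

# Illustrative only: the exact IDs depend on token frequencies in the training data.
print(lookup_layer(["the", "eu", "sometotallyunknowntoken"]))
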
"""
Create two new `Dataset` objects from the training and validation data.
"""

train_data = tf.data.TextLineDataset("./data/conll_train.txt")
val_data = tf.data.TextLineDataset("./data/conll_val.txt")

"""
Print out one line to make sure it looks good. The first field in the line is the number of tokens.
After that, we have all the tokens followed by all the NER tags.
"""

print(list(train_data.take(1).as_numpy_iterator()))

"""
We will be using the following map functions to transform the data in the dataset:
"""


def map_record_to_training_data(record):
    record = tf.strings.split(record, sep="\t")
    length = tf.strings.to_number(record[0], out_type=tf.int32)
    tokens = record[1 : length + 1]
    tags = record[length + 1 :]
    tags = tf.strings.to_number(tags, out_type=tf.int64)
    tags += 1
    return tokens, tags


def lowercase_and_convert_to_ids(tokens):
    tokens = tf.strings.lower(tokens)
    return lookup_layer(tokens)


# We use `padded_batch` here because each record in the dataset has a
# different length.
batch_size = 32
train_dataset = (
    train_data.map(map_record_to_training_data)
    .map(lambda x, y: (lowercase_and_convert_to_ids(x), y))
    .padded_batch(batch_size)
)
val_dataset = (
    val_data.map(map_record_to_training_data)
    .map(lambda x, y: (lowercase_and_convert_to_ids(x), y))
    .padded_batch(batch_size)
)

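"""
As an illustrative check (not in the original example), we can pull one batch to see what the
model will consume: token IDs and tags are zero-padded to the longest sequence in the batch.
"""

# Illustrative only: both tensors have shape (batch_size, max_sequence_length_in_batch).
for tokens_batch, tags_batch in train_dataset.take(1):
    print(tokens_batch.shape, tags_batch.shape)
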
ner_model = NERModel(num_tags, vocab_size, embed_dim=32, num_heads=4, ff_dim=64)

"""
We will be using a custom loss function that will ignore the loss from padded tokens.
"""


class CustomNonPaddingTokenLoss(keras.losses.Loss):
    def __init__(self, name="custom_ner_loss"):
        super().__init__(name=name)

    def call(self, y_true, y_pred):
        loss_fn = keras.losses.SparseCategoricalCrossentropy(
            from_logits=False, reduction=None
        )
        loss = loss_fn(y_true, y_pred)
        mask = ops.cast((y_true > 0), dtype="float32")
        loss = loss * mask
        return ops.sum(loss) / ops.sum(mask)


loss = CustomNonPaddingTokenLoss()

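"""
A small hand-worked check (with made-up numbers, purely illustrative): positions where
`y_true` is 0 (padding) contribute nothing, so the loss below averages the cross-entropy of
the two non-padded positions only.
"""

# Illustrative only: one sequence of length 3 whose last position is padding.
_y_true = ops.convert_to_tensor([[2, 1, 0]])
_y_pred = ops.ones((1, 3, 10)) / 10.0  # uniform probabilities over the 10 tags
print(loss(_y_true, _y_pred))  # expected: about 2.30, i.e. -log(0.1)
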
"""
## Compile and fit the model
"""

tf.config.run_functions_eagerly(True)
ner_model.compile(optimizer="adam", loss=loss)
ner_model.fit(train_dataset, epochs=10)


def tokenize_and_convert_to_ids(text):
    tokens = text.split()
    return lowercase_and_convert_to_ids(tokens)


# Sample inference using the trained model
sample_input = tokenize_and_convert_to_ids(
    "eu rejects german call to boycott british lamb"
)
sample_input = ops.reshape(sample_input, [1, -1])
print(sample_input)

output = ner_model.predict(sample_input)
prediction = np.argmax(output, axis=-1)[0]
prediction = [mapping[i] for i in prediction]

# eu -> B-ORG, german -> B-MISC, british -> B-MISC
print(prediction)

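"""
To make the output easier to read (a small formatting helper added for illustration), we can
pair each input token with its predicted tag:
"""

# Illustrative only: print tokens next to their predicted tags.
for token, tag in zip("eu rejects german call to boycott british lamb".split(), prediction):
    print(f"{token:>10} -> {tag}")
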
"""
## Metrics calculation

Here is a function to calculate the metrics. The function calculates the F1 score for the
overall NER dataset as well as individual scores for each NER tag.
"""


def calculate_metrics(dataset):
    all_true_tag_ids, all_predicted_tag_ids = [], []

    for x, y in dataset:
        output = ner_model.predict(x, verbose=0)
        predictions = ops.argmax(output, axis=-1)
        predictions = ops.reshape(predictions, [-1])

        true_tag_ids = ops.reshape(y, [-1])

        mask = (true_tag_ids > 0) & (predictions > 0)
        true_tag_ids = true_tag_ids[mask]
        predicted_tag_ids = predictions[mask]

        all_true_tag_ids.append(true_tag_ids)
        all_predicted_tag_ids.append(predicted_tag_ids)

    all_true_tag_ids = np.concatenate(all_true_tag_ids)
    all_predicted_tag_ids = np.concatenate(all_predicted_tag_ids)

    predicted_tags = [mapping[tag] for tag in all_predicted_tag_ids]
    real_tags = [mapping[tag] for tag in all_true_tag_ids]

    evaluate(real_tags, predicted_tags)


calculate_metrics(val_dataset)

"""
## Conclusions

In this exercise, we created a simple transformer-based named entity recognition model.
We trained it on the CoNLL 2003 shared task data and got an overall F1 score of around 70%.
State-of-the-art NER models fine-tuned on pretrained models such as BERT or ELECTRA can easily
get a much higher F1 score, between 90 and 95% on this dataset, owing to the inherent knowledge
of words gained during pretraining and the use of subword tokenization.

You can use the trained model hosted on the [Hugging Face Hub](https://huggingface.co/keras-io/ner-with-transformers)
and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/ner_with_transformers).
"""