Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
keras-team
GitHub Repository: keras-team/keras-io
Path: blob/master/examples/nlp/addition_rnn.py
3507 views
1
"""
2
Title: Sequence to sequence learning for performing number addition
3
Author: [Smerity](https://twitter.com/Smerity) and others
4
Date created: 2015/08/17
5
Last modified: 2024/02/13
6
Description: A model that learns to add strings of numbers, e.g. "535+61" -> "596".
7
Accelerator: GPU
8
"""
9
10
"""
11
## Introduction
12
13
In this example, we train a model to learn to add two numbers, provided as strings.
14
15
**Example:**
16
17
- Input: "535+61"
18
- Output: "596"
19
20
Input may optionally be reversed, which was shown to increase performance in many tasks
21
in: [Learning to Execute](http://arxiv.org/abs/1410.4615) and
22
[Sequence to Sequence Learning with Neural Networks](http://papers.nips.cc/paper/5346-sequence-to-sequence-learning-with-neural-networks.pdf).
23
24
Theoretically, sequence order inversion introduces shorter term dependencies between
25
source and target for this problem.
26
27
**Results:**
28
29
For two digits (reversed):
30
31
+ One layer LSTM (128 HN), 5k training examples = 99% train/test accuracy in 55 epochs
32
33
Three digits (reversed):
34
35
+ One layer LSTM (128 HN), 50k training examples = 99% train/test accuracy in 100 epochs
36
37
Four digits (reversed):
38
39
+ One layer LSTM (128 HN), 400k training examples = 99% train/test accuracy in 20 epochs
40
41
Five digits (reversed):
42
43
+ One layer LSTM (128 HN), 550k training examples = 99% train/test accuracy in 30 epochs
44
"""
45
46
"""
47
## Setup
48
"""
49
50
import keras
51
from keras import layers
52
import numpy as np
53
54
# Dataset and model hyper-parameters.
TRAINING_SIZE = 50000  # how many distinct addition problems to generate
DIGITS = 3  # maximum number of digits in each operand
REVERSE = True  # whether each padded query string is reversed before encoding

# The longest possible query is two DIGITS-wide operands joined by a single
# "+" (e.g. '345+678'), so every query is padded to this many characters.
MAXLEN = 2 * DIGITS + 1
62
63
"""
64
## Generate the data
65
"""
66
67
68
class CharacterTable:
    """Map characters to and from one-hot encoded rows.

    Given a set of characters, this table can:
    + Encode a string into a one-hot matrix (one row per character).
    + Decode a one-hot matrix, or a vector of character indices, back to a string.
    + Decode a matrix of per-position probabilities to its most likely string.
    """

    def __init__(self, chars):
        """Initialize the character table.

        # Arguments
            chars: Characters that can appear in the input. Duplicates are
                ignored; characters are indexed in sorted order.
        """
        self.chars = sorted(set(chars))
        self.char_indices = {c: i for i, c in enumerate(self.chars)}
        self.indices_char = {i: c for i, c in enumerate(self.chars)}

    def encode(self, C, num_rows=None):
        """One-hot encode the string `C`.

        # Arguments
            C: string to be encoded. Every character must be in the table.
            num_rows: Number of rows in the returned encoding. Pass a fixed
                value to keep the number of rows the same for every sample;
                rows past the end of `C` are left all-zero. Defaults to
                `len(C)`.

        # Returns
            A float array of shape `(num_rows, len(self.chars))` with a single
            1 per row for each character of `C`.

        # Raises
            KeyError: if `C` contains a character not in the table.
            IndexError: if `C` is longer than `num_rows`.
        """
        if num_rows is None:
            num_rows = len(C)
        x = np.zeros((num_rows, len(self.chars)))
        for i, c in enumerate(C):
            x[i, self.char_indices[c]] = 1
        return x

    def decode(self, x, calc_argmax=True):
        """Decode the given vector or 2D array to its character output.

        # Arguments
            x: A vector or a 2D array of probabilities or one-hot
                representations (the last axis indexes characters); or a
                vector of character indices (used with `calc_argmax=False`).
            calc_argmax: Whether to find the character index with maximum
                probability, defaults to `True`. An all-zero row decodes to
                the first character in the table, since its argmax is 0.

        # Returns
            The decoded string.
        """
        if calc_argmax:
            # argmax over the character axis. For a single probability
            # vector this is a 0-d array, so promote it to 1-d to keep it
            # iterable (decoding to a one-character string).
            x = np.atleast_1d(np.asarray(x).argmax(axis=-1))
        return "".join(self.indices_char[index] for index in x)
107
108
109
# All the numbers, plus sign and space for padding.
chars = "0123456789+ "
ctable = CharacterTable(chars)


def _random_operand():
    """Return a random non-negative integer with at most DIGITS digits.

    The digit count is drawn uniformly from 1..DIGITS first, then each digit
    independently, so leading zeros are possible (e.g. '07' -> 7).
    """
    num_digits = np.random.randint(1, DIGITS + 1)
    return int("".join(np.random.choice(list("0123456789")) for _ in range(num_digits)))


# Operands range over [0, 10**DIGITS), and only unordered pairs are kept, so
# at most n * (n + 1) / 2 distinct questions exist. Asking for more would make
# the loop below spin forever.
_num_operands = 10**DIGITS
_max_unique_questions = _num_operands * (_num_operands + 1) // 2
if TRAINING_SIZE > _max_unique_questions:
    raise ValueError(
        f"TRAINING_SIZE={TRAINING_SIZE} exceeds the {_max_unique_questions} "
        f"distinct questions possible with DIGITS={DIGITS}."
    )

questions = []
expected = []
seen = set()
print("Generating data...")
while len(questions) < TRAINING_SIZE:
    a, b = _random_operand(), _random_operand()
    # Skip any addition questions we've already seen.
    # Also skip any such that a+b == b+a (hence the sorting).
    key = tuple(sorted((a, b)))
    if key in seen:
        continue
    seen.add(key)
    # Pad the query with spaces so that it is always MAXLEN characters.
    query = f"{a}+{b}".ljust(MAXLEN)
    # Answers can be at most DIGITS + 1 characters (e.g. 999+999=1998).
    ans = str(a + b).ljust(DIGITS + 1)
    if REVERSE:
        # Reverse the query, e.g., '12+345 ' becomes ' 543+21'. (Note the
        # space used for padding.)
        query = query[::-1]
    questions.append(query)
    expected.append(ans)
print("Total questions:", len(questions))
144
145
"""
146
## Vectorize the data
147
"""
148
149
print("Vectorization...")
150
x = np.zeros((len(questions), MAXLEN, len(chars)), dtype=bool)
151
y = np.zeros((len(questions), DIGITS + 1, len(chars)), dtype=bool)
152
for i, sentence in enumerate(questions):
153
x[i] = ctable.encode(sentence, MAXLEN)
154
for i, sentence in enumerate(expected):
155
y[i] = ctable.encode(sentence, DIGITS + 1)
156
157
# Shuffle (x, y) in unison as the later parts of x will almost all be larger
158
# digits.
159
indices = np.arange(len(y))
160
np.random.shuffle(indices)
161
x = x[indices]
162
y = y[indices]
163
164
# Explicitly set apart 10% for validation data that we never train over.
165
split_at = len(x) - len(x) // 10
166
(x_train, x_val) = x[:split_at], x[split_at:]
167
(y_train, y_val) = y[:split_at], y[split_at:]
168
169
print("Training Data:")
170
print(x_train.shape)
171
print(y_train.shape)
172
173
print("Validation Data:")
174
print(x_val.shape)
175
print(y_val.shape)
176
177
"""
178
## Build the model
179
"""
180
181
print("Build model...")
182
num_layers = 1 # Try to add more LSTM layers!
183
184
model = keras.Sequential()
185
# "Encode" the input sequence using a LSTM, producing an output of size 128.
186
# Note: In a situation where your input sequences have a variable length,
187
# use input_shape=(None, num_feature).
188
model.add(layers.Input((MAXLEN, len(chars))))
189
model.add(layers.LSTM(128))
190
# As the decoder RNN's input, repeatedly provide with the last output of
191
# RNN for each time step. Repeat 'DIGITS + 1' times as that's the maximum
192
# length of output, e.g., when DIGITS=3, max output is 999+999=1998.
193
model.add(layers.RepeatVector(DIGITS + 1))
194
# The decoder RNN could be multiple layers stacked or a single layer.
195
for _ in range(num_layers):
196
# By setting return_sequences to True, return not only the last output but
197
# all the outputs so far in the form of (num_samples, timesteps,
198
# output_dim). This is necessary as TimeDistributed in the below expects
199
# the first dimension to be the timesteps.
200
model.add(layers.LSTM(128, return_sequences=True))
201
202
# Apply a dense layer to the every temporal slice of an input. For each of step
203
# of the output sequence, decide which character should be chosen.
204
model.add(layers.Dense(len(chars), activation="softmax"))
205
model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
206
model.summary()
207
208
"""
209
## Train the model
210
"""
211
212
# Training parameters.
epochs = 30
batch_size = 32
num_samples_to_show = 10  # validation samples printed after every epoch

# ANSI escape codes used to colour the results display.
green_color = "\033[92m"
red_color = "\033[91m"
end_char = "\033[0m"

# Train the model one epoch at a time. After each epoch, show predictions
# for random validation samples. `range` is end-exclusive, so `epochs + 1`
# is needed to run exactly `epochs` iterations.
for epoch in range(1, epochs + 1):
    print()
    print("Iteration", epoch)
    model.fit(
        x_train,
        y_train,
        batch_size=batch_size,
        epochs=1,
        validation_data=(x_val, y_val),
    )
    # Select samples from the validation set at random (with replacement)
    # so we can visualize errors, and predict them all in a single call.
    sample_ids = np.random.randint(0, len(x_val), size=num_samples_to_show)
    preds = np.argmax(model.predict(x_val[sample_ids], verbose=0), axis=-1)
    for rowx, rowy, pred in zip(x_val[sample_ids], y_val[sample_ids], preds):
        q = ctable.decode(rowx)
        correct = ctable.decode(rowy)
        guess = ctable.decode(pred, calc_argmax=False)
        print("Q", q[::-1] if REVERSE else q, end=" ")
        print("T", correct, end=" ")
        color = green_color if correct == guess else red_color
        print(f"{color}{guess}{end_char}")
248
249
"""
250
You'll get to 99+% validation accuracy after ~30 epochs.
251
"""
252
253