Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
keras-team
GitHub Repository: keras-team/keras-io
Path: blob/master/examples/nlp/lstm_seq2seq.py
3507 views
1
"""
2
Title: Character-level recurrent sequence-to-sequence model
3
Author: [fchollet](https://twitter.com/fchollet)
4
Date created: 2017/09/29
5
Last modified: 2023/11/22
6
Description: Character-level recurrent sequence-to-sequence model.
7
Accelerator: GPU
8
"""
9
10
"""
11
## Introduction
12
13
This example demonstrates how to implement a basic character-level
14
recurrent sequence-to-sequence model. We apply it to translating
15
short English sentences into short French sentences,
16
character-by-character. Note that it is fairly unusual to
17
do character-level machine translation, as word-level
18
models are more common in this domain.
19
20
**Summary of the algorithm**
21
22
- We start with input sequences from a domain (e.g. English sentences)
23
and corresponding target sequences from another domain
24
(e.g. French sentences).
25
- An encoder LSTM turns input sequences into 2 state vectors
26
(we keep the last LSTM state and discard the outputs).
27
- A decoder LSTM is trained to turn the target sequences into
28
the same sequence but offset by one timestep in the future,
29
a training process called "teacher forcing" in this context.
30
It uses the state vectors from the encoder as its initial state.
31
Effectively, the decoder learns to generate `targets[t+1...]`
32
given `targets[...t]`, conditioned on the input sequence.
33
- In inference mode, when we want to decode unknown input sequences, we:
34
- Encode the input sequence into state vectors
35
- Start with a target sequence of size 1
36
(just the start-of-sequence character)
37
- Feed the state vectors and 1-char target sequence
38
to the decoder to produce predictions for the next character
39
- Sample the next character using these predictions
40
(we simply use argmax).
41
- Append the sampled character to the target sequence
42
- Repeat until we generate the end-of-sequence character or we
43
hit the character limit.
44
"""
45
46
"""
47
## Setup
48
"""
49
50
import os
import zipfile
from pathlib import Path

import keras
import numpy as np
54
55
"""
56
## Download the data
57
"""
58
59
fpath = keras.utils.get_file(origin="http://www.manythings.org/anki/fra-eng.zip")
dirpath = Path(fpath).parent.absolute()
# Extract with the standard library rather than shelling out to `unzip`:
# - no external binary is required;
# - paths containing spaces are handled (the old shell command did not quote them);
# - a previous extraction is overwritten without an interactive prompt;
# - a bad archive raises an exception instead of the exit status being ignored.
with zipfile.ZipFile(fpath) as archive:
    archive.extractall(dirpath)
62
63
"""
64
## Configuration
65
"""
66
67
# Training hyperparameters.
batch_size = 64  # Samples per gradient update.
epochs = 100  # Full passes over the training data.
latent_dim = 256  # Number of LSTM units, i.e. size of each encoder state vector.
num_samples = 10000  # Maximum number of sentence pairs to load.

# Location of the extracted sentence-pair text file.
data_path = os.path.join(dirpath, "fra.txt")
73
74
"""
75
## Prepare the data
76
"""
77
78
# Vectorize the data.
input_texts = []
target_texts = []
input_characters = set()
target_characters = set()
with open(data_path, "r", encoding="utf-8") as f:
    # Drop blank lines (such as the empty string left after a trailing
    # newline) so they are neither parsed nor counted toward num_samples.
    lines = [line for line in f.read().split("\n") if line]
for line in lines[:num_samples]:
    # Lines are tab-separated; only the first two fields (source sentence,
    # target sentence) are used and any further columns are ignored.
    input_text, target_text = line.split("\t")[:2]
    # We use "tab" as the "start sequence" character
    # for the targets, and "\n" as "end sequence" character.
    target_text = "\t" + target_text + "\n"
    input_texts.append(input_text)
    target_texts.append(target_text)
    input_characters.update(input_text)
    target_characters.update(target_text)

input_characters = sorted(input_characters)
target_characters = sorted(target_characters)
num_encoder_tokens = len(input_characters)
num_decoder_tokens = len(target_characters)
max_encoder_seq_length = max(len(txt) for txt in input_texts)
max_decoder_seq_length = max(len(txt) for txt in target_texts)

print("Number of samples:", len(input_texts))
print("Number of unique input tokens:", num_encoder_tokens)
print("Number of unique output tokens:", num_decoder_tokens)
print("Max sequence length for inputs:", max_encoder_seq_length)
print("Max sequence length for outputs:", max_decoder_seq_length)

input_token_index = {char: i for i, char in enumerate(input_characters)}
target_token_index = {char: i for i, char in enumerate(target_characters)}

# One-hot tensors of shape (num_pairs, max_seq_length, vocabulary_size).
encoder_input_data = np.zeros(
    (len(input_texts), max_encoder_seq_length, num_encoder_tokens),
    dtype="float32",
)
decoder_input_data = np.zeros(
    (len(input_texts), max_decoder_seq_length, num_decoder_tokens),
    dtype="float32",
)
decoder_target_data = np.zeros(
    (len(input_texts), max_decoder_seq_length, num_decoder_tokens),
    dtype="float32",
)

for i, (input_text, target_text) in enumerate(zip(input_texts, target_texts)):
    for t, char in enumerate(input_text):
        encoder_input_data[i, t, input_token_index[char]] = 1.0
    # Pad the remaining timesteps with the space character. Using len()
    # instead of the leaked loop variable keeps this correct for empty text.
    encoder_input_data[i, len(input_text) :, input_token_index[" "]] = 1.0
    for t, char in enumerate(target_text):
        decoder_input_data[i, t, target_token_index[char]] = 1.0
        if t > 0:
            # decoder_target_data is ahead of decoder_input_data by one
            # timestep and does not include the start character.
            decoder_target_data[i, t - 1, target_token_index[char]] = 1.0
    decoder_input_data[i, len(target_text) :, target_token_index[" "]] = 1.0
    # The target is one step shorter (no start character), so its padding
    # begins one position earlier.
    decoder_target_data[i, len(target_text) - 1 :, target_token_index[" "]] = 1.0
141
142
"""
143
## Build the model
144
"""
145
146
# Encoder: read the one-hot input characters and keep only the final LSTM
# states (hidden and cell); the per-timestep outputs are discarded.
encoder_inputs = keras.Input(shape=(None, num_encoder_tokens))
encoder = keras.layers.LSTM(latent_dim, return_state=True)
encoder_outputs, *encoder_states = encoder(encoder_inputs)
state_h, state_c = encoder_states

# Decoder: consumes the target characters (teacher forcing), starting from
# the encoder states. It returns the whole output sequence plus its own final
# states; those states are ignored during training but reused at inference.
decoder_inputs = keras.Input(shape=(None, num_decoder_tokens))
decoder_lstm = keras.layers.LSTM(
    latent_dim,
    return_sequences=True,
    return_state=True,
)
decoder_outputs = decoder_lstm(decoder_inputs, initial_state=encoder_states)[0]
# Per-timestep softmax over the target character vocabulary.
decoder_dense = keras.layers.Dense(num_decoder_tokens, activation="softmax")
decoder_outputs = decoder_dense(decoder_outputs)

# Training model: [encoder_input_data, decoder_input_data] -> decoder_target_data.
model = keras.Model(
    inputs=[encoder_inputs, decoder_inputs],
    outputs=decoder_outputs,
)
168
169
"""
170
## Train the model
171
"""
172
173
model.compile(
    optimizer="rmsprop",
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)
# Reserve 20% of the samples for validation.
model.fit(
    x=[encoder_input_data, decoder_input_data],
    y=decoder_target_data,
    batch_size=batch_size,
    epochs=epochs,
    validation_split=0.2,
)
# Persist the trained model so the inference section can reload it from disk.
model.save("s2s_model.keras")
185
186
"""
187
## Run inference (sampling)
188
189
1. Encode the input and retrieve the initial decoder state
190
2. Run one step of the decoder with this initial state
191
and a "start of sequence" token as target.
192
Output will be the next target token.
193
3. Repeat with the current target token and current states
194
"""
195
196
# Define sampling models
# Restore the model and construct the encoder and decoder.
model = keras.models.load_model("s2s_model.keras")

# Find the trained layers by type instead of by hard-coded position in
# `model.layers`, so the lookup does not depend on whether or where InputLayers
# appear in that list. `model.layers` is expected to be in graph (topological)
# order, so the encoder LSTM — whose states feed the decoder — comes first.
# The unpacking fails loudly if the layer counts are not as expected.
encoder_lstm, decoder_lstm = [
    layer for layer in model.layers if isinstance(layer, keras.layers.LSTM)
]
(decoder_dense,) = [
    layer for layer in model.layers if isinstance(layer, keras.layers.Dense)
]

# Encoder inference model: input characters -> final [h, c] LSTM states.
encoder_inputs = model.input[0]
encoder_outputs, state_h_enc, state_c_enc = encoder_lstm.output
encoder_states = [state_h_enc, state_c_enc]
encoder_model = keras.Model(encoder_inputs, encoder_states)

# Decoder inference model: reuses the trained decoder LSTM and Dense layers but
# takes the initial states as explicit inputs, so it can be stepped one
# character at a time, feeding back the states it returns.
decoder_inputs = model.input[1]
decoder_state_input_h = keras.Input(shape=(latent_dim,))
decoder_state_input_c = keras.Input(shape=(latent_dim,))
decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]
decoder_outputs, state_h_dec, state_c_dec = decoder_lstm(
    decoder_inputs, initial_state=decoder_states_inputs
)
decoder_states = [state_h_dec, state_c_dec]
decoder_outputs = decoder_dense(decoder_outputs)
decoder_model = keras.Model(
    [decoder_inputs] + decoder_states_inputs, [decoder_outputs] + decoder_states
)

# Reverse-lookup token index to decode sequences back to
# something readable.
reverse_input_char_index = {i: char for char, i in input_token_index.items()}
reverse_target_char_index = {i: char for char, i in target_token_index.items()}
224
225
226
def decode_sequence(input_seq):
    """Greedily translate one encoded input sequence, character by character.

    `input_seq` is passed straight to `encoder_model.predict`; the caller
    below supplies a size-1 slice of `encoder_input_data`. The decoder is then
    stepped one character at a time, always feeding back the most probable
    character. It stops when it emits the "\\n" end-of-sequence character or
    when the decoded string grows longer than `max_decoder_seq_length`.

    Returns the decoded string, including the last sampled character.
    """
    states = encoder_model.predict(input_seq, verbose=0)

    # A single one-hot timestep (batch of 1) holding the start character.
    step_input = np.zeros((1, 1, num_decoder_tokens))
    step_input[0, 0, target_token_index["\t"]] = 1.0

    decoded = ""
    while True:
        probs, h, c = decoder_model.predict([step_input] + states, verbose=0)

        # Greedy choice: the most probable next character.
        best_index = np.argmax(probs[0, -1, :])
        best_char = reverse_target_char_index[best_index]
        decoded += best_char

        finished = best_char == "\n" or len(decoded) > max_decoder_seq_length
        if finished:
            return decoded

        # Feed the chosen character and the new states into the next step.
        step_input = np.zeros((1, 1, num_decoder_tokens))
        step_input[0, 0, best_index] = 1.0
        states = [h, c]
261
262
263
"""
264
You can now generate decoded sentences as follows:
265
"""
266
267
# Decode a handful of sequences (from the start of the data, i.e. part of the
# training set) to spot-check the model. Cap the count at the number of loaded
# samples so a small `num_samples` does not cause an IndexError.
num_examples = min(20, len(input_texts))
for seq_index in range(num_examples):
    # Slice rather than index so the batch dimension (size 1) is kept.
    input_seq = encoder_input_data[seq_index : seq_index + 1]
    decoded_sentence = decode_sequence(input_seq)
    print("-")
    print("Input sentence:", input_texts[seq_index])
    print("Decoded sentence:", decoded_sentence)
275
276