Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
keras-team
GitHub Repository: keras-team/keras-io
Path: blob/master/examples/nlp/bidirectional_lstm_imdb.py
3507 views
1
"""
2
Title: Bidirectional LSTM on IMDB
3
Author: [fchollet](https://twitter.com/fchollet)
4
Date created: 2020/05/03
5
Last modified: 2020/05/03
6
Description: Train a 2-layer bidirectional LSTM on the IMDB movie review sentiment classification dataset.
7
Accelerator: GPU
8
"""
9
10
"""
11
## Setup
12
"""
13
14
import numpy as np
15
import keras
16
from keras import layers
17
18
# Vocabulary cap: passed to the dataset loader (num_words) and used as the
# Embedding layer's input size further down.
max_features = 20000 # Only consider the top 20k words
19
# Fixed sequence length handed to pad_sequences below.
# NOTE: pad_sequences is called without `truncating=`; its documented default
# ("pre") drops words from the START of a long review, so the LAST 200 words
# are the ones kept, not the first 200.
maxlen = 200 # Keep at most 200 words of each movie review (see note above)
20
21
"""
22
## Build the model
23
"""
24
25
# Input for variable-length sequences of integers
# (shape=(None,) leaves the sequence length unspecified; values are int32 word ids)
26
inputs = keras.Input(shape=(None,), dtype="int32")
27
# Embed each integer in a 128-dimensional vector
# (the table has max_features rows, matching the num_words cap used when loading the data)
28
x = layers.Embedding(max_features, 128)(inputs)
29
# Add 2 bidirectional LSTMs
# The first one sets return_sequences=True so the second LSTM still receives
# a full sequence; the second leaves it at the default and yields one vector
# per review for the classifier.
30
x = layers.Bidirectional(layers.LSTM(64, return_sequences=True))(x)
31
x = layers.Bidirectional(layers.LSTM(64))(x)
32
# Add a classifier
# (a single sigmoid unit, paired with the binary_crossentropy loss at compile time)
33
outputs = layers.Dense(1, activation="sigmoid")(x)
34
model = keras.Model(inputs, outputs)
35
# Print the layer-by-layer architecture summary.
model.summary()
36
37
"""
38
## Load the IMDB movie review sentiment data
39
"""
40
41
# Load the IMDB reviews as sequences of word indices, capped at max_features words.
# NOTE(review): per the Keras docs the second tuple returned by load_data is
# the dataset's test split; this script uses it as validation data.
(x_train, y_train), (x_val, y_val) = keras.datasets.imdb.load_data(
42
num_words=max_features
43
)
44
print(len(x_train), "Training sequences")
45
print(len(x_val), "Validation sequences")
46
# Use pad_sequences to standardize sequence length:
47
# this will truncate sequences longer than 200 words and zero-pad sequences shorter than 200 words.
# With the default arguments used here, both truncation and padding are
# applied at the start of each sequence (documented defaults "pre"/"pre"),
# so the end of every review is what is preserved.
48
x_train = keras.utils.pad_sequences(x_train, maxlen=maxlen)
49
x_val = keras.utils.pad_sequences(x_val, maxlen=maxlen)
50
51
"""
52
## Train and evaluate the model
53
54
You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/bidirectional-lstm-imdb)
55
and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/bidirectional_lstm_imdb).
56
"""
57
58
# Binary cross-entropy matches the single sigmoid output of the model;
# accuracy is tracked as the reported metric.
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
59
# Train for 2 epochs with batches of 32 reviews, passing (x_val, y_val) as
# the validation data.
model.fit(x_train, y_train, batch_size=32, epochs=2, validation_data=(x_val, y_val))
60
61