"""
2
Title: Data Parallel Training with KerasHub and tf.distribute
3
Author: Anshuman Mishra
4
Date created: 2023/07/07
5
Last modified: 2023/07/07
6
Description: Data Parallel training with KerasHub and tf.distribute.
7
Accelerator: GPU
8
"""
9
10
"""
11
## Introduction
12
13
Distributed training is a technique used to train deep learning models on multiple devices
14
or machines simultaneously. It helps to reduce training time and allows for training larger
15
models with more data. KerasHub is a library that provides tools and utilities for natural
16
language processing tasks, including distributed training.
17
18
In this tutorial, we will use KerasHub to train a BERT-based masked language model (MLM)
19
on the wikitext-2 dataset (a 2 million word dataset of wikipedia articles). The MLM task
20
involves predicting the masked words in a sentence, which helps the model learn contextual
21
representations of words.
22
23
This guide focuses on data parallelism, in particular synchronous data parallelism, where
24
each accelerator (a GPU or TPU) holds a complete replica of the model, and sees a
25
different partial batch of the input data. Partial gradients are computed on each device,
26
aggregated, and used to compute a global gradient update.
27
28
Specifically, this guide teaches you how to use the `tf.distribute` API to train Keras
29
models on multiple GPUs, with minimal changes to your code, in the following two setups:
30
31
- On multiple GPUs (typically 2 to 8) installed on a single machine (single host,
32
multi-device training). This is the most common setup for researchers and small-scale
33
industry workflows.
34
- On a cluster of many machines, each hosting one or multiple GPUs (multi-worker
35
distributed training). This is a good setup for large-scale industry workflows, e.g.
36
training high-resolution text summarization models on billion word datasets on 20-100 GPUs.
37
"""

"""shell
pip install -q --upgrade keras-hub
pip install -q --upgrade keras # Upgrade to Keras 3.
"""

"""
## Imports
"""

import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import tensorflow as tf
import keras
import keras_hub
"""
57
Before we start any training, let's configure our single GPU to show up as two logical
58
devices.
59
60
When you are training with two or more physical GPUs, this is totally uncessary. This
61
is just a trick to show real distributed training on the default colab GPU runtime,
62
which has only one GPU available.
63
"""
64
65
"""shell
66
nvidia-smi --query-gpu=memory.total --format=csv,noheader
67
"""
68
69
physical_devices = tf.config.list_physical_devices("GPU")
70
tf.config.set_logical_device_configuration(
71
physical_devices[0],
72
[
73
tf.config.LogicalDeviceConfiguration(memory_limit=15360 // 2),
74
tf.config.LogicalDeviceConfiguration(memory_limit=15360 // 2),
75
],
76
)
77
78
logical_devices = tf.config.list_logical_devices("GPU")
79
logical_devices
80
81
EPOCHS = 3


"""
To do single-host, multi-device synchronous training with a Keras model, you would use
the `tf.distribute.MirroredStrategy` API. Here's how it works:

- Instantiate a `MirroredStrategy`, optionally configuring which specific devices you
want to use (by default the strategy will use all GPUs available).
- Use the strategy object to open a scope, and within this scope, create all the Keras
objects you need that contain variables. Typically, that means **creating & compiling the
model** inside the distribution scope.
- Train the model via `fit()` as usual.
"""
strategy = tf.distribute.MirroredStrategy()
print(f"Number of devices: {strategy.num_replicas_in_sync}")

"""
Base batch size and learning rate.
"""
base_batch_size = 32
base_learning_rate = 1e-4

"""
Calculate the scaled batch size and learning rate. Each replica processes a slice of the
global batch, so we scale the batch size with the number of replicas in sync; the learning
rate is scaled linearly along with it, a common heuristic for synchronous data parallelism.
"""
scaled_batch_size = base_batch_size * strategy.num_replicas_in_sync
scaled_learning_rate = base_learning_rate * strategy.num_replicas_in_sync
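
"""
For example, with the two logical GPUs configured above (`strategy.num_replicas_in_sync == 2`),
`scaled_batch_size` works out to `32 * 2 = 64` (each replica still sees 32 examples per step)
and `scaled_learning_rate` to `1e-4 * 2 = 2e-4`. With a single device, both keep their base
values.
"""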

"""
Now, we need to download and preprocess the wikitext-2 dataset. This dataset will be
used for pretraining the BERT model. We will filter out short lines to ensure that the
data has enough context for training.
"""

keras.utils.get_file(
    origin="https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip",
    extract=True,
)
wiki_dir = os.path.expanduser("~/.keras/datasets/wikitext-2/")

# Load wikitext-2 and filter out short lines.
wiki_train_ds = (
    tf.data.TextLineDataset(
        wiki_dir + "wiki.train.tokens",
    )
    .filter(lambda x: tf.strings.length(x) > 100)
    .shuffle(buffer_size=500)
    .batch(scaled_batch_size)
    .cache()
    .prefetch(tf.data.AUTOTUNE)
)
wiki_val_ds = (
    tf.data.TextLineDataset(wiki_dir + "wiki.valid.tokens")
    .filter(lambda x: tf.strings.length(x) > 100)
    .shuffle(buffer_size=500)
    .batch(scaled_batch_size)
    .cache()
    .prefetch(tf.data.AUTOTUNE)
)
wiki_test_ds = (
    tf.data.TextLineDataset(wiki_dir + "wiki.test.tokens")
    .filter(lambda x: tf.strings.length(x) > 100)
    .shuffle(buffer_size=500)
    .batch(scaled_batch_size)
    .cache()
    .prefetch(tf.data.AUTOTUNE)
)

"""
In the above code, we download the wikitext-2 dataset and extract it. Then, we define
three datasets: wiki_train_ds, wiki_val_ds, and wiki_test_ds. These datasets are
filtered to remove short lines and are batched for efficient training. Note that they
contain raw text; the `BertMaskedLM` task model we create below includes a preprocessor
that tokenizes and masks the text on the fly.
"""

"""
It's a common practice to use a decayed learning rate in NLP training/tuning. We'll
use the `PolynomialDecay` schedule here, decaying the learning rate from its scaled
initial value down to zero over the course of training.
"""

# Count the training batches once so the schedule spans all epochs.
total_training_steps = sum(1 for _ in wiki_train_ds.as_numpy_iterator()) * EPOCHS
lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=scaled_learning_rate,
    decay_steps=total_training_steps,
    end_learning_rate=0.0,
)
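
"""
The schedule is a callable that maps a step index to a learning rate, so you can inspect it
directly if you want to confirm the decay (a quick sketch):

```python
# PolynomialDecay with the default power=1.0 decays linearly from the
# scaled learning rate at step 0 down to 0.0 at the final step.
print(float(lr_schedule(0)))  # ~= scaled_learning_rate
print(float(lr_schedule(total_training_steps)))  # ~= 0.0
```
"""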


class PrintLR(tf.keras.callbacks.Callback):
    """Print the current learning rate at the end of each epoch."""

    def on_epoch_end(self, epoch, logs=None):
        print(
            f"\nLearning rate for epoch {epoch + 1} is "
            f"{self.model.optimizer.learning_rate.numpy()}"
        )


"""
Let's also make a TensorBoard callback; this will enable visualization of different metrics
while we train the model in a later part of this tutorial. We put all the callbacks
together as follows:
"""
callbacks = [
    tf.keras.callbacks.TensorBoard(log_dir="./logs"),
    PrintLR(),
]
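
"""
After (or during) training, you can inspect these metrics by pointing TensorBoard at the log
directory, e.g. by running `tensorboard --logdir ./logs` from a terminal.
"""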


print(tf.config.list_physical_devices("GPU"))


"""
With the datasets prepared, we now initialize and compile our model and optimizer within
the `strategy.scope()`:
"""

with strategy.scope():
    # Everything that creates variables should be under the strategy scope.
    # In general this is only model construction & `compile()`.
    model_dist = keras_hub.models.BertMaskedLM.from_preset("bert_tiny_en_uncased")

    # This line just sets the pooled_dense layer as non-trainable; we do this to avoid
    # warnings about this layer being unused.
    model_dist.get_layer("bert_backbone").get_layer("pooled_dense").trainable = False

    model_dist.compile(
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer=tf.keras.optimizers.AdamW(learning_rate=lr_schedule),  # decay schedule defined above
        weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
        jit_compile=False,
    )

    model_dist.fit(
        wiki_train_ds, validation_data=wiki_val_ds, epochs=EPOCHS, callbacks=callbacks
    )

"""
After fitting our model under the scope, we evaluate it normally!
"""

model_dist.evaluate(wiki_test_ds)
"""
224
For distributed training across multiple machines (as opposed to training that only leverages
225
multiple devices on a single machine), there are two distribution strategies you
226
could use: `MultiWorkerMirroredStrategy` and `ParameterServerStrategy`:
227
228
- `tf.distribute.MultiWorkerMirroredStrategy` implements a synchronous CPU/GPU
229
multi-worker solution to work with Keras-style model building and training loop,
230
using synchronous reduction of gradients across the replicas.
231
- `tf.distribute.experimental.ParameterServerStrategy` implements an asynchronous CPU/GPU
232
multi-worker solution, where the parameters are stored on parameter servers, and
233
workers update the gradients to parameter servers asynchronously.
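
To give a sense of how little the code changes in the synchronous multi-worker case, here is
a sketch. It is not runnable as-is: it assumes each machine sets a `TF_CONFIG` environment
variable describing the cluster, and `make_model()` / `make_datasets()` are hypothetical
helpers standing in for the model and dataset code from this guide.

```python
import json
import os

import tensorflow as tf

# Each worker describes the full cluster and its own position in it.
os.environ["TF_CONFIG"] = json.dumps(
    {
        "cluster": {"worker": ["host1:12345", "host2:12345"]},
        "task": {"type": "worker", "index": 0},  # use index 1 on the second machine
    }
)

# Swap MirroredStrategy for MultiWorkerMirroredStrategy; the scope/compile/fit
# pattern stays the same as in the single-host example above.
strategy = tf.distribute.MultiWorkerMirroredStrategy()
with strategy.scope():
    model = make_model()  # hypothetical: build & compile the model, as above

model.fit(make_datasets(), epochs=EPOCHS)  # hypothetical: the tf.data pipeline, as above
```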

### Further reading

1. [TensorFlow distributed training guide](https://www.tensorflow.org/guide/distributed_training)
2. [Tutorial on multi-worker training with Keras](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras)
3. [MirroredStrategy docs](https://www.tensorflow.org/api_docs/python/tf/distribute/MirroredStrategy)
4. [MultiWorkerMirroredStrategy docs](https://www.tensorflow.org/api_docs/python/tf/distribute/experimental/MultiWorkerMirroredStrategy)
5. [Distributed training in tf.keras with Weights & Biases](https://towardsdatascience.com/distributed-training-in-tf-keras-with-w-b-ccf021f9322e)
"""