Path: blob/master/examples/nlp/data_parallel_training_with_keras_hub.py
3507 views
"""1Title: Data Parallel Training with KerasHub and tf.distribute2Author: Anshuman Mishra3Date created: 2023/07/074Last modified: 2023/07/075Description: Data Parallel training with KerasHub and tf.distribute.6Accelerator: GPU7"""89"""10## Introduction1112Distributed training is a technique used to train deep learning models on multiple devices13or machines simultaneously. It helps to reduce training time and allows for training larger14models with more data. KerasHub is a library that provides tools and utilities for natural15language processing tasks, including distributed training.1617In this tutorial, we will use KerasHub to train a BERT-based masked language model (MLM)18on the wikitext-2 dataset (a 2 million word dataset of wikipedia articles). The MLM task19involves predicting the masked words in a sentence, which helps the model learn contextual20representations of words.2122This guide focuses on data parallelism, in particular synchronous data parallelism, where23each accelerator (a GPU or TPU) holds a complete replica of the model, and sees a24different partial batch of the input data. Partial gradients are computed on each device,25aggregated, and used to compute a global gradient update.2627Specifically, this guide teaches you how to use the `tf.distribute` API to train Keras28models on multiple GPUs, with minimal changes to your code, in the following two setups:2930- On multiple GPUs (typically 2 to 8) installed on a single machine (single host,31multi-device training). This is the most common setup for researchers and small-scale32industry workflows.33- On a cluster of many machines, each hosting one or multiple GPUs (multi-worker34distributed training). This is a good setup for large-scale industry workflows, e.g.35training high-resolution text summarization models on billion word datasets on 20-100 GPUs.36"""3738"""shell39pip install -q --upgrade keras-hub40pip install -q --upgrade keras # Upgrade to Keras 3.41"""4243"""44## Imports45"""4647import os4849os.environ["KERAS_BACKEND"] = "tensorflow"5051import tensorflow as tf52import keras53import keras_hub5455"""56Before we start any training, let's configure our single GPU to show up as two logical57devices.5859When you are training with two or more physical GPUs, this is totally uncessary. This60is just a trick to show real distributed training on the default colab GPU runtime,61which has only one GPU available.62"""6364"""shell65nvidia-smi --query-gpu=memory.total --format=csv,noheader66"""6768physical_devices = tf.config.list_physical_devices("GPU")69tf.config.set_logical_device_configuration(70physical_devices[0],71[72tf.config.LogicalDeviceConfiguration(memory_limit=15360 // 2),73tf.config.LogicalDeviceConfiguration(memory_limit=15360 // 2),74],75)7677logical_devices = tf.config.list_logical_devices("GPU")78logical_devices7980EPOCHS = 3818283"""84To do single-host, multi-device synchronous training with a Keras model, you would use85the `tf.distribute.MirroredStrategy` API. Here's how it works:8687- Instantiate a `MirroredStrategy`, optionally configuring which specific devices you88want to use (by default the strategy will use all GPUs available).89- Use the strategy object to open a scope, and within this scope, create all the Keras90objects you need that contain variables. Typically, that means **creating & compiling the91model** inside the distribution scope.92- Train the model via `fit()` as usual.93"""94strategy = tf.distribute.MirroredStrategy()95print(f"Number of devices: {strategy.num_replicas_in_sync}")9697"""98Base batch size and learning rate99"""100base_batch_size = 32101base_learning_rate = 1e-4102103"""104Calculate scaled batch size and learning rate105106"""107scaled_batch_size = base_batch_size * strategy.num_replicas_in_sync108scaled_learning_rate = base_learning_rate * strategy.num_replicas_in_sync109110"""111Now, we need to download and preprocess the wikitext-2 dataset. This dataset will be112used for pretraining the BERT model. We will filter out short lines to ensure that the113data has enough context for training.114"""115116keras.utils.get_file(117origin="https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip",118extract=True,119)120wiki_dir = os.path.expanduser("~/.keras/datasets/wikitext-2/")121122# Load wikitext-103 and filter out short lines.123wiki_train_ds = (124tf.data.TextLineDataset(125wiki_dir + "wiki.train.tokens",126)127.filter(lambda x: tf.strings.length(x) > 100)128.shuffle(buffer_size=500)129.batch(scaled_batch_size)130.cache()131.prefetch(tf.data.AUTOTUNE)132)133wiki_val_ds = (134tf.data.TextLineDataset(wiki_dir + "wiki.valid.tokens")135.filter(lambda x: tf.strings.length(x) > 100)136.shuffle(buffer_size=500)137.batch(scaled_batch_size)138.cache()139.prefetch(tf.data.AUTOTUNE)140)141wiki_test_ds = (142tf.data.TextLineDataset(wiki_dir + "wiki.test.tokens")143.filter(lambda x: tf.strings.length(x) > 100)144.shuffle(buffer_size=500)145.batch(scaled_batch_size)146.cache()147.prefetch(tf.data.AUTOTUNE)148)149150"""151In the above code, we download the wikitext-2 dataset and extract it. Then, we define152three datasets: wiki_train_ds, wiki_val_ds, and wiki_test_ds. These datasets are153filtered to remove short lines and are batched for efficient training.154"""155156"""157It's a common practice to use a decayed learning rate in NLP training/tuning. We'll158use `PolynomialDecay` schedule here.159160"""161162total_training_steps = sum(1 for _ in wiki_train_ds.as_numpy_iterator()) * EPOCHS163lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(164initial_learning_rate=scaled_learning_rate,165decay_steps=total_training_steps,166end_learning_rate=0.0,167)168169170class PrintLR(tf.keras.callbacks.Callback):171def on_epoch_end(self, epoch, logs=None):172print(173f"\nLearning rate for epoch {epoch + 1} is {model_dist.optimizer.learning_rate.numpy()}"174)175176177"""178Let's also make a callback to TensorBoard, this will enable visualization of different179metrics while we train the model in later part of this tutorial. We put all the callbacks180together as follows:181"""182callbacks = [183tf.keras.callbacks.TensorBoard(log_dir="./logs"),184PrintLR(),185]186187188print(tf.config.list_physical_devices("GPU"))189190191"""192With the datasets prepared, we now initialize and compile our model and optimizer within193the `strategy.scope()`:194"""195196with strategy.scope():197# Everything that creates variables should be under the strategy scope.198# In general this is only model construction & `compile()`.199model_dist = keras_hub.models.BertMaskedLM.from_preset("bert_tiny_en_uncased")200201# This line just sets pooled_dense layer as non-trainiable, we do this to avoid202# warnings of this layer being unused203model_dist.get_layer("bert_backbone").get_layer("pooled_dense").trainable = False204205model_dist.compile(206loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),207optimizer=tf.keras.optimizers.AdamW(learning_rate=scaled_learning_rate),208weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],209jit_compile=False,210)211212model_dist.fit(213wiki_train_ds, validation_data=wiki_val_ds, epochs=EPOCHS, callbacks=callbacks214)215216"""217After fitting our model under the scope, we evaluate it normally!218"""219220model_dist.evaluate(wiki_test_ds)221222"""223For distributed training across multiple machines (as opposed to training that only leverages224multiple devices on a single machine), there are two distribution strategies you225could use: `MultiWorkerMirroredStrategy` and `ParameterServerStrategy`:226227- `tf.distribute.MultiWorkerMirroredStrategy` implements a synchronous CPU/GPU228multi-worker solution to work with Keras-style model building and training loop,229using synchronous reduction of gradients across the replicas.230- `tf.distribute.experimental.ParameterServerStrategy` implements an asynchronous CPU/GPU231multi-worker solution, where the parameters are stored on parameter servers, and232workers update the gradients to parameter servers asynchronously.233234### Further reading2352361. [TensorFlow distributed training guide](https://www.tensorflow.org/guide/distributed_training)2372. [Tutorial on multi-worker training with Keras](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras)2383. [MirroredStrategy docs](https://www.tensorflow.org/api_docs/python/tf/distribute/MirroredStrategy)2394. [MultiWorkerMirroredStrategy docs](https://www.tensorflow.org/api_docs/python/tf/distribute/experimental/MultiWorkerMirroredStrategy)2405. [Distributed training in tf.keras with Weights & Biases](https://towardsdatascience.com/distributed-training-in-tf-keras-with-w-b-ccf021f9322e)241"""242243244