"""1Title: GPT2 Text Generation with KerasHub2Author: Chen Qian3Date created: 2023/04/174Last modified: 2024/04/125Description: Use KerasHub GPT2 model and `samplers` to do text generation.6Accelerator: GPU7"""89"""10In this tutorial, you will learn to use [KerasHub](https://keras.io/keras_hub/) to load a11pre-trained Large Language Model (LLM) - [GPT-2 model](https://openai.com/research/better-language-models)12(originally invented by OpenAI), finetune it to a specific text style, and13generate text based on users' input (also known as prompt). You will also learn14how GPT2 adapts quickly to non-English languages, such as Chinese.15"""1617"""18## Before we begin1920Colab offers different kinds of runtimes. Make sure to go to **Runtime ->21Change runtime type** and choose the GPU Hardware Accelerator runtime22(which should have >12G host RAM and ~15G GPU RAM) since you will finetune the23GPT-2 model. Running this tutorial on CPU runtime will take hours.24"""2526"""27## Install KerasHub, Choose Backend and Import Dependencies2829This examples uses [Keras 3](https://keras.io/keras_3/) to work in any of30`"tensorflow"`, `"jax"` or `"torch"`. Support for Keras 3 is baked into31KerasHub, simply change the `"KERAS_BACKEND"` environment variable to select32the backend of your choice. We select the JAX backend below.33"""3435"""shell36pip install git+https://github.com/keras-team/keras-hub.git -q37"""3839import os4041os.environ["KERAS_BACKEND"] = "jax" # or "tensorflow" or "torch"4243import keras_hub44import keras45import tensorflow as tf46import time4748keras.mixed_precision.set_global_policy("mixed_float16")4950"""51## Introduction to Generative Large Language Models (LLMs)5253Large language models (LLMs) are a type of machine learning models that are54trained on a large corpus of text data to generate outputs for various natural55language processing (NLP) tasks, such as text generation, question answering,56and machine translation.5758Generative LLMs are typically based on deep learning neural networks, such as59the [Transformer architecture](https://arxiv.org/abs/1706.03762) invented by60Google researchers in 2017, and are trained on massive amounts of text data,61often involving billions of words. These models, such as Google [LaMDA](https://blog.google/technology/ai/lamda/)62and [PaLM](https://ai.googleblog.com/2022/04/pathways-language-model-palm-scaling-to.html),63are trained with a large dataset from various data sources which allows them to64generate output for many tasks. The core of Generative LLMs is predicting the65next word in a sentence, often referred as **Causal LM Pretraining**. In this66way LLMs can generate coherent text based on user prompts. For a more67pedagogical discussion on language models, you can refer to the68[Stanford CS324 LLM class](https://stanford-cs324.github.io/winter2022/lectures/introduction/).69"""7071"""72## Introduction to KerasHub7374Large Language Models are complex to build and expensive to train from scratch.75Luckily there are pretrained LLMs available for use right away. [KerasHub](https://keras.io/keras_hub/)76provides a large number of pre-trained checkpoints that allow you to experiment77with SOTA models without needing to train them yourself.7879KerasHub is a natural language processing library that supports users through80their entire development cycle. 
"""
## Introduction to KerasHub

Large Language Models are complex to build and expensive to train from scratch.
Luckily there are pretrained LLMs available for use right away. [KerasHub](https://keras.io/keras_hub/)
provides a large number of pre-trained checkpoints that allow you to experiment
with SOTA models without needing to train them yourself.

KerasHub is a natural language processing library that supports users through
their entire development cycle. KerasHub offers both pretrained models and
modularized building blocks, so developers can easily reuse pretrained models
or stack their own LLM.

In a nutshell, for generative LLMs, KerasHub offers:

- Pretrained models with a `generate()` method, e.g.,
  `keras_hub.models.GPT2CausalLM` and `keras_hub.models.OPTCausalLM`.
- Sampler classes that implement generation algorithms such as Top-K, Beam and
  contrastive search. These samplers can be used to generate text with
  custom models.
"""

"""
## Load a pre-trained GPT-2 model and generate some text

KerasHub provides a number of pre-trained models, such as [Google
Bert](https://ai.googleblog.com/2018/11/open-sourcing-bert-state-of-art-pre.html)
and [GPT-2](https://openai.com/research/better-language-models). You can see
the list of models available in the [KerasHub repository](https://github.com/keras-team/keras-hub/tree/master/keras_hub/models).

It's very easy to load the GPT-2 model, as you can see below:
"""

# To speed up training and generation, we use a preprocessor with a sequence
# length of 128 instead of the full length of 1024.
preprocessor = keras_hub.models.GPT2CausalLMPreprocessor.from_preset(
    "gpt2_base_en",
    sequence_length=128,
)
gpt2_lm = keras_hub.models.GPT2CausalLM.from_preset(
    "gpt2_base_en", preprocessor=preprocessor
)

"""
Once the model is loaded, you can use it to generate some text right away. Run
the cells below to give it a try. It's as simple as calling a single function
*generate()*:
"""

start = time.time()

output = gpt2_lm.generate("My trip to Yosemite was", max_length=200)
print("\nGPT-2 output:")
print(output)

end = time.time()
print(f"TOTAL TIME ELAPSED: {end - start:.2f}s")

"""
Try another one:
"""

start = time.time()

output = gpt2_lm.generate("That Italian restaurant is", max_length=200)
print("\nGPT-2 output:")
print(output)

end = time.time()
print(f"TOTAL TIME ELAPSED: {end - start:.2f}s")

"""
Notice how much faster the second call is. This is because the computational
graph is [XLA compiled](https://www.tensorflow.org/xla) on the first run and
reused on the second behind the scenes.

The quality of the generated text looks OK, but we can improve it via
fine-tuning.
"""

"""
## More on the GPT-2 model from KerasHub

Next up, we will actually fine-tune the model to update its parameters, but
before we do, let's take a look at the full set of tools we have for working
with GPT2.

The code of GPT2 can be found
[here](https://github.com/keras-team/keras-hub/blob/master/keras_hub/models/gpt2/).
Conceptually the `GPT2CausalLM` can be hierarchically broken down into several
modules in KerasHub, all of which have a *from_preset()* function that loads a
pretrained model:

- `keras_hub.models.GPT2Tokenizer`: the tokenizer used by the GPT2 model, which
  is a [byte-pair encoder](https://huggingface.co/course/chapter6/5?fw=pt).
- `keras_hub.models.GPT2CausalLMPreprocessor`: the preprocessor used for GPT2
  causal LM training. It does the tokenization along with other preprocessing
  work, such as creating the labels and appending the end token.
- `keras_hub.models.GPT2Backbone`: the GPT2 model, which is a stack of
  `keras_hub.layers.TransformerDecoder` layers. This is usually just referred
  to as `GPT2`.
- `keras_hub.models.GPT2CausalLM`: wraps `GPT2Backbone`; it multiplies the
  output of `GPT2Backbone` by the embedding matrix to generate logits over
  vocab tokens.
"""
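"""
To make the breakdown above concrete, here is a small sketch (not part of the
original example) that loads the tokenizer on its own and round-trips a prompt
through it. The exact token IDs you see depend on the `gpt2_base_en`
vocabulary.
"""

# Load just the tokenizer from the same preset and inspect its output.
tokenizer = keras_hub.models.GPT2Tokenizer.from_preset("gpt2_base_en")
token_ids = tokenizer("My trip to Yosemite was")
print("Token IDs:", token_ids)
print("Detokenized:", tokenizer.detokenize(token_ids))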
"""
## Finetune on Reddit dataset

Now that you know the GPT-2 model from KerasHub, you can take one step further
and finetune the model so that it generates text in a specific style, short or
long, strict or casual. In this tutorial, we will use the Reddit dataset as an
example.
"""

import tensorflow_datasets as tfds

reddit_ds = tfds.load("reddit_tifu", split="train", as_supervised=True)

"""
Let's take a look inside sample data from the Reddit TensorFlow Dataset. There
are two features:

- **document**: text of the post.
- **title**: the title.
"""

for document, title in reddit_ds:
    print(document.numpy())
    print(title.numpy())
    break

"""
In our case, we are performing next word prediction in a language model, so we
only need the 'document' feature.
"""

train_ds = (
    reddit_ds.map(lambda document, _: document)
    .batch(32)
    .cache()
    .prefetch(tf.data.AUTOTUNE)
)

"""
Now you can finetune the model using the familiar *fit()* function. Note that
`preprocessor` will be automatically called inside the `fit` method since
`GPT2CausalLM` is a `keras_hub.models.Task` instance.

This step would take quite a bit of GPU memory and a long time if we were to
train the model all the way to a fully trained state. Here we just use part of
the dataset for demo purposes.
"""

train_ds = train_ds.take(500)
num_epochs = 1

# Linearly decaying learning rate.
learning_rate = keras.optimizers.schedules.PolynomialDecay(
    5e-5,
    decay_steps=train_ds.cardinality() * num_epochs,
    end_learning_rate=0.0,
)
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
gpt2_lm.compile(
    optimizer=keras.optimizers.Adam(learning_rate),
    loss=loss,
    weighted_metrics=["accuracy"],
)

gpt2_lm.fit(train_ds, epochs=num_epochs)

"""
After fine-tuning is finished, you can again generate text using the same
*generate()* function. This time, the text will be closer to the Reddit writing
style, and the generated length will be close to our preset length in the
training set.
"""

start = time.time()

output = gpt2_lm.generate("I like basketball", max_length=200)
print("\nGPT-2 output:")
print(output)

end = time.time()
print(f"TOTAL TIME ELAPSED: {end - start:.2f}s")

"""
## Into the Sampling Method

In KerasHub, we offer a few sampling methods, e.g., contrastive search,
Top-K and beam sampling. By default, our `GPT2CausalLM` uses Top-K search, but
you can choose your own sampling method.

Much like optimizers and activations, there are two ways to specify your custom
sampler:

- Use a string identifier, such as `"greedy"`; this uses the default
  configuration.
- Pass a `keras_hub.samplers.Sampler` instance; this lets you use a custom
  configuration.
"""

# Use a string identifier.
gpt2_lm.compile(sampler="top_k")
output = gpt2_lm.generate("I like basketball", max_length=200)
print("\nGPT-2 output:")
print(output)

# Use a `Sampler` instance. Note that `GreedySampler` tends to repeat itself.
greedy_sampler = keras_hub.samplers.GreedySampler()
gpt2_lm.compile(sampler=greedy_sampler)

output = gpt2_lm.generate("I like basketball", max_length=200)
print("\nGPT-2 output:")
print(output)
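"""
As a further illustration of the second option (this sketch is not part of the
original example), you can pass a `keras_hub.samplers.TopKSampler` with a
custom `k` to control how many candidate tokens are considered at each step.
The value `k=10` below is arbitrary and chosen just for demonstration.
"""

# Use a `Sampler` instance with a custom configuration.
top_k_sampler = keras_hub.samplers.TopKSampler(k=10)
gpt2_lm.compile(sampler=top_k_sampler)

output = gpt2_lm.generate("I like basketball", max_length=200)
print("\nGPT-2 output:")
print(output)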
"""
For more details on the KerasHub `Sampler` class, you can check the code
[here](https://github.com/keras-team/keras-hub/tree/master/keras_hub/samplers).
"""

"""
## Finetune on Chinese Poem Dataset

We can also finetune GPT2 on non-English datasets. For readers who know
Chinese, this part illustrates how to fine-tune GPT2 on a Chinese poem dataset
to teach our model to become a poet!

Because GPT2 uses a byte-pair encoder, and the original pretraining dataset
contains some Chinese characters, we can use the original vocab to finetune on
a Chinese dataset.
"""

"""shell
# Load the Chinese poetry dataset.
git clone https://github.com/chinese-poetry/chinese-poetry.git
"""

"""
Load text from the JSON files. We only use《全唐诗》(Complete Tang Poems) for
demo purposes.
"""

import os
import json

poem_collection = []
for file in os.listdir("chinese-poetry/全唐诗"):
    if ".json" not in file or "poet" not in file:
        continue
    full_filename = "%s/%s" % ("chinese-poetry/全唐诗", file)
    with open(full_filename, "r") as f:
        content = json.load(f)
    poem_collection.extend(content)

paragraphs = ["".join(data["paragraphs"]) for data in poem_collection]

"""
Let's take a look at a sample.
"""

print(paragraphs[0])

"""
As in the Reddit example, we convert the data to a TF dataset, and only use
part of it for training.
"""

train_ds = (
    tf.data.Dataset.from_tensor_slices(paragraphs)
    .batch(16)
    .cache()
    .prefetch(tf.data.AUTOTUNE)
)

# Running through the whole dataset takes a long time, so we only take `500`
# batches and run 1 epoch for demo purposes.
train_ds = train_ds.take(500)
num_epochs = 1

learning_rate = keras.optimizers.schedules.PolynomialDecay(
    5e-4,
    decay_steps=train_ds.cardinality() * num_epochs,
    end_learning_rate=0.0,
)
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
gpt2_lm.compile(
    optimizer=keras.optimizers.Adam(learning_rate),
    loss=loss,
    weighted_metrics=["accuracy"],
)

gpt2_lm.fit(train_ds, epochs=num_epochs)

"""
Let's check the result!
"""

output = gpt2_lm.generate("昨夜雨疏风骤", max_length=200)
print(output)

"""
Not bad 😀
"""
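"""
As a final illustration (not part of the original example), `generate()` also
accepts a list of prompts, which lets you produce several continuations in a
single call. The second prompt below is an arbitrary opening line added just
for demonstration.
"""

# Generate continuations for a batch of prompts in one call.
outputs = gpt2_lm.generate(["昨夜雨疏风骤", "白日依山尽"], max_length=200)
for generated in outputs:
    print(generated)
    print("-" * 40)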