"""1Title: Memory-efficient embeddings for recommendation systems2Author: [Khalid Salama](https://www.linkedin.com/in/khalid-salama-24403144/)3Date created: 2021/02/154Last modified: 2023/11/155Description: Using compositional & mixed-dimension embeddings for memory-efficient recommendation models.6Accelerator: GPU7"""89"""10## Introduction1112This example demonstrates two techniques for building memory-efficient recommendation models13by reducing the size of the embedding tables, without sacrificing model effectiveness:14151. [Quotient-remainder trick](https://arxiv.org/abs/1909.02107), by Hao-Jun Michael Shi et al.,16which reduces the number of embedding vectors to store, yet produces unique embedding17vector for each item without explicit definition.182. [Mixed Dimension embeddings](https://arxiv.org/abs/1909.11810), by Antonio Ginart et al.,19which stores embedding vectors with mixed dimensions, where less popular items have20reduced dimension embeddings.2122We use the [1M version of the Movielens dataset](https://grouplens.org/datasets/movielens/1m/).23The dataset includes around 1 million ratings from 6,000 users on 4,000 movies.24"""2526"""27## Setup28"""2930import os3132os.environ["KERAS_BACKEND"] = "tensorflow"3334from zipfile import ZipFile35from urllib.request import urlretrieve36import numpy as np37import pandas as pd38import tensorflow as tf39import keras40from keras import layers41from keras.layers import StringLookup42import matplotlib.pyplot as plt4344"""45## Prepare the data4647## Download and process data48"""4950urlretrieve("http://files.grouplens.org/datasets/movielens/ml-1m.zip", "movielens.zip")51ZipFile("movielens.zip", "r").extractall()5253ratings_data = pd.read_csv(54"ml-1m/ratings.dat",55sep="::",56names=["user_id", "movie_id", "rating", "unix_timestamp"],57)5859ratings_data["movie_id"] = ratings_data["movie_id"].apply(lambda x: f"movie_{x}")60ratings_data["user_id"] = ratings_data["user_id"].apply(lambda x: f"user_{x}")61ratings_data["rating"] = ratings_data["rating"].apply(lambda x: float(x))62del ratings_data["unix_timestamp"]6364print(f"Number of users: {len(ratings_data.user_id.unique())}")65print(f"Number of movies: {len(ratings_data.movie_id.unique())}")66print(f"Number of ratings: {len(ratings_data.index)}")6768"""69## Create train and eval data splits70"""7172random_selection = np.random.rand(len(ratings_data.index)) <= 0.8573train_data = ratings_data[random_selection]74eval_data = ratings_data[~random_selection]7576train_data.to_csv("train_data.csv", index=False, sep="|", header=False)77eval_data.to_csv("eval_data.csv", index=False, sep="|", header=False)78print(f"Train data split: {len(train_data.index)}")79print(f"Eval data split: {len(eval_data.index)}")80print("Train and eval data files are saved.")8182"""83## Define dataset metadata and hyperparameters84"""8586csv_header = list(ratings_data.columns)87user_vocabulary = list(ratings_data.user_id.unique())88movie_vocabulary = list(ratings_data.movie_id.unique())89target_feature_name = "rating"90learning_rate = 0.00191batch_size = 12892num_epochs = 393base_embedding_dim = 649495"""96## Train and evaluate the model97"""9899100def get_dataset_from_csv(csv_file_path, batch_size=128, shuffle=True):101return tf.data.experimental.make_csv_dataset(102csv_file_path,103batch_size=batch_size,104column_names=csv_header,105label_name=target_feature_name,106num_epochs=1,107header=False,108field_delim="|",109shuffle=shuffle,110)111112113def run_experiment(model):114# Compile the 
def run_experiment(model):
    # Compile the model.
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate),
        loss=keras.losses.MeanSquaredError(),
        metrics=[keras.metrics.MeanAbsoluteError(name="mae")],
    )
    # Read the training data.
    train_dataset = get_dataset_from_csv("train_data.csv", batch_size)
    # Read the test data.
    eval_dataset = get_dataset_from_csv("eval_data.csv", batch_size, shuffle=False)
    # Fit the model with the training data.
    history = model.fit(
        train_dataset,
        epochs=num_epochs,
        validation_data=eval_dataset,
    )
    return history


"""
## Experiment 1: baseline collaborative filtering model

### Implement embedding encoder
"""


def embedding_encoder(vocabulary, embedding_dim, num_oov_indices=0, name=None):
    return keras.Sequential(
        [
            StringLookup(
                vocabulary=vocabulary, mask_token=None, num_oov_indices=num_oov_indices
            ),
            layers.Embedding(
                input_dim=len(vocabulary) + num_oov_indices, output_dim=embedding_dim
            ),
        ],
        name=f"{name}_embedding" if name else None,
    )


"""
### Implement the baseline model
"""


def create_baseline_model():
    # Receive the user as an input.
    user_input = layers.Input(name="user_id", shape=(), dtype=tf.string)
    # Get user embedding.
    user_embedding = embedding_encoder(
        vocabulary=user_vocabulary, embedding_dim=base_embedding_dim, name="user"
    )(user_input)

    # Receive the movie as an input.
    movie_input = layers.Input(name="movie_id", shape=(), dtype=tf.string)
    # Get movie embedding.
    movie_embedding = embedding_encoder(
        vocabulary=movie_vocabulary, embedding_dim=base_embedding_dim, name="movie"
    )(movie_input)

    # Compute dot product similarity between user and movie embeddings.
    logits = layers.Dot(axes=1, name="dot_similarity")(
        [user_embedding, movie_embedding]
    )
    # Convert to rating scale.
    prediction = keras.activations.sigmoid(logits) * 5
    # Create the model.
    model = keras.Model(
        inputs=[user_input, movie_input], outputs=prediction, name="baseline_model"
    )
    return model


baseline_model = create_baseline_model()
baseline_model.summary()

"""
Notice that the number of trainable parameters is 623,744.
"""

history = run_experiment(baseline_model)

plt.plot(history.history["loss"])
plt.plot(history.history["val_loss"])
plt.title("model loss")
plt.ylabel("loss")
plt.xlabel("epoch")
plt.legend(["train", "eval"], loc="upper left")
plt.show()

"""
## Experiment 2: memory-efficient model
"""

"""
### Implement Quotient-Remainder embedding as a layer

The Quotient-Remainder technique works as follows. For a vocabulary of size `vocabulary_size`
and an embedding size `embedding_dim`, instead of creating a `vocabulary_size X embedding_dim`
embedding table, we create *two* `num_buckets X embedding_dim` embedding tables, where
`num_buckets` is much smaller than `vocabulary_size`.
An embedding for a given item `index` is generated via the following steps:

1. Compute the `quotient_index` as `index // num_buckets`.
2. Compute the `remainder_index` as `index % num_buckets`.
3. Lookup `quotient_embedding` from the first embedding table using `quotient_index`.
4. Lookup `remainder_embedding` from the second embedding table using `remainder_index`.
5. Return `quotient_embedding` * `remainder_embedding`.

This technique not only reduces the number of embedding vectors that need to be stored and trained,
but also generates a *unique* embedding vector for each item of size `embedding_dim`.
Note that `quotient_embedding` and `remainder_embedding` can be combined using other operations,
like `Add` and `Concatenate`.
"""

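"""
To get a feel for the savings, here is a small sketch with illustrative numbers (they are not
tied to the experiments below): with a vocabulary of 4,000 items and `embedding_dim = 64`, a
full embedding table holds 4,000 x 64 = 256,000 weights, while the quotient and remainder
tables with `num_buckets = 64` hold only 2 x 64 x 64 = 8,192 weights.
"""

# Illustrative sizes only; the experiments below use the real MovieLens vocabularies.
illustrative_vocabulary_size = 4000
illustrative_num_buckets = 64
full_table_weights = illustrative_vocabulary_size * base_embedding_dim
qr_tables_weights = 2 * illustrative_num_buckets * base_embedding_dim
print(f"Full embedding table weights: {full_table_weights}")
print(f"Quotient-remainder tables weights: {qr_tables_weights}")
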
class QREmbedding(keras.layers.Layer):
    def __init__(self, vocabulary, embedding_dim, num_buckets, name=None):
        super().__init__(name=name)
        self.num_buckets = num_buckets

        self.index_lookup = StringLookup(
            vocabulary=vocabulary, mask_token=None, num_oov_indices=0
        )
        self.q_embeddings = layers.Embedding(
            num_buckets,
            embedding_dim,
        )
        self.r_embeddings = layers.Embedding(
            num_buckets,
            embedding_dim,
        )

    def call(self, inputs):
        # Get the item index.
        embedding_index = self.index_lookup(inputs)
        # Get the quotient index.
        quotient_index = tf.math.floordiv(embedding_index, self.num_buckets)
        # Get the remainder index.
        remainder_index = tf.math.floormod(embedding_index, self.num_buckets)
        # Lookup the quotient_embedding using the quotient_index.
        quotient_embedding = self.q_embeddings(quotient_index)
        # Lookup the remainder_embedding using the remainder_index.
        remainder_embedding = self.r_embeddings(remainder_index)
        # Use multiplication as a combiner operation.
        return quotient_embedding * remainder_embedding

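"""
As a quick, optional check of the layer's interface, we can embed a tiny made-up vocabulary
and verify the output shape. The item names and bucket count below are hypothetical and only
serve to illustrate how the layer is called.
"""

toy_qr_embedding = QREmbedding(
    vocabulary=["item_a", "item_b", "item_c", "item_d"],
    embedding_dim=base_embedding_dim,
    num_buckets=2,
)
# Expected output shape: (3, 64), one embedding vector per input item.
print(toy_qr_embedding(tf.constant(["item_a", "item_c", "item_d"])).shape)
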
"""
### Implement Mixed Dimension embedding as a layer

In the mixed dimension embedding technique, we train embedding vectors with full dimensions
for the frequently queried items, while training embedding vectors with *reduced dimensions*
for less frequent items, plus a *projection weights matrix* to bring the low-dimension
embeddings up to the full dimension.

More precisely, we define *blocks* of items of similar frequencies. For each block,
a `block_vocab_size X block_embedding_dim` embedding table and a `block_embedding_dim X full_embedding_dim`
projection weights matrix are created. Note that, if `block_embedding_dim` equals `full_embedding_dim`,
the projection weights matrix becomes an *identity* matrix. Embeddings for a given batch of item
`indices` are generated via the following steps:

1. For each block, look up the `block_embedding_dim`-dimensional embedding vectors using `indices`,
and project them to `full_embedding_dim`.
2. If an item index does not belong to a given block, an out-of-vocabulary embedding is returned.
Each block will return a `batch_size X full_embedding_dim` tensor.
3. A mask is applied to the embeddings returned from each block in order to convert the
out-of-vocabulary embeddings to vectors of zeros. That is, for each item in the batch,
a single non-zero embedding vector is returned from all the block embeddings.
4. Embeddings retrieved from the blocks are combined using *sum* to produce the final
`batch_size X full_embedding_dim` tensor.
"""


class MDEmbedding(keras.layers.Layer):
    def __init__(
        self, blocks_vocabulary, blocks_embedding_dims, base_embedding_dim, name=None
    ):
        super().__init__(name=name)
        self.num_blocks = len(blocks_vocabulary)

        # Create a vocabulary-to-block lookup.
        keys = []
        values = []
        for block_idx, block_vocab in enumerate(blocks_vocabulary):
            keys.extend(block_vocab)
            values.extend([block_idx] * len(block_vocab))
        self.vocab_to_block = tf.lookup.StaticHashTable(
            tf.lookup.KeyValueTensorInitializer(keys, values), default_value=-1
        )

        self.block_embedding_encoders = []
        self.block_embedding_projectors = []

        # Create block embedding encoders and projectors.
        for idx in range(self.num_blocks):
            vocabulary = blocks_vocabulary[idx]
            embedding_dim = blocks_embedding_dims[idx]
            block_embedding_encoder = embedding_encoder(
                vocabulary, embedding_dim, num_oov_indices=1
            )
            self.block_embedding_encoders.append(block_embedding_encoder)
            if embedding_dim == base_embedding_dim:
                self.block_embedding_projectors.append(layers.Lambda(lambda x: x))
            else:
                self.block_embedding_projectors.append(
                    layers.Dense(units=base_embedding_dim)
                )

    def call(self, inputs):
        # Get the block index for each input item.
        block_indices = self.vocab_to_block.lookup(inputs)
        # Initialize output embeddings to zeros.
        embeddings = tf.zeros(shape=(tf.shape(inputs)[0], base_embedding_dim))
        # Generate embeddings from blocks.
        for idx in range(self.num_blocks):
            # Lookup embeddings from the current block.
            block_embeddings = self.block_embedding_encoders[idx](inputs)
            # Project embeddings to base_embedding_dim.
            block_embeddings = self.block_embedding_projectors[idx](block_embeddings)
            # Create a mask to filter out embeddings of items that do not belong to the current block.
            mask = tf.expand_dims(tf.cast(block_indices == idx, tf.dtypes.float32), 1)
            # Set the embeddings for the items not belonging to the current block to zeros.
            block_embeddings = block_embeddings * mask
            # Add the block embeddings to the final embeddings.
            embeddings += block_embeddings

        return embeddings

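"""
Again as an optional sanity check, we can call the layer on a tiny hypothetical vocabulary
split into two blocks of different embedding dimensions and verify that every item comes out
with the full `base_embedding_dim` width.
"""

toy_md_embedding = MDEmbedding(
    blocks_vocabulary=[["item_a", "item_b"], ["item_c", "item_d", "item_e"]],
    blocks_embedding_dims=[base_embedding_dim, 16],
    base_embedding_dim=base_embedding_dim,
)
# Expected output shape: (3, 64), regardless of which block each item belongs to.
print(toy_md_embedding(tf.constant(["item_a", "item_c", "item_e"])).shape)
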
"""
### Implement the memory-efficient model

In this experiment, we are going to use the **Quotient-Remainder** technique to reduce the
size of the user embeddings, and the **Mixed Dimension** technique to reduce the size of the
movie embeddings.

While in the [paper](https://arxiv.org/abs/1909.11810) an alpha-power rule is used to determine
the dimensions of the embedding of each block, we simply set the number of blocks and the
dimensions of the embeddings of each block based on the histogram visualization of movie popularity.
"""

movie_frequencies = ratings_data["movie_id"].value_counts()
movie_frequencies.hist(bins=10)

"""
You can see that we can group the movies into three blocks, and assign them 64, 32, and 16
embedding dimensions, respectively. Feel free to experiment with a different number of blocks
and dimensions.
"""

sorted_movie_vocabulary = list(movie_frequencies.keys())

movie_blocks_vocabulary = [
    sorted_movie_vocabulary[:400],  # high popularity movies block
    sorted_movie_vocabulary[400:1700],  # normal popularity movies block
    sorted_movie_vocabulary[1700:],  # low popularity movies block
]

movie_blocks_embedding_dims = [64, 32, 16]

user_embedding_num_buckets = len(user_vocabulary) // 50


def create_memory_efficient_model():
    # Take the user as an input.
    user_input = layers.Input(name="user_id", shape=(), dtype="string")
    # Get user embedding.
    user_embedding = QREmbedding(
        vocabulary=user_vocabulary,
        embedding_dim=base_embedding_dim,
        num_buckets=user_embedding_num_buckets,
        name="user_embedding",
    )(user_input)

    # Take the movie as an input.
    movie_input = layers.Input(name="movie_id", shape=(), dtype="string")
    # Get movie embedding.
    movie_embedding = MDEmbedding(
        blocks_vocabulary=movie_blocks_vocabulary,
        blocks_embedding_dims=movie_blocks_embedding_dims,
        base_embedding_dim=base_embedding_dim,
        name="movie_embedding",
    )(movie_input)

    # Compute dot product similarity between user and movie embeddings.
    logits = layers.Dot(axes=1, name="dot_similarity")(
        [user_embedding, movie_embedding]
    )
    # Convert to rating scale.
    prediction = keras.activations.sigmoid(logits) * 5
    # Create the model.
    model = keras.Model(
        inputs=[user_input, movie_input],
        outputs=prediction,
        name="memory_efficient_model",
    )
    return model


memory_efficient_model = create_memory_efficient_model()
memory_efficient_model.summary()

"""
Notice that the number of trainable parameters is 117,968, which is less than one fifth of
the number of parameters in the baseline model.
"""

history = run_experiment(memory_efficient_model)

plt.plot(history.history["loss"])
plt.plot(history.history["val_loss"])
plt.title("model loss")
plt.ylabel("loss")
plt.xlabel("epoch")
plt.legend(["train", "eval"], loc="upper left")
plt.show()
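"""
As a final check, we can compare the sizes of the two models directly. The counts below should
match the trainable parameter counts reported by the model summaries above.
"""

print(f"Baseline model parameters: {baseline_model.count_params():,}")
print(f"Memory-efficient model parameters: {memory_efficient_model.count_params():,}")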