"""
Title: Deep Recommenders
Author: [Fabien Hertschuh](https://github.com/hertschuh/), [Abheesht Sharma](https://github.com/abheesht17/)
Date created: 2025/04/28
Last modified: 2025/04/28
Description: Building a deep retrieval model with multiple stacked layers.
Accelerator: GPU
"""

"""
## Introduction

One of the great advantages of using Keras to build recommender models is the
freedom to build rich, flexible feature representations.

The first step in doing so is preparing the features, as raw features will
usually not be immediately usable in a model.

For example:

- User and item IDs may be strings (titles, usernames) or large, non-contiguous
integers (database IDs).
- Item descriptions could be raw text.
- Interaction timestamps could be raw Unix timestamps.

These need to be appropriately transformed in order to be useful in building
models:

- User and item IDs have to be translated into embedding vectors,
high-dimensional numerical representations that are adjusted during training
to help the model predict its objective better.
- Raw text needs to be tokenized (split into smaller parts such as individual
words) and translated into embeddings.
- Numerical features need to be normalized so that their values lie in a small
interval around 0.

Fortunately, the Keras
[`FeatureSpace`](/api/utils/feature_space/) utility makes this
preprocessing easy.

In this tutorial, we are going to incorporate multiple features in our models.
These features will come from preprocessing the MovieLens dataset.

In the
[basic retrieval](/keras_rs/examples/basic_retrieval/)
tutorial, the models consist of only an embedding layer. In this tutorial, we
add more dense layers to our models to increase their expressive power.

In general, deeper models are capable of learning more complex patterns than
shallower models. For example, our user model incorporates user IDs and user
features such as age, gender and occupation.
A shallow model (say, a single
embedding layer) may only be able to learn the simplest relationships between
those features and movies: a given user generally prefers horror movies to
comedies. To capture more complex relationships, such as user preferences
evolving with their age, we may need a deeper model with multiple stacked dense
layers.

Of course, complex models also have their disadvantages. The first is
computational cost, as larger models require both more memory and more
computation to train and serve. The second is the requirement for more data. In
general, more training data is needed to take advantage of deeper models. With
more parameters, deep models might overfit or even simply memorize the training
examples instead of learning a function that can generalize. Finally, training
deeper models may be harder, and more care needs to be taken in choosing
settings like regularization and learning rate.

Finding a good architecture for a real-world recommender system is a complex
art, requiring good intuition and careful hyperparameter tuning. For example,
factors such as the depth and width of the model, activation function, learning
rate, and optimizer can radically change the performance of the model. Modelling
choices are further complicated by the fact that good offline evaluation metrics
may not correspond to good online performance, and that the choice of what to
optimize for is often more critical than the choice of model itself.

Nevertheless, effort put into building and fine-tuning larger models often pays
off.
In this tutorial, we will illustrate how to build a deep retrieval model.
We'll do this by building progressively more complex models to see how this
affects model performance.
"""

"""shell
pip install -q keras-rs
"""

import os

os.environ["KERAS_BACKEND"] = "jax"  # `"tensorflow"`/`"torch"`

import keras
import matplotlib.pyplot as plt
import tensorflow as tf  # Needed for the dataset
import tensorflow_datasets as tfds

import keras_rs

"""
## The MovieLens dataset

Let's first have a look at what features we can use from the MovieLens dataset.
"""

# Ratings data with user and movie data.
ratings = tfds.load("movielens/100k-ratings", split="train")
# Features of all the available movies.
movies = tfds.load("movielens/100k-movies", split="train")

"""
Each element of the ratings dataset is a dictionary containing the movie ID,
user ID, assigned rating, timestamp, movie information, and user information:
"""

for data in ratings.take(1).as_numpy_iterator():
    print(str(data).replace(", '", ",\n '"))

"""
In the MovieLens dataset, user IDs are integers (represented as strings)
starting at 1 and with no gap. Normally, you would need to create a lookup table
to map user IDs to integers from 0 to N-1. But as a simplification, we'll use the
user ID directly as an index in our model, in particular to look up the user
embedding from the user embedding table. So we need to know the number of users.
"""

USERS_COUNT = (
    ratings.map(lambda x: tf.strings.to_number(x["user_id"], out_type=tf.int32))
    .reduce(tf.constant(0, tf.int32), tf.maximum)
    .numpy()
)

"""
The movies dataset contains the movie ID, movie title, and the genres it belongs
to. Note that the genres are encoded with integer labels.
"""

for data in movies.take(1).as_numpy_iterator():
    print(str(data).replace(", '", ",\n '"))

"""
In the MovieLens dataset, movie IDs are integers (represented as strings)
starting at 1 and with no gap.
Normally, you would need to create a lookup table
to map movie IDs to integers from 0 to N-1. But as a simplification, we'll use the
movie ID directly as an index in our model, in particular to look up the movie
embedding from the movie embedding table. So we need to know the number of
movies.
"""

MOVIES_COUNT = movies.cardinality().numpy()

"""
## Preprocessing the dataset

### Normalizing continuous features

Continuous features may need normalization so that they fall within an
acceptable range for the model. We will give two examples of such normalization.

#### Discretization

A common transformation is to turn a continuous feature into a categorical
feature by bucketing its values. This makes good sense if we have reason to
suspect that a feature's effect is non-continuous.

We need to decide on the number of buckets to use for discretization. Then,
we will use the Keras `FeatureSpace` utility to learn the bucket boundaries from
the data when we call `adapt`. The underlying `Discretization` layer places the
boundaries at approximate quantiles of the observed values, so each bucket
receives roughly the same number of examples.

In this example, we will discretize the user age.
"""

AGE_BINS_COUNT = 10
user_age_feature = keras.utils.FeatureSpace.float_discretized(
    num_bins=AGE_BINS_COUNT, output_mode="int"
)

"""
#### Rescaling

Often, we want continuous features to be between 0 and 1, or between -1 and 1.
To achieve this, we can rescale features that have a different range.

In this example, we will rescale the rating, which is an integer between 1
and 5, to be a float between 0 and 1.
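
Here is a quick check of this affine map. `FeatureSpace.float_rescaled` wraps a
`keras.layers.Rescaling` layer, which computes `x * scale + offset`. Requiring
1 to map to 0 and 5 to map to 1 gives `scale = 1/4` and `offset = -1/4`. The
plain-Python sketch below involves no Keras. It applies that map to every
possible rating through `rescale_rating`, a hypothetical helper:

```python
# Sketch of the affine map applied by `Rescaling(scale, offset)`:
# x * scale + offset. Solving 1 * scale + offset = 0 and
# 5 * scale + offset = 1 gives scale = 1/4 and offset = -1/4.
SCALE = 1.0 / 4.0
OFFSET = -1.0 / 4.0


def rescale_rating(rating):
    # Hypothetical helper mirroring what the FeatureSpace feature computes.
    return rating * SCALE + OFFSET


rescaled = [rescale_rating(r) for r in [1, 2, 3, 4, 5]]
print(rescaled)  # [0.0, 0.25, 0.5, 0.75, 1.0]
```

Only a 5-star rating ends up above 0.9, which is the threshold used later
when computing the top-k metric.
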

We need to scale the rating by 1/4 and offset it by -1/4:
"""

user_rating_feature = keras.utils.FeatureSpace.float_rescaled(
    scale=1.0 / 4.0, offset=-1.0 / 4.0
)

"""
### Turning categorical features into embeddings

A categorical feature is a feature that does not express a continuous quantity,
but rather takes on one of a set of fixed values.

Most deep learning models express these features by turning them into
high-dimensional vectors. During model training, the value of that vector is
adjusted to help the model predict its objective better.

For example, suppose that our goal is to predict which user is going to watch
which movie. To do that, we represent each user and each movie by an embedding
vector. Initially, these embeddings will take on random values. During training,
we adjust them so that embeddings of users and the movies they watch end up
closer together.

Taking raw categorical features and turning them into embeddings is normally a
two-step process:

1. First, we need to translate the raw values into a range of contiguous
integers, normally by building a mapping (called a "vocabulary") that maps
raw values to integers.
2. Second, we need to take these integers and turn them into embeddings.
"""

"""
#### Defining categorical features

We will use the Keras `FeatureSpace` utility for the first step. Its `adapt`
method automatically discovers the vocabulary for categorical features.
"""

user_gender_feature = keras.utils.FeatureSpace.integer_categorical(
    num_oov_indices=0, output_mode="int"
)
user_occupation_feature = keras.utils.FeatureSpace.integer_categorical(
    num_oov_indices=0, output_mode="int"
)

"""
#### Using feature crosses

With crosses, we can model interactions between multiple categorical
features.
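
A hashed cross can be sketched in a few lines. Each (gender, age bucket) pair is
turned into a single key, and that key is hashed into one of `crossing_dim`
buckets. Keras builds `FeatureSpace.cross` on a `HashedCrossing` layer with its
own hash function. The sketch below substitutes `zlib.crc32` purely as a
stand-in, so its bucket numbers will not match Keras's. It also assumes 2 gender
values and the 10 discretized age buckets, and `hashed_cross` is a hypothetical
helper:

```python
import zlib

# Mirrors USER_GENDER_CROSS_COUNT; the gender and age-bucket ranges below are
# assumptions made for illustration.
CROSSING_DIM = 20


def hashed_cross(gender, age_bucket, crossing_dim=CROSSING_DIM):
    # Turn the pair of values into one key, then hash it into a fixed number
    # of buckets. zlib.crc32 is a stand-in here; Keras uses its own hash.
    key = f"{gender}_X_{age_bucket}".encode("utf-8")
    return zlib.crc32(key) % crossing_dim


crossed = {(g, a): hashed_cross(g, a) for g in (0, 1) for a in range(10)}
print(len(crossed), "pairs ->", len(set(crossed.values())), "distinct buckets")
```

With 20 possible pairs and `crossing_dim=20`, a one-to-one assignment is
possible in principle. In practice, hashing usually produces some collisions,
where distinct pairs share a bucket. Raising `crossing_dim` makes collisions
rarer at the cost of a larger embedding table.
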

Crossing is a powerful way to express that a particular combination of feature
values represents a specific taste in movies.

Note that combining multiple features can result in a very large feature
space. That is why the `crossing_dim` parameter is important: crossed values are
hashed into `crossing_dim` buckets, which bounds the output dimension of the
cross feature.

In this example, we will cross age and gender with the Keras `FeatureSpace`
utility.
"""

USER_GENDER_CROSS_COUNT = 20
user_gender_age_cross = keras.utils.FeatureSpace.cross(
    feature_names=("user_gender", "raw_user_age"),
    crossing_dim=USER_GENDER_CROSS_COUNT,
    output_mode="int",
)

"""
### Processing text features

We may also want to add text features to our model. Usually, things like product
descriptions are free-form text, and we can hope that our model can learn to use
the information they contain to make better recommendations, especially in a
cold-start or long-tail scenario.

While the MovieLens dataset does not give us rich textual features, we can still
use movie titles.
This may help us capture the fact that movies with very
similar titles are likely to belong to the same series.

The first transformation we need to apply to text is tokenization (splitting
into constituent words or word-pieces), followed by vocabulary learning,
followed by an embedding.

The
[`keras.layers.TextVectorization`](/api/layers/preprocessing_layers/text/text_vectorization/)
layer can do the first two steps for us.
"""

title_vectorizer = keras.layers.TextVectorization(
    max_tokens=10_000, output_sequence_length=16, dtype="int32"
)
title_vectorizer.adapt(movies.map(lambda x: x["movie_title"]))

"""
Let's try it out:
"""

for data in movies.take(1).as_numpy_iterator():
    print(title_vectorizer(data["movie_title"]))

"""
Each title is translated into a sequence of integer token IDs, one per word,
padded with zeros (or truncated) to `output_sequence_length=16`.

We can check the learned vocabulary to verify that the layer is using the
correct tokenization:
"""

print(title_vectorizer.get_vocabulary()[40:50])

"""
This looks correct: the layer is tokenizing titles into individual words. Later,
we will see how to embed this tokenized text. For now, we turn this vectorizer
into a Keras `FeatureSpace` feature.
"""

title_feature = keras.utils.FeatureSpace.feature(
    preprocessor=title_vectorizer, dtype="string", output_mode="float"
)
TITLE_TOKEN_COUNT = title_vectorizer.vocabulary_size()

"""
### Putting the FeatureSpace features together

We're now ready to assemble the features with preprocessors in a `FeatureSpace`
object.
We then use `adapt` to go through the dataset and learn what needs
to be learned, such as the vocabulary for categorical features or the
bucket boundaries for discretized features.
"""

feature_space = keras.utils.FeatureSpace(
    features={
        # Numerical features to discretize.
        "raw_user_age": user_age_feature,
        # Categorical features encoded as integers.
        "user_gender": user_gender_feature,
        "user_occupation_label": user_occupation_feature,
        # Labels are ratings between 0 and 1.
        "user_rating": user_rating_feature,
        "movie_title": title_feature,
    },
    crosses=[user_gender_age_cross],
    output_mode="dict",
)

feature_space.adapt(ratings)
GENDERS_COUNT = feature_space.preprocessors["user_gender"].vocabulary_size()
OCCUPATIONS_COUNT = feature_space.preprocessors[
    "user_occupation_label"
].vocabulary_size()

"""
## Pre-building the candidate set

Our model is going to be based on a retrieval layer, `BruteForceRetrieval`,
which returns the best candidates among the full set of candidates. To do this,
the retrieval layer needs to know all the candidates and their features. In this
section, we assemble the full set of movies with the associated features.

### Extract raw candidate features

First, we gather all the raw features from the dataset in lists, namely the
movie titles and genres. Note that one or more genres are associated with each
movie, and the number of genres varies per movie.
"""

movie_titles = [""] * (MOVIES_COUNT + 1)
movie_genres = [[]] * (MOVIES_COUNT + 1)
for x in movies.as_numpy_iterator():
    movie_id = int(x["movie_id"])
    movie_titles[movie_id] = x["movie_title"]
    movie_genres[movie_id] = x["movie_genres"].tolist()

"""
### Preprocess candidate features

Genres are already in the form of category numbers starting at zero.
However, we
do need to figure out two things:
- The maximum number of genres a single movie can have; this will determine the
dimension for this feature.
- The maximum value for the genre, which will give us the total number of genres
and determine the size of our embedding table for genres.
"""

MAX_GENRES_PER_MOVIE = 0
max_genre_id = 0
for one_movie_genres in movie_genres:
    MAX_GENRES_PER_MOVIE = max(MAX_GENRES_PER_MOVIE, len(one_movie_genres))
    if one_movie_genres:
        max_genre_id = max(max_genre_id, max(one_movie_genres))

GENRES_COUNT = max_genre_id + 1

"""
Now we need to pad genres with an out-of-vocabulary (OOV) value to be able to
represent genres as a fixed-size vector. We'll pad with zeros for simplicity, so
we add one to every genre to avoid a conflict with genre zero, which is a valid
genre.
"""

movie_genres = [
    [g + 1 for g in genres] + [0] * (MAX_GENRES_PER_MOVIE - len(genres))
    for genres in movie_genres
]

"""
Then, we vectorize all the movie titles.
"""

movie_titles_vectors = title_vectorizer(movie_titles)

"""
### Convert candidate set to native tensors

We're now ready to combine these in a dataset. The last step is to make sure
everything is a native tensor that can be consumed by the retrieval layer.
As a reminder, movie ID zero does not exist, so row zero is only a placeholder.
"""

MOVIES_DATASET = {
    "movie_id": keras.ops.arange(0, MOVIES_COUNT + 1, dtype="int32"),
    "movie_title_vector": movie_titles_vectors,
    "movie_genres": keras.ops.convert_to_tensor(movie_genres, dtype="int32"),
}

"""
## Preparing the data

We can now define our preprocessing function. Most features will be handled
by the `FeatureSpace`. User IDs and movie IDs need to be extracted. Movie genres
need to be padded.
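
The genre padding mentioned above uses the same scheme as for the candidate
set. Every genre ID is shifted by one, and the list is right-padded with zeros
up to the maximum length. Below is a small NumPy sketch of that transformation.
The genre lists, `TOY_MAX_GENRES`, and the `shift_and_pad` helper are all
hypothetical:

```python
import numpy as np

# Hypothetical values for illustration; in the tutorial they come from the data.
TOY_MAX_GENRES = 4
toy_movie_genres = [[0], [3, 7], [1, 4, 5, 9]]


def shift_and_pad(genres, max_len):
    # Shift every genre ID by one so that 0 is free to mean "padding".
    shifted = np.asarray(genres, dtype=np.int32) + 1
    # Append zeros on the right up to `max_len`.
    return np.pad(shifted, (0, max_len - len(genres)))


padded = np.stack([shift_and_pad(g, TOY_MAX_GENRES) for g in toy_movie_genres])
print(padded)
```

Genre 0 becomes 1, so every remaining 0 is padding. These are the zeros that
`mask_zero=True` in the candidate model's genre embedding ignores when
averaging.
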

The function then packages everything as a tuple: a dict of input features, and
the rescaled rating as a float label.
"""


def preprocess_rating(x):
    features = feature_space(
        {
            "raw_user_age": x["raw_user_age"],
            "user_gender": x["user_gender"],
            "user_occupation_label": x["user_occupation_label"],
            "user_rating": x["user_rating"],
            "movie_title": x["movie_title"],
        }
    )
    features = {k: tf.squeeze(v, axis=0) for k, v in features.items()}
    movie_genres = x["movie_genres"]

    return (
        {
            # User inputs are user ID and user features
            "user_id": int(x["user_id"]),
            "raw_user_age": features["raw_user_age"],
            "user_gender": features["user_gender"],
            "user_occupation_label": features["user_occupation_label"],
            "user_gender_X_raw_user_age": tf.squeeze(
                features["user_gender_X_raw_user_age"], axis=-1
            ),
            # Movie inputs are movie ID, vectorized title and genres
            "movie_id": int(x["movie_id"]),
            "movie_title_vector": features["movie_title"],
            "movie_genres": tf.pad(
                movie_genres + 1,
                [[0, MAX_GENRES_PER_MOVIE - tf.shape(movie_genres)[0]]],
            ),
        },
        # Label is user rating between 0 and 1
        features["user_rating"],
    )


"""
We shuffle and then split the data into a training set and a testing set.
"""

shuffled_ratings = ratings.map(preprocess_rating).shuffle(
    100_000, seed=42, reshuffle_each_iteration=False
)

train_ratings = shuffled_ratings.take(80_000).batch(1000).cache()
test_ratings = shuffled_ratings.skip(80_000).take(20_000).batch(1000).cache()

"""
## Model definition

### Query model

The query model is first tasked with converting user features to embeddings. The
embeddings are then concatenated into a single vector.

Defining deeper models will require us to stack more layers on top of this first
set of embeddings.
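
Here are the concrete shapes. With the default `embedding_dimension=32`, the
query model looks up five embeddings per example and concatenates them into a
single 160-dimensional vector, which is what the first dense layer receives. The
NumPy sketch below mimics those lookups with random tables. The vocabulary sizes
are hypothetical stand-ins for `USERS_COUNT + 1`, `GENDERS_COUNT`, and so on:

```python
import numpy as np

rng = np.random.default_rng(0)
EMBEDDING_DIMENSION = 32
BATCH_SIZE = 4

# Hypothetical table sizes standing in for USERS_COUNT + 1, GENDERS_COUNT,
# AGE_BINS_COUNT, USER_GENDER_CROSS_COUNT and OCCUPATIONS_COUNT.
vocab_sizes = {
    "user_id": 944,
    "user_gender": 2,
    "raw_user_age": 10,
    "user_gender_X_raw_user_age": 20,
    "user_occupation_label": 21,
}
# One randomly initialized table per feature, like an untrained Embedding layer.
tables = {
    name: rng.normal(size=(size, EMBEDDING_DIMENSION))
    for name, size in vocab_sizes.items()
}
# A batch of valid integer inputs for each feature.
batch = {
    name: rng.integers(0, size, size=BATCH_SIZE) for name, size in vocab_sizes.items()
}

# An embedding lookup is row indexing; the results are concatenated on axis 1.
feature_embedding = np.concatenate(
    [tables[name][batch[name]] for name in vocab_sizes], axis=1
)
print(feature_embedding.shape)  # (4, 160)
```
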

A progressively narrower stack of layers, separated by an
activation function, is a common pattern:

```
+----------------------+
|       64 x 32        |
+----------------------+
           | relu
+--------------------------+
|         128 x 64         |
+--------------------------+
           | relu
+------------------------------+
|          ... x 128           |
+------------------------------+
```

Since the expressive power of deep linear models is no greater than that of
shallow linear models, we use ReLU activations for all but the last hidden
layer. The final hidden layer does not use any activation function: using an
activation function would limit the output space of the final embeddings and
might negatively impact the performance of the model. For instance, if ReLUs are
used in the projection layer, all components in the output embedding would be
non-negative.

We're going to try this here. To make experimentation with different depths
easy, let's define a model whose depth (and width) is defined by a constructor
parameter. The `layer_sizes` parameter gives us the depth and width of the
model.
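
The statement above that stacked linear layers add no expressive power can be
checked numerically. Two dense layers without an activation compute
`(x @ W1 + b1) @ W2 + b2`. That is exactly one dense layer with kernel
`W1 @ W2` and bias `b1 @ W2 + b2`. The NumPy sketch below verifies the identity
and shows that a ReLU in between breaks it. It uses random weights, with shapes
chosen to mirror a `(64, 32)` configuration fed by a 160-dimensional input:

```python
import numpy as np

rng = np.random.default_rng(42)
x = rng.normal(size=(8, 160))  # A batch of 8 concatenated feature embeddings.
w1, b1 = rng.normal(size=(160, 64)), rng.normal(size=64)
w2, b2 = rng.normal(size=(64, 32)), rng.normal(size=32)

# Two linear (activation-free) dense layers stacked on top of each other.
two_linear_layers = (x @ w1 + b1) @ w2 + b2
# The single equivalent linear layer.
one_linear_layer = x @ (w1 @ w2) + (b1 @ w2 + b2)
print(np.allclose(two_linear_layers, one_linear_layer))  # True

# With a ReLU in between, the collapse no longer holds in general.
with_relu = np.maximum(x @ w1 + b1, 0.0) @ w2 + b2
print(np.allclose(with_relu, one_linear_layer))  # False
```
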

We can vary `layer_sizes` to experiment with shallower or deeper models.
"""


class QueryModel(keras.Model):
    """Model for encoding user queries."""

    def __init__(self, layer_sizes, embedding_dimension=32):
        """Construct a model for encoding user queries.

        Args:
          layer_sizes: A list of integers where the i-th entry represents the
            number of units the i-th layer contains.
          embedding_dimension: Output dimension for all embedding tables.
        """
        super().__init__()

        # We first generate embeddings.
        self.user_embedding = keras.layers.Embedding(
            # +1 for user ID zero, which does not exist
            USERS_COUNT + 1,
            embedding_dimension,
        )
        self.gender_embedding = keras.layers.Embedding(
            GENDERS_COUNT, embedding_dimension
        )
        self.age_embedding = keras.layers.Embedding(AGE_BINS_COUNT, embedding_dimension)
        self.gender_x_age_embedding = keras.layers.Embedding(
            USER_GENDER_CROSS_COUNT, embedding_dimension
        )
        self.occupation_embedding = keras.layers.Embedding(
            OCCUPATIONS_COUNT, embedding_dimension
        )

        # Then construct the layers.
        self.dense_layers = keras.Sequential()

        # Use the ReLU activation for all but the last layer.
        for layer_size in layer_sizes[:-1]:
            self.dense_layers.add(keras.layers.Dense(layer_size, activation="relu"))

        # No activation for the last layer.
        self.dense_layers.add(keras.layers.Dense(layer_sizes[-1]))

    def call(self, inputs):
        # Take the inputs, pass each through its embedding layer, concatenate.
        feature_embedding = keras.ops.concatenate(
            [
                self.user_embedding(inputs["user_id"]),
                self.gender_embedding(inputs["user_gender"]),
                self.age_embedding(inputs["raw_user_age"]),
                self.gender_x_age_embedding(inputs["user_gender_X_raw_user_age"]),
                self.occupation_embedding(inputs["user_occupation_label"]),
            ],
            axis=1,
        )
        return self.dense_layers(feature_embedding)


"""
### Candidate model

We can adopt the same approach for the candidate model.
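
One new ingredient in the candidate model is that titles and genres are
zero-padded lists of IDs. Each element is embedded, and the embeddings are
averaged while skipping the padding. That is what `mask_zero=True` on the
`Embedding` layer combined with `GlobalAveragePooling1D` does. Here is a NumPy
sketch of that masked mean. It uses a random table and made-up token IDs, and
`masked_mean_embedding` is a hypothetical helper:

```python
import numpy as np

rng = np.random.default_rng(0)
EMBEDDING_DIMENSION = 4
# Hypothetical table: row 0 is reserved for padding, rows 1..5 are real tokens.
table = rng.normal(size=(6, EMBEDDING_DIMENSION))

# Two padded sequences of token IDs; 0 means "padding".
token_ids = np.array([[3, 1, 0, 0], [2, 5, 4, 0]])


def masked_mean_embedding(ids, table):
    embedded = table[ids]  # (batch, seq_len, dim)
    mask = (ids != 0).astype(embedded.dtype)[..., None]  # (batch, seq_len, 1)
    # Sum only the non-padding embeddings and divide by how many there are.
    return (embedded * mask).sum(axis=1) / mask.sum(axis=1)


pooled = masked_mean_embedding(token_ids, table)
print(pooled.shape)  # (2, 4)
```

Without the mask, the padding row would be averaged in and pull every short
title or genre list toward the padding embedding.
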

Again, we start by converting movie features to embeddings, concatenating them,
and then passing the result through hidden layers:
"""


class CandidateModel(keras.Model):
    """Model for encoding candidates (movies)."""

    def __init__(self, layer_sizes, embedding_dimension=32):
        """Construct a model for encoding candidates (movies).

        Args:
          layer_sizes: A list of integers where the i-th entry represents the
            number of units the i-th layer contains.
          embedding_dimension: Output dimension for all embedding tables.
        """
        super().__init__()

        # We first generate embeddings.
        self.movie_embedding = keras.layers.Embedding(
            # +1 for movie ID zero, which does not exist
            MOVIES_COUNT + 1,
            embedding_dimension,
        )
        # Take all the title tokens for the title of the movie, embed each
        # token, and then take the mean of all token embeddings.
        self.movie_title_embedding = keras.Sequential(
            [
                keras.layers.Embedding(
                    # +1 for OOV token, which is used for padding
                    TITLE_TOKEN_COUNT + 1,
                    embedding_dimension,
                    mask_zero=True,
                ),
                keras.layers.GlobalAveragePooling1D(),
            ]
        )
        # Take all the genres for the movie, embed each genre, and then take the
        # mean of all genre embeddings.
        self.movie_genres_embedding = keras.Sequential(
            [
                keras.layers.Embedding(
                    # +1 for OOV genre, which is used for padding
                    GENRES_COUNT + 1,
                    embedding_dimension,
                    mask_zero=True,
                ),
                keras.layers.GlobalAveragePooling1D(),
            ]
        )

        # Then construct the layers.
        self.dense_layers = keras.Sequential()

        # Use the ReLU activation for all but the last layer.
        for layer_size in layer_sizes[:-1]:
            self.dense_layers.add(keras.layers.Dense(layer_size, activation="relu"))

        # No activation for the last layer.
        self.dense_layers.add(keras.layers.Dense(layer_sizes[-1]))

    def call(self, inputs):
        movie_id = inputs["movie_id"]
        movie_title_vector = inputs["movie_title_vector"]
        movie_genres = inputs["movie_genres"]
        feature_embedding = keras.ops.concatenate(
            [
                self.movie_embedding(movie_id),
                self.movie_title_embedding(movie_title_vector),
                self.movie_genres_embedding(movie_genres),
            ],
            axis=1,
        )
        return self.dense_layers(feature_embedding)


"""
### Combined model

With both `QueryModel` and `CandidateModel` defined, we can put together a
combined model and implement our loss and metrics logic. To make things simple,
we'll enforce that the model structure is the same across the query and
candidate models.
"""


class RetrievalModel(keras.Model):
    """Combined model."""

    def __init__(
        self,
        layer_sizes=(32,),
        embedding_dimension=32,
        retrieval_k=100,
    ):
        """Construct a combined model.

        Args:
          layer_sizes: A list of integers where the i-th entry represents the
            number of units the i-th layer contains.
          embedding_dimension: Output dimension for all embedding tables.
          retrieval_k: How many candidate movies to retrieve.
        """
        super().__init__()
        self.query_model = QueryModel(layer_sizes, embedding_dimension)
        self.candidate_model = CandidateModel(layer_sizes, embedding_dimension)
        self.retrieval = keras_rs.layers.BruteForceRetrieval(
            k=retrieval_k, return_scores=False
        )
        self.update_candidates()  # Provide an initial set of candidates
        self.loss_fn = keras.losses.MeanSquaredError()
        self.top_k_metric = keras.metrics.SparseTopKCategoricalAccuracy(
            k=retrieval_k, from_sorted_ids=True
        )

    def update_candidates(self):
        self.retrieval.update_candidates(
            self.candidate_model.predict(MOVIES_DATASET, verbose=0)
        )

    def call(self, inputs, training=False):
        query_embeddings = self.query_model(
            {
                "user_id": inputs["user_id"],
                "raw_user_age": inputs["raw_user_age"],
                "user_gender": inputs["user_gender"],
                "user_occupation_label": inputs["user_occupation_label"],
                "user_gender_X_raw_user_age": inputs["user_gender_X_raw_user_age"],
            }
        )
        candidate_embeddings = self.candidate_model(
            {
                "movie_id": inputs["movie_id"],
                "movie_title_vector": inputs["movie_title_vector"],
                "movie_genres": inputs["movie_genres"],
            }
        )

        result = {
            "query_embeddings": query_embeddings,
            "candidate_embeddings": candidate_embeddings,
        }
        if not training:
            # No need to spend time extracting top predicted movies during
            # training; they are not used.
            result["predictions"] = self.retrieval(query_embeddings)
        return result

    def evaluate(
        self,
        x=None,
        y=None,
        batch_size=None,
        verbose="auto",
        sample_weight=None,
        steps=None,
        callbacks=None,
        return_dict=False,
        **kwargs,
    ):
        """Overridden to update the candidate set.

        Before evaluating the model, we need to update our retrieval layer by
        re-computing the values predicted by the candidate model for all the
        candidates.
        """
        self.update_candidates()
        return super().evaluate(
            x,
            y,
            batch_size=batch_size,
            verbose=verbose,
            sample_weight=sample_weight,
            steps=steps,
            callbacks=callbacks,
            return_dict=return_dict,
            **kwargs,
        )

    def compute_loss(self, x, y, y_pred, sample_weight, training=True):
        query_embeddings = y_pred["query_embeddings"]
        candidate_embeddings = y_pred["candidate_embeddings"]

        labels = keras.ops.expand_dims(y, -1)
        # Compute the affinity score as the dot product of the two embeddings.
        scores = keras.ops.sum(
            keras.ops.multiply(query_embeddings, candidate_embeddings),
            axis=1,
            keepdims=True,
        )
        return self.loss_fn(labels, scores, sample_weight)

    def compute_metrics(self, x, y, y_pred, sample_weight=None):
        if "predictions" in y_pred:
            # We are evaluating or predicting. Update `top_k_metric`.
            movie_ids = x["movie_id"]
            predictions = y_pred["predictions"]
            # For `top_k_metric`, which is a `SparseTopKCategoricalAccuracy`, we
            # only count 5-star ratings (rescaled to 1.0, the only value above
            # 0.9) and give every other example a weight of 0.
            rating_weight = keras.ops.cast(keras.ops.greater(y, 0.9), "float32")
            sample_weight = (
                rating_weight
                if sample_weight is None
                else keras.ops.multiply(rating_weight, sample_weight)
            )
            self.top_k_metric.update_state(
                movie_ids, predictions, sample_weight=sample_weight
            )
            return self.get_metrics_result()
        else:
            # We are training. `top_k_metric` is not updated and is zero, so
            # don't report it.
            result = self.get_metrics_result()
            result.pop(self.top_k_metric.name)
            return result


"""
## Training the model

### Shallow model

We're ready to try out our first, shallow, model!
"""

NUM_EPOCHS = 30

one_layer_model = RetrievalModel((32,))
one_layer_model.compile(optimizer=keras.optimizers.Adagrad(0.05))

one_layer_history = one_layer_model.fit(
    train_ratings,
    validation_data=test_ratings,
    validation_freq=5,
    epochs=NUM_EPOCHS,
)

"""
This gives us a top-100 accuracy of around 0.30. We can use this as a reference
point for evaluating deeper models.

### Deeper model

What about a deeper model with two layers?
"""

two_layer_model = RetrievalModel((64, 32))
two_layer_model.compile(optimizer=keras.optimizers.Adagrad(0.05))
two_layer_history = two_layer_model.fit(
    train_ratings,
    validation_data=test_ratings,
    validation_freq=5,
    epochs=NUM_EPOCHS,
)

"""
While the deeper model seems to learn a bit better than the shallow model at
first, the difference becomes minimal towards the end of the training.
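
Before looking at the curves, it helps to pin down what the reported number
measures. The sketch below reimplements, in plain NumPy, the idea behind the
weighted top-k accuracy used in `compute_metrics`. For each example, it checks
whether the movie the user actually rated is among the first `k` retrieved IDs,
and it counts only examples whose rescaled rating is above 0.9. The IDs and
ratings are made up, and `weighted_top_k_accuracy` is a hypothetical helper, not
the Keras implementation:

```python
import numpy as np


def weighted_top_k_accuracy(true_ids, retrieved_ids, ratings, k, threshold=0.9):
    # true_ids: (batch,) ID of the movie each user actually rated.
    # retrieved_ids: (batch, >= k) movie IDs sorted by decreasing score.
    # ratings: (batch,) rescaled ratings in [0, 1], turned into 0/1 weights.
    hits = (retrieved_ids[:, :k] == true_ids[:, None]).any(axis=1).astype(float)
    weights = (ratings > threshold).astype(float)
    # Weighted mean of hits; assumes at least one example has weight 1.
    return (hits * weights).sum() / weights.sum()


true_ids = np.array([7, 3, 9, 4])
retrieved_ids = np.array([[7, 1, 2], [5, 6, 8], [2, 9, 1], [4, 3, 2]])
ratings = np.array([1.0, 1.0, 0.75, 0.25])
print(weighted_top_k_accuracy(true_ids, retrieved_ids, ratings, k=2))  # 0.5
```

The last two examples are hits, but their ratings are below the threshold, so
they do not count. Only one of the two 5-star examples is retrieved, giving 0.5.
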

We can plot the validation accuracy curves to compare the two models:
"""

METRIC = "val_sparse_top_k_categorical_accuracy"
num_validation_runs = len(one_layer_history.history[METRIC])
# Validation ran every 5 epochs (`validation_freq=5`).
epochs = [(x + 1) * 5 for x in range(num_validation_runs)]

plt.plot(epochs, one_layer_history.history[METRIC], label="1 layer")
plt.plot(epochs, two_layer_history.history[METRIC], label="2 layers")
plt.title("Accuracy vs epoch")
plt.xlabel("epoch")
plt.ylabel("Top-100 accuracy")
plt.legend()
plt.show()

"""
Deeper models are not necessarily better. The following model extends the depth
to three layers:
"""

three_layer_model = RetrievalModel((128, 64, 32))
three_layer_model.compile(optimizer=keras.optimizers.Adagrad(0.05))
three_layer_history = three_layer_model.fit(
    train_ratings,
    validation_data=test_ratings,
    validation_freq=5,
    epochs=NUM_EPOCHS,
)

"""
We don't really see an improvement over the shallow model:
"""

plt.plot(epochs, one_layer_history.history[METRIC], label="1 layer")
plt.plot(epochs, two_layer_history.history[METRIC], label="2 layers")
plt.plot(epochs, three_layer_history.history[METRIC], label="3 layers")
plt.title("Accuracy vs epoch")
plt.xlabel("epoch")
plt.ylabel("Top-100 accuracy")
plt.legend()
plt.show()

"""
This is a good illustration of the fact that deeper and larger models, while
capable of superior performance, often require very careful tuning. For example,
throughout this tutorial we used a single, fixed learning rate. Alternative
choices may give very different results and are worth exploring.

With appropriate tuning and sufficient data, the effort put into building larger
and deeper models is in many cases well worth it: larger models can lead to
substantial improvements in prediction accuracy.

## Next Steps

In this tutorial we expanded our retrieval model with dense layers and
activation functions.
To see how to create a model that can perform not only
retrieval tasks but also rating tasks, take a look at the multitask tutorial.
"""