"""1Title: Deep Learning for Customer Lifetime Value2Author: [Praveen Hosdrug](https://www.linkedin.com/in/praveenhosdrug/)3Date created: 2024/11/234Last modified: 2024/11/275Description: A hybrid deep learning architecture for predicting customer purchase patterns and lifetime value.6Accelerator: None7"""89"""10## Introduction1112A hybrid deep learning architecture combining Transformer encoders and LSTM networks13for predicting customer purchase patterns and lifetime value using transaction history.14While many existing review articles focus on classic parametric models and traditional machine learning algorithms15,this implementation leverages recent advancements in Transformer-based models for time series prediction.16The approach handles multi-granularity prediction across different temporal scales.1718"""1920"""21## Setting up Libraries for the Deep Learning Project22"""23import subprocess242526def install_packages(packages):27"""28Install a list of packages using pip.2930Args:31packages (list): A list of package names to install.32"""33for package in packages:34subprocess.run(["pip", "install", package], check=True)353637"""38## List of Packages to Install39401. uciml: For the purpose of the tutorial; we will be using41the UK Retail [Dataset](https://archive.ics.uci.edu/dataset/352/online+retail)422. keras_hub: Access to the transformer encoder layer.43"""4445packages_to_install = ["ucimlrepo", "keras_hub"]4647# Install the packages48install_packages(packages_to_install)4950# Core data processing and numerical libraries51import os5253os.environ["KERAS_BACKEND"] = "jax"54import keras55import numpy as np56import pandas as pd57from typing import Dict585960# Visualization61import matplotlib.pyplot as plt6263# Keras imports64from keras import layers65from keras import Model66from keras import ops67from keras_hub.layers import TransformerEncoder68from keras import regularizers6970# UK Retail Dataset71from ucimlrepo import fetch_ucirepo7273"""74## Preprocessing the UK Retail dataset75"""767778def prepare_time_series_data(data):79"""80Preprocess retail transaction data for deep learning.8182Args:83data: Raw transaction data containing InvoiceDate, UnitPrice, etc.84Returns:85Processed DataFrame with calculated features86"""87processed_data = data.copy()8889# Essential datetime handling for temporal ordering90processed_data["InvoiceDate"] = pd.to_datetime(processed_data["InvoiceDate"])9192# Basic business constraints and calculations93processed_data = processed_data[processed_data["UnitPrice"] > 0]94processed_data["Amount"] = processed_data["UnitPrice"] * processed_data["Quantity"]95processed_data["CustomerID"] = processed_data["CustomerID"].fillna(99999.0)9697# Handle outliers in Amount using statistical thresholds98q1 = processed_data["Amount"].quantile(0.25)99q3 = processed_data["Amount"].quantile(0.75)100101# Define bounds - using 1.5 IQR rule102lower_bound = q1 - 1.5 * (q3 - q1)103upper_bound = q3 + 1.5 * (q3 - q1)104105# Filter outliers106processed_data = processed_data[107(processed_data["Amount"] >= lower_bound)108& (processed_data["Amount"] <= upper_bound)109]110111return processed_data112113114# Load Data115116online_retail = fetch_ucirepo(id=352)117raw_data = online_retail.data.features118transformed_data = prepare_time_series_data(raw_data)119120121def prepare_data_for_modeling(122df: pd.DataFrame,123input_sequence_length: int = 6,124output_sequence_length: int = 6,125) -> Dict:126"""127Transform retail data into sequence-to-sequence format with separate128temporal and trend 

def prepare_data_for_modeling(
    df: pd.DataFrame,
    input_sequence_length: int = 6,
    output_sequence_length: int = 6,
) -> Dict:
    """
    Transform retail data into sequence-to-sequence format with separate
    temporal and trend components.
    """
    df = df.copy()

    # Daily aggregation
    daily_purchases = (
        df.groupby(["CustomerID", pd.Grouper(key="InvoiceDate", freq="D")])
        .agg({"Amount": "sum", "Quantity": "sum", "Country": "first"})
        .reset_index()
    )

    daily_purchases["frequency"] = np.where(daily_purchases["Amount"] > 0, 1, 0)

    # Monthly resampling
    monthly_purchases = (
        daily_purchases.set_index("InvoiceDate")
        .groupby("CustomerID")
        .resample("M")
        .agg(
            {"Amount": "sum", "Quantity": "sum", "frequency": "sum", "Country": "first"}
        )
        .reset_index()
    )

    # Add cyclical temporal features
    def prepare_temporal_features(input_window: pd.DataFrame) -> np.ndarray:
        month = input_window["InvoiceDate"].dt.month
        month_sin = np.sin(2 * np.pi * month / 12)
        month_cos = np.cos(2 * np.pi * month / 12)
        is_quarter_start = (month % 3 == 1).astype(int)

        temporal_features = np.column_stack(
            [
                month,
                input_window["InvoiceDate"].dt.year,
                month_sin,
                month_cos,
                is_quarter_start,
            ]
        )
        return temporal_features

    # Prepare trend features with lagged values
    def prepare_trend_features(input_window: pd.DataFrame, lag: int = 3) -> np.ndarray:
        lagged_data = pd.DataFrame()
        for i in range(1, lag + 1):
            lagged_data[f"Amount_lag_{i}"] = input_window["Amount"].shift(i)
            lagged_data[f"Quantity_lag_{i}"] = input_window["Quantity"].shift(i)
            lagged_data[f"frequency_lag_{i}"] = input_window["frequency"].shift(i)

        lagged_data = lagged_data.fillna(0)

        trend_features = np.column_stack(
            [
                input_window["Amount"].values,
                input_window["Quantity"].values,
                input_window["frequency"].values,
                lagged_data.values,
            ]
        )
        return trend_features

    sequence_containers = {
        "temporal_sequences": [],
        "trend_sequences": [],
        "static_features": [],
        "output_sequences": [],
    }

    # Process sequences for each customer
    for customer_id, customer_data in monthly_purchases.groupby("CustomerID"):
        customer_data = customer_data.sort_values("InvoiceDate")
        sequence_ranges = (
            len(customer_data) - input_sequence_length - output_sequence_length + 1
        )

        country = customer_data["Country"].iloc[0]

        for i in range(sequence_ranges):
            input_window = customer_data.iloc[i : i + input_sequence_length]
            output_window = customer_data.iloc[
                i
                + input_sequence_length : i
                + input_sequence_length
                + output_sequence_length
            ]

            if (
                len(input_window) == input_sequence_length
                and len(output_window) == output_sequence_length
            ):
                temporal_features = prepare_temporal_features(input_window)
                trend_features = prepare_trend_features(input_window)

                sequence_containers["temporal_sequences"].append(temporal_features)
                sequence_containers["trend_sequences"].append(trend_features)
                sequence_containers["static_features"].append(country)
                sequence_containers["output_sequences"].append(
                    output_window["Amount"].values
                )

    return {
        "temporal_sequences": (
            np.array(sequence_containers["temporal_sequences"], dtype=np.float32)
        ),
        "trend_sequences": (
            np.array(sequence_containers["trend_sequences"], dtype=np.float32)
        ),
        "static_features": np.array(sequence_containers["static_features"]),
        "output_sequences": (
            np.array(sequence_containers["output_sequences"], dtype=np.float32)
        ),
    }


# Transform the data into input and output sequences, returned as a dictionary
output = prepare_data_for_modeling(
    df=transformed_data, input_sequence_length=6, output_sequence_length=6
)
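"""
Each entry of the returned dictionary is a NumPy array whose first axis indexes the customer
windows. As a rough guide (the exact number of sequences depends on how many customers have
purchase histories spanning at least 12 months), `temporal_sequences` has shape
`(num_sequences, 6, 5)`, `trend_sequences` has shape `(num_sequences, 6, 12)` and
`output_sequences` has shape `(num_sequences, 6)`. The check below simply prints the shapes.
"""

# Inspect the shapes of the prepared sequence arrays.
for key, value in output.items():
    print(key, value.shape)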
"""
## Scaling and Splitting
"""


def robust_scale(data):
    """
    Min-max scale the data; min-max scaling is preferred over standardization here
    because the standard deviation of the monetary features is high.
    """
    data = np.array(data)
    data_min = np.min(data)
    data_max = np.max(data)
    scaled = (data - data_min) / (data_max - data_min)
    return scaled


def create_temporal_splits_with_scaling(
    prepared_data: Dict[str, np.ndarray],
    test_ratio: float = 0.2,
    val_ratio: float = 0.2,
):
    """
    Scale the trend and output sequences, then split all sequences into
    train, validation and test sets.
    """
    total_sequences = len(prepared_data["trend_sequences"])

    # Calculate split points
    test_size = int(total_sequences * test_ratio)
    val_size = int(total_sequences * val_ratio)
    train_size = total_sequences - (test_size + val_size)

    # Scale trend sequences
    trend_shape = prepared_data["trend_sequences"].shape
    scaled_trends = np.zeros_like(prepared_data["trend_sequences"])

    # Scale each feature independently
    for i in range(trend_shape[-1]):
        scaled_trends[..., i] = robust_scale(prepared_data["trend_sequences"][..., i])

    # Scale output sequences
    scaled_outputs = robust_scale(prepared_data["output_sequences"])

    # Create splits
    train_data = {
        "trend_sequences": scaled_trends[:train_size],
        "temporal_sequences": prepared_data["temporal_sequences"][:train_size],
        "static_features": prepared_data["static_features"][:train_size],
        "output_sequences": scaled_outputs[:train_size],
    }

    val_data = {
        "trend_sequences": scaled_trends[train_size : train_size + val_size],
        "temporal_sequences": prepared_data["temporal_sequences"][
            train_size : train_size + val_size
        ],
        "static_features": prepared_data["static_features"][
            train_size : train_size + val_size
        ],
        "output_sequences": scaled_outputs[train_size : train_size + val_size],
    }

    test_data = {
        "trend_sequences": scaled_trends[train_size + val_size :],
        "temporal_sequences": prepared_data["temporal_sequences"][
            train_size + val_size :
        ],
        "static_features": prepared_data["static_features"][train_size + val_size :],
        "output_sequences": scaled_outputs[train_size + val_size :],
    }

    return train_data, val_data, test_data


# Usage
train_data, val_data, test_data = create_temporal_splits_with_scaling(output)
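"""
A quick look at the split sizes confirms the roughly 60/20/20 partition of the sequence set.
Note that `robust_scale` computes its minimum and maximum over the full array before the
split, so the scaling parameters are shared across the three splits.
"""

# Report how many sequences ended up in each split.
for name, split in [("train", train_data), ("val", val_data), ("test", test_data)]:
    print(name, split["trend_sequences"].shape[0], "sequences")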
"""
## Evaluation
"""


def calculate_metrics(y_true, y_pred):
    """
    Calculate RMSE, MAE and R².
    """
    # Convert inputs to float32
    y_true = ops.cast(y_true, dtype="float32")
    y_pred = ops.cast(y_pred, dtype="float32")

    # RMSE
    rmse = np.sqrt(np.mean(np.square(y_true - y_pred)))

    # R² (coefficient of determination)
    ss_res = np.sum(np.square(y_true - y_pred))
    ss_tot = np.sum(np.square(y_true - np.mean(y_true)))
    r2 = 1 - (ss_res / ss_tot)

    return {"mae": np.mean(np.abs(y_true - y_pred)), "rmse": rmse, "r2": r2}


def plot_lorenz_analysis(y_true, y_pred):
    """
    Plot Lorenz curves to show the distribution of high- and low-value users.
    """
    # Convert to numpy arrays and flatten
    y_true = np.array(y_true).flatten()
    y_pred = np.array(y_pred).flatten()

    # Sort values in descending order (for high-value users analysis)
    true_sorted = np.sort(-y_true)
    pred_sorted = np.sort(-y_pred)

    # Calculate cumulative sums
    true_cumsum = np.cumsum(true_sorted)
    pred_cumsum = np.cumsum(pred_sorted)

    # Normalize to percentages
    true_cumsum_pct = true_cumsum / true_cumsum[-1]
    pred_cumsum_pct = pred_cumsum / pred_cumsum[-1]

    # Generate percentiles for the x-axis
    percentiles = np.linspace(0, 1, len(y_true))

    # Calculate the Mutual Gini (area between the curves)
    mutual_gini = np.abs(
        np.trapz(true_cumsum_pct, percentiles) - np.trapz(pred_cumsum_pct, percentiles)
    )

    # Create the plot
    plt.figure(figsize=(10, 6))
    plt.plot(percentiles, true_cumsum_pct, "g-", label="True Values")
    plt.plot(percentiles, pred_cumsum_pct, "r-", label="Predicted Values")
    plt.xlabel("Cumulative % of Users (Descending Order)")
    plt.ylabel("Cumulative % of LTV")
    plt.title("Lorenz Curves: True vs Predicted Values")
    plt.legend()
    plt.grid(True)
    print(f"\nMutual Gini: {mutual_gini:.4f} (lower is better)")
    plt.show()

    return mutual_gini
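"""
As a minimal illustration of the metrics helper (not part of the pipeline itself), the call
below evaluates a small toy pair of arrays; a perfect prediction would give an RMSE and MAE
of 0 and an R² of 1.
"""

# Toy example of the metrics helper on made-up values.
print(calculate_metrics(np.array([1.0, 2.0, 3.0]), np.array([1.1, 1.9, 3.2])))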
"""
## Hybrid Transformer / LSTM model architecture

The hybrid nature of this model is significant because it combines the LSTM's ability to
model sequential purchase behaviour with the Transformer's attention mechanism, which
captures global patterns across countries and seasonality.
"""


def build_hybrid_model(
    input_sequence_length: int,
    output_sequence_length: int,
    num_countries: int,
    d_model: int = 8,
    num_heads: int = 4,
):
    """
    Build and compile the hybrid Transformer / LSTM model.
    """
    keras.utils.set_random_seed(seed=42)

    # Inputs
    temporal_inputs = layers.Input(
        shape=(input_sequence_length, 5), name="temporal_inputs"
    )
    trend_inputs = layers.Input(shape=(input_sequence_length, 12), name="trend_inputs")
    country_inputs = layers.Input(
        shape=(num_countries,), dtype="int32", name="country_inputs"
    )

    # Process country features
    country_embedding = layers.Embedding(
        input_dim=num_countries,
        output_dim=d_model,
        mask_zero=False,
        name="country_embedding",
    )(
        country_inputs
    )  # Output shape: (batch_size, num_countries, d_model)

    # Flatten the embedding output
    country_embedding = layers.Flatten(name="flatten_country_embedding")(
        country_embedding
    )

    # Repeat the country embedding across timesteps
    country_embedding_repeated = layers.RepeatVector(
        input_sequence_length, name="repeat_country_embedding"
    )(country_embedding)

    # Projection of temporal inputs to match the Transformer dimensions
    temporal_projection = layers.Dense(
        d_model, activation="tanh", name="temporal_projection"
    )(temporal_inputs)

    # Combine all features
    combined_features = layers.Concatenate()(
        [temporal_projection, country_embedding_repeated]
    )

    transformer_output = combined_features
    for _ in range(3):
        transformer_output = TransformerEncoder(
            intermediate_dim=16, num_heads=num_heads
        )(transformer_output)

    lstm_output = layers.LSTM(units=64, name="lstm_trend")(trend_inputs)

    transformer_flattened = layers.GlobalAveragePooling1D(name="flatten_transformer")(
        transformer_output
    )
    transformer_flattened = layers.Dense(1, activation="sigmoid")(transformer_flattened)

    # Concatenate the flattened Transformer output with the LSTM output
    merged_features = layers.Concatenate(name="concatenate_transformer_lstm")(
        [transformer_flattened, lstm_output]
    )

    # Repeat the merged features to match the output sequence length
    decoder_initial = layers.RepeatVector(
        output_sequence_length, name="repeat_merged_features"
    )(merged_features)

    decoder_lstm = layers.LSTM(
        units=64,
        return_sequences=True,
        recurrent_regularizer=regularizers.L1L2(l1=1e-5, l2=1e-4),
    )(decoder_initial)

    # Output Dense layer
    output = layers.Dense(units=1, activation="linear", name="output_dense")(
        decoder_lstm
    )

    model = Model(
        inputs=[temporal_inputs, trend_inputs, country_inputs], outputs=output
    )

    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=0.001),
        loss="mse",
        metrics=["mse"],
    )

    return model


# Create the hybrid model
model = build_hybrid_model(
    input_sequence_length=6,
    output_sequence_length=6,
    num_countries=len(np.unique(train_data["static_features"])) + 1,
    d_model=8,
    num_heads=4,
)

# Configure StringLookup
label_encoder = layers.StringLookup(output_mode="one_hot", num_oov_indices=1)

# Adapt and encode
label_encoder.adapt(train_data["static_features"])

train_static_encoded = label_encoder(train_data["static_features"])
val_static_encoded = label_encoder(val_data["static_features"])
test_static_encoded = label_encoder(test_data["static_features"])

# Convert sequences with proper type casting
x_train_seq = np.asarray(train_data["trend_sequences"]).astype(np.float32)
x_val_seq = np.asarray(val_data["trend_sequences"]).astype(np.float32)
x_train_temporal = np.asarray(train_data["temporal_sequences"]).astype(np.float32)
x_val_temporal = np.asarray(val_data["temporal_sequences"]).astype(np.float32)
train_outputs = np.asarray(train_data["output_sequences"]).astype(np.float32)
val_outputs = np.asarray(val_data["output_sequences"]).astype(np.float32)
test_output = np.asarray(test_data["output_sequences"]).astype(np.float32)

# Training setup
keras.utils.set_random_seed(seed=42)

history = model.fit(
    [x_train_temporal, x_train_seq, train_static_encoded],
    train_outputs,
    validation_data=(
        [x_val_temporal, x_val_seq, val_static_encoded],
        val_outputs,
    ),
    epochs=20,
    batch_size=30,
)

# Make predictions
predictions = model.predict(
    [
        test_data["temporal_sequences"].astype(np.float32),
        test_data["trend_sequences"].astype(np.float32),
        test_static_encoded,
    ]
)

# Remove the trailing feature dimension from the predictions
predictions = np.squeeze(predictions)

# Calculate basic metrics
hybrid_metrics = calculate_metrics(test_data["output_sequences"], predictions)

# Plot Lorenz curves and get the Mutual Gini
hybrid_mutual_gini = plot_lorenz_analysis(test_data["output_sequences"], predictions)

"""
## Conclusion

LSTMs excel at sequence-to-sequence learning, as demonstrated by Sutskever, Vinyals and Le (2014),
"Sequence to Sequence Learning with Neural Networks", and the hybrid approach here builds on that
foundation. The addition of attention mechanisms allows the model to adaptively focus on relevant
temporal and geographical patterns while retaining the LSTM's inherent strengths in sequence
learning. This combination has proven especially effective for handling both periodic patterns and
special events in time series forecasting, as shown by Zhou, Zhang, Peng, Zhang, Li, Xiong and
Zhang (2021), "Informer: Beyond Efficient Transformer for Long Sequence Time-Series Forecasting".
"""
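"""
A closing practical note: the targets were min-max scaled before training, so the model's
predictions live in the [0, 1] range. To read them in the original currency units, the scaling
can be inverted with the global minimum and maximum of the unscaled output sequences, mirroring
what `robust_scale` computed. Below is a minimal sketch of that inversion.
"""

# Invert the min-max scaling to express predictions in the original Amount units.
output_min = output["output_sequences"].min()
output_max = output["output_sequences"].max()
predictions_in_amount = predictions * (output_max - output_min) + output_min
print("Example rescaled 6-month prediction window:", predictions_in_amount[0])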