"""
Title: Imbalanced classification: credit card fraud detection
Author: [fchollet](https://twitter.com/fchollet)
Date created: 2019/05/28
Last modified: 2020/04/17
Description: Demonstration of how to handle highly imbalanced classification problems.
Accelerator: GPU
"""

"""
## Introduction

This example looks at the
[Kaggle Credit Card Fraud Detection](https://www.kaggle.com/mlg-ulb/creditcardfraud/)
dataset to demonstrate how
to train a classification model on data with highly imbalanced classes.
"""

"""
## First, vectorize the CSV data
"""

import csv
import numpy as np

# Get the real data from https://www.kaggle.com/mlg-ulb/creditcardfraud/
fname = "/Users/fchollet/Downloads/creditcard.csv"

all_features = []
all_targets = []
with open(fname) as f:
    for i, line in enumerate(f):
        if i == 0:
            print("HEADER:", line.strip())
            continue  # Skip header
        fields = line.strip().split(",")
        all_features.append([float(v.replace('"', "")) for v in fields[:-1]])
        all_targets.append([int(fields[-1].replace('"', ""))])
        if i == 1:
            print("EXAMPLE FEATURES:", all_features[-1])

features = np.array(all_features, dtype="float32")
targets = np.array(all_targets, dtype="uint8")
print("features.shape:", features.shape)
print("targets.shape:", targets.shape)

"""
## Prepare a validation set
"""

num_val_samples = int(len(features) * 0.2)
train_features = features[:-num_val_samples]
train_targets = targets[:-num_val_samples]
val_features = features[-num_val_samples:]
val_targets = targets[-num_val_samples:]

print("Number of training samples:", len(train_features))
print("Number of validation samples:", len(val_features))

"""
## Analyze class imbalance in the targets
"""

counts = np.bincount(train_targets[:, 0])
print(
    "Number of positive samples in training data: {} ({:.2f}% of total)".format(
        counts[1], 100 * float(counts[1]) / len(train_targets)
    )
)

weight_for_0 = 1.0 / counts[0]
weight_for_1 = 1.0 / counts[1]

"""
## Normalize the data using training set statistics
"""

mean = np.mean(train_features, axis=0)
train_features -= mean
val_features -= mean
std = np.std(train_features, axis=0)
train_features /= std
val_features /= std

"""
## Build a binary classification model
"""

import keras

model = keras.Sequential(
    [
        keras.Input(shape=train_features.shape[1:]),
        keras.layers.Dense(256, activation="relu"),
        keras.layers.Dense(256, activation="relu"),
        keras.layers.Dropout(0.3),
        keras.layers.Dense(256, activation="relu"),
        keras.layers.Dropout(0.3),
        keras.layers.Dense(1, activation="sigmoid"),
    ]
)
model.summary()

"""
## Train the model with `class_weight` argument
"""

metrics = [
    keras.metrics.FalseNegatives(name="fn"),
    keras.metrics.FalsePositives(name="fp"),
    keras.metrics.TrueNegatives(name="tn"),
    keras.metrics.TruePositives(name="tp"),
    keras.metrics.Precision(name="precision"),
    keras.metrics.Recall(name="recall"),
]

model.compile(
    optimizer=keras.optimizers.Adam(1e-2), loss="binary_crossentropy", metrics=metrics
)

callbacks = [keras.callbacks.ModelCheckpoint("fraud_model_at_epoch_{epoch}.keras")]
class_weight = {0: weight_for_0, 1: weight_for_1}

model.fit(
    train_features,
    train_targets,
    batch_size=2048,
    epochs=30,
    verbose=2,
    callbacks=callbacks,
    validation_data=(val_features, val_targets),
    class_weight=class_weight,
)

"""
## Conclusions

At the end of training, out of 56,961 validation transactions, we are:

- Correctly identifying 66 of them as fraudulent
- Missing 9 fraudulent transactions
- At the cost of incorrectly flagging 441 legitimate transactions

In the real world, one would put an even higher weight on class 1,
so as to reflect that False Negatives are more costly than False Positives.

Next time your credit card gets declined in an online purchase -- this is why.

"""
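
"""
The counts reported in the conclusions can be turned into the usual summary
metrics with a few lines of arithmetic. The sketch below is standalone and only
uses the four numbers quoted in the conclusions; the variable names are
illustrative, and the true-negative count is derived by subtraction rather than
taken from a training log.

```python
# Counts quoted in the conclusions (validation set).
total = 56_961  # validation transactions
tp = 66         # fraudulent transactions correctly identified
fn = 9          # fraudulent transactions missed
fp = 441        # legitimate transactions incorrectly flagged
tn = total - tp - fn - fp  # legitimate transactions correctly passed (derived)

precision = tp / (tp + fp)            # share of flagged transactions that are fraud
recall = tp / (tp + fn)               # share of fraud that gets flagged
false_positive_rate = fp / (fp + tn)  # share of legitimate transactions flagged
accuracy = (tp + tn) / total

# Accuracy of a trivial model that never flags anything.
baseline_accuracy = (tn + fp) / total

print(f"precision:           {precision:.3f}")
print(f"recall:              {recall:.3f}")
print(f"false positive rate: {false_positive_rate:.4f}")
print(f"accuracy:            {accuracy:.4f}")
print(f"baseline accuracy:   {baseline_accuracy:.4f}")
```

With these counts, recall is 0.88 while precision is roughly 0.13, and the
never-flag baseline scores a higher accuracy than the trained model. This is
why the example tracks precision and recall rather than accuracy.
"""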
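
"""
To make the role of the inverse-frequency class weights more concrete, here is
a minimal NumPy sketch of a class-weighted binary cross-entropy. It is an
illustration of the idea only, not Keras internals: the helper names, the toy
labels, and the boost factor of 5 are all hypothetical, and the way Keras
reduces and normalizes the weighted loss over a batch may differ.

```python
import numpy as np


def per_sample_weights(y_true, class_weight):
    # Map each 0/1 label to the weight of its class.
    return np.where(y_true == 1, class_weight[1], class_weight[0])


def weighted_bce(y_true, y_prob, class_weight, eps=1e-7):
    # Sum of per-sample binary cross-entropy, each scaled by its class weight.
    y_prob = np.clip(y_prob, eps, 1.0 - eps)
    losses = -(y_true * np.log(y_prob) + (1 - y_true) * np.log(1 - y_prob))
    return float(np.sum(per_sample_weights(y_true, class_weight) * losses))


# Toy labels (hypothetical): 995 legitimate and 5 fraudulent transactions.
y_true = np.array([0] * 995 + [1] * 5)
counts = np.bincount(y_true)

# Same recipe as weight_for_0 / weight_for_1 in the example.
class_weight = {0: 1.0 / counts[0], 1: 1.0 / counts[1]}

# With inverse-frequency weights, each class carries the same total weight.
weights = per_sample_weights(y_true, class_weight)
total_weight_0 = weights[y_true == 0].sum()
total_weight_1 = weights[y_true == 1].sum()
print("total weight, class 0:", total_weight_0)
print("total weight, class 1:", total_weight_1)

# A model that assigns every transaction a 1% fraud probability.
y_prob = np.full(len(y_true), 0.01)
loss_balanced = weighted_bce(y_true, y_prob, class_weight)

# Putting an even higher weight on class 1 (the factor 5.0 is arbitrary).
boosted = {0: class_weight[0], 1: 5.0 * class_weight[1]}
loss_boosted = weighted_bce(y_true, y_prob, boosted)
print("balanced loss:", loss_balanced)
print("boosted loss: ", loss_boosted)
```

In this sketch the handful of fraudulent samples contributes as much total
weight as the much larger legitimate class, and raising the weight on class 1
increases the penalty for a model that misses fraud.
"""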