# Source: keras-team/keras-io — examples/structured_data/imbalanced_classification.py
"""
Title: Imbalanced classification: credit card fraud detection
Author: [fchollet](https://twitter.com/fchollet)
Date created: 2019/05/28
Last modified: 2020/04/17
Description: Demonstration of how to handle highly imbalanced classification problems.
Accelerator: GPU
"""

"""
## Introduction

This example looks at the
[Kaggle Credit Card Fraud Detection](https://www.kaggle.com/mlg-ulb/creditcardfraud/)
dataset to demonstrate how
to train a classification model on data with highly imbalanced classes.
"""

"""
## First, vectorize the CSV data
"""

import csv

import numpy as np

# Get the real data from https://www.kaggle.com/mlg-ulb/creditcardfraud/
fname = "/Users/fchollet/Downloads/creditcard.csv"

all_features = []
all_targets = []
with open(fname) as f:
    # Use the stdlib csv parser (previously imported but unused) instead of
    # manual split(",") + quote stripping, which would break on any quoted
    # field containing a comma.
    reader = csv.reader(f)
    header = next(reader)  # skip the header row
    print("HEADER:", ",".join(header))
    for i, row in enumerate(reader):
        # All columns except the last are features; the last is the 0/1 class.
        all_features.append([float(v) for v in row[:-1]])
        all_targets.append([int(row[-1])])
        if i == 0:
            print("EXAMPLE FEATURES:", all_features[-1])

features = np.array(all_features, dtype="float32")
targets = np.array(all_targets, dtype="uint8")
print("features.shape:", features.shape)
print("targets.shape:", targets.shape)
"""
## Prepare a validation set
"""

# Hold out the last 20% of samples for validation.
num_val_samples = int(len(features) * 0.2)
# Slice with an explicit split index: `features[:-num_val_samples]` would
# produce an EMPTY training split when num_val_samples == 0 (fewer than 5
# samples), whereas `[:split_index]` stays correct for any input size.
split_index = len(features) - num_val_samples
train_features = features[:split_index]
train_targets = targets[:split_index]
val_features = features[split_index:]
val_targets = targets[split_index:]

print("Number of training samples:", len(train_features))
print("Number of validation samples:", len(val_features))
"""
## Analyze class imbalance in the targets
"""

# Tally how many samples fall into each class (index 0 = legit, 1 = fraud).
counts = np.bincount(train_targets[:, 0])
num_positive = counts[1]
positive_pct = 100 * float(num_positive) / len(train_targets)
print(
    f"Number of positive samples in training data: {num_positive}"
    f" ({positive_pct:.2f}% of total)"
)

# Inverse-frequency class weights: the rarer a class, the larger its weight.
weight_for_0 = 1.0 / counts[0]
weight_for_1 = 1.0 / counts[1]
"""
## Normalize the data using training set statistics
"""

# Center then scale both splits in place, using statistics computed from the
# training split only. Note the std is taken AFTER centering, matching the
# original two-pass order.
mean = train_features.mean(axis=0)
for split in (train_features, val_features):
    split -= mean
std = train_features.std(axis=0)
for split in (train_features, val_features):
    split /= std
"""
## Build a binary classification model
"""

import keras

# Three 256-unit ReLU layers — dropout after the second and third — feeding a
# single sigmoid unit that outputs the fraud probability.
layer_stack = [keras.Input(shape=train_features.shape[1:])]
for dropout_rate in (None, 0.3, 0.3):
    layer_stack.append(keras.layers.Dense(256, activation="relu"))
    if dropout_rate is not None:
        layer_stack.append(keras.layers.Dropout(dropout_rate))
layer_stack.append(keras.layers.Dense(1, activation="sigmoid"))

model = keras.Sequential(layer_stack)
model.summary()
"""
## Train the model with `class_weight` argument
"""

# Confusion-matrix counts plus precision/recall, each given a short name so
# the per-epoch logs stay compact.
metrics = [
    metric_cls(name=short_name)
    for metric_cls, short_name in [
        (keras.metrics.FalseNegatives, "fn"),
        (keras.metrics.FalsePositives, "fp"),
        (keras.metrics.TrueNegatives, "tn"),
        (keras.metrics.TruePositives, "tp"),
        (keras.metrics.Precision, "precision"),
        (keras.metrics.Recall, "recall"),
    ]
]

model.compile(
    optimizer=keras.optimizers.Adam(1e-2),
    loss="binary_crossentropy",
    metrics=metrics,
)

# Snapshot the model after every epoch; weight each class inversely to its
# frequency so the rare fraud class contributes equally to the loss.
callbacks = [keras.callbacks.ModelCheckpoint("fraud_model_at_epoch_{epoch}.keras")]
class_weight = {0: weight_for_0, 1: weight_for_1}

model.fit(
    train_features,
    train_targets,
    batch_size=2048,
    epochs=30,
    verbose=2,
    callbacks=callbacks,
    validation_data=(val_features, val_targets),
    class_weight=class_weight,
)
"""
## Conclusions

At the end of training, out of 56,961 validation transactions, we are:

- Correctly identifying 66 of them as fraudulent
- Missing 9 fraudulent transactions
- At the cost of incorrectly flagging 441 legitimate transactions

In the real world, one would put an even higher weight on class 1,
so as to reflect that False Negatives are more costly than False Positives.

Next time your credit card gets declined in an online purchase -- this is why.
"""