Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
keras-team
GitHub Repository: keras-team/keras-io
Path: blob/master/examples/audio/stft.py
3507 views
1
"""
2
Title: Audio Classification with the STFTSpectrogram layer
3
Author: [Mostafa M. Amin](https://mostafa-amin.com)
4
Date created: 2024/10/04
5
Last modified: 2024/10/04
6
Description: Introducing the `STFTSpectrogram` layer to extract spectrograms for audio classification.
7
Accelerator: GPU
8
"""
9
10
"""
11
## Introduction
12
13
Preprocessing audio as spectrograms is an essential step in the vast majority
of audio-based applications. Spectrograms, which represent the frequency content
of a signal over time, are widely used for this purpose. In this tutorial, we'll
demonstrate how to use the `STFTSpectrogram` layer in Keras to convert raw
audio waveforms into spectrograms **within the model**. We'll then feed
these spectrograms into convolutional networks followed by Dense layers to
perform audio classification on the ESC-10 dataset.
20
21
We will:
22
23
- Load the ESC-10 dataset.
24
- Preprocess the raw audio waveforms and generate spectrograms using
25
`STFTSpectrogram`.
26
- Build two models: one that treats spectrograms as 1D signals, and another
that treats them as images (2D signals) and uses a pretrained image model.
28
- Train and evaluate the models.
29
30
## Setup
31
32
### Importing the necessary libraries
33
"""
34
35
import os
36
37
# Select the JAX backend; this must run before `keras` is first imported.
os.environ.update(KERAS_BACKEND="jax")
38
39
import keras
40
import matplotlib.pyplot as plt
41
import numpy as np
42
import pandas as pd
43
import scipy.io.wavfile
44
from keras import layers
45
from scipy.signal import resample
46
47
# Seed Keras' global random state so that runs of this tutorial are reproducible.
keras.utils.set_random_seed(seed=41)
48
49
"""
50
### Define some variables
51
"""
52
53
# Root of the extracted ESC-50 archive; `meta/` and `audio/` live under it.
BASE_DATA_DIR = "./datasets/esc-50_extracted/ESC-50-master/"

# Every clip is resampled to this rate (Hz) when it is loaded.
SAMPLE_RATE = 16_000

# Task size and training hyper-parameters.
NUM_CLASSES = 10  # number of ESC-10 classes
BATCH_SIZE = 16
EPOCHS = 200
58
59
"""
60
## Download and Preprocess the ESC-10 Dataset
61
62
We'll use ESC-10, the 10-class subset of ESC-50 (the Dataset for Environmental
Sound Classification). This dataset consists of five-second .wav files of
environmental sounds.
64
65
### Download and Extract the dataset
66
"""
67
68
# Fetch the ESC-50 archive into ./datasets/ and unpack it there; the ESC-10
# subset is selected from its metadata afterwards.
keras.utils.get_file(
    fname="esc-50.zip",
    origin="https://github.com/karoldvl/ESC-50/archive/master.zip",
    cache_dir="./",
    cache_subdir="datasets",
    extract=True,
)
75
76
"""
77
### Read the CSV file
78
"""
79
80
# Load the ESC-50 metadata table.
pd_data = pd.read_csv(os.path.join(BASE_DATA_DIR, "meta", "esc50.csv"))

# Keep only the ESC-10 rows. `.copy()` makes this an independent frame, so the
# `target` re-assignment below does not write into a view of the full table
# (which triggers pandas' SettingWithCopyWarning).
pd_data = pd_data[pd_data["esc10"]].copy()

# The ESC-10 rows keep their original ESC-50 class ids; remap them to the
# contiguous range 0..NUM_CLASSES-1 expected by the sparse categorical loss.
targets = sorted(pd_data["target"].unique().tolist())
if len(targets) != NUM_CLASSES:
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    raise ValueError(
        f"Expected {NUM_CLASSES} ESC-10 classes, found {len(targets)}."
    )
old_target_to_new_target = {old: new for new, old in enumerate(targets)}
pd_data["target"] = pd_data["target"].map(old_target_to_new_target)
pd_data
88
89
"""
90
### Define functions to read and preprocess the WAV files
91
"""
92
93
94
def read_wav_file(path, target_sr=SAMPLE_RATE):
    """Load a WAV file from `BASE_DATA_DIR/audio` as a mono float waveform.

    Args:
        path: File name relative to `BASE_DATA_DIR/audio`.
        target_sr: Sample rate (Hz) to resample to; defaults to `SAMPLE_RATE`.

    Returns:
        A float32 NumPy array of shape `(num_samples, 1)`. Integer PCM is
        scaled to roughly `[-1, 1]`.
    """
    sr, wav = scipy.io.wavfile.read(os.path.join(BASE_DATA_DIR, "audio", path))

    # Scale integer PCM by the full range of its dtype rather than assuming
    # 16-bit audio (for int16 this is the same division by 32768).
    if np.issubdtype(wav.dtype, np.unsignedinteger):
        # Unsigned PCM (8-bit WAV) is centred on the middle of its range.
        half_range = (np.iinfo(wav.dtype).max + 1) / 2
        wav = (wav.astype(np.float32) - half_range) / half_range
    elif np.issubdtype(wav.dtype, np.integer):
        wav = wav.astype(np.float32) / -float(np.iinfo(wav.dtype).min)
    else:
        # Floating-point WAV data is used as stored.
        wav = wav.astype(np.float32)

    # Multi-channel files come back 2-D; average the channels so the output
    # always has exactly one channel.
    if wav.ndim > 1:
        wav = wav.mean(axis=1)

    if sr != target_sr:
        # Integer arithmetic avoids float rounding when computing the length.
        num_samples = int(len(wav) * target_sr // sr)
        wav = resample(wav, num_samples)

    # Append a channel axis of size 1.
    return wav.astype(np.float32, copy=False)[:, None]
100
101
102
"""
103
Create a function that uses the `STFTSpectrogram` to compute a spectrogram,
104
then plots it.
105
"""
106
107
108
def plot_single_spectrogram(sample_wav_data):
    """Compute a `mode="log"` STFT spectrogram of one waveform and show it.

    Args:
        sample_wav_data: A single (un-batched) waveform, presumably of shape
            `(num_samples, 1)` as produced by `read_wav_file`.
    """
    # 20 ms windows and a 5 ms hop, converted from milliseconds to samples.
    window_samples = SAMPLE_RATE * 20 // 1000
    hop_samples = SAMPLE_RATE * 5 // 1000
    stft_layer = layers.STFTSpectrogram(
        mode="log",
        frame_length=window_samples,
        frame_step=hop_samples,
        fft_length=1024,
        trainable=False,
    )

    # The layer works on batches: add a leading batch axis, then drop it.
    batched_wav = sample_wav_data[None, ...]
    spectrogram = stft_layer(batched_wav)[0, ...]

    # Transpose so that time runs horizontally and frequency vertically.
    plt.imshow(spectrogram.T, origin="lower")
    plt.title("Single Channel Spectrogram")
    plt.xlabel("Time")
    plt.ylabel("Frequency")
    plt.show()
123
124
125
"""
126
Create a function that uses `STFTSpectrogram` to compute three spectrograms
with different bandwidths, stacks them as the channels of an image to form a
multi-bandwidth spectrogram, and then plots it.
130
"""
131
132
133
def plot_multi_bandwidth_spectrogram(sample_wav_data):
    """Plot three STFT spectrograms of different bandwidths as one RGB image.

    The spectrograms use 5, 10 and 20 ms analysis windows with a shared 5 ms
    hop. They are stacked along the last axis so that each one becomes a
    color channel of the displayed image.

    Args:
        sample_wav_data: A single (un-batched) waveform, presumably of shape
            `(num_samples, 1)` as produced by `read_wav_file`.
    """
    # All spectrograms must use the same `fft_length`, `frame_step`, and
    # `padding="same"` so that they come out with identical shapes and can be
    # stacked; `expand_dims=True` makes the shapes compatible with image models.
    spectrograms = np.concatenate(
        [
            layers.STFTSpectrogram(
                mode="log",
                frame_length=SAMPLE_RATE * frame_ms // 1000,
                frame_step=SAMPLE_RATE * 5 // 1000,
                fft_length=1024,
                padding="same",
                expand_dims=True,
            )(sample_wav_data[None, ...])[0, ...]
            for frame_ms in [5, 10, 20]  # window lengths in milliseconds
        ],
        axis=-1,
    ).transpose([1, 0, 2])  # swap the first two axes so frequency is vertical

    # Min-max normalize each color channel independently for better viewing.
    # A constant channel (e.g. from a silent clip) has zero range; dividing by
    # 1 there yields zeros instead of NaNs.
    lo = spectrograms.min(axis=(0, 1), keepdims=True)
    span = spectrograms.max(axis=(0, 1), keepdims=True) - lo
    spectrograms = (spectrograms - lo) / np.where(span > 0, span, 1.0)

    plt.imshow(spectrograms, origin="lower")
    plt.title("Multi-bandwidth Spectrogram")
    plt.xlabel("Time")
    plt.ylabel("Frequency")
    plt.show()
164
165
166
"""
167
Load a sample WAV file and plot its waveform.
168
"""
169
170
# Load clip number 52 of the ESC-10 table and plot its raw waveform.
sample_filename = pd_data["filename"].iloc[52]
sample_wav_data = read_wav_file(sample_filename)
plt.plot(sample_wav_data[:, 0])  # the single channel
plt.show()

"""
Plot a Spectrogram
"""

plot_single_spectrogram(sample_wav_data)

"""
Plot a multi-bandwidth spectrogram
"""

plot_multi_bandwidth_spectrogram(sample_wav_data)
185
186
"""
187
### Define a function to load dataset splits as NumPy arrays
188
"""
189
190
191
def read_dataset(df, folds):
    """Load every clip that belongs to the given folds.

    Args:
        df: Metadata frame with `fold`, `filename` and `target` columns.
        folds: Iterable of fold ids to include.

    Returns:
        `(waves, targets)`. `waves` is a float32 array of shape
        `(num_clips, max_num_samples, 1)`; clips shorter than the longest one
        are zero-padded at the end. `targets` is a NumPy array of the matching
        `target` values, in the same order.
    """
    mask = df["fold"].isin(folds)
    # Each wave is assumed to be `(num_samples, 1)`, as `read_wav_file` returns.
    waves = [read_wav_file(filename) for filename in df.loc[mask, "filename"]]
    targets = df.loc[mask, "target"].to_numpy()

    # `np.array(list_of_waves)` raises on ragged lengths, so copy each clip
    # into a zero-initialized buffer instead. For equal-length clips this is
    # identical to stacking them.
    max_len = max((len(wave) for wave in waves), default=0)
    batch = np.zeros((len(waves), max_len, 1), dtype=np.float32)
    for i, wave in enumerate(waves):
        batch[i, : len(wave)] = wave
    return batch, targets
197
198
199
"""
200
### Create the datasets
201
"""
202
203
# Folds 1-3 are used for training, fold 4 for validation and fold 5 for testing.
(train_x, train_y), (valid_x, valid_y), (test_x, test_y) = (
    read_dataset(pd_data, folds=split_folds)
    for split_folds in ([1, 2, 3], [4], [5])
)
206
207
"""
208
## Training the Models
209
210
In this tutorial we demonstrate different use cases of the `STFTSpectrogram`
layer.

The first model will use a non-trainable `STFTSpectrogram` layer, so it is
intended purely for preprocessing. Additionally, the model will use 1D signals,
hence it makes use of Conv1D layers.
216
217
The second model will use a trainable `STFTSpectrogram` layer with the
218
`expand_dims` option, which expands the shapes to be compatible with image
219
models.
220
221
### Create the 1D model
222
223
1. Create non-trainable spectrograms, extracting a 1D time signal.
2. Apply `Conv1D` layers with `LayerNormalization`, similar to the
classic VGG design.
3. Apply global maximum pooling to obtain a fixed-size set of features.
4. Add `Dense` layers to make the final predictions based on the features.
228
"""
229
230
# 1D model: a frozen STFT front-end, a Conv1D feature extractor, and a Dense
# classification head.
model1d = keras.Sequential(
    [
        # Variable-length waveform with a single channel.
        layers.InputLayer(shape=(None, 1)),
        # 40 ms windows with a 15 ms hop; `trainable=False` keeps the STFT
        # fixed, so it acts purely as preprocessing.
        layers.STFTSpectrogram(
            mode="log",
            frame_length=SAMPLE_RATE * 40 // 1000,
            frame_step=SAMPLE_RATE * 15 // 1000,
            trainable=False,
        ),
        # First convolutional stage.
        layers.Conv1D(filters=64, kernel_size=64, activation="relu"),
        layers.Conv1D(filters=128, kernel_size=16, activation="relu"),
        layers.LayerNormalization(),
        layers.MaxPooling1D(pool_size=4),
        # Second convolutional stage.
        layers.Conv1D(filters=128, kernel_size=8, activation="relu"),
        layers.Conv1D(filters=256, kernel_size=8, activation="relu"),
        layers.Conv1D(filters=512, kernel_size=4, activation="relu"),
        layers.LayerNormalization(),
        layers.Dropout(rate=0.5),
        # Collapse the time axis into a fixed-size feature vector.
        layers.GlobalMaxPooling1D(),
        # Classification head.
        layers.Dense(units=256, activation="relu"),
        layers.Dense(units=256, activation="relu"),
        layers.Dropout(rate=0.5),
        layers.Dense(units=NUM_CLASSES, activation="softmax"),
    ],
    name="model_1d_non_trainble_stft",
)
model1d.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-5),
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)
model1d.summary()
262
263
"""
264
Train the model and restore the best weights.
265
"""
266
267
# With `patience=EPOCHS`, early stopping should never end the run before the
# last epoch. The callback is used to restore the weights from the epoch with
# the best `val_loss`.
history_model1d = model1d.fit(
    x=train_x,
    y=train_y,
    validation_data=(valid_x, valid_y),
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    callbacks=[
        keras.callbacks.EarlyStopping(
            monitor="val_loss",
            restore_best_weights=True,
            patience=EPOCHS,
        )
    ],
)
281
282
"""
283
### Create the 2D model
284
285
1. Create three spectrograms with different bandwidths from the raw input.
2. Concatenate the three spectrograms to get three channels.
3. Load `MobileNet` with weights pretrained on `ImageNet`.
4. Apply global maximum pooling to obtain a fixed-size set of features.
5. Add `Dense` layers to make the final predictions based on the features.
290
"""
291
292
# 2D model: three trainable STFT front-ends with different window lengths are
# stacked as image channels and fed to an ImageNet-pretrained MobileNet.
# The input tensor is called `inputs` (not `input`) to avoid shadowing the
# Python builtin of that name.
inputs = layers.Input((None, 1))
spectrograms = [
    # No `trainable=False` here: these STFT layers are trainable by default.
    # The shared `fft_length`, `frame_step` and `padding="same"` give all three
    # spectrograms the same shape so they can be concatenated.
    layers.STFTSpectrogram(
        mode="log",
        frame_length=SAMPLE_RATE * frame_size // 1000,
        frame_step=SAMPLE_RATE * 15 // 1000,
        fft_length=2048,
        padding="same",
        expand_dims=True,
    )(inputs)
    for frame_size in [30, 40, 50]  # frame size in milliseconds
]

# One spectrogram per channel -> a three-channel "image".
multi_spectrograms = layers.Concatenate(axis=-1)(spectrograms)

# ImageNet-pretrained backbone without its classifier; `pooling="max"` applies
# global max pooling to give a fixed-size feature vector.
img_model = keras.applications.MobileNet(
    include_top=False, weights="imagenet", pooling="max"
)
output = img_model(multi_spectrograms)

# Classification head.
output = layers.Dropout(0.5)(output)
output = layers.Dense(256, activation="relu")(output)
output = layers.Dense(256, activation="relu")(output)
output = layers.Dense(NUM_CLASSES, activation="softmax")(output)
model2d = keras.Model(inputs, output, name="model_2d_trainble_stft")

model2d.compile(
    optimizer=keras.optimizers.Adam(1e-4),
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)
model2d.summary()
323
324
"""
325
Train the model and restore the best weights.
326
"""
327
328
# Same schedule as the 1D model. `patience=EPOCHS` means early stopping should
# not cut the run short; the callback restores the best-`val_loss` weights.
history_model2d = model2d.fit(
    x=train_x,
    y=train_y,
    validation_data=(valid_x, valid_y),
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    callbacks=[
        keras.callbacks.EarlyStopping(
            monitor="val_loss",
            restore_best_weights=True,
            patience=EPOCHS,
        )
    ],
)
342
343
"""
344
### Plot Training History
345
"""
346
347
# Plot per-epoch accuracy (left) and loss (right) for both models. Each curve's
# x-axis comes from its own recorded history length rather than assuming
# `EPOCHS` entries, so the plot still works if a run stops early.
histories = [
    ("1D model with non-trainable STFT", history_model1d),
    ("2D model with trainable STFT", history_model2d),
]
plot_panels = [
    # (history key, title word, legend location)
    ("accuracy", "Accuracy", "lower right"),
    ("loss", "Loss", "upper right"),
]

plt.figure(figsize=(14, 5))
for panel_index, (metric, metric_title, legend_loc) in enumerate(
    plot_panels, start=1
):
    plt.subplot(1, 2, panel_index)
    # Curve order per panel: 1D train, 1D val, 2D train, 2D val.
    for model_description, history in histories:
        for key_prefix, split_name in (("", "Training"), ("val_", "Validation")):
            values = history.history[key_prefix + metric]
            plt.plot(
                range(len(values)),
                values,
                label=f"{split_name} {metric_title}, {model_description}",
            )
    plt.legend(loc=legend_loc)
    plt.title(f"Training and Validation {metric_title}")
plt.show()
398
399
"""
400
### Evaluate on Test Data
401
402
Running the models on the test set.
403
"""
404
405
# Report the held-out accuracy of each trained model (the loss is discarded).
for model, description in (
    (model1d, "1D model with non-trainable STFT"),
    (model2d, "2D model with trainable STFT"),
):
    _, test_acc = model.evaluate(test_x, test_y)
    print(f"{description} -> Test Accuracy: {test_acc * 100:.2f}%")
410
411