Path: blob/master/examples/audio/melgan_spectrogram_inversion.py
3507 views
"""1Title: MelGAN-based spectrogram inversion using feature matching2Author: [Darshan Deshpande](https://twitter.com/getdarshan)3Date created: 02/09/20214Last modified: 15/09/20215Description: Inversion of audio from mel-spectrograms using the MelGAN architecture and feature matching.6Accelerator: GPU7"""89"""10## Introduction1112Autoregressive vocoders have been ubiquitous for a majority of the history of speech processing,13but for most of their existence they have lacked parallelism.14[MelGAN](https://arxiv.org/abs/1910.06711) is a15non-autoregressive, fully convolutional vocoder architecture used for purposes ranging16from spectral inversion and speech enhancement to present-day state-of-the-art17speech synthesis when used as a decoder18with models like Tacotron2 or FastSpeech that convert text to mel spectrograms.1920In this tutorial, we will have a look at the MelGAN architecture and how it can achieve21fast spectral inversion, i.e. conversion of spectrograms to audio waves. The MelGAN22implemented in this tutorial is similar to the original implementation with only the23difference of method of padding for convolutions where we will use 'same' instead of24reflect padding.25"""2627"""28## Importing and Defining Hyperparameters29"""3031"""shell32pip install -qqq tensorflow_addons33pip install -qqq tensorflow-io34"""3536import tensorflow as tf37import tensorflow_io as tfio38from tensorflow import keras39from tensorflow.keras import layers40from tensorflow_addons import layers as addon_layers4142# Setting logger level to avoid input shape warnings43tf.get_logger().setLevel("ERROR")4445# Defining hyperparameters4647DESIRED_SAMPLES = 819248LEARNING_RATE_GEN = 1e-549LEARNING_RATE_DISC = 1e-650BATCH_SIZE = 165152mse = keras.losses.MeanSquaredError()53mae = keras.losses.MeanAbsoluteError()5455"""56## Loading the Dataset5758This example uses the [LJSpeech dataset](https://keithito.com/LJ-Speech-Dataset/).5960The LJSpeech dataset is primarily used for text-to-speech and consists of 13,100 discrete61speech samples taken from 7 non-fiction books, having a total length of approximately 2462hours. The MelGAN training is only concerned with the audio waves so we process only the63WAV files and ignore the audio annotations.64"""6566"""shell67wget https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz268tar -xf /content/LJSpeech-1.1.tar.bz269"""7071"""72We create a `tf.data.Dataset` to load and process the audio files on the fly.73The `preprocess()` function takes the file path as input and returns two instances of the74wave, one for input and one as the ground truth for comparison. The input wave will be75mapped to a spectrogram using the custom `MelSpec` layer as shown later in this example.76"""7778# Splitting the dataset into training and testing splits79wavs = tf.io.gfile.glob("LJSpeech-1.1/wavs/*.wav")80print(f"Number of audio files: {len(wavs)}")818283# Mapper function for loading the audio. This function returns two instances of the wave84def preprocess(filename):85audio = tf.audio.decode_wav(tf.io.read_file(filename), 1, DESIRED_SAMPLES).audio86return audio, audio878889# Create tf.data.Dataset objects and apply preprocessing90train_dataset = tf.data.Dataset.from_tensor_slices((wavs,))91train_dataset = train_dataset.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)9293"""94## Defining custom layers for MelGAN9596The MelGAN architecture consists of 3 main modules:97981. The residual block992. Dilated convolutional block1003. Discriminator block101102103"""104105"""106Since the network takes a mel-spectrogram as input, we will create an additional custom107layer108which can convert the raw audio wave to a spectrogram on-the-fly. We use the raw audio109tensor from `train_dataset` and map it to a mel-spectrogram using the `MelSpec` layer110below.111"""112113# Custom keras layer for on-the-fly audio to spectrogram conversion114115116class MelSpec(layers.Layer):117def __init__(118self,119frame_length=1024,120frame_step=256,121fft_length=None,122sampling_rate=22050,123num_mel_channels=80,124freq_min=125,125freq_max=7600,126**kwargs,127):128super().__init__(**kwargs)129self.frame_length = frame_length130self.frame_step = frame_step131self.fft_length = fft_length132self.sampling_rate = sampling_rate133self.num_mel_channels = num_mel_channels134self.freq_min = freq_min135self.freq_max = freq_max136# Defining mel filter. This filter will be multiplied with the STFT output137self.mel_filterbank = tf.signal.linear_to_mel_weight_matrix(138num_mel_bins=self.num_mel_channels,139num_spectrogram_bins=self.frame_length // 2 + 1,140sample_rate=self.sampling_rate,141lower_edge_hertz=self.freq_min,142upper_edge_hertz=self.freq_max,143)144145def call(self, audio, training=True):146# We will only perform the transformation during training.147if training:148# Taking the Short Time Fourier Transform. Ensure that the audio is padded.149# In the paper, the STFT output is padded using the 'REFLECT' strategy.150stft = tf.signal.stft(151tf.squeeze(audio, -1),152self.frame_length,153self.frame_step,154self.fft_length,155pad_end=True,156)157158# Taking the magnitude of the STFT output159magnitude = tf.abs(stft)160161# Multiplying the Mel-filterbank with the magnitude and scaling it using the db scale162mel = tf.matmul(tf.square(magnitude), self.mel_filterbank)163log_mel_spec = tfio.audio.dbscale(mel, top_db=80)164return log_mel_spec165else:166return audio167168def get_config(self):169config = super().get_config()170config.update(171{172"frame_length": self.frame_length,173"frame_step": self.frame_step,174"fft_length": self.fft_length,175"sampling_rate": self.sampling_rate,176"num_mel_channels": self.num_mel_channels,177"freq_min": self.freq_min,178"freq_max": self.freq_max,179}180)181return config182183184"""185The residual convolutional block extensively uses dilations and has a total receptive186field of 27 timesteps per block. The dilations must grow as a power of the `kernel_size`187to ensure reduction of hissing noise in the output. The network proposed by the paper is188as follows:189190191"""192193# Creating the residual stack block194195196def residual_stack(input, filters):197"""Convolutional residual stack with weight normalization.198199Args:200filters: int, determines filter size for the residual stack.201202Returns:203Residual stack output.204"""205c1 = addon_layers.WeightNormalization(206layers.Conv1D(filters, 3, dilation_rate=1, padding="same"), data_init=False207)(input)208lrelu1 = layers.LeakyReLU()(c1)209c2 = addon_layers.WeightNormalization(210layers.Conv1D(filters, 3, dilation_rate=1, padding="same"), data_init=False211)(lrelu1)212add1 = layers.Add()([c2, input])213214lrelu2 = layers.LeakyReLU()(add1)215c3 = addon_layers.WeightNormalization(216layers.Conv1D(filters, 3, dilation_rate=3, padding="same"), data_init=False217)(lrelu2)218lrelu3 = layers.LeakyReLU()(c3)219c4 = addon_layers.WeightNormalization(220layers.Conv1D(filters, 3, dilation_rate=1, padding="same"), data_init=False221)(lrelu3)222add2 = layers.Add()([add1, c4])223224lrelu4 = layers.LeakyReLU()(add2)225c5 = addon_layers.WeightNormalization(226layers.Conv1D(filters, 3, dilation_rate=9, padding="same"), data_init=False227)(lrelu4)228lrelu5 = layers.LeakyReLU()(c5)229c6 = addon_layers.WeightNormalization(230layers.Conv1D(filters, 3, dilation_rate=1, padding="same"), data_init=False231)(lrelu5)232add3 = layers.Add()([c6, add2])233234return add3235236237"""238Each convolutional block uses the dilations offered by the residual stack239and upsamples the input data by the `upsampling_factor`.240"""241242# Dilated convolutional block consisting of the Residual stack243244245def conv_block(input, conv_dim, upsampling_factor):246"""Dilated Convolutional Block with weight normalization.247248Args:249conv_dim: int, determines filter size for the block.250upsampling_factor: int, scale for upsampling.251252Returns:253Dilated convolution block.254"""255conv_t = addon_layers.WeightNormalization(256layers.Conv1DTranspose(conv_dim, 16, upsampling_factor, padding="same"),257data_init=False,258)(input)259lrelu1 = layers.LeakyReLU()(conv_t)260res_stack = residual_stack(lrelu1, conv_dim)261lrelu2 = layers.LeakyReLU()(res_stack)262return lrelu2263264265"""266The discriminator block consists of convolutions and downsampling layers. This block is267essential for the implementation of the feature matching technique.268269Each discriminator outputs a list of feature maps that will be compared during training270to compute the feature matching loss.271"""272273274def discriminator_block(input):275conv1 = addon_layers.WeightNormalization(276layers.Conv1D(16, 15, 1, "same"), data_init=False277)(input)278lrelu1 = layers.LeakyReLU()(conv1)279conv2 = addon_layers.WeightNormalization(280layers.Conv1D(64, 41, 4, "same", groups=4), data_init=False281)(lrelu1)282lrelu2 = layers.LeakyReLU()(conv2)283conv3 = addon_layers.WeightNormalization(284layers.Conv1D(256, 41, 4, "same", groups=16), data_init=False285)(lrelu2)286lrelu3 = layers.LeakyReLU()(conv3)287conv4 = addon_layers.WeightNormalization(288layers.Conv1D(1024, 41, 4, "same", groups=64), data_init=False289)(lrelu3)290lrelu4 = layers.LeakyReLU()(conv4)291conv5 = addon_layers.WeightNormalization(292layers.Conv1D(1024, 41, 4, "same", groups=256), data_init=False293)(lrelu4)294lrelu5 = layers.LeakyReLU()(conv5)295conv6 = addon_layers.WeightNormalization(296layers.Conv1D(1024, 5, 1, "same"), data_init=False297)(lrelu5)298lrelu6 = layers.LeakyReLU()(conv6)299conv7 = addon_layers.WeightNormalization(300layers.Conv1D(1, 3, 1, "same"), data_init=False301)(lrelu6)302return [lrelu1, lrelu2, lrelu3, lrelu4, lrelu5, lrelu6, conv7]303304305"""306### Create the generator307"""308309310def create_generator(input_shape):311inp = keras.Input(input_shape)312x = MelSpec()(inp)313x = layers.Conv1D(512, 7, padding="same")(x)314x = layers.LeakyReLU()(x)315x = conv_block(x, 256, 8)316x = conv_block(x, 128, 8)317x = conv_block(x, 64, 2)318x = conv_block(x, 32, 2)319x = addon_layers.WeightNormalization(320layers.Conv1D(1, 7, padding="same", activation="tanh")321)(x)322return keras.Model(inp, x)323324325# We use a dynamic input shape for the generator since the model is fully convolutional326generator = create_generator((None, 1))327generator.summary()328329"""330### Create the discriminator331"""332333334def create_discriminator(input_shape):335inp = keras.Input(input_shape)336out_map1 = discriminator_block(inp)337pool1 = layers.AveragePooling1D()(inp)338out_map2 = discriminator_block(pool1)339pool2 = layers.AveragePooling1D()(pool1)340out_map3 = discriminator_block(pool2)341return keras.Model(inp, [out_map1, out_map2, out_map3])342343344# We use a dynamic input shape for the discriminator345# This is done because the input shape for the generator is unknown346discriminator = create_discriminator((None, 1))347348discriminator.summary()349350"""351## Defining the loss functions352353**Generator Loss**354355The generator architecture uses a combination of two losses3563571. Mean Squared Error:358359This is the standard MSE generator loss calculated between ones and the outputs from the360discriminator with _N_ layers.361362<p align="center">363<img src="https://i.imgur.com/dz4JS3I.png" width=300px;></img>364</p>3653662. Feature Matching Loss:367368This loss involves extracting the outputs of every layer from the discriminator for both369the generator and ground truth and compare each layer output _k_ using Mean Absolute Error.370371<p align="center">372<img src="https://i.imgur.com/gEpSBar.png" width=400px;></img>373</p>374375**Discriminator Loss**376377The discriminator uses the Mean Absolute Error and compares the real data predictions378with ones and generated predictions with zeros.379380<p align="center">381<img src="https://i.imgur.com/bbEnJ3t.png" width=425px;></img>382</p>383"""384385# Generator loss386387388def generator_loss(real_pred, fake_pred):389"""Loss function for the generator.390391Args:392real_pred: Tensor, output of the ground truth wave passed through the discriminator.393fake_pred: Tensor, output of the generator prediction passed through the discriminator.394395Returns:396Loss for the generator.397"""398gen_loss = []399for i in range(len(fake_pred)):400gen_loss.append(mse(tf.ones_like(fake_pred[i][-1]), fake_pred[i][-1]))401402return tf.reduce_mean(gen_loss)403404405def feature_matching_loss(real_pred, fake_pred):406"""Implements the feature matching loss.407408Args:409real_pred: Tensor, output of the ground truth wave passed through the discriminator.410fake_pred: Tensor, output of the generator prediction passed through the discriminator.411412Returns:413Feature Matching Loss.414"""415fm_loss = []416for i in range(len(fake_pred)):417for j in range(len(fake_pred[i]) - 1):418fm_loss.append(mae(real_pred[i][j], fake_pred[i][j]))419420return tf.reduce_mean(fm_loss)421422423def discriminator_loss(real_pred, fake_pred):424"""Implements the discriminator loss.425426Args:427real_pred: Tensor, output of the ground truth wave passed through the discriminator.428fake_pred: Tensor, output of the generator prediction passed through the discriminator.429430Returns:431Discriminator Loss.432"""433real_loss, fake_loss = [], []434for i in range(len(real_pred)):435real_loss.append(mse(tf.ones_like(real_pred[i][-1]), real_pred[i][-1]))436fake_loss.append(mse(tf.zeros_like(fake_pred[i][-1]), fake_pred[i][-1]))437438# Calculating the final discriminator loss after scaling439disc_loss = tf.reduce_mean(real_loss) + tf.reduce_mean(fake_loss)440return disc_loss441442443"""444Defining the MelGAN model for training.445This subclass overrides the `train_step()` method to implement the training logic.446"""447448449class MelGAN(keras.Model):450def __init__(self, generator, discriminator, **kwargs):451"""MelGAN trainer class452453Args:454generator: keras.Model, Generator model455discriminator: keras.Model, Discriminator model456"""457super().__init__(**kwargs)458self.generator = generator459self.discriminator = discriminator460461def compile(462self,463gen_optimizer,464disc_optimizer,465generator_loss,466feature_matching_loss,467discriminator_loss,468):469"""MelGAN compile method.470471Args:472gen_optimizer: keras.optimizer, optimizer to be used for training473disc_optimizer: keras.optimizer, optimizer to be used for training474generator_loss: callable, loss function for generator475feature_matching_loss: callable, loss function for feature matching476discriminator_loss: callable, loss function for discriminator477"""478super().compile()479480# Optimizers481self.gen_optimizer = gen_optimizer482self.disc_optimizer = disc_optimizer483484# Losses485self.generator_loss = generator_loss486self.feature_matching_loss = feature_matching_loss487self.discriminator_loss = discriminator_loss488489# Trackers490self.gen_loss_tracker = keras.metrics.Mean(name="gen_loss")491self.disc_loss_tracker = keras.metrics.Mean(name="disc_loss")492493def train_step(self, batch):494x_batch_train, y_batch_train = batch495496with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:497# Generating the audio wave498gen_audio_wave = generator(x_batch_train, training=True)499500# Generating the features using the discriminator501real_pred = discriminator(y_batch_train)502fake_pred = discriminator(gen_audio_wave)503504# Calculating the generator losses505gen_loss = generator_loss(real_pred, fake_pred)506fm_loss = feature_matching_loss(real_pred, fake_pred)507508# Calculating final generator loss509gen_fm_loss = gen_loss + 10 * fm_loss510511# Calculating the discriminator losses512disc_loss = discriminator_loss(real_pred, fake_pred)513514# Calculating and applying the gradients for generator and discriminator515grads_gen = gen_tape.gradient(gen_fm_loss, generator.trainable_weights)516grads_disc = disc_tape.gradient(disc_loss, discriminator.trainable_weights)517gen_optimizer.apply_gradients(zip(grads_gen, generator.trainable_weights))518disc_optimizer.apply_gradients(zip(grads_disc, discriminator.trainable_weights))519520self.gen_loss_tracker.update_state(gen_fm_loss)521self.disc_loss_tracker.update_state(disc_loss)522523return {524"gen_loss": self.gen_loss_tracker.result(),525"disc_loss": self.disc_loss_tracker.result(),526}527528529"""530## Training531532The paper suggests that the training with dynamic shapes takes around 400,000 steps (~500533epochs). For this example, we will run it only for a single epoch (819 steps).534Longer training time (greater than 300 epochs) will almost certainly provide better results.535"""536537gen_optimizer = keras.optimizers.Adam(538LEARNING_RATE_GEN, beta_1=0.5, beta_2=0.9, clipnorm=1539)540disc_optimizer = keras.optimizers.Adam(541LEARNING_RATE_DISC, beta_1=0.5, beta_2=0.9, clipnorm=1542)543544# Start training545generator = create_generator((None, 1))546discriminator = create_discriminator((None, 1))547548mel_gan = MelGAN(generator, discriminator)549mel_gan.compile(550gen_optimizer,551disc_optimizer,552generator_loss,553feature_matching_loss,554discriminator_loss,555)556mel_gan.fit(557train_dataset.shuffle(200).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE), epochs=1558)559560"""561## Testing the model562563The trained model can now be used for real time text-to-speech translation tasks.564To test how fast the MelGAN inference can be, let us take a sample audio mel-spectrogram565and convert it. Note that the actual model pipeline will not include the `MelSpec` layer566and hence this layer will be disabled during inference. The inference input will be a567mel-spectrogram processed similar to the `MelSpec` layer configuration.568569For testing this, we will create a randomly uniformly distributed tensor to simulate the570behavior of the inference pipeline.571"""572573# Sampling a random tensor to mimic a batch of 128 spectrograms of shape [50, 80]574audio_sample = tf.random.uniform([128, 50, 80])575576"""577Timing the inference speed of a single sample. Running this, you can see that the average578inference time per spectrogram ranges from 8 milliseconds to 10 milliseconds on a K80 GPU which is579pretty fast.580"""581pred = generator.predict(audio_sample, batch_size=32, verbose=1)582"""583## Conclusion584585The MelGAN is a highly effective architecture for spectral inversion that has a Mean586Opinion Score (MOS) of 3.61 that considerably outperforms the Griffin587Lim algorithm having a MOS of just 1.57. In contrast with this, the MelGAN compares with588the state-of-the-art WaveGlow and WaveNet architectures on text-to-speech and speech589enhancement tasks on590the LJSpeech and VCTK datasets <sup>[1]</sup>.591592This tutorial highlights:5935941. The advantages of using dilated convolutions that grow with the filter size5952. Implementation of a custom layer for on-the-fly conversion of audio waves to596mel-spectrograms5973. Effectiveness of using the feature matching loss function for training GAN generators.598599Further reading6006011. [MelGAN paper](https://arxiv.org/abs/1910.06711) (Kundan Kumar et al.) to602understand the reasoning behind the architecture and training process6032. For in-depth understanding of the feature matching loss, you can refer to [Improved604Techniques for Training GANs](https://arxiv.org/abs/1606.03498) (Tim Salimans et605al.).606"""607608609