Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place.
Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place.
Path: blob/main/intermediate_source/neural_tangent_kernels.py
Views: 712
# -*- coding: utf-8 -*-1"""2Neural Tangent Kernels3======================45The neural tangent kernel (NTK) is a kernel that describes6`how a neural network evolves during training <https://en.wikipedia.org/wiki/Neural_tangent_kernel>`_.7There has been a lot of research around it `in recent years <https://arxiv.org/abs/1806.07572>`_.8This tutorial, inspired by the implementation of `NTKs in JAX <https://github.com/google/neural-tangents>`_9(see `Fast Finite Width Neural Tangent Kernel <https://arxiv.org/abs/2206.08720>`_ for details),10demonstrates how to easily compute this quantity using ``torch.func``,11composable function transforms for PyTorch.1213.. note::1415This tutorial requires PyTorch 2.0.0 or later.1617Setup18-----1920First, some setup. Let's define a simple CNN that we wish to compute the NTK of.21"""2223import torch24import torch.nn as nn25from torch.func import functional_call, vmap, vjp, jvp, jacrev26device = 'cuda' if torch.cuda.device_count() > 0 else 'cpu'2728class CNN(nn.Module):29def __init__(self):30super(CNN, self).__init__()31self.conv1 = nn.Conv2d(3, 32, (3, 3))32self.conv2 = nn.Conv2d(32, 32, (3, 3))33self.conv3 = nn.Conv2d(32, 32, (3, 3))34self.fc = nn.Linear(21632, 10)3536def forward(self, x):37x = self.conv1(x)38x = x.relu()39x = self.conv2(x)40x = x.relu()41x = self.conv3(x)42x = x.flatten(1)43x = self.fc(x)44return x4546######################################################################47# And let's generate some random data4849x_train = torch.randn(20, 3, 32, 32, device=device)50x_test = torch.randn(5, 3, 32, 32, device=device)5152######################################################################53# Create a function version of the model54# --------------------------------------55#56# ``torch.func`` transforms operate on functions. In particular, to compute the NTK,57# we will need a function that accepts the parameters of the model and a single58# input (as opposed to a batch of inputs!) and returns a single output.59#60# We'll use ``torch.func.functional_call``, which allows us to call an ``nn.Module``61# using different parameters/buffers, to help accomplish the first step.62#63# Keep in mind that the model was originally written to accept a batch of input64# data points. In our CNN example, there are no inter-batch operations. That65# is, each data point in the batch is independent of other data points. With66# this assumption in mind, we can easily generate a function that evaluates the67# model on a single data point:686970net = CNN().to(device)7172# Detaching the parameters because we won't be calling Tensor.backward().73params = {k: v.detach() for k, v in net.named_parameters()}7475def fnet_single(params, x):76return functional_call(net, params, (x.unsqueeze(0),)).squeeze(0)7778######################################################################79# Compute the NTK: method 1 (Jacobian contraction)80# ------------------------------------------------81# We're ready to compute the empirical NTK. The empirical NTK for two data82# points :math:`x_1` and :math:`x_2` is defined as the matrix product between the Jacobian83# of the model evaluated at :math:`x_1` and the Jacobian of the model evaluated at84# :math:`x_2`:85#86# .. math::87#88# J_{net}(x_1) J_{net}^T(x_2)89#90# In the batched case where :math:`x_1` is a batch of data points and :math:`x_2` is a91# batch of data points, then we want the matrix product between the Jacobians92# of all combinations of data points from :math:`x_1` and :math:`x_2`.93#94# The first method consists of doing just that - computing the two Jacobians,95# and contracting them. Here's how to compute the NTK in the batched case:9697def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2):98# Compute J(x1)99jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)100jac1 = jac1.values()101jac1 = [j.flatten(2) for j in jac1]102103# Compute J(x2)104jac2 = vmap(jacrev(fnet_single), (None, 0))(params, x2)105jac2 = jac2.values()106jac2 = [j.flatten(2) for j in jac2]107108# Compute J(x1) @ J(x2).T109result = torch.stack([torch.einsum('Naf,Mbf->NMab', j1, j2) for j1, j2 in zip(jac1, jac2)])110result = result.sum(0)111return result112113result = empirical_ntk_jacobian_contraction(fnet_single, params, x_train, x_test)114print(result.shape)115116######################################################################117# In some cases, you may only want the diagonal or the trace of this quantity,118# especially if you know beforehand that the network architecture results in an119# NTK where the non-diagonal elements can be approximated by zero. It's easy to120# adjust the above function to do that:121122def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2, compute='full'):123# Compute J(x1)124jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)125jac1 = jac1.values()126jac1 = [j.flatten(2) for j in jac1]127128# Compute J(x2)129jac2 = vmap(jacrev(fnet_single), (None, 0))(params, x2)130jac2 = jac2.values()131jac2 = [j.flatten(2) for j in jac2]132133# Compute J(x1) @ J(x2).T134einsum_expr = None135if compute == 'full':136einsum_expr = 'Naf,Mbf->NMab'137elif compute == 'trace':138einsum_expr = 'Naf,Maf->NM'139elif compute == 'diagonal':140einsum_expr = 'Naf,Maf->NMa'141else:142assert False143144result = torch.stack([torch.einsum(einsum_expr, j1, j2) for j1, j2 in zip(jac1, jac2)])145result = result.sum(0)146return result147148result = empirical_ntk_jacobian_contraction(fnet_single, params, x_train, x_test, 'trace')149print(result.shape)150151######################################################################152# The asymptotic time complexity of this method is :math:`N O [FP]` (time to153# compute the Jacobians) + :math:`N^2 O^2 P` (time to contract the Jacobians),154# where :math:`N` is the batch size of :math:`x_1` and :math:`x_2`, :math:`O`155# is the model's output size, :math:`P` is the total number of parameters, and156# :math:`[FP]` is the cost of a single forward pass through the model. See157# section 3.2 in158# `Fast Finite Width Neural Tangent Kernel <https://arxiv.org/abs/2206.08720>`_159# for details.160#161# Compute the NTK: method 2 (NTK-vector products)162# -----------------------------------------------163#164# The next method we will discuss is a way to compute the NTK using NTK-vector165# products.166#167# This method reformulates NTK as a stack of NTK-vector products applied to168# columns of an identity matrix :math:`I_O` of size :math:`O\times O`169# (where :math:`O` is the output size of the model):170#171# .. math::172#173# J_{net}(x_1) J_{net}^T(x_2) = J_{net}(x_1) J_{net}^T(x_2) I_{O} = \left[J_{net}(x_1) \left[J_{net}^T(x_2) e_o\right]\right]_{o=1}^{O},174#175# where :math:`e_o\in \mathbb{R}^O` are column vectors of the identity matrix176# :math:`I_O`.177#178# - Let :math:`\textrm{vjp}_o = J_{net}^T(x_2) e_o`. We can use179# a vector-Jacobian product to compute this.180# - Now, consider :math:`J_{net}(x_1) \textrm{vjp}_o`. This is a181# Jacobian-vector product!182# - Finally, we can run the above computation in parallel over all183# columns :math:`e_o` of :math:`I_O` using ``vmap``.184#185# This suggests that we can use a combination of reverse-mode AD (to compute186# the vector-Jacobian product) and forward-mode AD (to compute the187# Jacobian-vector product) to compute the NTK.188#189# Let's code that up:190191def empirical_ntk_ntk_vps(func, params, x1, x2, compute='full'):192def get_ntk(x1, x2):193def func_x1(params):194return func(params, x1)195196def func_x2(params):197return func(params, x2)198199output, vjp_fn = vjp(func_x1, params)200201def get_ntk_slice(vec):202# This computes ``vec @ J(x2).T``203# `vec` is some unit vector (a single slice of the Identity matrix)204vjps = vjp_fn(vec)205# This computes ``J(X1) @ vjps``206_, jvps = jvp(func_x2, (params,), vjps)207return jvps208209# Here's our identity matrix210basis = torch.eye(output.numel(), dtype=output.dtype, device=output.device).view(output.numel(), -1)211return vmap(get_ntk_slice)(basis)212213# ``get_ntk(x1, x2)`` computes the NTK for a single data point x1, x2214# Since the x1, x2 inputs to ``empirical_ntk_ntk_vps`` are batched,215# we actually wish to compute the NTK between every pair of data points216# between {x1} and {x2}. That's what the ``vmaps`` here do.217result = vmap(vmap(get_ntk, (None, 0)), (0, None))(x1, x2)218219if compute == 'full':220return result221if compute == 'trace':222return torch.einsum('NMKK->NM', result)223if compute == 'diagonal':224return torch.einsum('NMKK->NMK', result)225226# Disable TensorFloat-32 for convolutions on Ampere+ GPUs to sacrifice performance in favor of accuracy227with torch.backends.cudnn.flags(allow_tf32=False):228result_from_jacobian_contraction = empirical_ntk_jacobian_contraction(fnet_single, params, x_test, x_train)229result_from_ntk_vps = empirical_ntk_ntk_vps(fnet_single, params, x_test, x_train)230231assert torch.allclose(result_from_jacobian_contraction, result_from_ntk_vps, atol=1e-5)232233######################################################################234# Our code for ``empirical_ntk_ntk_vps`` looks like a direct translation from235# the math above! This showcases the power of function transforms: good luck236# trying to write an efficient version of the above by only using237# ``torch.autograd.grad``.238#239# The asymptotic time complexity of this method is :math:`N^2 O [FP]`, where240# :math:`N` is the batch size of :math:`x_1` and :math:`x_2`, :math:`O` is the241# model's output size, and :math:`[FP]` is the cost of a single forward pass242# through the model. Hence this method performs more forward passes through the243# network than method 1, Jacobian contraction (:math:`N^2 O` instead of244# :math:`N O`), but avoids the contraction cost altogether (no :math:`N^2 O^2 P`245# term, where :math:`P` is the total number of model's parameters). Therefore,246# this method is preferable when :math:`O P` is large relative to :math:`[FP]`,247# such as fully-connected (not convolutional) models with many outputs :math:`O`.248# Memory-wise, both methods should be comparable. See section 3.3 in249# `Fast Finite Width Neural Tangent Kernel <https://arxiv.org/abs/2206.08720>`_250# for details.251252253