"""
2
Title: 3D volumetric rendering with NeRF
3
Authors: [Aritra Roy Gosthipaty](https://twitter.com/arig23498), [Ritwik Raha](https://twitter.com/ritwik_raha)
4
Date created: 2021/08/09
5
Last modified: 2023/11/13
6
Description: Minimal implementation of volumetric rendering as shown in NeRF.
7
Accelerator: GPU
8
"""

"""
## Introduction

In this example, we present a minimal implementation of the research paper
[**NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis**](https://arxiv.org/abs/2003.08934)
by Ben Mildenhall et al. The authors propose an ingenious way
to *synthesize novel views of a scene* by modelling the *volumetric
scene function* through a neural network.

To help you understand this intuitively, let's start with the following question:
*would it be possible to give a neural network the position of a pixel in
an image, and ask the network to predict the color at that position?*

| ![2d-train](https://i.imgur.com/DQM92vN.png) |
| :---: |
| **Figure 1**: A neural network being given coordinates of an image as input and asked to predict the color at the coordinates. |

The neural network would hypothetically *memorize* (overfit on) the
image. This means that our neural network would have encoded the entire image
in its weights. We could query the neural network with each position,
and it would eventually reconstruct the entire image.

| ![2d-test](https://i.imgur.com/6Qz5Hp1.png) |
| :---: |
| **Figure 2**: The trained neural network recreates the image from scratch. |
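
To make this idea concrete, here is a minimal sketch of such a coordinate-based
MLP (it is not part of the NeRF pipeline implemented below): it maps normalized
`(x, y)` pixel coordinates to RGB values and simply overfits a single image.
The image, layer sizes, and training settings are arbitrary placeholders.

```python
import numpy as np
import keras
from keras import layers

# Placeholder image: any (H, W, 3) float array in [0, 1] would do.
image = np.random.uniform(size=(64, 64, 3)).astype("float32")
H, W = image.shape[:2]

# Build (x, y) -> RGB training pairs from every pixel of the image.
xs, ys = np.meshgrid(np.linspace(0, 1, W), np.linspace(0, 1, H))
coords = np.stack([xs.ravel(), ys.ravel()], axis=-1)  # (H * W, 2)
colors = image.reshape(-1, 3)  # (H * W, 3)

# A small MLP that maps a 2-D coordinate to a color.
coord_mlp = keras.Sequential(
    [
        keras.Input(shape=(2,)),
        layers.Dense(128, activation="relu"),
        layers.Dense(128, activation="relu"),
        layers.Dense(3, activation="sigmoid"),
    ]
)
coord_mlp.compile(optimizer="adam", loss="mse")

# Overfit on the single image, then reconstruct it by querying every coordinate.
coord_mlp.fit(coords, colors, epochs=5, batch_size=1024, verbose=0)
reconstruction = coord_mlp.predict(coords, verbose=0).reshape(H, W, 3)
```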

A question now arises: how do we extend this idea to learn a 3D
volumetric scene? Implementing a similar process as above would
require knowledge of every voxel (volume pixel), which turns out to
be quite a challenging task.

The authors of the paper propose a minimal and elegant way to learn a
3D scene using a few images of the scene. They discard the use of
voxels for training. The network learns to model the volumetric scene,
thus generating novel views (images) of the 3D scene that the model
was not shown at training time.

There are a few prerequisites one needs to understand to fully
appreciate the process. We structure the example in such a way that
you will have all the required knowledge before starting the
implementation.
"""

"""
## Setup
"""

import os

os.environ["KERAS_BACKEND"] = "tensorflow"

# Setting random seed to obtain reproducible results.
import tensorflow as tf

tf.random.set_seed(42)

import keras
from keras import layers

import glob
import imageio.v2 as imageio
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

# Initialize global variables.
AUTO = tf.data.AUTOTUNE
BATCH_SIZE = 5
NUM_SAMPLES = 32
POS_ENCODE_DIMS = 16
EPOCHS = 20
"""
85
## Download and load the data
86
87
The `npz` data file contains images, camera poses, and a focal length.
88
The images are taken from multiple camera angles as shown in
89
**Figure 3**.
90
91
| ![camera-angles](https://i.imgur.com/FLsi2is.png) |
92
| :---: |
93
| **Figure 3**: Multiple camera angles <br>
94
[Source: NeRF](https://arxiv.org/abs/2003.08934) |
95
96
97
To understand camera poses in this context we have to first allow
98
ourselves to think that a *camera is a mapping between the real-world
99
and the 2-D image*.
100
101
| ![mapping](https://www.mathworks.com/help/vision/ug/calibration_coordinate_blocks.png) |
102
| :---: |
103
| **Figure 4**: 3-D world to 2-D image mapping through a camera <br>
104
[Source: Mathworks](https://www.mathworks.com/help/vision/ug/camera-calibration.html) |
105
106
Consider the following equation:
107
108
<img src="https://i.imgur.com/TQHKx5v.pngg" width="100" height="50"/>
109
110
Where **x** is the 2-D image point, **X** is the 3-D world point and
111
**P** is the camera-matrix. **P** is a 3 x 4 matrix that plays the
112
crucial role of mapping the real world object onto an image plane.
113
114
<img src="https://i.imgur.com/chvJct5.png" width="300" height="100"/>
115
116
The camera-matrix is an *affine transform matrix* that is
117
concatenated with a 3 x 1 column `[image height, image width, focal length]`
118
to produce the *pose matrix*. This matrix is of
119
dimensions 3 x 5 where the first 3 x 3 block is in the camera’s point
120
of view. The axes are `[down, right, backwards]` or `[-y, x, z]`
121
where the camera is facing forwards `-z`.
122
123
| ![camera-mapping](https://i.imgur.com/kvjqbiO.png) |
124
| :---: |
125
| **Figure 5**: The affine transformation. |
126
127
The COLMAP frame is `[right, down, forwards]` or `[x, -y, -z]`. Read
128
more about COLMAP [here](https://colmap.github.io/).
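
To make the **x** = **PX** relationship concrete, here is a small sketch with
made-up intrinsics and a trivial extrinsic matrix (these numbers are purely
illustrative and are not taken from the dataset used below):

```python
import numpy as np

# Toy intrinsics: focal length 100, principal point at (50, 50).
K = np.array([[100.0, 0.0, 50.0], [0.0, 100.0, 50.0], [0.0, 0.0, 1.0]])

# Extrinsics [R | t]: identity rotation, camera shifted 4 units along +z.
Rt = np.hstack([np.eye(3), np.array([[0.0], [0.0], [4.0]])])

P = K @ Rt  # The 3 x 4 camera matrix.

X = np.array([1.0, 2.0, 6.0, 1.0])  # A 3-D world point in homogeneous coordinates.
x = P @ X  # Homogeneous 2-D image point.
x = x[:2] / x[2]  # Perspective divide -> pixel coordinates, here [60., 70.].
```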
"""

# Download the data if it does not already exist.
url = (
    "http://cseweb.ucsd.edu/~viscomp/projects/LF/papers/ECCV20/nerf/tiny_nerf_data.npz"
)
data = keras.utils.get_file(origin=url)

data = np.load(data)
images = data["images"]
im_shape = images.shape
(num_images, H, W, _) = images.shape
(poses, focal) = (data["poses"], data["focal"])

# Plot a random image from the dataset for visualization.
plt.imshow(images[np.random.randint(low=0, high=num_images)])
plt.show()

"""
## Data pipeline

Now that you've understood the notion of camera matrix
and the mapping from a 3D scene to 2D images,
let's talk about the inverse mapping, i.e. from 2D image to the 3D scene.

We'll need to talk about volumetric rendering with ray casting and tracing,
which are common computer graphics techniques.
This section will help you get up to speed with these techniques.

Consider an image with `N` pixels. We shoot a ray through each pixel
and sample some points on the ray. A ray is commonly parameterized by
the equation `r(t) = o + td` where `t` is the parameter, `o` is the
origin and `d` is the unit directional vector as shown in **Figure 6**.

| ![img](https://i.imgur.com/ywrqlzt.gif) |
| :---: |
| **Figure 6**: `r(t) = o + td` where t is 3 |

In **Figure 7**, we consider a ray, and we sample some random points on
the ray. These sample points each have a unique location `(x, y, z)`
and the ray has a viewing angle `(theta, phi)`. The viewing angle is
particularly interesting as we can shoot a ray through a single pixel
in a lot of different ways, each with a unique viewing angle. Another
interesting thing to notice here is the noise that is added to the
sampling process. We add uniform noise to each sample so that the
samples correspond to a continuous distribution. In **Figure 7** the
blue points are the evenly distributed samples and the white points
`(t1, t2, t3)` are randomly placed between the samples.

| ![img](https://i.imgur.com/r9TS2wv.gif) |
| :---: |
| **Figure 7**: Sampling the points from a ray. |
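
As a small numeric illustration of this sampling scheme (independent of the
pipeline code below, with `near`, `far` and the sample count chosen
arbitrarily), stratified sampling along a single ray could look like this:

```python
import numpy as np

near, far, num_samples = 2.0, 6.0, 8  # Arbitrary values for illustration.

# Evenly spaced sample positions along the ray (the blue points in Figure 7).
t_vals = np.linspace(near, far, num_samples)

# Jitter each sample by a uniform offset within its bin so that, over many
# iterations, the samples cover the ray continuously (the white points).
noise = np.random.uniform(size=num_samples) * (far - near) / num_samples
t_vals = t_vals + noise

# A sampled 3-D point is then r(t) = o + t * d for each t in t_vals.
origin = np.array([0.0, 0.0, 0.0])
direction = np.array([0.0, 0.0, -1.0])  # Unit vector pointing into the scene.
points = origin + t_vals[:, None] * direction
```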

**Figure 8** showcases the entire sampling process in 3D, where you
can see the rays coming out of the white image. This means that each
pixel will have its corresponding rays and each ray will be sampled at
distinct points.

| ![3-d rays](https://i.imgur.com/hr4D2g2.gif) |
| :---: |
| **Figure 8**: Shooting rays from all the pixels of an image in 3-D |

These sampled points act as the input to the NeRF model. The model is
then asked to predict the RGB color and the volume density at that
point.

| ![3-Drender](https://i.imgur.com/HHb6tlQ.png) |
| :---: |
| **Figure 9**: Data pipeline <br> [Source: NeRF](https://arxiv.org/abs/2003.08934) |
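
For later reference (this formula is not part of the original prose of this
example, but it is what the `render_rgb_depth` function below computes), the
predicted colors `c_i` and densities `sigma_i` along a ray are composited into
a pixel color `C` and a depth value `d` with the usual volume rendering
quadrature, where `delta_i` is the spacing between adjacent samples:

$$
\alpha_i = 1 - e^{-\sigma_i \delta_i}, \qquad
T_i = \prod_{j < i} (1 - \alpha_j), \qquad
w_i = T_i \, \alpha_i, \qquad
C = \sum_i w_i \, c_i, \qquad
d = \sum_i w_i \, t_i
$$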
"""


def encode_position(x):
    """Encodes the position into its corresponding Fourier feature.

    Args:
        x: The input coordinate.

    Returns:
        Fourier features tensors of the position.
    """
    positions = [x]
    for i in range(POS_ENCODE_DIMS):
        for fn in [tf.sin, tf.cos]:
            positions.append(fn(2.0**i * x))
    return tf.concat(positions, axis=-1)


def get_rays(height, width, focal, pose):
    """Computes origin point and direction vector of rays.

    Args:
        height: Height of the image.
        width: Width of the image.
        focal: The focal length between the images and the camera.
        pose: The pose matrix of the camera.

    Returns:
        Tuple of origin point and direction vector for rays.
    """
    # Build a meshgrid for the rays.
    i, j = tf.meshgrid(
        tf.range(width, dtype=tf.float32),
        tf.range(height, dtype=tf.float32),
        indexing="xy",
    )

    # Normalize the x axis coordinates.
    transformed_i = (i - width * 0.5) / focal

    # Normalize the y axis coordinates.
    transformed_j = (j - height * 0.5) / focal

    # Create the direction unit vectors.
    directions = tf.stack([transformed_i, -transformed_j, -tf.ones_like(i)], axis=-1)

    # Get the camera matrix.
    camera_matrix = pose[:3, :3]
    height_width_focal = pose[:3, -1]

    # Get origins and directions for the rays.
    transformed_dirs = directions[..., None, :]
    camera_dirs = transformed_dirs * camera_matrix
    ray_directions = tf.reduce_sum(camera_dirs, axis=-1)
    ray_origins = tf.broadcast_to(height_width_focal, tf.shape(ray_directions))

    # Return the origins and directions.
    return (ray_origins, ray_directions)


def render_flat_rays(ray_origins, ray_directions, near, far, num_samples, rand=False):
    """Renders the rays and flattens them.

    Args:
        ray_origins: The origin points for rays.
        ray_directions: The direction unit vectors for the rays.
        near: The near bound of the volumetric scene.
        far: The far bound of the volumetric scene.
        num_samples: Number of sample points in a ray.
        rand: Choice for randomising the sampling strategy.

    Returns:
        Tuple of flattened rays and sample points on each ray.
    """
    # Compute 3D query points.
    # Equation: r(t) = o+td -> Building the "t" here.
    t_vals = tf.linspace(near, far, num_samples)
    if rand:
        # Inject uniform noise into sample space to make the sampling
        # continuous.
        shape = list(ray_origins.shape[:-1]) + [num_samples]
        noise = tf.random.uniform(shape=shape) * (far - near) / num_samples
        t_vals = t_vals + noise

    # Equation: r(t) = o + td -> Building the "r" here.
    rays = ray_origins[..., None, :] + (
        ray_directions[..., None, :] * t_vals[..., None]
    )
    rays_flat = tf.reshape(rays, [-1, 3])
    rays_flat = encode_position(rays_flat)
    return (rays_flat, t_vals)


def map_fn(pose):
    """Maps individual pose to flattened rays and sample points.

    Args:
        pose: The pose matrix of the camera.

    Returns:
        Tuple of flattened rays and sample points corresponding to the
        camera pose.
    """
    (ray_origins, ray_directions) = get_rays(height=H, width=W, focal=focal, pose=pose)
    (rays_flat, t_vals) = render_flat_rays(
        ray_origins=ray_origins,
        ray_directions=ray_directions,
        near=2.0,
        far=6.0,
        num_samples=NUM_SAMPLES,
        rand=True,
    )
    return (rays_flat, t_vals)


# Create the training split.
split_index = int(num_images * 0.8)

# Split the images into training and validation.
train_images = images[:split_index]
val_images = images[split_index:]

# Split the poses into training and validation.
train_poses = poses[:split_index]
val_poses = poses[split_index:]

# Make the training pipeline.
train_img_ds = tf.data.Dataset.from_tensor_slices(train_images)
train_pose_ds = tf.data.Dataset.from_tensor_slices(train_poses)
train_ray_ds = train_pose_ds.map(map_fn, num_parallel_calls=AUTO)
training_ds = tf.data.Dataset.zip((train_img_ds, train_ray_ds))
train_ds = (
    training_ds.shuffle(BATCH_SIZE)
    .batch(BATCH_SIZE, drop_remainder=True, num_parallel_calls=AUTO)
    .prefetch(AUTO)
)

# Make the validation pipeline.
val_img_ds = tf.data.Dataset.from_tensor_slices(val_images)
val_pose_ds = tf.data.Dataset.from_tensor_slices(val_poses)
val_ray_ds = val_pose_ds.map(map_fn, num_parallel_calls=AUTO)
validation_ds = tf.data.Dataset.zip((val_img_ds, val_ray_ds))
val_ds = (
    validation_ds.shuffle(BATCH_SIZE)
    .batch(BATCH_SIZE, drop_remainder=True, num_parallel_calls=AUTO)
    .prefetch(AUTO)
)
"""
350
## NeRF model
351
352
The model is a multi-layer perceptron (MLP), with ReLU as its non-linearity.
353
354
An excerpt from the paper:
355
356
*"We encourage the representation to be multiview-consistent by
357
restricting the network to predict the volume density sigma as a
358
function of only the location `x`, while allowing the RGB color `c` to be
359
predicted as a function of both location and viewing direction. To
360
accomplish this, the MLP first processes the input 3D coordinate `x`
361
with 8 fully-connected layers (using ReLU activations and 256 channels
362
per layer), and outputs sigma and a 256-dimensional feature vector.
363
This feature vector is then concatenated with the camera ray's viewing
364
direction and passed to one additional fully-connected layer (using a
365
ReLU activation and 128 channels) that output the view-dependent RGB
366
color."*
367
368
Here we have gone for a minimal implementation and have used 64
369
Dense units instead of 256 as mentioned in the paper.
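
For reference, a hedged sketch of the architecture described in the excerpt
(positionally encoded location through 8 fully-connected layers, a density
head, and a view-direction-conditioned color head) could look like the code
below. This is only an illustration of the paper's description; the simplified
model actually used in this example is `get_nerf_model` defined next, and it
does not take the viewing direction as input.

```python
import keras
from keras import layers


def get_paper_style_nerf_model(pos_dims, view_dims):
    """Sketch of the MLP described in the NeRF paper (not used in this example)."""
    pos_input = keras.Input(shape=(pos_dims,))  # Encoded 3-D location.
    view_input = keras.Input(shape=(view_dims,))  # Encoded viewing direction.

    x = pos_input
    for i in range(8):
        x = layers.Dense(256, activation="relu")(x)
        if i == 4:
            # Skip connection that re-injects the encoded location.
            x = layers.concatenate([x, pos_input], axis=-1)

    sigma = layers.Dense(1, activation="relu")(x)  # View-independent density.
    features = layers.Dense(256)(x)  # 256-dimensional feature vector.

    # The color depends on both the features and the viewing direction.
    y = layers.concatenate([features, view_input], axis=-1)
    y = layers.Dense(128, activation="relu")(y)
    rgb = layers.Dense(3, activation="sigmoid")(y)

    return keras.Model(inputs=[pos_input, view_input], outputs=[rgb, sigma])
```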
"""


def get_nerf_model(num_layers, num_pos):
    """Generates the NeRF neural network.

    Args:
        num_layers: The number of MLP layers.
        num_pos: The number of dimensions of positional encoding.

    Returns:
        The `keras` model.
    """
    inputs = keras.Input(shape=(num_pos, 2 * 3 * POS_ENCODE_DIMS + 3))
    x = inputs
    for i in range(num_layers):
        x = layers.Dense(units=64, activation="relu")(x)
        if i % 4 == 0 and i > 0:
            # Inject residual connection.
            x = layers.concatenate([x, inputs], axis=-1)
    outputs = layers.Dense(units=4)(x)
    return keras.Model(inputs=inputs, outputs=outputs)


def render_rgb_depth(model, rays_flat, t_vals, rand=True, train=True):
    """Generates the RGB image and depth map from model prediction.

    Args:
        model: The MLP model that is trained to predict the rgb and
            volume density of the volumetric scene.
        rays_flat: The flattened rays that serve as the input to
            the NeRF model.
        t_vals: The sample points for the rays.
        rand: Choice to randomise the sampling strategy.
        train: Whether the model is in the training or testing phase.

    Returns:
        Tuple of rgb image and depth map.
    """
    # Get the predictions from the nerf model and reshape it.
    if train:
        predictions = model(rays_flat)
    else:
        predictions = model.predict(rays_flat)
    predictions = tf.reshape(predictions, shape=(BATCH_SIZE, H, W, NUM_SAMPLES, 4))

    # Slice the predictions into rgb and sigma.
    rgb = tf.sigmoid(predictions[..., :-1])
    sigma_a = tf.nn.relu(predictions[..., -1])

    # Get the distance of adjacent intervals.
    delta = t_vals[..., 1:] - t_vals[..., :-1]
    # delta shape = (num_samples)
    if rand:
        delta = tf.concat(
            [delta, tf.broadcast_to([1e10], shape=(BATCH_SIZE, H, W, 1))], axis=-1
        )
        alpha = 1.0 - tf.exp(-sigma_a * delta)
    else:
        delta = tf.concat(
            [delta, tf.broadcast_to([1e10], shape=(BATCH_SIZE, 1))], axis=-1
        )
        alpha = 1.0 - tf.exp(-sigma_a * delta[:, None, None, :])

    # Get transmittance.
    exp_term = 1.0 - alpha
    epsilon = 1e-10
    transmittance = tf.math.cumprod(exp_term + epsilon, axis=-1, exclusive=True)
    weights = alpha * transmittance
    rgb = tf.reduce_sum(weights[..., None] * rgb, axis=-2)

    if rand:
        depth_map = tf.reduce_sum(weights * t_vals, axis=-1)
    else:
        depth_map = tf.reduce_sum(weights * t_vals[:, None, None], axis=-1)
    return (rgb, depth_map)


"""
## Training

The training step is implemented as part of a custom `keras.Model` subclass
so that we can make use of the `model.fit` functionality.
"""


class NeRF(keras.Model):
    def __init__(self, nerf_model):
        super().__init__()
        self.nerf_model = nerf_model

    def compile(self, optimizer, loss_fn):
        super().compile()
        self.optimizer = optimizer
        self.loss_fn = loss_fn
        self.loss_tracker = keras.metrics.Mean(name="loss")
        self.psnr_metric = keras.metrics.Mean(name="psnr")

    def train_step(self, inputs):
        # Get the images and the rays.
        (images, rays) = inputs
        (rays_flat, t_vals) = rays

        with tf.GradientTape() as tape:
            # Get the predictions from the model.
            rgb, _ = render_rgb_depth(
                model=self.nerf_model, rays_flat=rays_flat, t_vals=t_vals, rand=True
            )
            loss = self.loss_fn(images, rgb)

        # Get the trainable variables.
        trainable_variables = self.nerf_model.trainable_variables

        # Get the gradients of the trainable variables with respect to the loss.
        gradients = tape.gradient(loss, trainable_variables)

        # Apply the grads and optimize the model.
        self.optimizer.apply_gradients(zip(gradients, trainable_variables))

        # Get the PSNR of the reconstructed images and the source images.
        psnr = tf.image.psnr(images, rgb, max_val=1.0)

        # Compute our own metrics
        self.loss_tracker.update_state(loss)
        self.psnr_metric.update_state(psnr)
        return {"loss": self.loss_tracker.result(), "psnr": self.psnr_metric.result()}

    def test_step(self, inputs):
        # Get the images and the rays.
        (images, rays) = inputs
        (rays_flat, t_vals) = rays

        # Get the predictions from the model.
        rgb, _ = render_rgb_depth(
            model=self.nerf_model, rays_flat=rays_flat, t_vals=t_vals, rand=True
        )
        loss = self.loss_fn(images, rgb)

        # Get the PSNR of the reconstructed images and the source images.
        psnr = tf.image.psnr(images, rgb, max_val=1.0)

        # Compute our own metrics
        self.loss_tracker.update_state(loss)
        self.psnr_metric.update_state(psnr)
        return {"loss": self.loss_tracker.result(), "psnr": self.psnr_metric.result()}

    @property
    def metrics(self):
        return [self.loss_tracker, self.psnr_metric]


test_imgs, test_rays = next(iter(train_ds))
test_rays_flat, test_t_vals = test_rays

loss_list = []


class TrainMonitor(keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        loss = logs["loss"]
        loss_list.append(loss)
        test_recons_images, depth_maps = render_rgb_depth(
            model=self.model.nerf_model,
            rays_flat=test_rays_flat,
            t_vals=test_t_vals,
            rand=True,
            train=False,
        )

        # Plot the rgb, depth and the loss plot.
        fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(20, 5))
        ax[0].imshow(keras.utils.array_to_img(test_recons_images[0]))
        ax[0].set_title(f"Predicted Image: {epoch:03d}")

        ax[1].imshow(keras.utils.array_to_img(depth_maps[0, ..., None]))
        ax[1].set_title(f"Depth Map: {epoch:03d}")

        ax[2].plot(loss_list)
        ax[2].set_xticks(np.arange(0, EPOCHS + 1, 5.0))
        ax[2].set_title(f"Loss Plot: {epoch:03d}")

        fig.savefig(f"images/{epoch:03d}.png")
        plt.show()
        plt.close()


num_pos = H * W * NUM_SAMPLES
nerf_model = get_nerf_model(num_layers=8, num_pos=num_pos)

model = NeRF(nerf_model)
model.compile(
    optimizer=keras.optimizers.Adam(), loss_fn=keras.losses.MeanSquaredError()
)

# Create a directory to save the images during training.
if not os.path.exists("images"):
    os.makedirs("images")

model.fit(
    train_ds,
    validation_data=val_ds,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    callbacks=[TrainMonitor()],
)


def create_gif(path_to_images, name_gif):
    filenames = glob.glob(path_to_images)
    filenames = sorted(filenames)
    images = []
    for filename in tqdm(filenames):
        images.append(imageio.imread(filename))
    kargs = {"duration": 0.25}
    imageio.mimsave(name_gif, images, "GIF", **kargs)


create_gif("images/*.png", "training.gif")

"""
## Visualize the training step

Here we see the training step. As the loss decreases, the rendered
images and the depth maps get better. On your local system, you will
find the generated `training.gif` file.

![training-20](https://i.imgur.com/ql5OcYA.gif)
"""

"""
## Inference

In this section, we ask the model to build novel views of the scene.
The model was shown 80% of the `106` views of the scene during training.
A collection of training images cannot cover each and every angle of
the scene, yet the trained model can represent the entire 3-D scene from
this sparse set of training images.

Here we provide different poses to the model and ask it to give us the
2-D image corresponding to that camera view. If we run inference for
poses covering the full 360 degrees, we should get an overview of the
entire scene from all around.
"""

# Get the trained NeRF model and infer.
nerf_model = model.nerf_model
test_recons_images, depth_maps = render_rgb_depth(
    model=nerf_model,
    rays_flat=test_rays_flat,
    t_vals=test_t_vals,
    rand=True,
    train=False,
)

# Create subplots.
fig, axes = plt.subplots(nrows=5, ncols=3, figsize=(10, 20))

for ax, ori_img, recons_img, depth_map in zip(
    axes, test_imgs, test_recons_images, depth_maps
):
    ax[0].imshow(keras.utils.array_to_img(ori_img))
    ax[0].set_title("Original")

    ax[1].imshow(keras.utils.array_to_img(recons_img))
    ax[1].set_title("Reconstructed")

    ax[2].imshow(keras.utils.array_to_img(depth_map[..., None]), cmap="inferno")
    ax[2].set_title("Depth Map")

"""
## Render 3D Scene

Here we will synthesize novel 3D views and stitch all of them together
to render a video encompassing the 360-degree view.
"""


def get_translation_t(t):
    """Get the translation matrix for movement in t."""
    matrix = [
        [1, 0, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 1, t],
        [0, 0, 0, 1],
    ]
    return tf.convert_to_tensor(matrix, dtype=tf.float32)


def get_rotation_phi(phi):
    """Get the rotation matrix for movement in phi."""
    matrix = [
        [1, 0, 0, 0],
        [0, tf.cos(phi), -tf.sin(phi), 0],
        [0, tf.sin(phi), tf.cos(phi), 0],
        [0, 0, 0, 1],
    ]
    return tf.convert_to_tensor(matrix, dtype=tf.float32)


def get_rotation_theta(theta):
    """Get the rotation matrix for movement in theta."""
    matrix = [
        [tf.cos(theta), 0, -tf.sin(theta), 0],
        [0, 1, 0, 0],
        [tf.sin(theta), 0, tf.cos(theta), 0],
        [0, 0, 0, 1],
    ]
    return tf.convert_to_tensor(matrix, dtype=tf.float32)


def pose_spherical(theta, phi, t):
    """
    Get the camera to world matrix for the corresponding theta, phi
    and t.
    """
    c2w = get_translation_t(t)
    c2w = get_rotation_phi(phi / 180.0 * np.pi) @ c2w
    c2w = get_rotation_theta(theta / 180.0 * np.pi) @ c2w
    c2w = np.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]]) @ c2w
    return c2w


rgb_frames = []
batch_flat = []
batch_t = []

# Iterate over different theta values and generate scenes.
for index, theta in tqdm(enumerate(np.linspace(0.0, 360.0, 120, endpoint=False))):
    # Get the camera to world matrix.
    c2w = pose_spherical(theta, -30.0, 4.0)

    # Get the rays and the sample points for the current pose.
    ray_oris, ray_dirs = get_rays(H, W, focal, c2w)
    rays_flat, t_vals = render_flat_rays(
        ray_oris, ray_dirs, near=2.0, far=6.0, num_samples=NUM_SAMPLES, rand=False
    )

    if index % BATCH_SIZE == 0 and index > 0:
        batched_flat = tf.stack(batch_flat, axis=0)
        batch_flat = [rays_flat]

        batched_t = tf.stack(batch_t, axis=0)
        batch_t = [t_vals]

        rgb, _ = render_rgb_depth(
            nerf_model, batched_flat, batched_t, rand=False, train=False
        )

        temp_rgb = [np.clip(255 * img, 0.0, 255.0).astype(np.uint8) for img in rgb]

        rgb_frames = rgb_frames + temp_rgb
    else:
        batch_flat.append(rays_flat)
        batch_t.append(t_vals)

rgb_video = "rgb_video.mp4"
imageio.mimwrite(rgb_video, rgb_frames, fps=30, quality=7, macro_block_size=None)

"""
### Visualize the video

Here we can see the rendered 360-degree view of the scene. The model
has successfully learned the entire volumetric space through the
sparse set of images in **only 20 epochs**. You can view the
rendered video saved locally as `rgb_video.mp4`.

![rendered-video](https://i.imgur.com/j2sIkzW.gif)
"""

"""
## Conclusion

We have produced a minimal implementation of NeRF to provide an intuition of its
core ideas and methodology. This method has been used in various
other works in the computer graphics space.

We encourage our readers to use this code as an example, play with the
hyperparameters, and visualize the outputs. Below we have also
provided the outputs of the model trained for more epochs.

| Epochs | GIF of the training step |
| :--- | :---: |
| **100** | ![100-epoch-training](https://i.imgur.com/2k9p8ez.gif) |
| **200** | ![200-epoch-training](https://i.imgur.com/l3rG4HQ.gif) |

## Way forward

If anyone is interested in going deeper into NeRF, we have built a 3-part blog
series at [PyImageSearch](https://pyimagesearch.com/).

- [Prerequisites of NeRF](https://www.pyimagesearch.com/2021/11/10/computer-graphics-and-deep-learning-with-nerf-using-tensorflow-and-keras-part-1/)
- [Concepts of NeRF](https://www.pyimagesearch.com/2021/11/17/computer-graphics-and-deep-learning-with-nerf-using-tensorflow-and-keras-part-2/)
- [Implementing NeRF](https://www.pyimagesearch.com/2021/11/24/computer-graphics-and-deep-learning-with-nerf-using-tensorflow-and-keras-part-3/)

## Reference

- [NeRF repository](https://github.com/bmild/nerf): The official repository for NeRF.
- [NeRF paper](https://arxiv.org/abs/2003.08934): The paper on NeRF.
- [Manim Repository](https://github.com/3b1b/manim): We have used manim to build all the animations.
- [Mathworks](https://www.mathworks.com/help/vision/ug/camera-calibration.html): The Mathworks article on camera calibration.
- [Mathew's video](https://www.youtube.com/watch?v=dPWLybp4LL0): A great video on NeRF.

You can try the model on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/NeRF).
"""