Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
Download

📚 The CoCalc Library - books, templates and other resources

132930 views
License: OTHER
1
import os
2
import sys
3
import tensorflow
4
import numpy as np
5
6
import matplotlib
7
matplotlib.use('TKAgg')
8
from matplotlib import pyplot as plt
9
10
from tensorflow.examples.tutorials.mnist import input_data
11
12
mnist_image_shape = [28, 28, 1]
13
14
def load_dataset(data_dir='MNIST_data'):
    """Load MNIST through TensorFlow's tutorial ``input_data`` helper.

    Args:
        data_dir: Directory handed to ``input_data.read_data_sets``. Defaults
            to ``'MNIST_data'`` (relative to the working directory), which is
            the value that used to be hard-coded, so existing zero-argument
            callers behave exactly as before.

    Returns:
        Whatever ``input_data.read_data_sets`` returns. In TF 1.x this is
        presumably a container with ``train`` / ``validation`` / ``test``
        splits (each exposing ``next_batch``). Verify against the pinned
        TF version.
    """
    # NOTE(review): tensorflow.examples.tutorials.mnist is a deprecated TF 1.x
    # tutorial module and is not shipped with TF 2.x. Confirm the TF version.
    return input_data.read_data_sets(data_dir)


def get_next_batch(dataset, batch_size):
    """Fetch the next ``batch_size`` examples from *dataset* as an image batch.

    Args:
        dataset: Object whose ``next_batch(batch_size)`` returns a 2-tuple
            whose first element is the image data. It is presumably one split
            of the object returned by ``load_dataset`` (mnist.train /
            mnist.validation / mnist.test).
        batch_size: Number of examples to request.

    Returns:
        The image data reshaped to ``[batch_size] + mnist_image_shape``, i.e.
        ``(batch_size, 28, 28, 1)``. The second tuple element (labels, for
        TF's MNIST DataSet) is discarded.
    """
    images, _discarded = dataset.next_batch(batch_size)
    # Prepend the batch axis to the per-image shape.
    target_shape = [batch_size] + mnist_image_shape
    return np.reshape(images, target_shape)


def visualize(_original, _reconstructions, num_visualize, vis_folder='./vis/'):
    """Save side-by-side PNGs of original images and their reconstructions.

    For each of the first ``num_visualize`` (original, reconstruction) pairs,
    one two-panel grayscale figure is written (original on the left,
    reconstruction on the right). The file is ``test_<k>.png`` inside
    *vis_folder*, with ``k`` counting from 1.

    Args:
        _original: Sliceable sequence/array of images. Each element must be
            reshapeable to ``(mnist_image_shape[0], mnist_image_shape[1])``.
        _reconstructions: Sliceable sequence/array paired element-wise with
            ``_original``, with the same per-element requirement.
        num_visualize: Maximum number of pairs to save. Pairs are taken with
            ``zip``, so output stops at the shorter of the two inputs.
        vis_folder: Output directory, created if it does not exist. Defaults
            to ``'./vis/'``, the previously hard-coded location.
    """
    if not os.path.exists(vis_folder):
        os.makedirs(vis_folder)

    image_hw = (mnist_image_shape[0], mnist_image_shape[1])
    original = _original[:num_visualize]
    reconstructions = _reconstructions[:num_visualize]

    for count, (orig, rec) in enumerate(zip(original, reconstructions), start=1):
        fig, ax = plt.subplots(1, 2)
        try:
            ax[0].imshow(np.reshape(orig, image_hw), cmap='gray')
            ax[1].imshow(np.reshape(rec, image_hw), cmap='gray')
            # Save the figure created above explicitly, rather than whatever
            # pyplot considers "current". os.path.join also works when
            # vis_folder lacks a trailing separator.
            fig.savefig(os.path.join(vis_folder, "test_%d.png" % count))
        finally:
            # Without this, every iteration leaves an open figure behind,
            # leaking memory and triggering matplotlib's too-many-figures
            # warning.
            plt.close(fig)