Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
Download

📚 The CoCalc Library - books, templates and other resources

132932 views
License: OTHER
1
""" This file contains different utility functions that are not connected
in any way to the networks presented in the tutorials, but rather help in
processing the outputs into a more understandable form.

For example ``tile_raster_images`` helps in generating an easy-to-grasp
image from a set of samples or weights.
"""
8
9
10
import numpy
11
from six.moves import xrange
12
13
14
def scale_to_unit_interval(ndar, eps=1e-8):
    """Scale all values in ``ndar`` to lie between 0 and 1.

    The data is shifted so its minimum becomes 0, then multiplied by
    ``1 / (max + eps)``. ``eps`` keeps the division finite when every
    element is equal, in which case the result is all zeros.

    :param ndar: array (or array-like) of real numbers; it is not modified.
    :param eps: small positive constant added to the denominator.
    :returns: a new array of the same shape. Floating-point inputs keep their
        dtype; any other input (e.g. integer or boolean pixel data) is
        converted to ``float64``.
    """
    ndar = numpy.asanyarray(ndar)
    if numpy.issubdtype(ndar.dtype, numpy.floating):
        # Copy so the in-place operations below never touch the caller's data.
        ndar = ndar.copy()
    else:
        # An in-place float multiply on an integer/bool array raises a
        # casting error, so promote to float (astype also makes a copy).
        ndar = ndar.astype(numpy.float64)
    ndar -= ndar.min()
    ndar *= 1.0 / (ndar.max() + eps)
    return ndar
20
21
22
def tile_raster_images(X, img_shape, tile_shape, tile_spacing=(0, 0),
                       scale_rows_to_unit_interval=True,
                       output_pixel_vals=True):
    """
    Transform an array with one flattened image per row into an array in
    which the images are reshaped and laid out like tiles on a floor.

    This function is useful for visualizing datasets whose rows are images,
    and also columns of matrices for transforming those rows
    (such as the first layer of a neural net).

    :type X: a 2-D ndarray or a tuple of 4 channels, elements of which can
    be 2-D ndarrays or None;
    :param X: a 2-D array in which every row is a flattened image. When X
    is a 4-tuple, each entry is tiled separately into one plane of an
    (H, W, 4) output; a None entry fills its plane with a default value
    (0 for the first three planes; 255, or 1 for non-pixel output, for the
    fourth, i.e. fully opaque alpha).

    :type img_shape: tuple; (height, width)
    :param img_shape: the original shape of each image

    :type tile_shape: tuple; (rows, cols)
    :param tile_shape: the number of images to tile (rows, cols). Images are
    placed in row-major order. Rows of X beyond rows * cols are ignored, and
    tiles with no corresponding row of X stay zero.

    :type tile_spacing: tuple; (vertical, horizontal)
    :param tile_spacing: number of zero-valued pixels between adjacent tiles

    :param scale_rows_to_unit_interval: if the values need to be scaled to
    [0, 1] (per image) before being plotted

    :param output_pixel_vals: if True, the output holds pixel values (uint8,
    i.e. the image values multiplied by 255); otherwise it holds values of
    the input's dtype

    :returns: array suitable for viewing as an image.
    (See:`Image.fromarray`.)
    :rtype: a 2-D array (an (H, W, 4) array when X is a tuple) whose dtype is
    uint8 if output_pixel_vals, else the dtype of the input data.
    """

    assert len(img_shape) == 2
    assert len(tile_shape) == 2
    assert len(tile_spacing) == 2

    # Along each axis there are `tshp` images of size `ishp`, separated by
    # `tsp` pixels of spacing, with no spacing after the last image:
    #     out = (ishp + tsp) * tshp - tsp
    out_shape = [
        (ishp + tsp) * tshp - tsp
        for ishp, tshp, tsp in zip(img_shape, tile_shape, tile_spacing)
    ]

    if isinstance(X, tuple):
        assert len(X) == 4
        if output_pixel_vals:
            dt = 'uint8'
            # colours default to 0, alpha defaults to 255 (opaque)
            channel_defaults = [0, 0, 0, 255]
        else:
            # A tuple has no ``dtype`` attribute, so take the dtype from the
            # first channel that is present (float64 if all are None).
            dt = next((numpy.asarray(ch).dtype for ch in X if ch is not None),
                      numpy.dtype('float64'))
            # colours default to 0, alpha defaults to 1 (opaque)
            channel_defaults = [0., 0., 0., 1.]

        out_array = numpy.zeros((out_shape[0], out_shape[1], 4), dtype=dt)
        for i, channel in enumerate(X):
            if channel is None:
                out_array[:, :, i] = channel_defaults[i]
            else:
                # Tile this channel on its own (recursively, as a single
                # 2-D input) and store the result in plane i.
                out_array[:, :, i] = tile_raster_images(
                    channel, img_shape, tile_shape, tile_spacing,
                    scale_rows_to_unit_interval, output_pixel_vals)
        return out_array

    # Single-channel case.
    H, W = img_shape
    Hs, Ws = tile_spacing

    dt = 'uint8' if output_pixel_vals else X.dtype
    out_array = numpy.zeros(out_shape, dtype=dt)
    # Factor mapping image values to stored values (255 for 8-bit pixels).
    c = 255 if output_pixel_vals else 1

    n_cols = tile_shape[1]
    n_images = min(X.shape[0], tile_shape[0] * n_cols)
    for idx in range(n_images):
        # Row-major placement: image idx goes to tile (idx // cols, idx % cols).
        tile_row, tile_col = divmod(idx, n_cols)
        this_img = X[idx].reshape(img_shape)
        if scale_rows_to_unit_interval:
            this_img = scale_to_unit_interval(this_img)
        top = tile_row * (H + Hs)
        left = tile_col * (W + Ws)
        out_array[top:top + H, left:left + W] = this_img * c
    return out_array
141
142