# -*- coding: utf-8 -*-
"""
(beta) Channels Last Memory Format in PyTorch
*******************************************************
**Author**: `Vitaly Fedyunin <https://github.com/VitalyFedyunin>`_

What is Channels Last
---------------------

Channels last memory format is an alternative way of ordering NCHW tensors in memory while preserving the dimension ordering. Channels last tensors are ordered in such a way that channels become the densest dimension (aka storing images pixel-per-pixel).

For example, classic (contiguous) storage of an NCHW tensor (in our case two 4x4 images with 3 color channels) looks like this:

.. figure:: /_static/img/classic_memory_format.png
   :alt: classic_memory_format

Channels last memory format orders data differently:

.. figure:: /_static/img/channels_last_memory_format.png
   :alt: channels_last_memory_format

PyTorch supports memory formats (and provides backward compatibility with existing models, including eager, JIT, and TorchScript) by utilizing the existing strides structure.
For example, a 10x3x16x16 batch in Channels last format will have strides equal to (768, 1, 48, 3).
"""

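######################################################################
# A quick way to verify the stride claim above: for an NCHW tensor in
# channels last format the strides are ``(H*W*C, 1, W*C, C)``, so a
# 10x3x16x16 batch should report ``(768, 1, 48, 3)``.
import torch  # also imported in the tutorial's own cell below

nchw = torch.empty(10, 3, 16, 16).to(memory_format=torch.channels_last)
print(nchw.stride())  # Outputs: (768, 1, 48, 3)
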
######################################################################
# Channels last memory format is implemented for 4D NCHW Tensors only.
#

######################################################################
# Memory Format API
# -----------------------
#
# Here is how to convert tensors between contiguous and channels
# last memory formats.

######################################################################
# Classic PyTorch contiguous tensor
import torch

N, C, H, W = 10, 3, 32, 32
x = torch.empty(N, C, H, W)
print(x.stride())  # Outputs: (3072, 1024, 32, 1)

######################################################################
# Conversion operator
x = x.to(memory_format=torch.channels_last)
print(x.shape)  # Outputs: (10, 3, 32, 32) as dimension order is preserved
print(x.stride())  # Outputs: (3072, 1, 96, 3)

######################################################################
# Back to contiguous
x = x.to(memory_format=torch.contiguous_format)
print(x.stride())  # Outputs: (3072, 1024, 32, 1)

######################################################################
# Alternative option
x = x.contiguous(memory_format=torch.channels_last)
print(x.stride())  # Outputs: (3072, 1, 96, 3)

######################################################################
# Format checks
print(x.is_contiguous(memory_format=torch.channels_last))  # Outputs: True

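######################################################################
# ``is_contiguous`` answers a yes/no question for one format at a time.
# A small helper like the one below (``describe_memory_format`` is just an
# illustrative name, not part of the PyTorch API) reports every format a
# 4D tensor's strides currently satisfy; an ambiguous tensor (see the next
# section) can satisfy both at once.
def describe_memory_format(t):
    formats = []
    if t.is_contiguous(memory_format=torch.contiguous_format):
        formats.append("contiguous")
    if t.is_contiguous(memory_format=torch.channels_last):
        formats.append("channels_last")
    return formats or ["neither (non-contiguous view)"]

print(describe_memory_format(x))  # Outputs: ['channels_last']
print(describe_memory_format(x.to(memory_format=torch.contiguous_format)))  # Outputs: ['contiguous']
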
######################################################################
# There are minor differences between the two APIs ``to`` and
# ``contiguous``. We suggest sticking with ``to`` when explicitly
# converting the memory format of a tensor.
#
# For general cases the two APIs behave the same. However, in the special
# case of a 4D tensor with size ``NCHW``, when either ``C==1`` or
# ``H==1 && W==1``, only ``to`` generates a proper stride to
# represent channels last memory format.
#
# This is because in either of the two cases above the memory format
# of a tensor is ambiguous, i.e. a contiguous tensor with size
# ``N1HW`` is both ``contiguous`` and channels last in memory storage.
# Such tensors are therefore already considered ``is_contiguous``
# for the given memory format, so the ``contiguous`` call becomes a
# no-op and does not update the strides. In contrast, ``to``
# restrides the tensor with meaningful strides on dimensions whose
# sizes are 1 in order to properly represent the intended memory
# format.
special_x = torch.empty(4, 1, 4, 4)
print(special_x.is_contiguous(memory_format=torch.channels_last))  # Outputs: True
print(special_x.is_contiguous(memory_format=torch.contiguous_format))  # Outputs: True

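######################################################################
# The difference described above is visible in the strides (a small
# demonstration; the exact values assume the ``4x1x4x4`` tensor from the
# previous cell):
print(special_x.contiguous(memory_format=torch.channels_last).stride())  # Outputs: (16, 16, 4, 1) -- no-op, strides unchanged
print(special_x.to(memory_format=torch.channels_last).stride())  # Outputs: (16, 1, 4, 1) -- size-1 dimensions restrided
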
######################################################################
# The same applies to the explicit permutation API ``permute``. In the
# special cases where ambiguity can occur, ``permute`` is not
# guaranteed to produce strides that properly carry the intended
# memory format. We suggest using ``to`` with an explicit memory format
# to avoid unintended behavior.
#
# As a side note, in the extreme case where all three non-batch
# dimensions are equal to ``1`` (``C==1 && H==1 && W==1``), the
# current implementation cannot mark a tensor as channels last memory
# format.

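######################################################################
# For example, permuting NHWC-laid-out data into an NCHW view does yield
# channels last strides in the unambiguous case below, but ``to`` with an
# explicit memory format remains the recommended, predictable route
# (a small illustration; the sizes here are arbitrary):
nhwc_data = torch.empty(2, 4, 4, 3)           # data physically stored as NHWC
as_nchw_view = nhwc_data.permute(0, 3, 1, 2)  # logical NCHW view, no copy
print(as_nchw_view.shape)    # Outputs: torch.Size([2, 3, 4, 4])
print(as_nchw_view.stride()) # Outputs: (48, 1, 12, 3)
print(as_nchw_view.is_contiguous(memory_format=torch.channels_last))  # Outputs: True
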
######################################################################
# Create as channels last
x = torch.empty(N, C, H, W, memory_format=torch.channels_last)
print(x.stride())  # Outputs: (3072, 1, 96, 3)

######################################################################
# ``clone`` preserves memory format
y = x.clone()
print(y.stride())  # Outputs: (3072, 1, 96, 3)

######################################################################
# ``to``, ``cuda``, ``float`` ... preserve memory format
if torch.cuda.is_available():
    y = x.cuda()
    print(y.stride())  # Outputs: (3072, 1, 96, 3)

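######################################################################
# The same holds for dtype conversions, which also run on CPU
# (a quick check of the statement above):
print(x.half().stride())  # Outputs: (3072, 1, 96, 3)
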
######################################################################
# ``empty_like``, ``*_like`` operators preserve memory format
y = torch.empty_like(x)
print(y.stride())  # Outputs: (3072, 1, 96, 3)

######################################################################
# Pointwise operators preserve memory format
z = x + y
print(z.stride())  # Outputs: (3072, 1, 96, 3)

######################################################################
# ``Conv``, ``Batchnorm`` modules using the ``cudnn`` backend support channels last
# (this only works for cuDNN >= 7.6). Convolution modules, unlike binary
# pointwise operators, have channels last as the dominating memory format.
# If all inputs are in contiguous memory format, the operator
# produces output in contiguous memory format. Otherwise, output will
# be in channels last memory format.

if torch.backends.cudnn.is_available() and torch.backends.cudnn.version() >= 7603:
    model = torch.nn.Conv2d(8, 4, 3).cuda().half()
    model = model.to(memory_format=torch.channels_last)  # Module parameters need to be channels last

    input = torch.randint(1, 10, (2, 8, 4, 4), dtype=torch.float32, requires_grad=True)
    input = input.to(device="cuda", memory_format=torch.channels_last, dtype=torch.float16)

    out = model(input)
    print(out.is_contiguous(memory_format=torch.channels_last))  # Outputs: True

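######################################################################
# A slightly larger sketch under the same guard: stacking ``Conv2d``,
# ``BatchNorm2d``, and a pointwise ``ReLU`` keeps the activation in
# channels last end to end (float32 is used here just to keep the
# illustration simple; the sizes are arbitrary).
if torch.backends.cudnn.is_available() and torch.backends.cudnn.version() >= 7603:
    block = torch.nn.Sequential(
        torch.nn.Conv2d(8, 4, 3, padding=1),
        torch.nn.BatchNorm2d(4),
        torch.nn.ReLU(),
    ).cuda().to(memory_format=torch.channels_last)

    activation = torch.randn(2, 8, 4, 4, device="cuda").to(memory_format=torch.channels_last)
    print(block(activation).is_contiguous(memory_format=torch.channels_last))  # Expected: True
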
######################################################################
# When an input tensor reaches an operator without channels last support,
# a permutation is automatically applied in the kernel to restore a
# contiguous input tensor. This introduces overhead and stops the
# channels last memory format propagation. Nevertheless, it guarantees
# correct output.

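######################################################################
# If such an operator drops the format mid-model, it can be restored
# explicitly (at the cost of one copy) before the next convolution.
# The ``.contiguous()`` call below merely simulates an operator that
# returned a contiguous result:
fallback = z.contiguous()  # pretend some op returned a contiguous output
print(fallback.stride())   # Outputs: (3072, 1024, 32, 1)
restored = fallback.to(memory_format=torch.channels_last)
print(restored.stride())   # Outputs: (3072, 1, 96, 3)
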
######################################################################
# Performance Gains
# --------------------------------------------------------------------
# Channels last memory format optimizations are available on both GPU and CPU.
# On GPU, the most significant performance gains are observed on NVIDIA's
# hardware with Tensor Cores support running in reduced precision
# (``torch.float16``).
# We were able to achieve over 22% performance gains with channels last
# compared to the contiguous format, both while utilizing
# 'AMP (Automated Mixed Precision)' training scripts.
# Our scripts use AMP as supplied by NVIDIA:
# https://github.com/NVIDIA/apex.
#
# ``python main_amp.py -a resnet50 --b 200 --workers 16 --opt-level O2 ./data``

# opt_level = O2
# keep_batchnorm_fp32 = None <class 'NoneType'>
# loss_scale = None <class 'NoneType'>
# CUDNN VERSION: 7603
# => creating model 'resnet50'
# Selected optimization level O2: FP16 training with FP32 batchnorm and FP32 master weights.
# Defaults for this optimization level are:
# enabled : True
# opt_level : O2
# cast_model_type : torch.float16
# patch_torch_functions : False
# keep_batchnorm_fp32 : True
# master_weights : True
# loss_scale : dynamic
# Processing user overrides (additional kwargs that are not None)...
# After processing overrides, optimization options are:
# enabled : True
# opt_level : O2
# cast_model_type : torch.float16
# patch_torch_functions : False
# keep_batchnorm_fp32 : True
# master_weights : True
# loss_scale : dynamic
# Epoch: [0][10/125] Time 0.866 (0.866) Speed 230.949 (230.949) Loss 0.6735125184 (0.6735) Prec@1 61.000 (61.000) Prec@5 100.000 (100.000)
# Epoch: [0][20/125] Time 0.259 (0.562) Speed 773.481 (355.693) Loss 0.6968704462 (0.6852) Prec@1 55.000 (58.000) Prec@5 100.000 (100.000)
# Epoch: [0][30/125] Time 0.258 (0.461) Speed 775.089 (433.965) Loss 0.7877287269 (0.7194) Prec@1 51.500 (55.833) Prec@5 100.000 (100.000)
# Epoch: [0][40/125] Time 0.259 (0.410) Speed 771.710 (487.281) Loss 0.8285319805 (0.7467) Prec@1 48.500 (54.000) Prec@5 100.000 (100.000)
# Epoch: [0][50/125] Time 0.260 (0.380) Speed 770.090 (525.908) Loss 0.7370464802 (0.7447) Prec@1 56.500 (54.500) Prec@5 100.000 (100.000)
# Epoch: [0][60/125] Time 0.258 (0.360) Speed 775.623 (555.728) Loss 0.7592862844 (0.7472) Prec@1 51.000 (53.917) Prec@5 100.000 (100.000)
# Epoch: [0][70/125] Time 0.258 (0.345) Speed 774.746 (579.115) Loss 1.9698858261 (0.9218) Prec@1 49.500 (53.286) Prec@5 100.000 (100.000)
# Epoch: [0][80/125] Time 0.260 (0.335) Speed 770.324 (597.659) Loss 2.2505953312 (1.0879) Prec@1 50.500 (52.938) Prec@5 100.000 (100.000)

######################################################################
# Passing ``--channels-last true`` runs the model in Channels last format, with an observed 22% performance gain.
#
# ``python main_amp.py -a resnet50 --b 200 --workers 16 --opt-level O2 --channels-last true ./data``

# opt_level = O2
# keep_batchnorm_fp32 = None <class 'NoneType'>
# loss_scale = None <class 'NoneType'>
#
# CUDNN VERSION: 7603
#
# => creating model 'resnet50'
# Selected optimization level O2: FP16 training with FP32 batchnorm and FP32 master weights.
#
# Defaults for this optimization level are:
# enabled : True
# opt_level : O2
# cast_model_type : torch.float16
# patch_torch_functions : False
# keep_batchnorm_fp32 : True
# master_weights : True
# loss_scale : dynamic
# Processing user overrides (additional kwargs that are not None)...
# After processing overrides, optimization options are:
# enabled : True
# opt_level : O2
# cast_model_type : torch.float16
# patch_torch_functions : False
# keep_batchnorm_fp32 : True
# master_weights : True
# loss_scale : dynamic
#
# Epoch: [0][10/125] Time 0.767 (0.767) Speed 260.785 (260.785) Loss 0.7579724789 (0.7580) Prec@1 53.500 (53.500) Prec@5 100.000 (100.000)
# Epoch: [0][20/125] Time 0.198 (0.482) Speed 1012.135 (414.716) Loss 0.7007197738 (0.7293) Prec@1 49.000 (51.250) Prec@5 100.000 (100.000)
# Epoch: [0][30/125] Time 0.198 (0.387) Speed 1010.977 (516.198) Loss 0.7113101482 (0.7233) Prec@1 55.500 (52.667) Prec@5 100.000 (100.000)
# Epoch: [0][40/125] Time 0.197 (0.340) Speed 1013.023 (588.333) Loss 0.8943189979 (0.7661) Prec@1 54.000 (53.000) Prec@5 100.000 (100.000)
# Epoch: [0][50/125] Time 0.198 (0.312) Speed 1010.541 (641.977) Loss 1.7113249302 (0.9551) Prec@1 51.000 (52.600) Prec@5 100.000 (100.000)
# Epoch: [0][60/125] Time 0.198 (0.293) Speed 1011.163 (683.574) Loss 5.8537774086 (1.7716) Prec@1 50.500 (52.250) Prec@5 100.000 (100.000)
# Epoch: [0][70/125] Time 0.198 (0.279) Speed 1011.453 (716.767) Loss 5.7595844269 (2.3413) Prec@1 46.500 (51.429) Prec@5 100.000 (100.000)
# Epoch: [0][80/125] Time 0.198 (0.269) Speed 1011.827 (743.883) Loss 2.8196096420 (2.4011) Prec@1 47.500 (50.938) Prec@5 100.000 (100.000)

######################################################################
# The following models fully support Channels last and show 8%-35% performance gains on Volta devices:
# ``alexnet``, ``mnasnet0_5``, ``mnasnet0_75``, ``mnasnet1_0``, ``mnasnet1_3``, ``mobilenet_v2``, ``resnet101``, ``resnet152``, ``resnet18``, ``resnet34``, ``resnet50``, ``resnext50_32x4d``, ``shufflenet_v2_x0_5``, ``shufflenet_v2_x1_0``, ``shufflenet_v2_x1_5``, ``shufflenet_v2_x2_0``, ``squeezenet1_0``, ``squeezenet1_1``, ``vgg11``, ``vgg11_bn``, ``vgg13``, ``vgg13_bn``, ``vgg16``, ``vgg16_bn``, ``vgg19``, ``vgg19_bn``, ``wide_resnet101_2``, ``wide_resnet50_2``
#

######################################################################
# The following models fully support Channels last and show 26%-76% performance gains on Intel(R) Xeon(R) Ice Lake (or newer) CPUs:
# ``alexnet``, ``densenet121``, ``densenet161``, ``densenet169``, ``googlenet``, ``inception_v3``, ``mnasnet0_5``, ``mnasnet1_0``, ``resnet101``, ``resnet152``, ``resnet18``, ``resnet34``, ``resnet50``, ``resnext101_32x8d``, ``resnext50_32x4d``, ``shufflenet_v2_x0_5``, ``shufflenet_v2_x1_0``, ``squeezenet1_0``, ``squeezenet1_1``, ``vgg11``, ``vgg11_bn``, ``vgg13``, ``vgg13_bn``, ``vgg16``, ``vgg16_bn``, ``vgg19``, ``vgg19_bn``, ``wide_resnet101_2``, ``wide_resnet50_2``
#

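######################################################################
# A rough way to check what channels last buys you for one of the models
# above on your own hardware is a small forward-pass timing loop like the
# one below. This is only a sketch: it assumes ``torchvision`` is
# installed, uses the native ``torch.cuda.amp.autocast`` instead of apex,
# and times inference rather than full training, so the numbers will not
# match the logs above.
import time

if torch.cuda.is_available():
    import torchvision

    def time_forward(net, inp, iters=20):
        # warm up, then synchronize and time the average forward pass
        for _ in range(3):
            net(inp)
        torch.cuda.synchronize()
        start = time.time()
        for _ in range(iters):
            net(inp)
        torch.cuda.synchronize()
        return (time.time() - start) / iters

    net = torchvision.models.resnet50().cuda().eval()
    batch = torch.randn(32, 3, 224, 224, device="cuda")

    with torch.no_grad(), torch.cuda.amp.autocast():
        contiguous_time = time_forward(net, batch)
        net = net.to(memory_format=torch.channels_last)
        channels_last_time = time_forward(net, batch.to(memory_format=torch.channels_last))

    print("contiguous:    {:.1f} ms / iter".format(contiguous_time * 1000))
    print("channels_last: {:.1f} ms / iter".format(channels_last_time * 1000))
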
######################################################################
# Converting existing models
# --------------------------
#
# Channels last support is not limited to existing models, as any
# model can be converted to channels last and propagate the format through
# the graph as soon as the input (or certain weights) is formatted
# correctly.
#

# Needs to be done once, after model initialization (or load)
model = model.to(memory_format=torch.channels_last)  # Replace with your model

# Needs to be done for every input
input = input.to(memory_format=torch.channels_last)  # Replace with your input
output = model(input)

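######################################################################
# The same recipe applied end to end to a small self-contained model
# (the layer sizes are arbitrary and ``toy_model``/``toy_input`` are just
# illustrative names, so the cell can run even on CPU):
toy_model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 16, 3, padding=1),
    torch.nn.BatchNorm2d(16),
    torch.nn.ReLU(),
)
toy_model = toy_model.to(memory_format=torch.channels_last)  # once, after init

toy_input = torch.randn(1, 3, 32, 32).to(memory_format=torch.channels_last)  # for every input
toy_output = toy_model(toy_input)
print(toy_output.is_contiguous(memory_format=torch.channels_last))  # Expected: True
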
#######################################################################
# However, not all operators have been fully converted to support channels
# last (they usually return a contiguous output instead). In the example
# posted above, layers that do not support channels last will stop the
# memory format propagation. In spite of that, because we have converted the
# model to channels last format, each convolution layer,
# which has its 4-dimensional weight in channels last memory format,
# will restore the channels last memory format and benefit from faster
# kernels.
#
# But operators that do not support channels last do introduce
# overhead through permutation. Optionally, you can investigate and identify
# operators in your model that do not support channels last, if you
# want to improve the performance of the converted model.
#
# That means you need to verify the list of used operators
# against the supported operators list https://github.com/pytorch/pytorch/wiki/Operators-with-Channels-Last-support,
# or introduce memory format checks into eager execution mode and run your model.
#
# After running the code below, operators will raise an exception if the output of the
# operator doesn't match the memory format of the input.
#
#
def contains_cl(args):
    # True if any (possibly nested) argument is an unambiguously channels_last tensor
    for t in args:
        if isinstance(t, torch.Tensor):
            if t.is_contiguous(memory_format=torch.channels_last) and not t.is_contiguous():
                return True
        elif isinstance(t, list) or isinstance(t, tuple):
            if contains_cl(list(t)):
                return True
    return False


def print_inputs(args, indent=""):
    # Print the strides, shapes, devices, and dtypes of (possibly nested) operator inputs
    for t in args:
        if isinstance(t, torch.Tensor):
            print(indent, t.stride(), t.shape, t.device, t.dtype)
        elif isinstance(t, list) or isinstance(t, tuple):
            print(indent, type(t))
            print_inputs(list(t), indent=indent + "    ")
        else:
            print(indent, t)


def check_wrapper(fn):
    name = fn.__name__

    def check_cl(*args, **kwargs):
        # Call ``fn`` and flag 4D outputs that lost the channels_last layout
        was_cl = contains_cl(args)
        try:
            result = fn(*args, **kwargs)
        except Exception as e:
            print("`{}` inputs are:".format(name))
            print_inputs(args)
            print("-------------------")
            raise e
        failed = False
        if was_cl:
            if isinstance(result, torch.Tensor):
                if result.dim() == 4 and not result.is_contiguous(memory_format=torch.channels_last):
                    print(
                        "`{}` got channels_last input, but output is not channels_last:".format(name),
                        result.shape,
                        result.stride(),
                        result.device,
                        result.dtype,
                    )
                    failed = True
        if failed:
            print("`{}` inputs are:".format(name))
            print_inputs(args)
            raise Exception("Operator `{}` lost channels_last property".format(name))
        return result

    return check_cl


old_attrs = dict()


def attribute(m):
    # Wrap every public callable attribute of ``m`` with ``check_wrapper``,
    # keeping the originals in ``old_attrs`` so they can be restored later
    old_attrs[m] = dict()
    for i in dir(m):
        e = getattr(m, i)
        exclude_functions = ["is_cuda", "has_names", "numel", "stride", "Tensor", "is_contiguous", "__class__"]
        if i not in exclude_functions and not i.startswith("_") and "__call__" in dir(e):
            try:
                old_attrs[m][i] = e
                setattr(m, i, check_wrapper(e))
            except Exception as e:
                print(i)
                print(e)


attribute(torch.Tensor)
attribute(torch.nn.functional)
attribute(torch)

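######################################################################
# With the wrappers installed, running your model (or a toy stand-in like
# the one below; ``probe`` and ``probe_conv`` are illustrative names) will
# either pass silently or raise as soon as an operator drops the channels
# last layout, printing the offending operator's inputs:
probe_conv = torch.nn.Conv2d(3, 8, 3, padding=1).to(memory_format=torch.channels_last)
probe = torch.randn(1, 3, 8, 8).to(memory_format=torch.channels_last)
try:
    probe_conv(probe)
    print("conv2d preserved channels_last")
except Exception as e:
    print("channels_last property was lost:", e)
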
######################################################################
# If you find an operator that doesn't support channels last tensors
# and you want to contribute, feel free to use the following developers
# guide: https://github.com/pytorch/pytorch/wiki/Writing-memory-format-aware-operators.
#

######################################################################
# The code below restores the original attributes of torch.

for (m, attrs) in old_attrs.items():
    for (k, v) in attrs.items():
        setattr(m, k, v)

######################################################################
# Work to do
# ----------
# There are still many things to do, such as:
#
# - Resolving ambiguity of ``N1HW`` and ``NC11`` Tensors;
# - Testing of Distributed Training support;
# - Improving operator coverage.
#
# If you have feedback and/or suggestions for improvement, please let us
# know by creating `an issue <https://github.com/pytorch/pytorch/issues>`_.