GitHub Repository: pytorch/tutorials
Path: blob/main/recipes_source/recipes/changing_default_device.py
"""
Changing default device
=======================

It is common practice to write PyTorch code in a device-agnostic way,
and then switch between CPU and CUDA depending on what hardware is available.
Typically, you might have done this with if-statements and ``cuda()`` calls:

.. note::
   This recipe requires PyTorch 2.0.0 or later.

"""
import torch

USE_CUDA = False

mod = torch.nn.Linear(20, 30)
if USE_CUDA:
    mod.cuda()

device = 'cpu'
if USE_CUDA:
    device = 'cuda'
inp = torch.randn(128, 20, device=device)
print(mod(inp).device)

###################################################################
# PyTorch now also has a context manager which can take care of the
# device transfer automatically. Here is an example:

with torch.device('cuda'):
    mod = torch.nn.Linear(20, 30)
    print(mod.weight.device)
    print(mod(torch.randn(128, 20)).device)
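
######################################################################
# The example above assumes a CUDA machine. A common device-agnostic
# variant (a sketch, not part of the original recipe) picks the device
# with ``torch.cuda.is_available()`` so the same code also runs on
# CPU-only machines; ``torch.device`` objects are usable as context
# managers in PyTorch 2.0+:

```python
import torch

# Fall back to CPU when CUDA is unavailable (assumption: either
# device is acceptable for this model).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

with device:
    mod = torch.nn.Linear(20, 30)    # parameters created on ``device``
    out = mod(torch.randn(128, 20))  # factory call honors the context too
print(out.device)
```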

#########################################
# You can also set it globally like this:

torch.set_default_device('cuda')

mod = torch.nn.Linear(20, 30)
print(mod.weight.device)
print(mod(torch.randn(128, 20)).device)
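
######################################################################
# ``torch.set_default_device`` stays in effect until it is changed
# again, so if you only need it temporarily it is worth restoring the
# default afterwards (or using the context manager above). A minimal
# CPU-only sketch, assuming PyTorch 2.0+:

```python
import torch

torch.set_default_device('cpu')  # restore the factory default
t = torch.ones(3)                # created on the current default device
print(t.device)

# An explicit ``device=`` argument always overrides the global default.
u = torch.ones(3, device='cpu')
print(u.device)
```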

################################################################
# This function imposes a slight performance cost on every Python
# call to the torch API (not just factory functions). If this
# is causing problems for you, please comment on
# `this issue <https://github.com/pytorch/pytorch/issues/92701>`__