CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Sign Up · Sign In
pytorch

CoCalc provides the best real-time collaborative environment for Jupyter Notebooks, LaTeX documents, and SageMath, scalable from individual users to large groups and classes!

GitHub Repository: pytorch/tutorials
Path: blob/main/advanced_source/rpc_ddp_tutorial/main.py
Views: 494
1
import random
2
3
import torch
4
import torch.distributed as dist
5
import torch.distributed.autograd as dist_autograd
6
import torch.distributed.rpc as rpc
7
import torch.multiprocessing as mp
8
import torch.optim as optim
9
from torch.distributed.nn import RemoteModule
10
from torch.distributed.optim import DistributedOptimizer
11
from torch.distributed.rpc import RRef
12
from torch.distributed.rpc import TensorPipeRpcBackendOptions
13
from torch.nn.parallel import DistributedDataParallel as DDP
14
15
# Size of the nn.EmbeddingBag table created on the parameter server:
# NUM_EMBEDDINGS rows (also the exclusive upper bound for sampled indices),
# each a vector of EMBEDDING_DIM values.
NUM_EMBEDDINGS, EMBEDDING_DIM = 100, 16


# BEGIN hybrid_model
class HybridModel(torch.nn.Module):
    r"""
    The model consists of a sparse part and a dense part.

    1) The dense part is an nn.Linear module that is replicated across all
       trainers using DistributedDataParallel.
    2) The sparse part is a Remote Module that holds an nn.EmbeddingBag on the
       parameter server. This remote module can get a Remote Reference to the
       embedding table on the parameter server.

    Args:
        remote_emb_module: the RemoteModule wrapping the embedding table that
            lives on the parameter server.
        device: CUDA device index used by this trainer; the dense layer is
            placed on it and it is the single entry of DDP's ``device_ids``.
    """

    def __init__(self, remote_emb_module, device):
        super().__init__()
        self.remote_emb_module = remote_emb_module
        # The dense layer consumes the pooled embeddings, so its input width is
        # tied to EMBEDDING_DIM instead of repeating the literal 16. The 8
        # outputs match the target labels drawn from [0, 8) by the trainer's
        # batch generator.
        self.fc = DDP(
            torch.nn.Linear(EMBEDDING_DIM, 8).cuda(device), device_ids=[device]
        )
        self.device = device

    def forward(self, indices, offsets):
        """Look up embeddings remotely, then apply the local dense layer.

        Args:
            indices: flat tensor of embedding ids for all bags in the batch.
            offsets: start position of each bag within ``indices``.

        Returns:
            The DDP-wrapped linear layer's output, computed on ``self.device``.
        """
        # The lookup is executed through the RemoteModule (the table is on the
        # parameter server); the returned tensor is moved to this trainer's
        # GPU before the dense layer runs.
        emb_lookup = self.remote_emb_module.forward(indices, offsets)
        return self.fc(emb_lookup.cuda(self.device))
# END hybrid_model


# BEGIN setup_trainer
def _run_trainer(remote_emb_module, rank):
    r"""
    Each trainer runs a forward pass which involves an embedding lookup on the
    parameter server and running nn.Linear locally. During the backward pass,
    DDP is responsible for aggregating the gradients for the dense part
    (nn.Linear) and distributed autograd ensures gradient updates are
    propagated to the parameter server.

    Args:
        remote_emb_module: RemoteModule holding the embedding table on the
            parameter server; the same handle is sent to every trainer.
        rank: this trainer's rank, also used as its CUDA device index.
    """

    # Setup the model.
    model = HybridModel(remote_emb_module, rank)

    # Retrieve all model parameters as rrefs for DistributedOptimizer.

    # Retrieve parameters for embedding table.
    model_parameter_rrefs = model.remote_emb_module.remote_parameters()

    # model.fc.parameters() only includes local parameters.
    # NOTE: Cannot call model.parameters() here,
    # because this will call remote_emb_module.parameters(),
    # which supports remote_parameters() but not parameters().
    for param in model.fc.parameters():
        model_parameter_rrefs.append(RRef(param))

    # Setup distributed optimizer
    opt = DistributedOptimizer(
        optim.SGD,
        model_parameter_rrefs,
        lr=0.05,
    )

    criterion = torch.nn.CrossEntropyLoss()
    # END setup_trainer

    # BEGIN run_trainer
    def get_next_batch(rank):
        r"""
        Yield 10 synthetic batches of ``(indices, offsets, target)``.

        ``indices`` holds 20-50 random ids in ``[0, NUM_EMBEDDINGS)``.
        ``offsets`` marks where each bag starts in ``indices`` (each bag has
        1-10 ids, the last one possibly fewer). ``target`` holds one class
        label in ``[0, 8)`` per bag and is placed on CUDA device ``rank``;
        ``indices`` and ``offsets`` stay on the CPU.
        """
        for _ in range(10):
            num_indices = random.randint(20, 50)
            indices = torch.LongTensor(num_indices).random_(0, NUM_EMBEDDINGS)

            # Generate offsets: random bag sizes until all indices are covered.
            offsets = []
            start = 0
            batch_size = 0
            while start < num_indices:
                offsets.append(start)
                start += random.randint(1, 10)
                batch_size += 1

            offsets_tensor = torch.LongTensor(offsets)
            # One label per bag, so the target length equals the bag count.
            target = torch.LongTensor(batch_size).random_(8).cuda(rank)
            yield indices, offsets_tensor, target

    # Train for 100 epochs
    for epoch in range(100):
        for indices, offsets, target in get_next_batch(rank):
            # Create a distributed autograd context for this iteration.
            with dist_autograd.context() as context_id:
                output = model(indices, offsets)
                loss = criterion(output, target)

                # Run distributed backward pass
                dist_autograd.backward(context_id, [loss])

                # Run distributed optimizer
                opt.step(context_id)

                # Not necessary to zero grads as each iteration creates a different
                # distributed autograd context which hosts different grads
        print("Training done for epoch {}".format(epoch))
    # END run_trainer


# BEGIN run_worker
def run_worker(rank, world_size):
    r"""
    A wrapper function that initializes RPC, calls the function, and shuts down
    RPC.

    The rank-to-role mapping is hard-coded: rank 2 is the master, rank 3 the
    parameter server ("ps"), and ranks 0 and 1 are the trainers, which also
    form a two-member gloo process group for DistributedDataParallel.

    Args:
        rank: this process's rank in the RPC group.
        world_size: total number of RPC workers (the mapping above assumes 4).
    """

    # We need to use different port numbers in TCP init_method for init_rpc and
    # init_process_group to avoid port conflicts.
    rpc_backend_options = TensorPipeRpcBackendOptions()
    rpc_backend_options.init_method = "tcp://localhost:29501"

    # Rank 2 is master, 3 is ps and 0 and 1 are trainers.
    if rank == 2:
        rpc.init_rpc(
            "master",
            rank=rank,
            world_size=world_size,
            rpc_backend_options=rpc_backend_options,
        )

        # Create the embedding table on the parameter server; the handle is
        # passed to every trainer.
        remote_emb_module = RemoteModule(
            "ps",
            torch.nn.EmbeddingBag,
            args=(NUM_EMBEDDINGS, EMBEDDING_DIM),
            kwargs={"mode": "sum"},
        )

        # Run the training loop on trainers.
        futs = []
        for trainer_rank in [0, 1]:
            trainer_name = "trainer{}".format(trainer_rank)
            fut = rpc.rpc_async(
                trainer_name, _run_trainer, args=(remote_emb_module, trainer_rank)
            )
            futs.append(fut)

        # Wait for all training to finish.
        for fut in futs:
            fut.wait()
    elif rank <= 1:
        # Initialize process group for Distributed DataParallel on trainers.
        # Only the two trainers join it (world_size=2).
        dist.init_process_group(
            backend="gloo", rank=rank, world_size=2, init_method="tcp://localhost:29500"
        )

        # Initialize RPC.
        trainer_name = "trainer{}".format(rank)
        rpc.init_rpc(
            trainer_name,
            rank=rank,
            world_size=world_size,
            rpc_backend_options=rpc_backend_options,
        )
        # The trainer does nothing else here: the master drives training by
        # calling _run_trainer over RPC, and rpc.shutdown() below waits for it.
    else:
        rpc.init_rpc(
            "ps",
            rank=rank,
            world_size=world_size,
            rpc_backend_options=rpc_backend_options,
        )
        # The parameter server does nothing else; it only serves RPCs until
        # shutdown.

    # block until all rpcs finish
    rpc.shutdown()

    # Trainers joined the DDP process group above; release it once RPC (and
    # therefore all training driven through it) has completed.
    if dist.is_initialized():
        dist.destroy_process_group()


if __name__ == "__main__":
    # Launch one process per role: ranks 0 and 1 are trainers, rank 2 is the
    # master and rank 3 is the parameter server.
    world_size = 4
    mp.spawn(
        run_worker,
        args=(world_size,),
        nprocs=world_size,
        join=True,
    )
# END run_worker