Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place.
Path: blob/main/advanced_source/rpc_ddp_tutorial/main.py
import random

import torch
import torch.distributed as dist
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
import torch.optim as optim
from torch.distributed.nn import RemoteModule
from torch.distributed.optim import DistributedOptimizer
from torch.distributed.rpc import RRef
from torch.distributed.rpc import TensorPipeRpcBackendOptions
from torch.nn.parallel import DistributedDataParallel as DDP

NUM_EMBEDDINGS = 100
EMBEDDING_DIM = 16


# BEGIN hybrid_model
class HybridModel(torch.nn.Module):
    r"""
    The model consists of a sparse part and a dense part.
    1) The dense part is an nn.Linear module that is replicated across all trainers using DistributedDataParallel.
    2) The sparse part is a Remote Module that holds an nn.EmbeddingBag on the parameter server.
    This remote module can get a Remote Reference to the embedding table on the parameter server.
    """

    def __init__(self, remote_emb_module, device):
        super(HybridModel, self).__init__()
        self.remote_emb_module = remote_emb_module
        self.fc = DDP(torch.nn.Linear(16, 8).cuda(device), device_ids=[device])
        self.device = device

    def forward(self, indices, offsets):
        emb_lookup = self.remote_emb_module.forward(indices, offsets)
        return self.fc(emb_lookup.cuda(self.device))
# END hybrid_model

# BEGIN setup_trainer
def _run_trainer(remote_emb_module, rank):
    r"""
    Each trainer runs a forward pass which involves an embedding lookup on the
    parameter server and running nn.Linear locally. During the backward pass,
    DDP is responsible for aggregating the gradients for the dense part
    (nn.Linear) and distributed autograd ensures gradient updates are
    propagated to the parameter server.
    """

    # Setup the model.
    model = HybridModel(remote_emb_module, rank)

    # Retrieve all model parameters as rrefs for DistributedOptimizer.

    # Retrieve parameters for embedding table.
    model_parameter_rrefs = model.remote_emb_module.remote_parameters()

    # model.fc.parameters() only includes local parameters.
    # NOTE: Cannot call model.parameters() here,
    # because this will call remote_emb_module.parameters(),
    # which supports remote_parameters() but not parameters().
    for param in model.fc.parameters():
        model_parameter_rrefs.append(RRef(param))

    # Setup distributed optimizer
    opt = DistributedOptimizer(
        optim.SGD,
        model_parameter_rrefs,
        lr=0.05,
    )

    criterion = torch.nn.CrossEntropyLoss()
    # END setup_trainer

    # BEGIN run_trainer
    def get_next_batch(rank):
        for _ in range(10):
            num_indices = random.randint(20, 50)
            indices = torch.LongTensor(num_indices).random_(0, NUM_EMBEDDINGS)

            # Generate offsets.
            offsets = []
            start = 0
            batch_size = 0
            while start < num_indices:
                offsets.append(start)
                start += random.randint(1, 10)
                batch_size += 1

            offsets_tensor = torch.LongTensor(offsets)
            target = torch.LongTensor(batch_size).random_(8).cuda(rank)
            yield indices, offsets_tensor, target

    # Train for 100 epochs
    for epoch in range(100):
        # Create a distributed autograd context per iteration.
        for indices, offsets, target in get_next_batch(rank):
            with dist_autograd.context() as context_id:
                output = model(indices, offsets)
                loss = criterion(output, target)

                # Run distributed backward pass
                dist_autograd.backward(context_id, [loss])

                # Run distributed optimizer
                opt.step(context_id)

                # Not necessary to zero grads as each iteration creates a different
                # distributed autograd context which hosts different grads
        print("Training done for epoch {}".format(epoch))
    # END run_trainer


# BEGIN run_worker
def run_worker(rank, world_size):
    r"""
    A wrapper function that initializes RPC, calls the function, and shuts down
    RPC.
    """

    # We need to use different port numbers in TCP init_method for init_rpc and
    # init_process_group to avoid port conflicts.
    rpc_backend_options = TensorPipeRpcBackendOptions()
    rpc_backend_options.init_method = "tcp://localhost:29501"

    # Rank 2 is master, 3 is ps, and 0 and 1 are trainers.
    if rank == 2:
        rpc.init_rpc(
            "master",
            rank=rank,
            world_size=world_size,
            rpc_backend_options=rpc_backend_options,
        )

        remote_emb_module = RemoteModule(
            "ps",
            torch.nn.EmbeddingBag,
            args=(NUM_EMBEDDINGS, EMBEDDING_DIM),
            kwargs={"mode": "sum"},
        )

        # Run the training loop on trainers.
        futs = []
        for trainer_rank in [0, 1]:
            trainer_name = "trainer{}".format(trainer_rank)
            fut = rpc.rpc_async(
                trainer_name, _run_trainer, args=(remote_emb_module, trainer_rank)
            )
            futs.append(fut)

        # Wait for all training to finish.
        for fut in futs:
            fut.wait()
    elif rank <= 1:
        # Initialize process group for Distributed DataParallel on trainers.
        dist.init_process_group(
            backend="gloo", rank=rank, world_size=2, init_method="tcp://localhost:29500"
        )

        # Initialize RPC.
        trainer_name = "trainer{}".format(rank)
        rpc.init_rpc(
            trainer_name,
            rank=rank,
            world_size=world_size,
            rpc_backend_options=rpc_backend_options,
        )

        # Trainer just waits for RPCs from master.
    else:
        rpc.init_rpc(
            "ps",
            rank=rank,
            world_size=world_size,
            rpc_backend_options=rpc_backend_options,
        )
        # The parameter server does nothing.
        pass

    # Block until all RPCs finish.
    rpc.shutdown()


if __name__ == "__main__":
    # 2 trainers, 1 parameter server, 1 master.
    world_size = 4
    mp.spawn(run_worker, args=(world_size,), nprocs=world_size, join=True)
# END run_worker
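A minimal standalone sketch (not part of main.py above; the values here are illustrative) of the batch format that get_next_batch() feeds to the remote nn.EmbeddingBag: a flat 1-D indices tensor plus an offsets tensor marking where each sample's bag begins, so variable-length samples need no padding. It runs on CPU and reuses the script's NUM_EMBEDDINGS/EMBEDDING_DIM values of 100 and 16.

import torch

# Two samples packed into one flat indices tensor:
# sample 0 covers indices[0:2], sample 1 covers indices[2:5].
emb = torch.nn.EmbeddingBag(100, 16, mode="sum")
indices = torch.LongTensor([3, 7, 42, 5, 9])
offsets = torch.LongTensor([0, 2])
output = emb(indices, offsets)
print(output.shape)  # torch.Size([2, 16]) -- one summed 16-dim embedding per sample

Note that the full script itself requires at least two visible CUDA devices, since trainer ranks 0 and 1 each move their DDP-wrapped nn.Linear and targets to cuda:0 and cuda:1.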