Horovod with PyTorch

To use Horovod with PyTorch, make the following modifications to your training script:

  1. Run hvd.init().

  2. Pin each GPU to a single process.

    With the typical setup of one GPU per process, set this to local rank. The first process on the server will be allocated the first GPU, the second process will be allocated the second GPU, and so forth.

    if torch.cuda.is_available():
        torch.cuda.set_device(hvd.local_rank())
    

  3. Scale the learning rate by the number of workers.

    The effective batch size in synchronous distributed training scales with the number of workers. An increase in learning rate compensates for the increased batch size (a minimal sketch follows this list).

  4. Wrap the optimizer in hvd.DistributedOptimizer.

    The distributed optimizer delegates gradient computation to the original optimizer, averages gradients using allreduce or allgather, and then applies those averaged gradients.

  5. Broadcast the initial variable states from rank 0 to all other processes:

    hvd.broadcast_parameters(model.state_dict(), root_rank=0)
    hvd.broadcast_optimizer_state(optimizer, root_rank=0)
    

    This is necessary to ensure consistent initialization of all workers when training is started with random weights or restored from a checkpoint.

  6. Modify your code to save checkpoints only on worker 0 to prevent other workers from corrupting them.

    Accomplish this by saving checkpoints only when hvd.rank() == 0 (see the sketch after the full training example below).
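
For step 3, a minimal sketch of scaling the learning rate by the number of workers, assuming hvd.init() has already been called and model is defined as in the example below; 0.01 is only an illustrative base learning rate:

import torch.optim as optim

# Scale the single-worker learning rate by the number of Horovod workers to
# compensate for the larger effective batch size.
base_lr = 0.01
optimizer = optim.SGD(model.parameters(), lr=base_lr * hvd.size())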

Example (also see a full training example):

import torch
import torch.nn.functional as F
import torch.optim as optim
import horovod.torch as hvd

# Initialize Horovod
hvd.init()

# Pin GPU to be used to process local rank (one GPU per process)
torch.cuda.set_device(hvd.local_rank())

# Define dataset...
train_dataset = ...

# Partition dataset among workers using DistributedSampler
train_sampler = torch.utils.data.distributed.DistributedSampler(
    train_dataset, num_replicas=hvd.size(), rank=hvd.rank())

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=..., sampler=train_sampler)

# Build model...
model = ...
model.cuda()

# Use a learning rate scaled by the number of workers (see step 3 above);
# 0.01 is just an illustrative base learning rate.
optimizer = optim.SGD(model.parameters(), lr=0.01 * hvd.size())

# Add Horovod Distributed Optimizer
optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=model.named_parameters())

# Broadcast initial parameters and optimizer state from rank 0 to all other processes.
hvd.broadcast_parameters(model.state_dict(), root_rank=0)
hvd.broadcast_optimizer_state(optimizer, root_rank=0)

for epoch in range(100):
    for batch_idx, (data, target) in enumerate(train_loader):
        # Move the batch to the GPU pinned to this process.
        data, target = data.cuda(), target.cuda()
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{}]\tLoss: {}'.format(
                epoch, batch_idx * len(data), len(train_sampler), loss.item()))
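
For step 6, a minimal sketch of saving a checkpoint only on rank 0, for example at the end of each epoch in the loop above; the file name checkpoint.pt is just an illustrative path:

# Only rank 0 writes the checkpoint so that other workers cannot corrupt it.
if hvd.rank() == 0:
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, 'checkpoint.pt')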

Note

PyTorch GPU support requires NCCL 2.2 or later. It also works with NCCL 2.1.15 if you are not using RoCE or InfiniBand.

PyTorch Lightning

Horovod is supported as a distributed backend in PyTorch Lightning v0.7.4 and above.

With PyTorch Lightning, distributed training using Horovod requires only a single-line code change to your existing training script:

# train Horovod on GPU (number of GPUs / machines provided on command-line)
trainer = pl.Trainer(accelerator='horovod', gpus=1)

# train Horovod on CPU (number of processes / machines provided on command-line)
trainer = pl.Trainer(accelerator='horovod')

In some older versions of pytorch_lightning, the parameter may need to be named "distributed_backend" instead of "accelerator".
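
For context, a minimal sketch of how the Trainer line fits into a script; MNISTModel stands in for any LightningModule subclass and train_loader for an ordinary PyTorch DataLoader (both are placeholders, not part of the Horovod API):

import pytorch_lightning as pl

model = MNISTModel()  # any LightningModule works here

# One GPU per process; horovodrun controls how many processes are launched.
trainer = pl.Trainer(accelerator='horovod', gpus=1, max_epochs=10)
trainer.fit(model, train_loader)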

Start the training job and specify the number of workers on the command line as you normally would when using Horovod:

# run training with 4 GPUs on a single machine
$ horovodrun -np 4 python train.py

# run training with 8 GPUs on two machines (4 GPUs each)
$ horovodrun -np 8 -H hostname1:4,hostname2:4 python train.py

You can find an example of using the PyTorch Lightning Trainer with the Horovod backend in pytorch_lightning_mnist.py

See the PyTorch Lightning docs for more details.

A PyTorch Lightning based Spark estimator is also available; see pytorch_lightning_spark_mnist.py for an example.
