Pytorch多GPU数据并行训练

代码来源于 https://github.com/WZMIAOMIAO/deep-learning-for-image-processing/blob/master/pytorch_classification/train_multi_GPU
up主的讲解视频 在pytorch框架下使用多卡(多GPU)进行并行训练

修改了模型和数据部分,作测试,仅记录。
多GPU数据并行训练主要包括以下方面:

  1. 数据
 train_sampler = torch.utils.data.distributed.DistributedSampler(train_data_set)
 val_sampler = torch.utils.data.distributed.DistributedSampler(val_data_set)

2.模型

 model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])

3.返回结果loss,中间值

torch.distributed.all_reduce(value)
#BN层
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)

完整代码

# -*- coding:utf-8 -*-
import math
import os
import sys
import argparse
import torch
from tqdm import tqdm
import torch.optim as optim
from torchvision.datasets import mnist
import torch.optim.lr_scheduler as lr_scheduler
import torchvision.transforms as transforms
import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

def init_distributed_mode(args):
    # 检查环境变量 RANK 和 WORLD_SIZE
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ['RANK'])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ.get('LOCAL_RANK', 0))  # 使用默认值 0
    # 检查环境变量 SLURM_PROCID
    elif 'SLURM_PROCID' in os.environ:
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True

    # 设置 GPU
    torch.cuda.set_device(args.gpu)

    # 设置通信后端
    args.dist_backend = 'nccl'  # 通信后端,nvidia GPU推荐使用NCCL

    # 初始化分布式环境
    print(f'| distributed init (rank {args.rank}): {args.dist_url}', flush=True)
    dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                            world_size=args.world_size, rank=args.rank)
    dist.barrier()


def cleanup():
    """
    清理函数,用于销毁进程组。
    """
    dist.destroy_process_group()
    

def is_dist_avail_and_initialized():
    """检查是否支持分布式环境"""
    # 检查是否支持分布式环境
    if not dist.is_available():
        return False
    # 检查是否已初始化分布式环境
    if not dist.is_initialized():
        return False
    return True
    
def get_world_size():
    # 检查分布式是否可用并已初始化
    if not is_dist_avail_and_initialized():
        return 1
    # 获取分布式大小
    return dist.get_world_size()

def get_rank():
    # 检查分布式环境是否可用并已初始化
    if not is_dist_avail_and_initialized():
        return 0
    # 获取当前进程的分布式排名
    return dist.get_rank()

def is_main_process():
    """
    判断当前进程是否为主进程
    """
    return get_rank() == 0

def reduce_value(value, average=True):
    # 获取当前进程的数量
    world_size = get_world_size()
    # 如果进程数量小于2,表示单GPU的情况,直接返回value
    if world_size < 2: 
        return value

    # 在不计算梯度的情况下,将value进行所有进程的求和操作
    with torch.no_grad():
        # 使用分布式训练库进行所有进程的求和操作
        dist.all_reduce(value)
        # 如果average为True,则将value除以进程数量,得到平均值
        if average:
            value /= world_size

    return value


# 定义模型
class CNNNet(torch.nn.Module):

    def __init__(self, in_channel, out_channel_one, out_channel_two, out_channel_three, fc_1, fc_2, fc_out):
        super(CNNNet, self).__init__()

        self.conv1 = torch.nn.Conv2d(in_channels=in_channel, out_channels=out_channel_one, kernel_size=5, stride=1, padding=1)
        self.pool1 = torch.nn.MaxPool2d(kernel_size=2, stride=2,padding=1)
        self.conv2 = torch.nn.Conv2d(in_channels=out_channel_one, out_channels=out_channel_two, kernel_size=5, stride=1,padding=1)
        self.pool2 = torch.nn.MaxPool2d(kernel_size=2, stride=2,padding=1)
        self.conv3 = torch.nn.Conv2d(in_channels=out_channel_two,out_channels=out_channel_three, kernel_size=5, stride=1,padding=1)

        self.fc1 = torch.nn.Linear(5*5*32, fc_1)
        self.fc2 = torch.nn.Linear(fc_1, fc_2)
        self.output = torch.nn.Linear(fc_2, fc_out)

    def forward(self, x):
        x = self.pool1(torch.nn.functional.relu(self.conv1(x)))
        x = self.pool2(torch.nn.functional.relu(self.conv2(x)))
        x = torch.nn.functional.relu(self.conv3(x))

        x = x.view(-1, x.size(1) * x.size(2) * x.size(3))
        x = torch.nn.functional.relu(self.fc1(x))
        x = torch.nn.functional.relu(self.fc2(x))
        x = torch.softmax(self.output(x), dim=1)

        return x


def train_one_epoch(model, optimizer, data_loader, device, epoch):
    # 设置模型为训练模式
    model.train()
    # 定义交叉熵损失函数
    loss_function = torch.nn.CrossEntropyLoss()
    # 初始化平均损失为0
    mean_loss = torch.zeros(1).to(device)
    # 清空优化器的梯度
    optimizer.zero_grad()

    # 在进程0中打印训练进度
    if is_main_process():
        data_loader = tqdm(data_loader, file=sys.stdout)

    # 遍历数据加载器中的每个步骤
    for step, data in enumerate(data_loader):
        # 获取图像和标签
        images, labels = data

        # 使用模型进行预测
        pred = model(images.to(device))

        # 计算损失
        loss = loss_function(pred, labels.to(device))
        # 反向传播计算梯度
        loss.backward()
        # 对损失进行平均
        loss = reduce_value(loss, average=True)
        # 更新平均损失
        mean_loss = (mean_loss * step + loss.detach()) / (step + 1)

        # 在进程0中打印平均损失
        if is_main_process():
            data_loader.desc = "[epoch {}] mean loss {}".format(epoch, round(mean_loss.item(), 3))

        # 检查损失是否为非有限值
        if not torch.isfinite(loss):
            print('WARNING: non-finite loss, ending training ', loss)
            sys.exit(1)

        # 更新优化器参数
        optimizer.step()
        # 清空优化器的梯度
        optimizer.zero_grad()

    # 等待所有进程计算完毕
    if device != torch.device("cpu"):
        torch.cuda.synchronize(device)

    # 返回平均损失
    return mean_loss.item()


def evaluate(model, data_loader, device):
    # 将模型设置为评估模式
    model.eval()

    # 用于存储预测正确的样本个数
    sum_num = torch.zeros(1).to(device)

    # 在进程0中打印验证进度
    if is_main_process():
        data_loader = tqdm(data_loader, file=sys.stdout)

    # 遍历数据加载器中的每个批次
    for step, data in enumerate(data_loader):
        # 获取图像和标签
        images, labels = data
        # 使用模型进行预测
        pred = model(images.to(device))
        # 获取预测结果中的最大值
        pred = torch.max(pred, dim=1)[1]
        # 统计预测正确的样本个数
        sum_num += torch.eq(pred, labels.to(device)).sum()

    # 等待所有进程计算完毕
    if device != torch.device("cpu"):
        torch.cuda.synchronize(device)

    # 将结果进行归一化处理
    sum_num = reduce_value(sum_num, average=False)

    # 返回预测正确的样本个数
    return sum_num.item()


def main(args):
    if not torch.cuda.is_available() :
        raise EnvironmentError("not find GPU device for training.")

    # 初始化各进程环境
    init_distributed_mode(args=args)

    rank = args.rank
    device = torch.device(args.device)
    batch_size = args.batch_size
    weights_path = args.weights
    args.lr *= args.world_size  # 学习率要根据并行GPU的数量进行倍增
    checkpoint_path = ""

    if rank == 0:  # 在第一个进程中打印信息,并实例化tensorboard
        print(args)
        print('Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/')
        tb_writer = SummaryWriter()
        if os.path.exists("./weights") is False:
            os.makedirs("./weights")

    # train_info, val_info, num_classes = read_split_data(args.data_path)
    # train_images_path, train_images_label = train_info
    # val_images_path, val_images_label = val_info

    # check num_classes
    # assert args.num_classes == num_classes, "dataset num_classes: {}, input {}".format(args.num_classes,num_classes)

    data_transform = {
        "train": transforms.Compose([transforms.Resize(28),
                                     transforms.ToTensor(),
                                     transforms.Normalize(mean=(0.1307,), std=(0.3081,))]),
        "val": transforms.Compose([transforms.Resize(28),
                                   transforms.ToTensor(),
                                   transforms.Normalize(mean=(0.1307,), std=(0.3081,))])}

    train_data_set = mnist.MNIST('./MNIST', train=True, transform=data_transform['train'], download=True)
    val_data_set = mnist.MNIST("./MNIST", train=False, transform=data_transform['val'], download=True)
    
    # 给每个rank对应的进程分配训练的样本索引
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_data_set)
    val_sampler = torch.utils.data.distributed.DistributedSampler(val_data_set)

    # 将样本索引每batch_size个元素组成一个list
    train_batch_sampler = torch.utils.data.BatchSampler(
        train_sampler, batch_size, drop_last=True)

    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    if rank == 0:
        print('Using {} dataloader workers every process'.format(nw))
    train_loader = torch.utils.data.DataLoader(train_data_set,
                                               batch_sampler=train_batch_sampler,
                                               pin_memory=True,
                                               num_workers=nw,)

    val_loader = torch.utils.data.DataLoader(val_data_set,
                                             batch_size=batch_size,
                                             sampler=val_sampler,
                                             pin_memory=True,
                                             num_workers=nw,)
    # 实例化模型
    model = CNNNet(1, 16, 32, 32, 128, 64, 10)
    model = model.to(device)

    # 如果存在预训练权重则载入
    if os.path.exists(weights_path):
        weights_dict = torch.load(weights_path, map_location=device)
        load_weights_dict = {k: v for k, v in weights_dict.items()
                             if model.state_dict()[k].numel() == v.numel()}
        model.load_state_dict(load_weights_dict, strict=False)
    else:
        checkpoint_path = os.path.join("./weights", "initial_weights.pt")
        # 如果不存在预训练权重,需要将第一个进程中的权重保存,然后其他进程载入,保持初始化权重一致
        if rank == 0:
            torch.save(model.state_dict(), checkpoint_path)

        dist.barrier()
        # 这里注意,一定要指定map_location参数,否则会导致第一块GPU占用更多资源
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))

    # 是否冻结权重
    if args.freeze_layers:
        for name, para in model.named_parameters():
            # 除最后的全连接层外,其他权重全部冻结
            if "fc" not in name:
                para.requires_grad_(False)
    else:
        # 只有训练带有BN结构的网络时使用SyncBatchNorm采用意义
        if args.syncBN:
            # 使用SyncBatchNorm后训练会更耗时
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)

    # 转为DDP模型
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])

    # optimizer
    pg = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.SGD(pg, lr=args.lr, momentum=0.9, weight_decay=0.005)
    # Scheduler https://arxiv.org/pdf/1812.01187.pdf
    lf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf  # cosine
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)

    for epoch in range(args.epochs):
        train_sampler.set_epoch(epoch)

        mean_loss = train_one_epoch(model=model,
                                    optimizer=optimizer,
                                    data_loader=train_loader,
                                    device=device,
                                    epoch=epoch)

        scheduler.step()

        sum_num = evaluate(model=model,
                           data_loader=val_loader,
                           device=device)
        acc = sum_num / val_sampler.total_size

        if rank == 0:
            print(f"[epoch {epoch}] accuracy: {acc:.3f}")
            tags = ["loss", "accuracy", "learning_rate"]
            tb_writer.add_scalar(tags[0], mean_loss, epoch)
            tb_writer.add_scalar(tags[1], acc, epoch)
            tb_writer.add_scalar(tags[2], optimizer.param_groups[0]["lr"], epoch)

            torch.save(model.module.state_dict(), f"./weights/model-{epoch}.pth")

    # 删除临时缓存文件
    if rank == 0:
        if os.path.exists(checkpoint_path) is True:
            os.remove(checkpoint_path)

    cleanup()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_classes', type=int, default=10)
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--batch-size', type=int, default=128)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--lrf', type=float, default=0.1)
    # 是否启用SyncBatchNorm
    parser.add_argument('--syncBN', type=bool, default=False)

    # 数据集所在根目录
    # https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
    # parser.add_argument('--data-path', type=str, default="/home/wz/data_set/flower_data/flower_photos")

    # resnet34 官方权重下载地址
    # https://download.pytorch.org/models/resnet34-333f7ec4.pth
    parser.add_argument('--weights', type=str, default='./weights/initial_weights.pth',
                        help='initial weights path')
    parser.add_argument('--freeze-layers', type=bool, default=False)
    # 不要改该参数,系统会自动分配
    parser.add_argument('--device', default='cuda', help='device id (i.e. 0 or 0,1 or cpu)')
    # 开启的进程数(注意不是线程),不用设置该参数,会根据nproc_per_node自动设置
    parser.add_argument('--world-size', default=4, type=int,
                        help='number of distributed processes')
    parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training')
    opt = parser.parse_args()

    main(opt)


训练命令

python -m torch.distributed.launch --nproc_per_node=4 --use_env mutil_gpu.py
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 159,569评论 4 363
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 67,499评论 1 294
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 109,271评论 0 244
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 44,087评论 0 209
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 52,474评论 3 287
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 40,670评论 1 222
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 31,911评论 2 313
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 30,636评论 0 202
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 34,397评论 1 246
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 30,607评论 2 246
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 32,093评论 1 261
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 28,418评论 2 254
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 33,074评论 3 237
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 26,092评论 0 8
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 26,865评论 0 196
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 35,726评论 2 276
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 35,627评论 2 270

推荐阅读更多精彩内容