pytorch如何不受权重文件限制,随心所欲定义模型

问题引入

闲来无事,自己搭建了yolov3的backbone darknet53玩一玩,使用pytorch搭建完成后,去yolov3的官网下载了yolo_weight.pth权重文件,并不能直接载入,有这样几个问题:

  • 我搭建的模型只是backbone部分,而官方的权重文件包含了yolov3_head部分,如何去掉yolov3_head只载入backbone部分的权重成为了第一个问题。
  • 我在搭建darkent53的时候只是参照了yolov3论文的图片,只有一些残差结构的内部去看了源码,所以我整体构建的方法和每一层的命名与官方不同。
  • 第三个问题是我打印了我自己的参数字典(model_dict)和官方权重的参数字典(weights_dict),我发现我的所有bn层的参数都比官网的多了一个参数num_batches_tracked,这是由于官方训练得到的权重文件是基于torch0.3.1,版本太老导致的。
    左边weights_dict,右边图model_dict.png

解决

我通过两个步骤解决了问题,首先过滤掉model_dict中的num_batches_tracked,之后使用循环进行遍历赋值。

##############################################################
#  > File Name        : darknet53.py
#  > Author           : zhw
#  > Created Time     : 2021年12月31日 星期五 22时12分45秒
##############################################################
import torch 
import torch.nn as nn
import math
from collections import OrderedDict

def ConvBNLRelu(in_channels, out_channels, kernel, stride=1, padding=0):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel, stride, padding, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.LeakyReLU(0.1)
        )

class BasicBlock(nn.Module):
    def __init__(self, in_channels, channels_list):
        super(BasicBlock,self).__init__()
        self.conv1 = ConvBNLRelu(in_channels, channels_list[0], 1, 1)
        self.conv2 = ConvBNLRelu(channels_list[0], channels_list[1], 3, 1, 1)
    def forward(self, x):
        residual = x
        x = self.conv1(x)
        x = self.conv2(x)
        x += residual
        return x
class Darknet53(nn.Module):
    def __init__(self):
        super(Darknet53, self).__init__()
        self.in_channels = 32
        layers = [1, 2, 8, 8, 4]
        self.conv1 = nn.Conv2d(3,self.in_channels,3,1,1,bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channels)
        self.relu1 = nn.LeakyReLU(0.1)
        self.layer1 = self._make_layer([32,64], layers[0])
        self.layer2 = self._make_layer([64,128], layers[1])
        self.layer3 = self._make_layer([128,256], layers[2])
        self.layer4 = self._make_layer([256,512], layers[3])
        self.layer5 = self._make_layer([512,1024], layers[4])

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
    def _make_layer(self, channels_list, blocks):
        layers = []
        layers.append(("ds_conv", nn.Conv2d(self.in_channels, channels_list[1], 3, 2, 1, bias=False)))
        layers.append(("ds_bn", nn.BatchNorm2d(channels_list[1])))
        layers.append(("ds_relu", nn.LeakyReLU(0.1)))
        self.in_channels = channels_list[1]
        for i in range(blocks):
            layers.append(("residual_{}".format(i), BasicBlock(self.in_channels, channels_list)))
        return nn.Sequential(OrderedDict(layers))
    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)
        out = self.layer1(out)
        out = self.layer2(out)
        out2 = self.layer3(out)
        out1 = self.layer4(out2)
        out0 = self.layer5(out1)
        return out0, out1, out2

def main():
    t = torch.randn([4,3,416,416])
    model = Darknet53()
    weights_path = "/home/zhw/dataset/weights/object_detection/yolo_weights.pth"
    weights_dict = torch.load(weights_path)
    weights_list_key = list(weights_dict.keys())
    len_weights = len(weights_list_key)
    model_dict = model.state_dict()
    # 过滤掉num_batches_tracked参数
    model_dict = {k: v for k, v in model_dict.items() if 'num_batches_tracked' not in k}
    model_list_key = list(model_dict.keys())
    len_model_dict = len(model_list_key)
    m,n = 0,0
    # 循环赋值,并保证shape一致
    while m < len_weights and n < len_model_dict:
        weights_name,model_name = weights_list_key[m],model_list_key[n]
        weights_shape,model_shape = weights_dict[weights_name].shape,model_dict[model_name].shape
        if weights_shape != model_shape:
           continue
        model_dict[model_name] = weights_dict[weights_name] 
        n += 1
        m += 1
    model.load_state_dict(model_dict)
    if n == min(len_weights, len_model_dict):
        print("all weights was loaded")
    #out0, out1, out2 = model(t) 
    #print("out0 shape: ", out0.shape)
    #print("out1 shape: ", out1.shape)
    #print("out2 shape: ", out2.shape)

if __name__ == "__main__":
    main()

总结

这种方法对好多情况其实都适用,但有的时候我们搭建的模型名字都与权重一致,只是想加载部分权重的话,其实用不着这么麻烦,之后我会给出总结。

©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 162,825评论 4 377
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 68,887评论 2 308
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 112,425评论 0 255
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 44,801评论 0 224
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 53,252评论 3 299
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 41,089评论 1 226
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 32,216评论 2 322
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 31,005评论 0 215
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 34,747评论 1 250
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 30,883评论 2 255
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 32,354评论 1 265
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 28,694评论 3 265
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 33,406评论 3 246
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 26,222评论 0 9
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 26,996评论 0 201
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 36,242评论 2 287
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 36,017评论 2 281

推荐阅读更多精彩内容