DARTS代码阅读

0x00 背景知识

先放上一篇综述文章,对于理解NAS(网络结构搜索)的问题有很大的帮助:https://blog.csdn.net/c9Yv2cf9I06K2A9E/article/details/82321884
另外,DARTS搜索,强烈建议先看下inception的网络结构和nasnet的论文,DARTS的论文基础是建立在之上的,某种程度上可以看做是对nasnet的优化。

0x01 搜索思路

基于前人的经验(inception/nasnet),DARTS使用cell作为模型结构搜索的基础单元,所学习的单元堆叠成卷积网络,也可以递归连接形成递归网络。
cell内节点间先默认所有可能的操作连接,每个连接初始化权重参数值,结构搜索也就是训练这些权重参数,最终两节点间选取权重最大的操作作为最终结构参数。

训练过程中,交替训练网络结构参数和网络参数。

0x02 代码定义

genotype结构定义

normal=[(‘sep_conv_3x3’, 0), (‘sep_conv_3x3’, 1), (‘sep_conv_3x3’, 0), (‘sep_conv_3x3’, 1), (‘sep_conv_3x3’, 1), (‘skip_connect’, 0), (‘skip_connect’, 0), (‘dil_conv_3x3’, 2)], normal_concat=[2, 3, 4, 5]

取了genotype里的一个normal cell的定义及其对应的cell结构图首先说明下,这个定义的解释。DARTS搜索的也就是这个定义。
normal定义里(‘sep_conv_3x3’, 1)的0,1,2,3,4,5对应到图中的红色字体标注的。
从normal文字定义两个元组一组,映射到图中一个蓝色方框的节点(这个是作者搜索出来的结构,结构不一样,对应关系不一定是这样的)
sep_conv_xxxx表示操作,0/1表示输入来源
(‘sep_conv_3x3’, 1), (‘sep_conv_3x3’, 0) —-> 节点0
(‘sep_conv_3x3’, 0), (‘sep_conv_3x3’, 1) —-> 节点1
(‘sep_conv_3x3’, 1), (‘skip_connect’, 0) —-> 节点2
(‘skip_connect’, 0), (‘dil_conv_3x3’, 2) —-> 节点3
normal_concat=[2, 3, 4, 5] —-> cell输出c_{k}

DARTS搜索NOTE

首先明确,DARTS搜索实际只搜cell内结构,整个模型的网络结构是预定好的,比如多少层,网络宽度,cell内几个节点等;
在构建搜索的网络结构时,有几个特别的地方:
1.预构建cell时,采用的一个MixedOp:包含了两个节点所有可能的连接(genotype中的PRIMITIVES);
2.初始化了一个alphas矩阵,网络做forward时,参数传入,在cell里使用,搜索过程中所有可能连接都在时,计算mixedOp的输出,采用加权的形式。
3.训练过程对train数据每个step又切成两份: train和validate, train用来训练网络参数,validate用来训练结构参数。

0x03 关键代码片段

以下把代码中一些关键的,影响到理解DARTS的地方说明一下:

  • file: train_search.py 第149行
    architect.step(input, target, input_search, target_search, lr, optimizer, unrolled=args.unrolled)
  logits = model(input)
  loss = criterion(logits, target)
  loss.backward()
  nn.utils.clip_grad_norm(model.parameters(), args.grad_clip)
  optimizer.step()

这里就是论文里近似后的交叉梯度下降,其中architect.step()是结构参数weights的梯度下降,optimizer.step()是网络参数的梯度下降。

  • file: model_search.py
class MixedOp(nn.Module):
  def __init__(self, C, stride):
    super(MixedOp, self).__init__()
    self._ops = nn.ModuleList()
    for primitive in PRIMITIVES:
      op = OPS[primitive](C, stride, False)
      if 'pool' in primitive:
        op = nn.Sequential(op, nn.BatchNorm2d(C, affine=False))
      self._ops.append(op)
  def forward(self, x, weights):
    return sum(w * op(x) for w, op in zip(weights, self._ops)) # weighted op

这个是MixedOp,两节点间操作把PRIMITIVES里定义的所有操作都连接上,计算输出时利用传入的weights进行加权。

  • file: model_search.py第47行
def forward(self, s0, s1, weights):
    s0 = self.preprocess0(s0)
    s1 = self.preprocess1(s1)
    states = [s0, s1]
    offset = 0
    for i in range(self._steps):
      s = sum(self._ops[offset+j](h, weights[offset+j]) for j, h in enumerate(states)) # all nodes before can be input, mixop.
      offset += len(states) #0, 2, 5, 9
      states.append(s)
    return torch.cat(states[-self._multiplier:], dim=1)

self.ops[], 实际是14(2+3+4+5)个MixedOp,2+3+4+5的解释,对于第一个内部节点,有两个可能的输入(c{k-1}, c_{k-2}),对于第二个内部节点,有三个可能的输入(两个同节点1,另加上第一个节点)……
代码里,weights[],也是一个长度14的list,前2个对应到第一个节点的两个输入的权重,第3~5这3个元素对应到第二个节点的三个输入的权重……这就是上面代码里offset的作用

  • file: architect.py 第11行
class Architect(object):
  def __init__(self, model, args):
    self.network_momentum = args.momentum
    self.network_weight_decay = args.weight_decay
    self.model = model
    self.optimizer = torch.optim.Adam(self.model.arch_parameters(),   #arch_parameters, 
        lr=args.arch_learning_rate, betas=(0.5, 0.999), weight_decay=args.arch_weight_decay) 

需要注意的是Architect里optimizer优化器的参数是model.arch_parameters(), 这个对应到的是model_search.py里定义的._arch_parameters,及初始化的各节点连接的权重。
def _initialize_alphas(self):
k = sum(1 for i in range(self._steps) for n in range(2+i)) # 2+i, 2 for two inputs, i=0,1,2,3, nodes before this. 2+3+4+5
num_ops = len(PRIMITIVES)

self.alphas_normal = Variable(1e-3*torch.randn(k, num_ops).cuda(), requires_grad=True)
    self.alphas_reduce = Variable(1e-3*torch.randn(k, num_ops).cuda(), requires_grad=True)
    self._arch_parameters = [
      self.alphas_normal,
      self.alphas_reduce,
    ]

  • file: model_search.py 第133行
def _parse(weights):
      #  weights: [2 + 3 + 4 + 5][len(PRIMITIVES)]
      gene = []
      n = 2
      start = 0
      for i in range(self._steps): #ch: steps = 4
        end = start + n 
        print('start=', start, 'end=', end, 'n=', n)
        W = weights[start:end].copy()
        print(W) # ch: add
        # chenhua: for x, -max(W[x][...]), W[][] is the parameters for architect. lambda elect out the OP weights most.
        edges = sorted(range(i + 2), key=lambda x: -max(W[x][k] for k in range(len(W[x])) if k != PRIMITIVES.index('none')))[:2]
        print(edges)
        for j in edges: #ch: j, edges mean op, all possible ops between two node
          print(j)
          k_best = None
          for k in range(len(W[j])):  #ch: k, the weights for possible connection?
            if k != PRIMITIVES.index('none'):
              if k_best is None or W[j][k] > W[j][k_best]:
                print('W[j][k]=', W[j][k], 'W[j][k_best]=', W[j][k_best])
                k_best = k
          gene.append((PRIMITIVES[k_best], j))  #ch: find ????
        start = end
        n += 1
      return gene
    # ch: alphas_xxx, parameters for architect??
    gene_normal = _parse(F.softmax(self.alphas_normal, dim=-1).data.cpu().numpy())
    gene_reduce = _parse(F.softmax(self.alphas_reduce, dim=-1).data.cpu().numpy())
    concat = range(2+self._steps-self._multiplier, self._steps+2) #ch: step=4, mltiplier=3
    print('concat', concat)
    genotype = Genotype(
      normal=gene_normal, normal_concat=concat,
      reduce=gene_reduce, reduce_concat=concat
    )
    print('genotype=', genotype)
    return genotype

搜索过程中搜索出的结果(节点间的op)的打印,就是靠这个函数。
核心是找出两个节点间不为none的所有ops中权重最大的,就是最终的结果。
注意:weights[][]的size是[2 + 3 + 4 + 5][len(PRIMITIVES)]

参考链接

  1. https://cloud.tencent.com/developer/article/1348049
  2. https://blog.csdn.net/srdlaplace/article/details/80863346
  3. https://www.jiqizhixin.com/articles/2018-06-27-6
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 160,026评论 4 364
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 67,655评论 1 296
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 109,726评论 0 244
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 44,204评论 0 213
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 52,558评论 3 287
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 40,731评论 1 222
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 31,944评论 2 314
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 30,698评论 0 203
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 34,438评论 1 246
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 30,633评论 2 247
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 32,125评论 1 260
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 28,444评论 3 255
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 33,137评论 3 238
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 26,103评论 0 8
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 26,888评论 0 197
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 35,772评论 2 276
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 35,669评论 2 271

推荐阅读更多精彩内容