精简CNN模型系列之七:Xception

介绍

Xception是Google出品,属于2017年左右的东东。它在Google家的MobileNet v1之后,MobileNet v2之前。
它的主旨与MobileNet系列很像即推动Depthwise Conv + Pointwise Conv的使用。只是它直接以Inception v3为模子,将里面的基本inception module替换为使用Depthwise Conv + Pointwise Conv,又外加了residual connects,
最终模型在ImageNet等数据集上都取得了相比Inception v3与Resnet-152更好的结果。当然其模型大小与计算效率相对Inception v3也取得了较大提高。

从Inception module到Separable Conv的演变之路

下图1为一个典型的Inception module,它的实现的基本assumption就是feature在经conv处理时可分别学习feature channels间的关联关系与feature单个channel内部空间上的关联关系,为此inception module中使用了大量的1x1 conv来重视学习
channels之间的关联,然后再分别使用3x3/5x5(两个3x3)等去学习其不同维度上的单个channel内的空间上的关联;若我们基于以上inception中用到的关联关系分离假设而只使用3x3 convs来表示单个channel内的空间关联关系,那么就可以得到出下图2表示的
简化后了的inception module。

经典版与简化版Inception模块

而本质上上图2中表示的简化版Inception模块又可被表示为下图3中的形式。可以看出实质上它等价于先使用一个1x1 conv来学习input feature maps之上channels间特征的关联关系,然后再将1x1 conv输出的feature maps进行分割,分别交由下面的若干个3x3
conv来处理其内的空间上元素的关联关系。

等价的简化版Inception模块

更进一步,何不做事做绝将每个channel上的空间关联分别使用一个相应的conv 3x3来单独处理呢。如此就得到了下图4中所示的Separable conv。

Separable_conv

Xception架构

下图中为Xception结构的表示。它就是由Inception v3直接演变而来。其中引入了Residual learning的结构(已经有多项工作,同时在本文中作者也有相关实验表明Residual learning在CNN模型中的使用可带来收敛速度的加快。)。

Xception架构

同一向复杂的Inception系列模型一样,它也引入了Entry/Middle/Exit三个flow,每个flow内部使用不同的重复模块,当然最最核心的属于中间不断分析、过滤特征的Middel flow。
Entry flow主要是用来不断下采样,减小空间维度;中间则是不断学习关联关系,优化特征;最终则是汇总、整理特征,用于交由FC来进行表达。

实验结果

下表为Xception与其它模型在Imagenet上分类精度的结果比较。

Xception与其它模型在Imagenet上的结果比较

然后下表则为Xception与Inception v3在模型参数大小与计算速度上的比较。

Xception与Inception_v3在模型大小及速度上的比较

代码分析

在TF的官方实现中,他们对Xception的模型结构做了些改变,主要如下两块。一是使用stride为2的conv来代替使用Maxpooling进行feature降维;二则是在depthwise conv之后同样使用ReLu与BN。

We made a few more changes on top of MSRA's modifications:
1. Fully convolutional: All the max-pooling layers are replaced with separable
  conv2d with stride = 2. This allows us to use atrous convolution to extract
  feature maps at any resolution.
2. We support adding ReLU and BatchNorm after depthwise convolution, motivated
  by the design of MobileNetv1.

以下为Xception的模型构建入口函数。

def xception(inputs,
             blocks,
             num_classes=None,
             is_training=True,
             global_pool=True,
             keep_prob=0.5,
             output_stride=None,
             reuse=None,
             scope=None):
  """Generator for Xception models.
  This function generates a family of Xception models. See the xception_*()
  methods for specific model instantiations, obtained by selecting different
  block instantiations that produce Xception of various depths.
  Args:
    inputs: A tensor of size [batch, height_in, width_in, channels]. Must be
      floating point. If a pretrained checkpoint is used, pixel values should be
      the same as during training (see go/slim-classification-models for
      specifics).
    blocks: A list of length equal to the number of Xception blocks. Each
      element is an Xception Block object describing the units in the block.
    num_classes: Number of predicted classes for classification tasks.
      If 0 or None, we return the features before the logit layer.
    is_training: whether batch_norm layers are in training mode.
    global_pool: If True, we perform global average pooling before computing the
      logits. Set to True for image classification, False for dense prediction.
    keep_prob: Keep probability used in the pre-logits dropout layer.
    output_stride: If None, then the output will be computed at the nominal
      network stride. If output_stride is not None, it specifies the requested
      ratio of input to output spatial resolution.
    reuse: whether or not the network and its variables should be reused. To be
      able to reuse 'scope' must be given.
    scope: Optional variable_scope.
  Returns:
    net: A rank-4 tensor of size [batch, height_out, width_out, channels_out].
      If global_pool is False, then height_out and width_out are reduced by a
      factor of output_stride compared to the respective height_in and width_in,
      else both height_out and width_out equal one. If num_classes is 0 or None,
      then net is the output of the last Xception block, potentially after
      global average pooling. If num_classes is a non-zero integer, net contains
      the pre-softmax activations.
    end_points: A dictionary from components of the network to the corresponding
      activation.
  Raises:
    ValueError: If the target output_stride is not valid.
  """
  with tf.variable_scope(
      scope, 'xception', [inputs], reuse=reuse) as sc:
    end_points_collection = sc.original_name_scope + 'end_points'
    with slim.arg_scope([slim.conv2d,
                         slim.separable_conv2d,
                         xception_module,
                         stack_blocks_dense],
                        outputs_collections=end_points_collection):
      with slim.arg_scope([slim.batch_norm], is_training=is_training):
        net = inputs
        if output_stride is not None:
          if output_stride % 2 != 0:
            raise ValueError('The output_stride needs to be a multiple of 2.')
          output_stride /= 2
        # Root block function operated on inputs.
        net = resnet_utils.conv2d_same(net, 32, 3, stride=2,
                                       scope='entry_flow/conv1_1')
        net = resnet_utils.conv2d_same(net, 64, 3, stride=1,
                                       scope='entry_flow/conv1_2')

        # Extract features for entry_flow, middle_flow, and exit_flow.
        net = stack_blocks_dense(net, blocks, output_stride)

        # Convert end_points_collection into a dictionary of end_points.
        end_points = slim.utils.convert_collection_to_dict(
            end_points_collection, clear_collection=True)

        if global_pool:
          # Global average pooling.
          net = tf.reduce_mean(net, [1, 2], name='global_pool', keepdims=True)
          end_points['global_pool'] = net
        if num_classes:
          net = slim.dropout(net, keep_prob=keep_prob, is_training=is_training,
                             scope='prelogits_dropout')
          net = slim.conv2d(net, num_classes, [1, 1], activation_fn=None,
                            normalizer_fn=None, scope='logits')
          end_points[sc.name + '/logits'] = net
          end_points['predictions'] = slim.softmax(net, scope='predictions')
        return net, end_points

以下函数则是构建具体的blocks。。感觉Google有过度设计的嫌疑,非得将一个简单模型构得那么多层次,即使是slim这样已经高级了些的model让人看起来仍然要追着看很久。。推荐大家使用Pytorch啊。。:)。

就到这了,再向下不追了,还想进一步探究的朋友可以直接去翻参考文献里标明的code file。

@slim.add_arg_scope
def stack_blocks_dense(net,
                       blocks,
                       output_stride=None,
                       outputs_collections=None):
  """Stacks Xception blocks and controls output feature density.
  First, this function creates scopes for the Xception in the form of
  'block_name/unit_1', 'block_name/unit_2', etc.
  Second, this function allows the user to explicitly control the output
  stride, which is the ratio of the input to output spatial resolution. This
  is useful for dense prediction tasks such as semantic segmentation or
  object detection.
  Control of the output feature density is implemented by atrous convolution.
  Args:
    net: A tensor of size [batch, height, width, channels].
    blocks: A list of length equal to the number of Xception blocks. Each
      element is an Xception Block object describing the units in the block.
    output_stride: If None, then the output will be computed at the nominal
      network stride. If output_stride is not None, it specifies the requested
      ratio of input to output spatial resolution, which needs to be equal to
      the product of unit strides from the start up to some level of Xception.
      For example, if the Xception employs units with strides 1, 2, 1, 3, 4, 1,
      then valid values for the output_stride are 1, 2, 6, 24 or None (which
      is equivalent to output_stride=24).
    outputs_collections: Collection to add the Xception block outputs.
  Returns:
    net: Output tensor with stride equal to the specified output_stride.
  Raises:
    ValueError: If the target output_stride is not valid.
  """
  # The current_stride variable keeps track of the effective stride of the
  # activations. This allows us to invoke atrous convolution whenever applying
  # the next residual unit would result in the activations having stride larger
  # than the target output_stride.
  current_stride = 1

  # The atrous convolution rate parameter.
  rate = 1

  for block in blocks:
    with tf.variable_scope(block.scope, 'block', [net]) as sc:
      for i, unit in enumerate(block.args):
        if output_stride is not None and current_stride > output_stride:
          raise ValueError('The target output_stride cannot be reached.')
        with tf.variable_scope('unit_%d' % (i + 1), values=[net]):
          # If we have reached the target output_stride, then we need to employ
          # atrous convolution with stride=1 and multiply the atrous rate by the
          # current unit's stride for use in subsequent layers.
          if output_stride is not None and current_stride == output_stride:
            net = block.unit_fn(net, rate=rate, **dict(unit, stride=1))
            rate *= unit.get('stride', 1)
          else:
            net = block.unit_fn(net, rate=1, **unit)
            current_stride *= unit.get('stride', 1)

      # Collect activations at the block's end before performing subsampling.
      net = slim.utils.collect_named_outputs(outputs_collections, sc.name, net)

  if output_stride is not None and current_stride != output_stride:
    raise ValueError('The target output_stride cannot be reached.')

  return net

参考文献

©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 158,233评论 4 360
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 67,013评论 1 291
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 108,030评论 0 241
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 43,827评论 0 204
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 52,221评论 3 286
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 40,542评论 1 216
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 31,814评论 2 312
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 30,513评论 0 198
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 34,225评论 1 241
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 30,497评论 2 244
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 31,998评论 1 258
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 28,342评论 2 253
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 32,986评论 3 235
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 26,055评论 0 8
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 26,812评论 0 194
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 35,560评论 2 271
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 35,461评论 2 266

推荐阅读更多精彩内容

  • 姓名:章津津 部门:财务部 职位:出纳 【知~学习】读六项精进4遍 【内容感悟分享】谦虚 【行~实践】 一、修身:...
    立早章_50c7阅读 136评论 0 0
  • 所谓“刘三姐”,是我嫡亲的三姐,在我刘氏一门八个兄弟姐妹当中行五,在五姊妹当中行三,故戏称“刘三姐”。今年芳龄五十...
    人生何处不欢乐阅读 1,278评论 7 6