cifar10 官方例子详解

本文按照程序执行的顺序进行详解。执行: python cifar10_train.py 进行训练:

1、首先进入 cifar10_train.py 的 main() 函数。先 调用 cifar10.py 的 maybe_download_and_extract() 下载数据

-------- cifar10.maybe_download_and_extract() --------

2、maybe_download_and_extract() 中先拼接出数据文件的路径 filepath (exp: /tmp/cifar10_data/cifar-10-binary.tar.gz)

3、初始调用时数据文件不存在需要下载,调用 urllib.request.urlretrieve 进行下载

4、调用 tarfile.open(filepath, 'r:gz').extractall(dest_directory) 对下载的压缩包进行解压,后返回  cifar10_train.py 的 main() 函数

-------- main() --------

5、重置 train_dir:如果存在,则先删掉,再创建;通过调用 tensorflow.python.platform 中的 gfile 里的 Exists 、 DeleteRecursively 、 MakeDirs 方法来实现

6、开始训练,进入 train()

-------- train() --------

7、调用 tf.Graph().as_default() 创建一个图,并作为以下所有操作默认的图。通过 with,将以下所有的操作都限定在该图中;

8、创建变量 global_step ,初始化为 0 ,后续作为 train_op 操作的输入参数

9、调用 cifar10.distorted_inputs() 获取 images (128, 24, 24, 3) 和 labels (128,)

-------- cifar10.distorted_inputs() --------

10、得到数据所在目录 data_dir (exp: /tmp/cifar10_data/cifar-10-batches-bin),再调用 cifar10_input.distorted_inputs(data_dir,batches_size) 返回数据

-------- cifar10_input.distorted_inputs() --------

11、得到数据文件路径数组 filenames ,包含有 5 个数据文件的路径:

filenames = ['/tmp/cifar10_data/cifar-10-batches-bin/data_batch_1.bin',

                    '/tmp/cifar10_data/cifar-10-batches-bin/data_batch_2.bin',

                    '/tmp/cifar10_data/cifar-10-batches-bin/data_batch_3.bin',

                    '/tmp/cifar10_data/cifar-10-batches-bin/data_batch_4.bin',

                    '/tmp/cifar10_data/cifar-10-batches-bin/data_batch_5.bin']

12、filename_queue = tf.train.string_input_producer(filenames) 生成一个文件队列对象,然后将该队列传入 read_cifar10(filename_queue) 读取数据

-------- cifar10_input.read_cifar10() --------

13、定义要返回的对象 result = CIFAR10Record() 

result.height = 32   #图片高度

result.width = 32    #图片宽度

result.depth = 3    #图片深度,RGB 三色

13、cifar10 的数据文件为二进制文件,其中每个记录的长度是固定的,1 个字节的标签,然后 3072 字节的图像数据。 record_bytes 即为每个记录的长度 —— 1+3072=3073 ,通过 FixedLengthRecordReader(record_bytes=record_bytes) 生成一个阅读操作器 reader

14、result.key, value = reader.read(filename_queue) 给 reader 传入I/O类型的参数filename_queue,返回一个 tensor(我们现在写的这些读取的代码,仅仅是在画 graph,在操作 run 执行并不会真的执行。仅代表 graph 中的一个节点)

15、record_bytes = tf.decode_raw(value, tf.uint8) 操作将一个字符串转化为一个 unit8 张量

16、将张量 record_bytes 中的第一个字符——标签取出,转化为 int32 类型,赋值给 result.label = tf.cast(tf.slice(record_bytes, [0], [label_bytes]), tf.int32)

17、再从张量 record_bytes 中取出第二部分——图片数据,原始图片数据是 [depth, height, width] ,需要通过 tf.transpose 函数转化为 [height, width, depth] 。此时可得到一张 unit8 格式的 [height, width, depth] 图片矩阵,赋值给 result.uint8image 

18、最终返回的 result 是一个 CIFAR10Record 对象,包含:

    height: 图片高度

    width: 图片宽度

    depth: 图片通道

    key: 描述 filename 和 record number 的 Tensor

    label: a int32 Tensor with the label in the range 0..9.

    uint8image: a [height, width, depth] uint8 Tensor with the image data

-------- cifar10_input.distorted_inputs() --------

19、继续对返回的数据进行处理、预处理。首先是 将图片格式从 unit8 转为 float32

        reshaped_image = tf.cast(read_input.uint8image, tf.float32)

20、对图片进行扩充,通过 随机裁剪 —— tf.random_crop 、左右翻转 —— tf.image.random_flip_left_right 、亮度变化 —— tf.image.random_brightness 、对比度变化 —— tf.image.random_contrast 、归一化处理 —— tf.image.per_image_standardization

21、再调用 _generate_image_and_label_batch ,将多个图片 Tensor 合并成 batch-Tensor 

-------- cifar10_input._generate_image_and_label_batch() --------

22、通过调用 tf.train.shuffle_batch 对样本进行乱序批处理,大致原理是,将样本的 Tensor 按顺序压到一个队列 RandomShuffleQueue 中,直到样本个数达到 capacity ,然后需要的时候随机从中取出 batch_size 个样本

images, label_batch = tf.train.shuffle_batch(

        [image, label],

        batch_size=batch_size,

        num_threads=num_preprocess_threads,

        capacity=min_queue_examples + 3 * batch_size,

        min_after_dequeue=min_queue_examples)

23、最后通过 tf.summary.image('images', images) 将 images 保存到 tensorflow board 中

24、最后返回的是 一个 batch_size 的 images 和 labels 的 Tensor ,返回给 cifar10_input.distorted_inputs() ,再返回给 cifar10.distorted_inputs() ,再返回给 cifar10_train.train()

-------- 数据处理结束 --------

-------- cifar10_train.train() --------

25、数据处理之后,就是要在 graph 中增加 关于神经网络 model 的 操作,调用 cifar10.inference(images)

-------- cifar10.inference() --------

26、在 该 model 中,共有 conv1 、pool1 、norm1 、 conv2 、pool2 、norm2 、 local3 、 local4 、 softmax 层,最终返回 softmax 层。

27、conv1 层,先创建一个 scope ,主要是对改成的变量进行统一命名:

        with tf.variable_scope('conv1') as scope:

28、conv1 层, 再调用 _variable_with_weight_decay 初始化卷积核参数

-------- cifar10._variable_with_weight_decay() --------

29、初始化参数 ,初始化使用 tf.truncated_normal_initializer 即截断正太分布,使用 tf.get_variable 来初始化参数,得到初始化参数 var ,并且在 wd > 0 时,会增加 L2范式稀疏化: weight_decay = tf.multiply(tf.nn.l2_loss(var), wd, name='weight_loss') ,wd 为衰减系数。然后使用 tf.add_to_collection('losses',weight_decay) 将 weight_decay 作为以 losses 为标签进行收集。这个例子里,只有全连接层对稀疏性有需求

-------- cifar10.inference() --------

30、conv1 层,先调用 tf.nn.conv2d ,再调用 tf.nn.bias_add ,最后调用 tf.nn.relu 生成第一个卷积层操作。该卷积层输入层数为 3,输出层数为 64

31、调用 _activation_summary 将 conv1 层输出到 tensorboard 中

32、pool1 层,通过 tf.nn.max_pool 生成池化层

33、norm1 层,通过 tf.nn.lrn ,对第一个 pooling 层进行局部响应归一化

34、conv2、pool2、norm2 类似

35、local3 层,为全连接层,高度为 384,将 pool2 展开成1纬,然后乘以权重、加上偏移,激活函数用 relu ,调用 _activation_summary 将 local3层输出到 tensorboard 中

36、local4 层,为全连接层,高度为 192,类似 local3

37、softmax 层,输出层,最终将该层返回

-------- 创建模型结束 --------

38、接下来定义 loss 操作,进入 cifar10.loss(logits, labels) ,输入 返回输出层 和 标准输出

-------- cifar10.loss() --------

39、传入的 labels 是 (batch_size,) ,需要转化为 (batch_size, 10),即将 labels = [3,5] 转化成 dense_labels = [[0,0,0,1,0,0,0,0,0,0],[0,0,0,0,1,0,0,0,0,0]]。需要先构造出 concated = [[0,3],[1,5]],代表在最终的 (batch_size, 10) 矩阵中 1 所在的坐标,然后再通过 tf.sparse_to_dense 得到最终的矩阵。其中 concated 可以通过 tf.concat([indices, sparse_labels],1) 得到,其中 indices 是 [[0],[1]] ,sparse_labels 是 [[3],[5]] 

40、计算交叉熵: tf.nn.softmax_cross_entropy_with_logits ,得到的是 [batch_size] 的张量,再调用 tf.reduce_mean 求平均,得到 cross_entropy_mean 平均交叉熵

41、最终将 cross_entropy_mean 跟 L2 的范式部分相加得到总的损失函数,并返回

-------- 创建损失函数结束 --------

42、接下来定义 train 操作,进入 train(total_loss, global_step) ,输入 总损失函数 和 step

-------- cifar10.train() --------

43、首先调用 lr = tf.train.exponential_decay 生成一个随着 steps 指数衰减的 learning_rate 

44、opt = tf.train.GradientDescentOptimizer(lr) 生成梯度递减操作 opt

45、grads = opt.compute_gradients(total_loss) 

推荐阅读更多精彩内容