5. A Guide to Using Keras Callbacks


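The code in this article comes from the book "deep learning on keras", which is excellent. If your English is good, I recommend reading the book directly; if you are short on time, this series of articles may help. The articles combine my own translation with some of my own thoughts. My ability is limited and there are bound to be shortcomings, so corrections are welcome. The translation is copyrighted; please contact me before reposting.
My own field is numerical computing, and I plan to move toward the intersection of deep learning and computational problems. If you work in a related area, feel free to get in touch.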


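Articles in this series:
1. Building your first neural network
2. Where to find the network once it is trained
3. [Keras in practice] Boston house price prediction
4. The Keras functional API
5. Using Keras callbacks
6. Machine learning fundamentals I: The four categories of machine learning
7. Machine learning fundamentals II: Evaluating machine learning models
8. Machine learning fundamentals III: Data preprocessing, feature engineering, and feature learning
9. Machine learning fundamentals IV: Overfitting and underfitting
10. Machine learning fundamentals V: The general workflow of machine learning
11. Deep learning for computer vision: An introduction to convolutional neural networks
12. Deep learning for computer vision: Training a convolutional network from scratch
13. Deep learning for computer vision: Using a pretrained network
14. Neural networks for computer vision: Visualizing what convolutional networks learn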


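The book offers a vivid analogy: training a model the way we have done so far is like throwing a paper airplane. You fold it, give it an initial push, and then have no control over how it flies or where it lands. Today we learn how to build a plane that can be steered.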

Using callbacks to control a model while it is training

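So far our procedure has been to train once, look at how the validation accuracy evolves to find the best number of epochs, set the number of epochs to that value, and then train again from scratch to get the final result. This wastes a lot of time.
A better approach is to stop training as soon as the validation accuracy stops improving, and callbacks let us do exactly that. A callback is an object that is passed to the model in the call to `fit` and that the model calls at various points during training. It has access to all available data about the model's state and performance, and it can take action: interrupt training, save the model, load a different set of weights, or otherwise alter the state of the model.
Callbacks can be used for things like: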

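  • Model checkpointing: saving the current weights of the model so that training can be resumed later.
  • Early stopping: interrupting training once the model's loss stops decreasing, while keeping the best model obtained along the way.
  • Dynamically adjusting training parameters, such as the learning rate of the optimizer.
  • And so on.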

EarlyStopping and ModelCheckpoint

import keras

# Callbacks are passed to the model via the `callbacks` argument of `fit`,
# which takes a list of callbacks. You can pass any number of callbacks.
callbacks_list = [
    # This callback interrupts training when the monitored metric stops improving
    keras.callbacks.EarlyStopping(
        # Monitor the validation accuracy of the model
        monitor='val_acc',
        # Number of epochs with no improvement after which training is stopped
        patience=1,
    ),
    # This callback saves the model at the end of an epoch
    keras.callbacks.ModelCheckpoint(
        filepath='my_model.h5',  # Path to the destination model file
        # The two arguments below mean that we will not overwrite the
        # model file unless `val_loss` has improved, which
        # allows us to keep the best model ever seen during training.
        monitor='val_loss',
        save_best_only=True,
    ),
]

# `model`, `x`, `y`, `x_val` and `y_val` are assumed to be defined already.
# Since we monitor `val_acc`, `acc` must be part of the metrics of the model.
model.compile(optimizer='rmsprop', loss='binary_crossentropy', metrics=['acc'])

# Since both callbacks monitor validation metrics,
# we need to pass some `validation_data` to our call to `fit`.
model.fit(x, y,
          epochs=10,
          batch_size=32,
          callbacks=callbacks_list,
          validation_data=(x_val, y_val))

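`monitor` names the metric to watch; for `EarlyStopping` we watch the accuracy here, so accuracy must be among the metrics passed to `compile`. `patience` is the number of epochs without improvement that we tolerate before training is stopped, here 1. `ModelCheckpoint` is what stores the best model: `filepath` gives the location and name of the saved file, with an .h5 suffix; `monitor` is again the metric to watch, here the validation loss (`val_loss`); and `save_best_only=True` means that only the best result seen so far is kept.
`validation_data` supplies the validation set.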
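To make the roles of `patience` and `save_best_only` more concrete, here is a minimal pure-Python sketch that replays a list of per-epoch validation losses through simplified versions of both rules. It is only an illustration, not Keras's implementation: the function name `simulate_training` and its return values are hypothetical, and for brevity both rules watch the same metric (a validation loss), whereas the example above uses accuracy for early stopping.

```python
def simulate_training(val_losses, patience=1):
    """Replay per-epoch validation losses through simplified
    early-stopping and save-best-only rules (illustration only)."""
    best = float("inf")   # best validation loss seen so far
    wait = 0              # epochs elapsed since the last improvement
    saved_epochs = []     # epochs at which a checkpoint would be written
    stopped_epoch = None  # epoch at which training is interrupted, if any

    for epoch, loss in enumerate(val_losses):
        if loss < best:
            # Improvement: remember it, reset the counter, and
            # overwrite the checkpoint (the save_best_only rule).
            best = loss
            wait = 0
            saved_epochs.append(epoch)
        else:
            # No improvement: stop once `patience` such epochs pile up.
            wait += 1
            if wait >= patience:
                stopped_epoch = epoch
                break
    return stopped_epoch, saved_epochs, best


# Hypothetical loss curve: improves for three epochs, then stalls.
losses = [0.9, 0.7, 0.6, 0.65, 0.62, 0.5]
print(simulate_training(losses, patience=1))
print(simulate_training(losses, patience=3))
```

With a small `patience` the run is cut short at the first stall and never reaches the later, better epoch; a larger `patience` trades extra training time for the chance to recover from a temporary plateau. In both cases the checkpoint on disk corresponds to the best epoch seen, not the last one.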

The learning-rate reduction callback

callbacks_list = [
    keras.callbacks.ReduceLROnPlateau(
        # This callback will monitor the validation loss of the model
        monitor='val_loss',
        # It will divide the learning rate by 10 when it gets triggered
        factor=0.1,
        # It will get triggered after the validation loss has stopped improving
        # for at least 10 epochs
        patience=10,
    )
]

# Since the callback monitors the validation loss,
# we need to pass some `validation_data` to our call to `fit`.
# With patience=10, a run of only 10 epochs is likely too short for the
# callback to fire; the values are kept as in the original example.
model.fit(x, y,
          epochs=10,
          batch_size=32,
          callbacks=callbacks_list,
          validation_data=(x_val, y_val))

In other words, if `val_loss` has not improved for 10 consecutive epochs, the learning rate is multiplied by 0.1.
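The rule can be sketched in a few lines of plain Python. This is a simplified model of the behaviour, not the Keras source: the function name `reduce_lr_on_plateau` and the `initial_lr` parameter are made up for illustration, and details such as cooldown periods and a minimum learning rate are left out.

```python
def reduce_lr_on_plateau(val_losses, initial_lr=0.001, factor=0.1, patience=10):
    """Return the learning rate in effect after each epoch, given the
    validation loss observed at the end of that epoch (illustration only)."""
    lr = initial_lr
    best = float("inf")  # best validation loss seen so far
    wait = 0             # epochs elapsed since the last improvement
    history = []

    for loss in val_losses:
        if loss < best:
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                lr *= factor  # shrink the learning rate
                wait = 0      # start counting again from scratch
        history.append(lr)
    return history


# Hypothetical loss curve that stalls at 0.8 for several epochs.
schedule = reduce_lr_on_plateau([1.0, 0.8, 0.8, 0.8, 0.8, 0.7],
                                initial_lr=1.0, factor=0.5, patience=2)
print(schedule)
```

Under these simplified rules the counter can reach at most 9 in a 10-epoch run, because the first epoch always counts as an improvement. That is why a run needs more epochs than `patience` before the learning rate can change at all.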

Writing your own callback

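If the built-in callbacks do not cover your needs, you can write your own. A custom callback subclasses `keras.callbacks.Callback` and implements any of the following methods: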

# Called at the start of every epoch
on_epoch_begin
# Called at the end of every epoch
on_epoch_end
# Called right before processing each batch
on_batch_begin
# Called right after processing each batch
on_batch_end
# Called at the start of training
on_train_begin
# Called at the end of training
on_train_end
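The order in which these hooks fire can be illustrated with a toy training loop in plain Python. Everything here is a stand-in written for this article: `RecordingCallback` and `toy_fit` are hypothetical names, no real training takes place, and the actual Keras loop passes additional information (such as a `logs` dictionary) that is omitted.

```python
class RecordingCallback:
    """Stand-in for a callback: records the name of every hook it receives."""

    def __init__(self):
        self.calls = []

    def on_train_begin(self):
        self.calls.append("on_train_begin")

    def on_train_end(self):
        self.calls.append("on_train_end")

    def on_epoch_begin(self, epoch):
        self.calls.append("on_epoch_begin")

    def on_epoch_end(self, epoch):
        self.calls.append("on_epoch_end")

    def on_batch_begin(self, batch):
        self.calls.append("on_batch_begin")

    def on_batch_end(self, batch):
        self.calls.append("on_batch_end")


def toy_fit(callbacks, epochs, batches_per_epoch):
    """Toy loop that only dispatches hooks, in the order a training loop would."""
    for cb in callbacks:
        cb.on_train_begin()
    for epoch in range(epochs):
        for cb in callbacks:
            cb.on_epoch_begin(epoch)
        for batch in range(batches_per_epoch):
            for cb in callbacks:
                cb.on_batch_begin(batch)
            # A real loop would run one gradient update here.
            for cb in callbacks:
                cb.on_batch_end(batch)
        for cb in callbacks:
            cb.on_epoch_end(epoch)
    for cb in callbacks:
        cb.on_train_end()


recorder = RecordingCallback()
toy_fit([recorder], epochs=1, batches_per_epoch=2)
print(recorder.calls)
```

The batch hooks nest inside the epoch hooks, which in turn nest inside the train hooks, so a callback that only needs per-epoch information can ignore the batch-level methods entirely.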

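Below is a callback that, at the end of every epoch, saves the activations of each layer to disk as arrays, computed on the first sample of the validation set: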

import keras
import numpy as np

class ActivationLogger(keras.callbacks.Callback):

    def set_model(self, model):
        # This method is called by the parent model
        # before training, to inform the callback
        # of what model will be calling it
        self.model = model
        layer_outputs = [layer.output for layer in model.layers]
        # This is a model instance that returns the activations of every layer
        self.activations_model = keras.models.Model(model.input, layer_outputs)

    def on_epoch_end(self, epoch, logs=None):
        # Older Keras releases set `self.validation_data`; newer ones may not,
        # in which case the data has to be handed to the callback explicitly.
        if self.validation_data is None:
            raise RuntimeError('Requires validation_data.')
        # Obtain first input sample of the validation data
        validation_sample = self.validation_data[0][0:1]
        activations = self.activations_model.predict(validation_sample)
        # Save arrays to disk, one per layer; the file must be opened
        # in binary mode for `np.savez` to write to it
        with open('activations_at_epoch_' + str(epoch) + '.npz', 'wb') as f:
            np.savez(f, *activations)
