五、keras callbacks使用攻略


文章代码来源:《deep learning on keras》,非常好的一本书,大家如果英语好,推荐直接阅读该书,如果时间不够,可以看看此系列文章,文章为我自己翻译的内容加上自己的一些思考,水平有限,多有不足,请多指正,翻译版权所有,若有转载,请先联系本人。
个人方向为数值计算,日后会向深度学习和计算问题的融合方面靠近,若有相近专业人士,欢迎联系。


系列文章:
一、搭建属于你的第一个神经网络
二、训练完的网络去哪里找
三、【keras实战】波士顿房价预测
四、keras的function API
五、keras callbacks使用
六、机器学习基础Ⅰ:机器学习的四个标签
七、机器学习基础Ⅱ:评估机器学习模型
八、机器学习基础Ⅲ:数据预处理、特征工程和特征学习
九、机器学习基础Ⅳ:过拟合和欠拟合
十、机器学习基础Ⅴ:机器学习的一般流程十一、计算机视觉中的深度学习:卷积神经网络介绍
十二、计算机视觉中的深度学习:从零开始训练卷积网络
十三、计算机视觉中的深度学习:使用预训练网络
十四、计算机视觉中的神经网络:可视化卷积网络所学到的东西


书中打了一个形象的比喻:我们之前训练模型就像扔纸飞机一样,叠好了,给个初速度,到底怎么飞,落到哪里,我们无法控制,今天我们要学会的就是如何造一个可以被控制的飞机。

使用callbacks来模型正在训练的时候来控制

我们之前训练的过程是先训练一遍,然后得到一个验证集的识别率变化趋势,从而知道最佳的epoch,设置epoch,再训练一遍,得到最终结果,这样很浪费时间。
一个好方法就是在测试识别率不再上升的时候,我们终止训练就可以了,callback可以帮助我们做到这一点,callback是一个obj类型的,它可以让模型去拟合,也常在各个点被调用。它和所有模型的状态和表现的数据,能够采取措施打断训练,保存模型,加载不同的权重,或者替代模型状态。
callbacks可以用来做这些事情:

  • 模型断点续训:保存当前模型的所有权重
  • 提早结束:当模型的损失不再下降的时候就终止训练,当然,会保存最优的模型。
  • 动态调整训练时的参数,比如优化的学习速度。
  • 等等

earlystopping和modelcheckpoint

import keras
# Callbacks are passed to the model fit the `callbacks` argument in `fit`,
# which takes a list of callbacks. You can pass any number of callbacks.
callbacks_list = [
  # This callback will interrupt training when we have stopped improving
  keras.callbacks.EarlyStopping(
  # This callback will monitor the validation accuracy of the model
  monitor='acc',
  # Training will be interrupted when the accuracy
  # has stopped improving for *more* than 1 epochs (i.e. 2 epochs)
  patience=1,
  ),
  # This callback will save the current weights after every epoch
  keras.callbacks.ModelCheckpoint(
  filepath='my_model.h5', # Path to the destination model file
  # The two arguments below mean that we will not overwrite the
  # model file unless `val_loss` has improved, which
  # allows us to keep the best model every seen during training.
  monitor='val_loss',
  save_best_only=True,
  )
]
# Since we monitor `acc`, it should be part of the metrics of the model.
model.compile(optimizer='rmsprop', loss='binary_crossentropy', metrics=['acc'])
# Note that since the callback will be monitor validation accuracy,
# we need to pass some `validation_data` to our call to `fit`.
model.fit(x, y,
  epochs=10,
  batch_size=32,
  callbacks=callbacks_list,
  validation_data=(x_val, y_val))

monitor为选择的检测指标,我们这里选择检测'acc'识别率为指标,patience就是我们能让训练停止变好多少epochs才终止训练,这里选择了1,而modelcheckpoint就起到了存储最优的模型的作用,filepath为我们存储的位置和模型名称,以.h5为后缀,monitor为检测的指标,这里我们检测验证集里面的成功率,save_best_only代表我们只保存最优的训练结果。
而validation_data就是给定的验证集数据。

学习率减少callback

callbacks_list = [
  keras.callbacks.ReduceLROnPlateau(
  # This callback will monitor the validation loss of the model
  monitor='val_loss',
  # It will divide the learning by 10 when it gets triggered
  factor=0.1,
  # It will get triggered after the validation loss has stopped improving
  # for at least 10 epochs
  patience=10,
  )
]# Note that since the callback will be monitor validation loss,
# we need to pass some `validation_data` to our call to `fit`.
model.fit(x, y,
  epochs=10,
  batch_size=32,
  callbacks=callbacks_list,
  validation_data=(x_val, y_val))

翻译一下,就是如果连续10个批次,val_loss不再下降,就把学习率弄到原来的0.1倍。

自己造callback

如果内置的那些callback操作还满足不了你的需求,这里给出了如何自己造callback的方法。

# Called at the start of every epoch
on_epoch_begin
# Called at the end of every epoch
on_epoch_end
# Called right before processing each batch
on_batch_begin
# Called right after processing each batch
on_batch_end
# Called at the start of training
on_train_begin
# Called at the end of training
on_train_end

下面给出一个将激活值以数组的形式存进磁盘的callback:

import keras
import numpy as np
class ActivationLogger(keras.callbacks.Callback):
  def set_model(self, model):
  # This method is called by the parent model
  # before training, to inform the callback
  # of what model will be calling it
  self.model = model
  layer_outputs = [layer.output for layer in model.layers]
  # This is a model instance that returns the activations of every layer
  self.activations_model = keras.models.Model(model.input, layer_outputs)
  def on_epoch_end(self, epoch, logs=None):
  if self.validation_data is None:
  raise RuntimeError('Requires validation_data.')
  # Obtain first input sample of the validation data
  validation_sample = self.validation_data[0][0:1]
  activations = self.activations_model.predict(validation_sample)
  # Save arrays to disk
  f = open('activations_at_epoch_' + str(epoch) + '.npz', 'w')
  np.savez(f, activations)
  f.close()
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 158,560评论 4 361
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 67,104评论 1 291
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 108,297评论 0 243
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 43,869评论 0 204
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 52,275评论 3 287
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 40,563评论 1 216
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 31,833评论 2 312
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 30,543评论 0 197
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 34,245评论 1 241
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 30,512评论 2 244
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 32,011评论 1 258
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 28,359评论 2 253
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 33,006评论 3 235
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 26,062评论 0 8
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 26,825评论 0 194
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 35,590评论 2 273
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 35,501评论 2 268

推荐阅读更多精彩内容