[Keras] ModelCheckpoint 无法保存多 gpu 训练的模型

问题描述

在使用 callbacks.ModelCheckpoint() 并进行多 gpu 并行计算时,callbacks 函数会报错:

TypeError: can't pickle ...(different text at different situation) objects

这个错误形式其实跟使用多 gpu 训练时保存模型不当造成的错误比较相似:

To save the multi-gpu model, use .save(fname) or .save_weights(fname)
with the template model (the argument you passed to multi_gpu_model),
rather than the model returned by multi_gpu_model.

这个问题在我之前的文章中也有提到:[Keras] 使用Keras调用多GPU,并保存模型
。显然,在使用检查点时,默认还是使用了 paralleled_model.save() ,进而导致错误。为了解决这个问题,我们需要自己定义一个召回函数。

解决方法

法一

original_model = ...
parallel_model = multi_gpu_model(original_model, gpus=n)

class MyCbk(keras.callbacks.Callback):

    def __init__(self, model):
         self.model_to_save = model

    def on_epoch_end(self, epoch, logs=None):
        self.model_to_save.save('model_at_epoch_%d.h5' % epoch)

cbk = MyCbk(original_model)
parallel_model.fit(..., callbacks=[cbk])

法二

class ParallelModelCheckpoint(ModelCheckpoint):
    def __init__(self,model,filepath, monitor='val_loss', verbose=0,
                 save_best_only=False, save_weights_only=False,
                 mode='auto', period=1):
        self.single_model = model
        super(ParallelModelCheckpoint,self).__init__(filepath, monitor, verbose,save_best_only, save_weights_only,mode, period)

    def set_model(self, model):
        super(ParallelModelCheckpoint,self).set_model(self.single_model)

check_point = ParallelModelCheckpoint(single_model ,'best.hd5')

法三

class CustomModelCheckpoint(keras.callbacks.Callback):

    def __init__(self, model, path):
        self.model = model
        self.path = path
        self.best_loss = np.inf

    def on_epoch_end(self, epoch, logs=None):
        val_loss = logs['val_loss']
        if val_loss < self.best_loss:
            print("\nValidation loss decreased from {} to {}, saving model".format(self.best_loss, val_loss))
            self.model.save_weights(self.path, overwrite=True)
            self.best_loss = val_loss

model.fit(X_train, y_train,
              batch_size=batch_size*G, epochs=nb_epoch, verbose=0, shuffle=True,
              validation_data=(X_valid, y_valid),
              callbacks=[CustomModelCheckpoint(model, '/path/to/save/model.h5')])

参考资料

推荐阅读更多精彩内容

  • Training spaCy’s Statistical Models训练spaCy模型 This guide d...
    Joe_Gao_89f1阅读 5,330评论 1 5
  • 1.男儿当自强,一条硬脊梁; 双肩担日月,胸纳海河江; 手掌乾坤家国事,脚踏天地负兴亡! 男儿身,当自强, 生当男...
    君子羊阅读 289评论 0 0
  • cell 折叠效果 OC版 感谢 来自集成的方法在github上面readMe里面有很详细的说明 github.c...
    Smile_J阅读 1,674评论 0 10
  • 妞妞熟睡。 先是在车上睡,下车后放到小推车上还继续睡,推着小车到了花园里,我把她抱到放置在草地中的小床上,还在继续...
    牧田麻麻阅读 57评论 0 0
  • 2017年2月23日 闹铃在早晨的6:30响起,7:30挣扎着从床上爬起来,新的一天又开始了…… 撑开眼皮看着周围...
    11lili阅读 48评论 0 0