将TF-checkpoint 文件转换为 pytorch-checkpoint 踩坑

  1. 改代码将Bert的Tensorflow 检查点转换为 Pytorch的检查点,整理Transformers的代码得到,为了方便使用同时记录踩的坑。

  2. Tensorflow检查点文件解析。

1. 包括以下3个文件
model.ckpt.data-00000-of-00001
model.ckpt.index
model.ckpt.meta
2. 其中model.ckpt为checkpoint的文件前缀,在命令行调用该代码提供 --tf_checkpoint_path 时需要同时提供checkpoint 前缀,例如 --tf_checkpoint_path model_checkpoint/model.ckpt
  1. 同时提供模型Config文件,名字通常为bert_config.json。

  2. 调用该代码命令行为:

# 依赖自行下载
# $checkpoint_path 为TF-checkpoint路径
# $save_file 为pytorch-checkpoint 保存文件
python3 convert_bert_tf_checkpoint_to_pytorch.py --tf_checkpoint_path $checkpoint_path/model.ckpt --bert_config_file $checkpoint_path/bert_config.json --pytorch_dump_path $save_file
  1. 保存后得到一个 pytorch-checkpoint, 需要同 bert_config.json 和 vocab.txt在同一个文件夹,同时需要将Bert_config.json增加一个命名为config.json的文件,Transformers加载Pytorch模型时会自动调用,之后可以通过Transformers正常使用。

  2. 目前该代码已经保存至 https://github.com/YaoXinZhi/Convert-Bert-TF-checkpoint-to-Pytorch