将 TensorFlow 版本的 .ckpt 模型转换为 PyTorch 的 .bin 模型
用 google-research 官方的 BERT 源码（TensorFlow 版本）对新的法律语料进行微调，迭代次数为 100000 次，每隔 1000 次保存一下模型，得到的结果如下：
将最后三个文件取出，改名为 bert_model.ckpt.data-00000-of-00001、bert_model.ckpt.index、bert_model.ckpt.meta。
加上之前微调使用过的 config.json 以及 vocab.txt 文件，运行如下脚本后生成 pytorch_model.bin，之后就可以被 PyTorch 的代码调用了。
1from__future__import absolute_import
2from__future__import division
3from__future__import print_function
4
5import argparse
6import torch
7
8from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert
9
10import logging
11 logging.basicConfig(level=logging.INFO)
12
13def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
14# Initialise PyTorch model
15    config = BertConfig.from_json_file(bert_config_file)
16print("Building PyTorch model from configuration: {}".format(str(config)))
17    model = BertForPreTraining(config)
18
19# Load weights from tf checkpoint
20    load_tf_weights_in_bert(model, config, tf_checkpoint_path)
21
22# Save pytorch-model
23print("Save PyTorch model to {}".format(pytorch_dump_path))
24    torch.save(model.state_dict(), pytorch_dump_path)
# Command-line entry point. Every option has a default pointing into
# ./chinese_L-12_H-768_A-12_improve1/, so the script also runs without arguments.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--tf_checkpoint_path",
                        default='./chinese_L-12_H-768_A-12_improve1/bert_model.ckpt',
                        type=str,
                        help="Path to the TensorFlow checkpoint path.")
    parser.add_argument("--bert_config_file",
                        default='./chinese_L-12_H-768_A-12_improve1/config.json',
                        type=str,
                        help="The config json file corresponding to the pre-trained BERT model. \n"
                             "This specifies the model architecture.")
    parser.add_argument("--pytorch_dump_path",
                        default='./chinese_L-12_H-768_A-12_improve1/pytorch_model.bin',
                        type=str,
                        help="Path to the output PyTorch model.")
    args = parser.parse_args()
    convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path,
                                     args.bert_config_file,
                                     args.pytorch_dump_path)
Tip：如果不是 BERT 模型，而是 BERT 的变种，比如 MobileBERT、DistilBERT 等，权重的数据形式可能不匹配，会报错 AttributeError:
tensorflow版本选择
'BertForPreTraining' object has no attribute 'bias'
此时需要参照 transformers 库里的源码修改 convert_tf_checkpoint_to_pytorch 函数，以 MobileBERT 为例：
# Adapted from the transformers library:
# transformers/models/mobilebert/convert_mobilebert_original_tf_checkpoint_to_pytorch.py
import torch
from transformers import MobileBertConfig, MobileBertForPreTraining, load_tf_weights_in_mobilebert


def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, mobilebert_config_file, pytorch_dump_path):
    """Convert a TensorFlow MobileBERT checkpoint into a PyTorch ``state_dict`` file.

    Args:
        tf_checkpoint_path: Checkpoint prefix of the TF model (path without the
            ``.index`` / ``.data-*`` suffixes).
        mobilebert_config_file: JSON file describing the MobileBERT architecture.
        pytorch_dump_path: Destination file; receives ``torch.save(model.state_dict())``.
    """
    # Initialise PyTorch model
    config = MobileBertConfig.from_json_file(mobilebert_config_file)
    print(f"Building PyTorch model from configuration: {config}")
    model = MobileBertForPreTraining(config)
    # Load weights from tf checkpoint; the loader's return value is used as the model.
    model = load_tf_weights_in_mobilebert(model, config, tf_checkpoint_path)
    # Save pytorch-model
    print(f"Save PyTorch model to {pytorch_dump_path}")
    torch.save(model.state_dict(), pytorch_dump_path)

版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系QQ:729038198,我们将在24小时内删除。