将 PyTorch 版 bin 模型转换成 Tensorflow 版 ckpt
最近由于工作上的需求,需要使用Tensorflow加载语言模型 SpanBERT(Facebook 发布的 BERT 模型的变体),但是作者只发布了 Pytorch 版的预训练权重,因此需要将其转换为 Tensorflow 可以加载的 checkpoint。
在 Pytorch 框架下,大多数开发者使用 Huggingface 发布的 Transformers 工具来加载语言模型,它同时支持加载 Pytorch 和 Tensorflow 版的模型。但是,目前基于 Tensorflow(或 Keras)的工具基本上都不支持加载 Pytorch 版的 bin 模型,转换代码在网上也很难找到,这带来了很多不便。
bert
通过搜索,目前能够找到的有以下几个转换代码片段可供参考:
- bin2ckpt:用于转换 TinyBERT,但实测并不可用;
- convert_pytorch_checkpoint_to_tf:Transformers 自带的转换脚本,但官方文档中并没有提及;
- pytorch_to_tf:实体同指任务的一篇论文提供的转换脚本;
- PyTorch 版的 BERT 转换成 Tensorflow 版的 BERT:VoidOc 编写的基于 Transformers 的转换脚本。
通过分析可以看到,将 PyTorch 版 bin 模型转换成 Tensorflow 版 ckpt 的过程并不复杂,可以分为以下几步:
- 读取出模型中每一层网络结构的名称和参数;
- 针对 PyTorch 和 Tensorflow 的模型格式差异对参数做一些调整;
- 按照 Tensorflow 的格式保存模型。
读取和保存模型
PyTorch 和 Tensorflow 框架都提供了模型的读取和保存功能,因此读取和保存语言模型的过程非常简单。
读取模型直接使用 PyTorch 自带函数 torch.load() 或者 Transformers 提供的对应模型包的 from_pretrained() 函数就可以了;而保存模型则使用 Tensorflow 自带的模型保存器 tf.train.Saver 来完成。
以 BERT 模型为例,读取模型的过程就是:
1 | model = BertModel.from_pretrained( |
或者
1 | model = torch.load(os.path.join(pytorch_bin_path, pytorch_bin_model), map_location='cpu') |
模型的保存过程则通过 Tensorflow 提供的保存器 tf.train.Saver 来完成:
1 | tf.reset_default_graph() |
整个过程就是先逐层读取模型的层名称和对应参数,然后将格式调整为 Tensorflow 模型的格式,再一次性写入到 checkpoint 文件中。
注意:部分转换脚本忽略了 reset_default_graph() 这一操作,会导致生成的 meta 文件不仅保存网络结构,还会保存完整的网络参数,从而体积庞大。
调整模型格式
由于 PyTorch 和 Tensorflow 的模型格式定义有所差异,因此转换的关键就是对部分层的名称和参数矩阵进行调整。具体来说,首先需要构建名称映射字典,对部分层的名称进行调整:
1 | var_map = ( |
注意:这里演示的是转换 BERT 模型,所以转换后的层名以 bert/ 开头。如果转换的是其他模型,需要做相应的修改。
然后,由于 PyTorch 和 Tensorflow 模型中 dense/kernel、attention/self/query、attention/self/key 和 attention/self/value 层的参数矩阵互为转置,因此还需要对模型中的对应层的参数进行调整:
1 | tensors_to_transpose = ("dense.weight", "attention.self.query", "attention.self.key", "attention.self.value") |
至此,转换过程就全部完成了。
完整的代码
综上所述,将 PyTorch 版 bin 模型转换成 Tensorflow 版 ckpt 的过程还是比较清晰的。本文对 VoidOc 编写的脚本进行了进一步的简化,以转换 BERT 模型为例,完整的代码如下(Github):
1 | # coding=utf-8 |
转换过程被包装为 convert() 函数,输入 PyTorch 版 bin 模型的路径和名称,以及 Tensorflow 版 ckpt 的保存路径和名称即可。
再次提醒一下,由于本文转换的 SpanBERT 只是 BERT 的一个变体,因此模型的层名称是与 BERT 模型完全一致的,如果需要转换其他模型,请自行修改 to_tf_var_name() 函数和 tensors_to_transpose 变量。