将 PyTorch 版 bin 模型转换成 Tensorflow 版 ckpt

最近由于工作上的需求,需要使用Tensorflow加载语言模型 SpanBERT(Facebook 发布的 BERT 模型的变体),但是作者只发布了 Pytorch 版的预训练权重,因此需要将其转换为 Tensorflow 可以加载的 checkpoint。

在 Pytorch 框架下,大多数开发者使用 Huggingface 发布的 Transformers 工具来加载语言模型,它同时支持加载 Pytorch 和 Tensorflow 版的模型。但是,目前基于 Tensorflow(或 Keras)的工具基本上都不支持加载 Pytorch 版的 bin 模型,转换代码在网上也很难找到,这带来了很多不便。

bert

通过搜索,目前能够找到的有以下几个转换代码片段可供参考:

  • bin2ckpt:用于转换 TinyBERT,但实测并不可用;
  • convert_pytorch_checkpoint_to_tf:Transformers 自带的转换脚本,但官方文档中并没有提及;
  • pytorch_to_tf:实体同指任务的一篇论文提供的转换脚本;
  • PyTorch 版的 BERT 转换成 Tensorflow 版的 BERT:VoidOc 编写的基于 Transformers 的转换脚本。

通过分析可以看到,将 PyTorch 版 bin 模型转换成 Tensorflow 版 ckpt 的过程并不复杂,可以分为以下几步:

  1. 读取出模型中每一层网络结构的名称和参数;
  2. 针对 PyTorch 和 Tensorflow 的模型格式差异对参数做一些调整;
  3. 按照 Tensorflow 的格式保存模型。

读取和保存模型

PyTorch 和 Tensorflow 框架都提供了模型的读取和保存功能,因此读取和保存语言模型的过程非常简单。

读取模型直接使用 PyTorch 自带函数 torch.load() 或者 Transformers 提供的对应模型包的 from_pretrained() 函数就可以了;而保存模型则使用 Tensorflow 自带的模型保存器 tf.train.Saver 来完成。

以 BERT 模型为例,读取模型的过程就是:

1
2
3
4
model = BertModel.from_pretrained(
pretrained_model_name_or_path=pytorch_bin_path,
state_dict=torch.load(os.path.join(pytorch_bin_path, pytorch_bin_model), map_location='cpu')
)

或者

1
model = torch.load(os.path.join(pytorch_bin_path, pytorch_bin_model), map_location='cpu')

模型的保存过程则通过 Tensorflow 提供的保存器 tf.train.Saver 来完成:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
tf.reset_default_graph()
with tf.Session() as session:
for var_name in state_dict:
tf_name = to_tf_var_name(var_name) # 将层名称改为Tensorflow模型格式
torch_tensor = state_dict[var_name].numpy()
# 将参数矩阵改为Tensorflow模型格式
if any([x in var_name for x in tensors_to_transpose]):
torch_tensor = torch_tensor.T
tf_var = create_tf_var(tensor=torch_tensor, name=tf_name, session=session)
tf.keras.backend.set_value(tf_var, torch_tensor)
tf_weight = session.run(tf_var)
print("Successfully created {}: {}".format(tf_name, np.allclose(tf_weight, torch_tensor)))

saver = tf.train.Saver(tf.trainable_variables())
saver.save(session, os.path.join(ckpt_dir, model_name.replace("-", "_").replace(".ckpt", "") + ".ckpt"))

整个过程就是先逐层读取模型的层名称和对应参数,然后将格式调整为 Tensorflow 模型的格式,再一次性写入到 checkpoint 文件中。

注意:部分转换脚本忽略了 reset_default_graph() 这一操作,会导致生成的 meta 文件不仅保存网络结构,还会保存完整的网络参数,从而体积庞大。

调整模型格式

由于 PyTorch 和 Tensorflow 的模型格式定义有所差异,因此转换的关键就是对部分层的名称和参数矩阵进行调整。具体来说,首先需要构建名称映射字典,对部分层的名称进行调整:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
var_map = (
("layer.", "layer_"),
("word_embeddings.weight", "word_embeddings"),
("position_embeddings.weight", "position_embeddings"),
("token_type_embeddings.weight", "token_type_embeddings"),
(".", "/"),
("LayerNorm/weight", "LayerNorm/gamma"),
("LayerNorm/bias", "LayerNorm/beta"),
("weight", "kernel"),
)

def to_tf_var_name(name: str):
for patt, repl in iter(var_map):
name = name.replace(patt, repl)
return "bert/{}".format(name)

注意:这里演示的是转换 BERT 模型,所以转换后的层名以 bert/ 开头。如果转换的是其他模型,需要做相应的修改。

然后,由于 PyTorch 和 Tensorflow 模型中 dense/kernel、attention/self/query、attention/self/key 和 attention/self/value 层的参数矩阵互为转置,因此还需要对模型中的对应层的参数进行调整:

1
2
3
4
tensors_to_transpose = ("dense.weight", "attention.self.query", "attention.self.key", "attention.self.value")

if any([x in var_name for x in tensors_to_transpose]):
torch_tensor = torch_tensor.T

至此,转换过程就全部完成了。

完整的代码

综上所述,将 PyTorch 版 bin 模型转换成 Tensorflow 版 ckpt 的过程还是比较清晰的。本文对 VoidOc 编写的脚本进行了进一步的简化,以转换 BERT 模型为例,完整的代码如下(Github):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
# coding=utf-8

"""
Convert Huggingface Pytorch checkpoint to Tensorflow checkpoint.
"""

import numpy as np
import tensorflow.compat.v1 as tf
import torch
from transformers import BertModel
import os

def convert_pytorch_checkpoint_to_tf(model: BertModel, ckpt_dir: str, model_name: str):

"""
:param model:BertModel Pytorch model instance to be converted
:param ckpt_dir: Tensorflow model directory
:param model_name: model name
:return:

Currently supported Huggingface models:
Y BertModel
N BertForMaskedLM
N BertForPreTraining
N BertForMultipleChoice
N BertForNextSentencePrediction
N BertForSequenceClassification
N BertForQuestionAnswering
"""

tensors_to_transpose = ("dense.weight", "attention.self.query", "attention.self.key", "attention.self.value")

var_map = (
("layer.", "layer_"),
("word_embeddings.weight", "word_embeddings"),
("position_embeddings.weight", "position_embeddings"),
("token_type_embeddings.weight", "token_type_embeddings"),
(".", "/"),
("LayerNorm/weight", "LayerNorm/gamma"),
("LayerNorm/bias", "LayerNorm/beta"),
("weight", "kernel"),
)

if not os.path.isdir(ckpt_dir):
os.makedirs(ckpt_dir)

state_dict = model.state_dict()

def to_tf_var_name(name: str):
for patt, repl in iter(var_map):
name = name.replace(patt, repl)
return "bert/{}".format(name)

def create_tf_var(tensor: np.ndarray, name: str, session: tf.Session):
tf_dtype = tf.dtypes.as_dtype(tensor.dtype)
tf_var = tf.get_variable(dtype=tf_dtype, shape=tensor.shape, name=name, initializer=tf.zeros_initializer())
session.run(tf.variables_initializer([tf_var]))
session.run(tf_var)
return tf_var

tf.reset_default_graph()
with tf.Session() as session:
for var_name in state_dict:
tf_name = to_tf_var_name(var_name)
torch_tensor = state_dict[var_name].numpy()
if any([x in var_name for x in tensors_to_transpose]):
torch_tensor = torch_tensor.T
tf_var = create_tf_var(tensor=torch_tensor, name=tf_name, session=session)
tf.keras.backend.set_value(tf_var, torch_tensor)
tf_weight = session.run(tf_var)
print("Successfully created {}: {}".format(tf_name, np.allclose(tf_weight, torch_tensor)))

saver = tf.train.Saver(tf.trainable_variables())
saver.save(session, os.path.join(ckpt_dir, model_name.replace("-", "_").replace(".ckpt", "") + ".ckpt"))

def convert(pytorch_bin_path: str, pytorch_bin_model: str, tf_ckpt_path: str, tf_ckpt_model: str):

model = BertModel.from_pretrained(
pretrained_model_name_or_path=pytorch_bin_path,
state_dict=torch.load(os.path.join(pytorch_bin_path, pytorch_bin_model), map_location='cpu')
)

convert_pytorch_checkpoint_to_tf(model=model, ckpt_dir=tf_ckpt_path, model_name=tf_ckpt_model)

if __name__ == '__main__':
bin_path = './pretrained_model/pytorch_model/'
bin_model = 'pytorch_model.bin'
ckpt_path = './pretrained_model/tensorflow_model/'
ckpt_model = 'bert_model.ckpt'

convert(bin_path, bin_model, ckpt_path, ckpt_model)

转换过程被包装为 convert() 函数,输入 PyTorch 版 bin 模型的路径和名称,以及 Tensorflow 版 ckpt 的保存路径和名称即可。

再次提醒一下,由于本文转换的 SpanBERT 只是 BERT 的一个变体,因此模型的层名称是与 BERT 模型完全一致的,如果需要转换其他模型,请自行修改 to_tf_var_name() 函数和 tensors_to_transpose 变量。