# Build a BertModel from the checkpoint directory, passing in a state dict that
# is loaded explicitly from `pytorch_bin_model` with every storage mapped to CPU.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
model = BertModel.from_pretrained(
    pretrained_model_name_or_path=pytorch_bin_path,
    state_dict=torch.load(
        os.path.join(pytorch_bin_path, pytorch_bin_model),
        map_location='cpu',
    ),
)
或者：
# Alternative: load the checkpoint file directly, mapping all storages to CPU.
# torch.load returns whatever object was serialized into the file -- presumably
# a state dict (parameter name -> tensor), not an nn.Module, for a Hugging Face
# `.bin` checkpoint; confirm before using `model` as a module.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
model = torch.load(
    os.path.join(pytorch_bin_path, pytorch_bin_model),
    map_location='cpu',
)
模型的保存过程则通过 TensorFlow 提供的保存器 tf.train.Saver 来完成。下面的代码先把每个 PyTorch 参数转换并写入对应的 TensorFlow 变量，之后再由 tf.train.Saver 写出检查点：
# Recreate every PyTorch parameter in `state_dict` as a TensorFlow variable in a
# fresh default graph, then read each one back to confirm the values match.
tf.reset_default_graph()
with tf.Session() as session:
    for var_name in state_dict:
        # Rename the layer/parameter to the TensorFlow checkpoint naming scheme.
        tf_name = to_tf_var_name(var_name)
        # Convert the parameter tensor to a NumPy array for TensorFlow.
        # detach().cpu() also makes this work for tensors that require grad or
        # live on a GPU; for plain CPU state-dict tensors both calls are no-ops.
        torch_tensor = state_dict[var_name].detach().cpu().numpy()
        # Parameters whose names match an entry in tensors_to_transpose are
        # transposed -- presumably because PyTorch nn.Linear stores weights as
        # (out, in) while TensorFlow dense kernels are (in, out); confirm.
        if any(x in var_name for x in tensors_to_transpose):
            torch_tensor = torch_tensor.T
        tf_var = create_tf_var(tensor=torch_tensor, name=tf_name, session=session)
        tf.keras.backend.set_value(tf_var, torch_tensor)
        # Read the variable back; the printed boolean is the round-trip check
        # (np.allclose), so "False" means the copy did NOT match.
        tf_weight = session.run(tf_var)
        print("Successfully created {}: {}".format(tf_name, np.allclose(tf_weight, torch_tensor)))
# NOTE(review): this string is the docstring of the PyTorch -> TensorFlow
# conversion function (parameters: model, ckpt_dir, model_name). Its `def` line
# does not appear in this excerpt, so here it is only a bare string expression.
# Per the text, only BertModel is supported (marked Y); the task-head variants
# are marked N.
""" :param model:BertModel Pytorch model instance to be converted :param ckpt_dir: Tensorflow model directory :param model_name: model name :return: Currently supported Huggingface models: Y BertModel N BertForMaskedLM N BertForPreTraining N BertForMultipleChoice N BertForNextSentencePrediction N BertForSequenceClassification N BertForQuestionAnswering """
# Recreate every PyTorch parameter in `state_dict` as a TensorFlow variable in a
# fresh default graph, then read each one back to confirm the values match.
tf.reset_default_graph()
with tf.Session() as session:
    for var_name in state_dict:
        # Map the PyTorch parameter name to its TensorFlow variable name.
        tf_name = to_tf_var_name(var_name)
        # detach().cpu() lets GPU / requires-grad tensors convert too; for plain
        # CPU state-dict tensors both calls are no-ops.
        torch_tensor = state_dict[var_name].detach().cpu().numpy()
        # Names matching an entry in tensors_to_transpose get transposed --
        # presumably PyTorch (out, in) weights vs TensorFlow (in, out) kernels;
        # confirm against to_tf_var_name / tensors_to_transpose definitions.
        if any(x in var_name for x in tensors_to_transpose):
            torch_tensor = torch_tensor.T
        tf_var = create_tf_var(tensor=torch_tensor, name=tf_name, session=session)
        tf.keras.backend.set_value(tf_var, torch_tensor)
        # Round-trip check: the printed boolean is np.allclose of the read-back
        # value against the source array.
        tf_weight = session.run(tf_var)
        print("Successfully created {}: {}".format(tf_name, np.allclose(tf_weight, torch_tensor)))