PyTorch保存⽹络结构以及参数【torch.save()、torch.load()】
保存⽹络结构以及参数
对于pytorch保存⽹络参数,⼤家⼀般可以看到有 .pkl⽂件 以及 .pth⽂件,对于这两者有什么区别,以及如何保存⽹络参数等,本⽂就好好讲述⼀下。
⼀、保存⽅式
⾸先我们知道不论是保存模型还是参数都需要⽤到torch.save()。
对于torch.save()有两种保存⽅式:
只保存神经⽹络的训练模型的参数,save的对象是model.state_dict();
既保存整个神经⽹络的的模型结构⼜保存模型参数,那么save的对象就是整个模型;
Eg. 假设我有⼀个训练好的模型名叫model,如何来保存参数以及结构?
import torch
# 保存模型步骤
torch.save(model,'net.pth')# 保存整个神经⽹络的模型结构以及参数
torch.save(model,'net.pkl')# 同上
torch.save(model.state_dict(),'net_params.pth')# 只保存模型参数
torch.save(model.state_dict(),'net_params.pkl')# 同上
# 加载模型步骤
model = torch.load('net.pth')# 加载整个神经⽹络的模型结构以及参数
model = torch.load('net.pkl')# 同上
model.load_state_dict(torch.load('net_params.pth'))# 仅加载参数
model.load_state_dict(torch.load('net_params.pkl'))# 同上
上⾯例⼦也可以看出若使⽤torch.save()来进⾏模型参数的保存,那保存⽂件的后缀其实没有任何影响,.pkl ⽂件和 .pth ⽂件⼀模⼀样。
⼆、pkl、pth⽂件区别
实际上,这两种格式的⽂件还是有区别的。
2.1 .pkl⽂件
⾸先介绍 .pkl ⽂件,它若直接打开会显⽰⼀堆序列化的东西,以⼆进制形式存储的。如果去 read 这些⽂件,需要⽤'rb'⽽不是'r'模式。import pickle
import pickle as pkl
file= os.path.join('annot',model.pkl)# 打开pkl⽂件
with open(file,'rb')as anno_file:
result = pkl.load(anno_file)
或者:
import pickle as pkl
file= os.path.join('annot',model.pkl)# 打开pkl⽂件
anno_file =open(file,'rb')
result = pkl.load(anno_file)
2.2 .pth⽂件
import torch
filename = r'E:\anaconda\model.pth'# 字符串前⾯加r,表⽰的意思是禁⽌字符串转义
model = torch.load(filename)
print(model)
但其实不管pkl⽂件还是pth⽂件,都是以⼆进制形式存储的,没有本质上的区别,你⽤pickle这个库去加载pkl⽂件或pth⽂件,效果都是⼀样的。
版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系QQ:729038198,我们将在24小时内删除。
发表评论