Pytorch——保存训练好的模型参数
⽂章⽬录
1.前⾔
训练好了⼀个模型, 我们当然想要保存它, 留到下次要⽤的时候直接提取直接⽤,下⾯我将来讲如何存储训练好的模型参数
⾸先,先搭建⼀个神经⽹络
import torch
from torch import nn
import matplotlib.pyplot as plt
torch.manual_seed(11)# 使每次得到的随机数是固定的。但是如果不加上torch.manual_seed这个函数调⽤的话,打印出来的随机数每次都不⼀样
x = torch.unsqueeze(torch.linspace(-1,1,100), dim=1)# [100] -> [100,1]
y = x.pow(2)+0.5*torch.rand(x.size())# y的形状与x⼀样
def make_and_save_model():
network = Sequential(
)
optimizer = torch.optim.SGD(network.parameters(), lr=0.3)#优化器
criterion = MSELoss()#损失函数
# 训练
for i in range(200):
prediction = network(x)#数据放⼊模型后得到预测值
loss = criterion(prediction, y)#计算预测值与真实值之间的误差
<_grad()#清空梯度
loss.backward()#误差反向传播
optimizer.step()#更新参数
torch.save(network,'network.pth')# 保存整个⽹络
torch.save(network.state_dict(),'network_params.pth')# 只保存⽹络中的参数
plt.figure(1, figsize =(10,3))
plt.subplot(131)
plt.title('network')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(),'yo', lw =5)
plt.pause(1)
这种⽅式将会提取整个神经⽹络, ⽹络⼤的时候可能会⽐较慢.
def load_whole_model():
network_whole = torch.load('network.pth')
prediction = network_whole(x)
plt.figure(1, figsize =(10,3))
plt.subplot(132)
plt.title('network_whole')
linspace numpy
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(),'yo', lw =5)
plt.pause(1)
这种⽅式将会提取所有的参数, 然后再放到你的新建⽹络中
def load_only_params():
network_params = Sequential(
)
network_params.load_state_dict(torch.load('network_params.pth'))
prediction = network_params(x)
plt.figure(1, figsize =(10,3))
plt.subplot(133)
plt.title('network_params')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(),'yo', lw =5)
5.调⽤三个函数
会看到加载后的模型画出的图是⼀样的,说明模型的参数正确加载了。
make_and_save_model()
load_whole_model()
load_only_params()
版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系QQ:729038198,我们将在24小时内删除。
发表评论