从Pytorch源码看.pt⽂件Pytorch中张量的保存与加载
保存张量
在Pytorch中,⼀个约定俗成的⽅法是使⽤.pt扩展的⽂件格式来保存张量,使⽤的⽅法为torch.save()。函数原型与参数说明
import torch
def save(obj, f: Union[str, os.PathLike, BinaryIO],
pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True)->None: """
pytorch框架的原型代码
"""
pass
# 参数说明
# obj:要保存的对象,类型为tensor
# f:保存的⽂件名,可以是⽂件路径(包含⽂件名的字符串)、可以是字符流、也可以是⽂件对象
# pickle:Python中的⼀个模块,实现了⽤于序列化和反序列化Python对象结构的⼆进制协议
# pickle_module:⽤来协议化元数据和对象的协议
import pickle
# pickle_protocol:可以指定来默认覆盖的协议
# 使⽤save⽅法
def save_tensor():
# 直接保存为⼀个张量
x = torch.Tensor([1,2,3])
torch.save(x,'save_tensor.pt')
# 保存为字符流的格式
buffer= io.BytesIO()
torch.save(x,buffer)
加载张量
在Pytorch中,使⽤torch.load()⽅法加载torch.save()⽅法保存的⽂件。
函数原型与参数说明
import torch
def load(f, map_location=None, pickle_module=pickle,**pickle_load_args):
"""
Pytorch框架的原型代码
"""
pass
# 参数说明
# f:保存的⽂件名
# map_location:加载位置,即将这个张量加载到哪,可选的内容包括:函数、torch.device、字符串以及指定如何重新映射存储的字典# pickle_module:⽤来协议化元数据和对象的协议
# pickle_load_args:需要加载的pickle模块的参数设置。这个包含的内容相当丰富,感兴趣的可以去阅读Pytorch的官⽅⼿册
# 使⽤load⽅法
def tensor_load():
# ⼩⽩式加载(最常⽤)
torch.load('save_tensor.pt')
# 加载到CPU中
torch.load('save_tensor.pt', map_location=torch.device('cpu'))
# 使⽤函数加载到CPU中
torch.load('save_tensor.pt', map_location=lambda storage, loc: storage)
# 加载到GPU1中
torch.load('save_tensor.pt', map_location=lambda storage, loc: storage.cuda(1))
# 从GPU0加载到GPU1中
torch.load('save_tensor.pt', map_location={'cuda: 1':'cuda: 0'})
# 指定加载的编码⽅式
torch.load('save_tensor.pt', encoding='ascii')
# 加载字符流格式的张量
with open('save_tensor.pt','rb')as f:
buffer= io.ad())
torch.load(buffer)

版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系QQ:729038198,我们将在24小时内删除。