PyTorch Source Code Walkthrough: torch.serialization & torch.hub
Author | 123456
Source | OpenMMLab
Editor | 极市平台
Introduction
This walkthrough is based on PyTorch 1.7 and introduces torch.serialization, torch.save, and torch.hub.
torch.serialization
torch.serialization implements binary serialization and deserialization of PyTorch object structures: serialization is handled by torch.save, and deserialization by torch.load.
torch.save
torch.save mainly uses pickle for binary serialization:
def save(obj,                                  # the object to serialize
         f: Union[str, os.PathLike, BinaryIO], # the file to write to
         pickle_module=pickle,                 # use pickle for serialization by default
         pickle_protocol=DEFAULT_PROTOCOL,     # use pickle protocol 2 by default
         _use_new_zipfile_serialization=True) -> None:
    # Since PyTorch 1.6 the zipfile-based file format is the default. Set this flag to
    # False to use the old format; torch.load can read files in both formats.

    # If dill is used for serialization, its version must be greater than 0.3.1.
    _check_dill_version(pickle_module)

    with _open_file_like(f, 'wb') as opened_file:
        if _use_new_zipfile_serialization:
            # zipfile-based storage format
            with _open_zipfile_writer(opened_file) as opened_zipfile:
                _save(obj, opened_zipfile, pickle_module, pickle_protocol)
                return
        # write to the file as plain binary
        _legacy_save(obj, opened_file, pickle_module, pickle_protocol)
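For orientation, here is a minimal usage sketch showing how user code selects between the two formats (the model and file names are illustrative, not from the source):

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
torch.save(model.state_dict(), 'new_format.pt')   # zipfile-based format (default since 1.6)
torch.save(model.state_dict(), 'old_format.pt',
           _use_new_zipfile_serialization=False)  # old plain-binary format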
The core functions here are _save() and _legacy_save(). We introduce them in turn, starting with _save():
def _save(obj, zip_file, pickle_module, pickle_protocol):
    serialized_storages = {}  # staging area for the actual data, indexed by key

    def persistent_id(obj):
        if torch.is_storage(obj):  # if this is data that needs to be stored
            storage_type = normalize_storage_type(type(obj))  # storage type: int, float, ...
            obj_key = str(obj._cdata)  # key for this piece of data; load() reads the data back by key
            location = location_tag(obj)  # cpu or cuda
            serialized_storages[obj_key] = obj  # stash the data under its key
            # note: no actual data here, only metadata describing the data
            return ('storage', storage_type, obj_key, location, obj.size())
        return None

    data_buf = io.BytesIO()  # allocate a buffer
    pickler = pickle_module.Pickler(data_buf, protocol=pickle_protocol)  # the object's structure will be written into data_buf
    pickler.persistent_id = persistent_id  # structure goes into data_buf; the actual data is staged in serialized_storages
    pickler.dump(obj)  # perform the write; persistent_id() is called during the dump
    data_value = data_buf.getvalue()  # fetch the serialized structure from the buffer
    # write it into the archive; note that only the object's structure is written here
    # (identified by 'data.pkl') -- the actual data has not been written yet
    zip_file.write_record('data.pkl', data_value, len(data_value))

    for key in sorted(serialized_storages.keys()):  # write the actual data
        name = f'data/{key}'  # name of the record
        storage = serialized_storages[key]  # the actual data
        if storage.device.type == 'cpu':  # data lives on the cpu
            num_bytes = storage.size() * storage.element_size()  # number of bytes it occupies
            zip_file.write_record(name, storage.data_ptr(), num_bytes)  # write the data
        else:  # data lives on a cuda device
            buf = io.BytesIO()  # allocate a buffer
            storage._write_file(buf, _should_read_directly(buf), False)  # copy the cuda data into host memory
            buf_value = buf.getvalue()  # fetch the bytes from the buffer
            zip_file.write_record(name, buf_value, len(buf_value))  # write the data
To sum up: when _save() serializes an object to binary, it first writes the object's structure and only then writes the actual data.
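You can verify this layout by opening a saved checkpoint with the standard zipfile module. A minimal sketch, assuming the file name 'ckpt.pt' (the exact record prefix depends on the archive name):

import torch
import zipfile

torch.save({'w': torch.randn(2, 3)}, 'ckpt.pt')
with zipfile.ZipFile('ckpt.pt') as zf:
    print(zf.namelist())
# Expect one '<archive>/data.pkl' record holding the structure, plus one
# '<archive>/data/<key>' record per storage, matching _save() above.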
Next, the _legacy_save() function:
def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
    import torch.nn as nn
    serialized_container_types = {}
    serialized_storages = {}

    def persistent_id(obj: Any) -> Optional[Tuple]:
        if isinstance(obj, type) and issubclass(obj, nn.Module):  # record the source code
            if obj in serialized_container_types:  # already recorded; no need to record it again
                return None
            serialized_container_types[obj] = True
            source_file = source = None
            try:
                source_lines, _, source_file = get_source_lines_and_file(obj)  # read the source code
                source = ''.join(source_lines)
            except Exception:  # if the source cannot be found, print a warning
                warnings.warn("Couldn't retrieve source code for container of "
                              "type " + obj.__name__ + ". It won't be checked "
                              "for correctness upon loading.")
            return ('module', obj, source_file, source)
        elif torch.is_storage(obj):  # similar to the corresponding branch of persistent_id() in _save() above
            view_metadata: Optional[Tuple[str, int, int]]
            obj = cast(Storage, obj)
            storage_type = normalize_storage_type(type(obj))
            # offset is always 0; kept for backward compatibility with the old
            # serialization format, which supported storage views
            offset = 0
            obj_key = str(obj._cdata)
            location = location_tag(obj)
            serialized_storages[obj_key] = obj
            # always False as written: storage views were removed from PyTorch,
            # and this line is a leftover kept so the record layout stays compatible
            is_view = obj._cdata != obj._cdata
            if is_view:
                view_metadata = (str(obj._cdata), offset, obj.size())
            else:
                view_metadata = None
            return ('storage', storage_type, obj_key, location, obj.size(),
                    view_metadata)
        return None

    # record some system information
    sys_info = dict(
        protocol_version=PROTOCOL_VERSION,
        little_endian=sys.byteorder == 'little',
        type_sizes=dict(
            short=SHORT_SIZE,
            int=INT_SIZE,
            long=LONG_SIZE,
        ),
    )

    pickle_module.dump(MAGIC_NUMBER, f, protocol=pickle_protocol)  # record MAGIC_NUMBER; load() uses it to check that the file is not corrupted
    pickle_module.dump(PROTOCOL_VERSION, f, protocol=pickle_protocol)  # record the protocol version; load() checks that it matches
    pickle_module.dump(sys_info, f, protocol=pickle_protocol)  # record the system information
    pickler = pickle_module.Pickler(f, protocol=pickle_protocol)  # the object's structure will be written into the file
    pickler.persistent_id = persistent_id  # structure goes into the file; the actual data is staged in serialized_storages
    pickler.dump(obj)  # perform the write; persistent_id() is called during the dump

    serialized_storage_keys = sorted(serialized_storages.keys())
    pickle_module.dump(serialized_storage_keys, f, protocol=pickle_protocol)  # write the keys of the actual data
    f.flush()  # flush the buffer
    for key in serialized_storage_keys:
        serialized_storages[key]._write_file(f, _should_read_directly(f), True)  # write the actual data
As you can see, _legacy_save() and _save() follow a similar overall pipeline during serialization; they differ only slightly in what exactly gets written.
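The legacy layout can also be checked by hand: the file starts with three plain pickle records (magic number, protocol version, system info) before the object pickle. A minimal sketch, with an assumed file name:

import pickle
import torch

torch.save(torch.arange(4.), 'legacy.pt', _use_new_zipfile_serialization=False)
with open('legacy.pt', 'rb') as f:
    magic = pickle.load(f)     # MAGIC_NUMBER, used as an integrity check
    protocol = pickle.load(f)  # PROTOCOL_VERSION
    sys_info = pickle.load(f)  # endianness and type sizes
print(protocol, sys_info['little_endian'])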
torch.load
torch.load mainly uses pickle for binary deserialization.
def load(f,  # the file to deserialize
         map_location=None,  # place objects on cpu or cuda; defaults to the location recorded in the file
         pickle_module=pickle,  # use pickle for deserialization by default
         **pickle_load_args):
    _check_dill_version(pickle_module)

    if 'encoding' not in pickle_load_args.keys():  # decode with utf-8 by default
        pickle_load_args['encoding'] = 'utf-8'

    with _open_file_like(f, 'rb') as opened_file:
        if _is_zipfile(opened_file):  # zipfile-based storage format
            orig_position = opened_file.tell()
            with _open_zipfile_reader(opened_file) as opened_zipfile:
                # if the file holds TorchScript, use torch.jit.load(); otherwise deserialize with _load()
                if _is_torchscript_zip(opened_zipfile):
                    warnings.warn(
                        "'torch.load' received a zip file that looks like a TorchScript archive"
                        " dispatching to 'torch.jit.load' (call 'torch.jit.load' directly to"
                        " silence this warning)", UserWarning)
                    opened_file.seek(orig_position)
                    return torch.jit.load(opened_file)
                return _load(opened_zipfile, map_location, pickle_module,
                             **pickle_load_args)
        # for a plain binary file, deserialize with _legacy_load()
        return _legacy_load(opened_file, map_location, pickle_module,
                            **pickle_load_args)
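As an aside on the map_location parameter above: it accepts a string, a torch.device, a dict, or a callable. Two common call patterns, with an assumed file name:

import torch

ckpt = torch.load('ckpt.pt', map_location='cpu')                 # force every storage onto the CPU
ckpt = torch.load('ckpt.pt', map_location={'cuda:1': 'cuda:0'})  # remap device ids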
The core functions here are _load() and _legacy_load(). We introduce them in turn, starting with _load():
def _load(zip_file,
          map_location,
          pickle_module,
          pickle_file='data.pkl',  # note: 'data.pkl' matches the record written by _save()
          **pickle_load_args):
    # build a restore_location function from map_location; it places data on cpu or cuda
    restore_location = _get_restore_location(map_location)

    loaded_storages = {}

    def load_tensor(data_type, size, key, location):
        name = f'data/{key}'  # the key of the data, used to locate it in the archive
        dtype = data_type(0).dtype  # the data type, e.g. int, float, ...
        storage = zip_file.get_storage_from_record(name, size, dtype).storage()  # read the data from the file
        loaded_storages[key] = restore_location(storage, location)  # place it on cpu or cuda

    def persistent_load(saved_id):
        assert isinstance(saved_id, tuple)  # saved_id = ('storage', storage_type, obj_key, location, obj.size())
        typename = _maybe_decode_ascii(saved_id[0])
        data = saved_id[1:]
        assert typename == 'storage', \
            f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
        data_type, key, location, size = data  # i.e. storage_type, obj_key, location, obj.size()
        if key not in loaded_storages:
            load_tensor(data_type, size, key, _maybe_decode_ascii(location))
        storage = loaded_storages[key]
        return storage

    data_file = io.BytesIO(zip_file.get_record(pickle_file))  # read the record 'data.pkl' holding the object's structure
    unpickler = pickle_module.Unpickler(data_file, **pickle_load_args)
    unpickler.persistent_load = persistent_load  # persistent_load function used to read the actual data
    result = unpickler.load()  # perform the read
    torch._utils._validate_loaded_sparse_tensors()
    return result
To sum up: while _load() deserializes an object, the actual data is loaded at the same time as the object's structure is rebuilt, since persistent_load() pulls each storage in as the unpickler encounters it. _legacy_load() works differently: it first rebuilds the object's structure and only then loads the actual data.
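The persistent_id/persistent_load pairing that both paths rely on is plain pickle machinery and is easy to reproduce outside PyTorch. A minimal, self-contained sketch (Blob and the payloads side table are illustrative, not PyTorch code):

import io
import pickle

payloads = {}  # side table playing the role of serialized_storages / loaded_storages

class Blob:
    def __init__(self, data):
        self.data = data

def persistent_id(obj):
    if isinstance(obj, Blob):
        key = str(id(obj))
        payloads[key] = obj.data  # stash the bulk data outside the pickle stream
        return ('blob', key)      # only a reference goes into the pickle
    return None

def persistent_load(saved_id):
    typename, key = saved_id
    assert typename == 'blob'
    return Blob(payloads[key])    # swap the reference back for the data

buf = io.BytesIO()
pickler = pickle.Pickler(buf)
pickler.persistent_id = persistent_id
pickler.dump({'weights': Blob(b'\x00' * 8)})

buf.seek(0)
unpickler = pickle.Unpickler(buf)
unpickler.persistent_load = persistent_load
restored = unpickler.load()
print(restored['weights'].data)  # b'\x00\x00\x00\x00\x00\x00\x00\x00'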
def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
    deserialized_objects: Dict[int, Any] = {}
    # build a restore_location function from map_location; it places data on cpu or cuda
    restore_location = _get_restore_location(map_location)

    def legacy_load(f):
        deserialized_objects: Dict[int, Any] = {}
        # files written by _legacy_save() are not in the old tar-based format, so
        # tarfile.open() raises here and the rest of this helper never runs
        with closing(tarfile.open(fileobj=f, mode='r:', format=tarfile.PAX_FORMAT)) as tar, \
                mkdtemp() as tmpdir:
            ...

    deserialized_objects = {}

    def persistent_load(saved_id):
        assert isinstance(saved_id, tuple)
        # saved_id = ('storage', storage_type, obj_key, location, obj.size(), view_metadata)
        # or saved_id = ('module', obj, source_file, source)
        typename = _maybe_decode_ascii(saved_id[0])
        data = saved_id[1:]
        if typename == 'module':
            # check the recorded source code against the current one, then return the module class
            if all(data[1:]):
                _check_container_source(*data)
            return data[0]
        ...

The 'storage' branch that follows is analogous to persistent_load() in _load(), with extra handling for view_metadata.