MATLAB调⽤Pytorch模型
前⾔
在⾃⼰的⼯作中遇到了要使⽤Pytorch训练神经⽹络但仿真程序却是在Matlab上的情况,为了得到⼀个快速的进展,就有了在Matlab调⽤Pytorch神经⽹络的想法。总的来说,还是⽐较好实现,但是作为新⼿难免会遇到些问题,所以就把⾃⼰的经验写在这篇帖⼦⾥,以便⽇后复习,也希望能给其他⼈⼀些思路。
由于最近⽐较忙,可能更新的速度会⽐较慢,但是会⼀直更新完。
平台和软件
Matlab: 2021a
Python: 3.8.3
我是直接使⽤Anaconda下载的Python.要查看不同Matlab版本⽀持的Python版本,请参考。
以当前较新的Python版本为例,2020b和2021a都⽀持python3.7、3.8。
基本测试
验证 Python 配置
这⾥参考即可,主要是在Matlabli熟悉⼀下python的基本操作。
MATLAB 到 Python 的数据类型映射
这个主要是查看数据传递到python后,数据的类型变化 ()
这⼀点很重要,因为数据格式的不匹配使得python⽆法识别传递的参数。
调⽤⽤户定义的 Python 模块
这⾥建议从最简单的调⽤开始尝试。
1. ⽆参数传递+⽆参数返回
2. 有参数传递+⽆参数返回
3. 有参数传递+有参数返回
传递关键字参数
pyargs
python新手函数这⾥⼀定要现实
重点 - 调⽤神经⽹络
以我最近在⽤的ViT模型为例⼦,python的函数如下:
number_of_classes: 输出的类别数⽬
image_input: 是测试的图⽚。这⾥需要注意,pytorch模型的输⼊图⽚的维度是 N * C * H * W. 因此从Matlab中传递给python的数据格式也应该是这样的。(区别:Matlab中神经⽹络使⽤的输⼊图⽚维度:H * W * C * N)
trained_ViT_path:训练好的模型路径
patch_size_given:图⽚的patch⼤⼩
def vit_prediction(number_of_classes,image_input,patch_size_given,trained_ViT_path :str):
#type(number_of_classes) # 可以⽤来观察传递到python的数据的类型
vit_model = ViT(
image_size = 64,
patch_size = patch_size_given,
num_classes = number_of_classes,
dim = 1024,
depth = 6,
heads = 16,
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1
)
# print(vit_model)
model_load_path= trained_ViT_path
try:
## 上载训练好的模型(整个)
vit_model.load_state_dict(torch.load(model_load_path))
print('done 1')
except:
## 上载训练好的模型(匹配的部分)
pretrained_dict=torch.load(model_load_path, map_location=torch.device('cpu'))
model_dict = vit_model.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}#filter out unnecessary keys_cpu
model_dict.update(pretrained_dict)
print ('done 2')
print('model loaded')
# 需要转换格式类型
image_input_array = np.asarray(image_input)
image_input_tensor = torch.from_numpy(image_input_array)
# 下⾯的程序就是使⽤神经⽹络预测,后期会补上
print("预测")
return pred_results
把上述代码保存成py的⽂件。可以在python环境下先进⾏测试,是否有代码上的错误,然后再使⽤Matlab调⽤。
MATLAB的调⽤代码如下
test_data_py的维度是N * C * H * W
pred_results = py.ViT_prediction_function_GPU.vit_prediction(pyargs(...
'number_of_modes',int8(number_of_classes),'image_input',test_data_py,'patch_size_given',int8(patch_size),'trained_ViT_path',trained_ViT_model_path)); pred_results = double(pred_results);
*我在windows下调⽤没有任何问题,最新的ViT也没有什么问题。不过,当我想在Linux上实现同样的调⽤时,遇到了⼀个问题:只要import XX 就会报错,matlab 也会崩溃。⽬前还没有解决,后期如果解决了,会上传⼀个在linux下调⽤的介绍。
版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系QQ:729038198,我们将在24小时内删除。
发表评论