Java调⽤Keras、Tensorflow模型
实现python离线训练模型,Java在线预测部署。查看原⽂
⽬前深度学习主流使⽤python训练⾃⼰的模型,有⾮常多的框架提供了能快速搭建神经⽹络的功能,其中Keras提供了high-level的语法,底层可以使⽤tensorflow或者theano。
但是有很多公司后台应⽤是⽤Java开发的,如果⽤python提供HTTP接⼝,对业务延迟要求⽐较⾼的话,仍然会有⼀定得延迟,所以能不能使⽤Java调⽤模型,python可以离线的训练模型?(tensorflow也提供了成熟的部署⽅案TensorFlow Serving)
⼿头上有⼀个⽤Keras训练的模型,⽹上关于Java调⽤Keras模型的资料不是很多,⽽且⼤部分是重复的,并且也没有讲的很详细。⼤致有两种⽅案,⼀种是基于Java的深度学习库导⼊Keras模型实现,另外⼀种是⽤tensorflow提供的Java接⼝调⽤。
Deeplearning4J
Eclipse Deeplearning4j is the first commercial-grade, open-source, distributed deep-learning library written for Java and Scala. Integrated with Hadoop and Spark, DL4J brings AI AI to business environments for use on distributed GPUs and CPUs.
Deeplearning4j⽬前⽀持导⼊Keras训练的模型,并且提供了类似python中numpy的⼀些功能,更⽅便地处理结构化的数据。遗憾的
是,Deeplearning4j现在只覆盖了Keras <2.0版本的⼤部分Layer,如果你是⽤Keras 2.0以上的版本,在导⼊模型的时候可能会报错。
了解更多:
Keras Model Import: Supported Features
Importing Models From Keras to Deeplearning4j
Tensorflow
⽂档,Java的⽂档很少,不过调⽤模型的过程也很简单。采⽤这种⽅式调⽤模型需要先将Keras导出的模型转成tensorflow的protobuf协议的模型。
1、Keras的h5模型转为pb模型
在Keras中使⽤model.save(model.h5)保存当前模型为HDF5格式的⽂件中。
python转java代码Keras的后端框架使⽤的是tensorflow,所以先把模型导出为pb模型。在Java中只需要调⽤模型进⾏预测,所以将当前的graph中的Variable全部变成Constant,并且使⽤训练后的weight。以下是freeze graph的代码:
版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系QQ:729038198,我们将在24小时内删除。
发表评论