Java调⽤Tensorflow训练模型预测结果
Java调⽤Tensorflow训练好的模型做预测,⾸先需要读取词典,然后加载模型,读⼊数据,最后预测结果。模型训练参考上⼀篇博客:
session如何设置和读取⾸先需要下载⼀些包,如果是maven项⽬在l中添加两个依赖。
<dependency>
<groupId&sorflow</groupId>
<artifactId>tensorflow</artifactId>
<version>1.5.0</version>
</dependency>
<dependency>
<groupId&sorflow</groupId>
<artifactId>libtensorflow_jni</artifactId>
<version>1.5.0</version>
</dependency>
读取词典⽂件
这个词典⽂件,就是上⼀篇对应训练模型之前⽣成的词典⽂件。每⾏⼀个词和词的编号。// 从⽂件读取词典⽂件存⼊Map
private static Map<String, Integer>readVocabFromFile(String pathname)throws IOException{
Map<String, Integer> wordMap =new HashMap<String, Integer>();
File filename =new File(pathname);
InputStreamReader reader =new InputStreamReader(new FileInputStream(filename));
BufferedReader br =new BufferedReader(reader);
String line ="";
line = br.readLine();
String[] lineArray;
while(line != null){
lineArray = line.split(" ");
wordMap.put(lineArray[0], Integer.parseInt(lineArray[1]));
line = br.readLine();
}
return wordMap;
}
加载Tensorflow模型⽂件
这⾥加载上⼀篇中训练完成保存的模型⽂件lstm_attention.pb。
// 读取tensorflow⼆进制的模型⽂件
private static byte[]readAllBytes(String pathname)throws IOException{
File filename =new File(pathname);
BufferedInputStream in =new BufferedInputStream(new FileInputStream(filename));
ByteArrayOutputStream out =new ByteArrayOutputStream(1024);
byte[] temp =new byte[1024];
int size =0;
while((size = in.read(temp))!=-1){
out.write(temp,0, size);
}
in.close();
byte[] content = ByteArray();
return content;
}
读取预测数据
预测可以是⼀条数据,也可以是⼀个batch的数据。
// 读取分词后的⼀个样本,并建⽴索引
public static int[][]getInputFromSentence(String sentence, Map<String, Integer> wordIndexMap){
int[][] indexArray =new int[1][MAX_SEQUENCE_LENGTH];
String[] words = sentence.split(" ");
for(int i=0; i<words.length; i++){
ainsKey(words[i])){
indexArray[0][i]= (words[i]);
}
}
return indexArray;
}
// 对⼀个batch的样本建⽴索引
public static int[][]getInputFromSentenceBatch(String[] sentences, Map<String, Integer> wordIndexMap){ int[][] indexArray =new int[sentences.length][MAX_SEQUENCE_LENGTH];
for(int i=0; i<sentences.length; i++){
String[] words = sentences[i].split(" ");
for(int j=0; j<words.length; j++){
ainsKey(words[j])){
indexArray[i][j]= (words[j]);
}
}
}
return indexArray;
}
预测结果
需要新建Tensorflow的Session会话,读取训练好的模型计算图和参数,
import java.io.*;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
sorflow.Graph;
sorflow.Session;
sorflow.Tensor;
public class TensorflowDemo {
private static String TensorFlow_MODEL_PATH ="lstm_attention.pb";
private static String WORD_INDEX_PATH ="";
private static int MAX_SEQUENCE_LENGTH =60;
private static int CLASS_NUM =2;
public static void main(String[] args)throws IOException{
// 构建词典Map
Map<String, Integer> wordsMap =readVocabFromFile(WORD_INDEX_PATH);        System.out.println("vocabulary size:"+wordsMap.size());
// 加载Tensorflow训练好的模型
byte[] graphDef =readAllBytes(TensorFlow_MODEL_PATH);
Graph graph =new Graph();
graph.importGraphDef(graphDef);
Session session =new Session(graph);
String test_sentence ="再也不⽤愁看不起病了,⽼祖宗留下此表!";
System.out.println("sentence: "+test_sentence);
// 输⼊模型的测试语句
int[][] sentenceBuf =getInputFromSentence(test_sentence, wordsMap);
int[] sentLength ={sentenceBuf[0].length};
Tensor inputTensor = ate(sentenceBuf);
Tensor lengthTensor = ate(sentLength);
// 输⼊数据,得到预测结果
Tensor result = session.runner()
.feed("Input_Layer/input_x:0", inputTensor)
.feed("Input_Layer/length:0", lengthTensor)
.fetch("Accuracy/score:0")
.run().get(0);
long[] rshape = result.shape();
int batchSize =(int) rshape[0];
// int nlabels = (int) rshape[1];
float[][] resultArray =new float[batchSize][CLASS_NUM];
System.out.println(resultArray[0][0]+" "+resultArray[0][1]);
}
注意预测结果时同样要保持模型输⼊输出格式名称⼀致。预测的输⼊输出要与模型最初构建时保持⼀致。模型构建时是下⾯这种写法:
java预测时是这种写法:
由于代码是⼀块⼀块分开的,感觉整体不是很连贯,后⾯会完善。

版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系QQ:729038198,我们将在24小时内删除。