Using Elasticsearch from Java to Compare Article Similarity with SimHash
A recent work requirement called for similar-text lookup, so I decided to implement it with SimHash.
The conventional approach usually breaks down into four steps:
1. Implement the SimHash algorithm.
2. When saving an article, also store its SimHash in an inverted index.
3. At ingest time, or via a scheduled job, find colliding SimHash values in the inverted index and write them to a results table.
4. When the articles similar to a given article are needed, look up the results table by article ID to find them.
There is a small catch, though: if an article is ingested several times and its SimHash changes, or the article gets deleted, the results table is hard to keep up to date.
At the same time, Elasticsearch happens to be very good at querying and maintaining inverted indexes, so I wondered whether I could let ES maintain the SimHash inverted index for me and skip the results table altogether.
That simplifies the logic to three steps:
1. Implement the SimHash algorithm.
2. When saving an article, also store its SimHash fields in ES (alongside the body and the other fields).
3. When the articles similar to a given article are needed, look up that article's SimHash by its ID, then query ES for the IDs of other matching articles; the service layer still has to filter the candidates by Hamming distance (a rough sketch of this flow follows the list).
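To make step 3 concrete, here is a minimal sketch of the service-layer flow, assuming the article's SimHash was stored at save time. The names articleRepository, esSearchService, and the index name "article_index" are hypothetical; searchBySimHash is the method shown later in this post.

// Hypothetical service-layer flow for step 3 (names are illustrative only).
public List<String> findSimilarArticles(String articleId) {
    // Look up the SimHash stored together with the article (hypothetical repository call).
    String simHash = articleRepository.findSimHashById(articleId);
    if (simHash == null) {
        return new ArrayList<>(); // the article was too short to produce a SimHash
    }
    // Query ES for candidates; the Hamming-distance filter happens inside searchBySimHash (shown later).
    return esSearchService.searchBySimHash("article_index", simHash);
}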
So I went ahead and built it. The implementation below is adapted from algorithms already floating around online; treat it as a starting point, and please point out anything that could be done better.
First, add the dependencies: HanLP for Chinese word segmentation, and Jsoup for stripping HTML tags from the body text.
<dependency>
    <groupId>com.hankcs</groupId>
    <artifactId>hanlp</artifactId>
    <version>portable-1.8.1</version>
</dependency>
<dependency>
    <groupId>org.jsoup</groupId>
    <artifactId>jsoup</artifactId>
    <version>1.13.1</version>
</dependency>
Next comes the SimHash core class. I hard-coded a 64-bit SimHash and a duplicate threshold (Hamming distance) of 3:
package ;
import com.hankcs.hanlp.HanLP;
import com.hankcs.hanlp.dictionary.stopword.CoreStopWordDictionary;
import com.hankcs.hanlp.seg.common.Term;
import com.springboot.commonUtil.StringUtils; // the project's own StringUtils; any isBlank() implementation works here
import org.jsoup.Jsoup;
import java.math.BigInteger;
import java.util.List;
/**
* Provides SimHash-related computation services.
*/
public class SimHashService {
public static final BigInteger BIGINT_0 = BigInteger.valueOf(0);
public static final BigInteger BIGINT_1 = BigInteger.valueOf(1);
public static final BigInteger BIGINT_2 = BigInteger.valueOf(2);
public static final BigInteger BIGINT_1000003 = BigInteger.valueOf(1000003);
public static final BigInteger BIGINT_2E64M1 = BIGINT_2.pow(64).subtract(BIGINT_1);
/**
 * Computes the SimHash of a piece of body text.
 * Warning: changing this method, or changing the HanLP segmentation behaviour (e.g. adding stop words), will change the computed SimHash values.
 *
 * @param text the text to hash
 * @return the SimHash as a 64-character 0/1 string, or null if the text is too short
 */
public static String get(String text) {
if (text == null) {
return null;
}
text = Jsoup.parse(text).text(); // strip HTML tags from the body before segmentation
int sumWeight = 0;
int maxWeight = 0;
int[] bits = new int[64];
List<Term> termList = HanLP.segment(text);
for (Term term : termList) {
String word = term.word;
String nature = term.nature.toString();
if (nature.startsWith("w") || CoreStopWordDictionary.contains(word)) {
// skip punctuation and stop words
continue;
}
BigInteger wordHash = getWordHash(word);
int wordWeight = getWordWeight(word);
if (wordWeight == 0) {
continue;
}
sumWeight += wordWeight;
if (maxWeight < wordWeight) {
maxWeight = wordWeight;
}
// Accumulate the word's hash, weighted, into the bits array, bit by bit:
// if a hash bit is 1, add the word's weight at that position; otherwise subtract it.
for (int i = 0; i < 64; i++) {
BigInteger bitMask = BIGINT_1.shiftLeft(63 - i);
if (wordHash.and(bitMask).signum() != 0) {
bits[i] += wordWeight;
} else {
bits[i] -= wordWeight;
}
}
}
if (3 * maxWeight >= sumWeight || sumWeight < 20) {
// The text is too short for the hash to be reliable, so refuse to return a result
// (otherwise too many documents would collide and query performance would suffer).
// For now this requires at least three "heavy" words before a result is returned.
return null;
}
// Reduce the accumulated per-bit counts to a 0/1 string and return it
StringBuilder simHashBuilder = new StringBuilder();
for (int i = 0; i < 64; i++) {
if (bits[i] > 0) {
simHashBuilder.append("1");
} else {
simHashBuilder.append("0");
}
}
return simHashBuilder.toString();
}
/**
* Computes the hash of a single word.
* Warning: changing this method will change the computed SimHash values.
*
* @param word the input word
* @return the hash value
*/
private static BigInteger getWordHash(String word) {
if (StringUtils.isBlank(word)) {
return BIGINT_0;
}
char[] sourceArray = word.toCharArray();
// After tuning, a left shift of about 11-12 bits works best:
// when most hashed words are two-character Chinese words it avoids an obvious bias in the high hash bits,
// whereas if the shift is too large, the low hash bits depend only on the last character of the word.
BigInteger hash = BigInteger.valueOf(((long) sourceArray[0]) << 12);
for (char ch : sourceArray) {
BigInteger chInt = BigInteger.valueOf(ch);
hash = hash.multiply(BIGINT_1000003).xor(chInt).and(BIGINT_2E64M1);
}
hash = hash.xor(BigInteger.valueOf(word.length()));
return hash;
}
/**
* Computes the weight of a single word.
* Warning: changing this method will change the computed SimHash values.
*
* @param word the input word
* @return the weight
*/
private static int getWordWeight(String word) {
if (StringUtils.isBlank(word)) {
return 0;
}
int length = word.length();
if (length == 1) {
// A single-character word yields too few hash bits (around 40), so its weight must stay very low, otherwise the high hash bits tend to end up all zero.
return 1;
} else if (word.charAt(0) >= 0x3040) { // CJK word (code point in or above the Hiragana block) gets a higher weight
if (length == 2) {
return 8;
} else {
return 16;
}
} else {
if (length == 2) {
return 2;
} else {
return 4;
}
}
}
/**
* Extracts one 16-bit segment of a SimHash and converts it to a Short.
*
* @param simHash the original SimHash string, 64 0/1 characters
* @param part    the index of the segment to extract (0-3)
* @return the Short value, or null for invalid input
*/
public static Short toShort(String simHash, int part) {
if (simHash == null || part < 0 || part > 3) {
return null;
}
int startBit = part * 16;
int endBit = (part + 1) * 16;
return Integer.valueOf(simHash.substring(startBit, endBit), 2).shortValue();
}
/**
* Joins four Short-format SimHash segments into a 0/1 string.
*
* @param simHashA segment A, the highest bits
* @param simHashB segment B
* @param simHashC segment C
* @param simHashD segment D, the lowest bits
* @return the SimHash as a 64-character 0/1 string
*/
public static String toSimHash(Short simHashA, Short simHashB, Short simHashC, Short simHashD) {
return toSimHash(simHashA) + toSimHash(simHashB) + toSimHash(simHashC) + toSimHash(simHashD);
}
/**
* Converts one Short-format SimHash segment into a 0/1 string.
*
* @param simHashX the Short-format segment to convert
* @return the segment as a 16-character 0/1 string
*/
public static String toSimHash(Short simHashX) {
StringBuilder simHashBuilder = new StringBuilder(Integer.toString(simHashX & 65535, 2));
int fill0Count = 16 - simHashBuilder.length();
for (int i = 0; i < fill0Count; i++) {
simHashBuilder.insert(0, "0");
}
return simHashBuilder.toString();
}
/**
* Compares two SimHashes (one as a 64-character 0/1 string, one as four Shorts) and computes their Hamming distance.
*
* @param simHash  SimHash X, a 64-character 0/1 string
* @param simHashA SimHash Y, Short format, the highest bits
* @param simHashB SimHash Y, Short format
* @param simHashC SimHash Y, Short format
* @param simHashD SimHash Y, Short format, the lowest bits
* @return the Hamming distance, or -1 if any argument is null
*/
public static int hammingDistance(String simHash, Short simHashA, Short simHashB, Short simHashC, Short simHashD) {
if (simHash == null || simHashA == null || simHashB == null || simHashC == null || simHashD == null) {
return -1;
}
int hammingDistance = 0;
for (int part = 0; part < 4; part++) {
Short simHashX = toShort(simHash, part);
Short simHashY = null;
switch (part) {
case 0:
simHashY = simHashA;
break;
case 1:
simHashY = simHashB;
break;
case 2:
simHashY = simHashC;
break;
case 3:
simHashY = simHashD;
break;
}
hammingDistance += hammingDistance(simHashX, simHashY);
}
return hammingDistance;
}
/**
* Computes the Hamming distance between two Short-format SimHash segments.
*
* @param simHashX segment X
* @param simHashY segment Y
* @return the Hamming distance, or -1 if either argument is null
*/
public static int hammingDistance(Short simHashX, Short simHashY) {
if (simHashX == null || simHashY == null) {
return -1;
}
int hammingDistance = 0;
int xorResult = (simHashX ^ simHashY) & 65535;
while (xorResult != 0) {
xorResult = xorResult & (xorResult - 1);
hammingDistance += 1;
}
return hammingDistance;
}
}
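For reference, here is a minimal sketch of how the class above can be exercised on its own; longTextX and longTextY are placeholders for two article bodies, not real data:

// Compute the SimHash of two texts and compare them.
String hashX = SimHashService.get(longTextX); // null if the text is too short
String hashY = SimHashService.get(longTextY);
if (hashX != null && hashY != null) {
    int distance = SimHashService.hammingDistance(
            hashX,
            SimHashService.toShort(hashY, 0),
            SimHashService.toShort(hashY, 1),
            SimHashService.toShort(hashY, 2),
            SimHashService.toShort(hashY, 3));
    // With the threshold used in this post, the two texts count as similar when 0 <= distance <= 3.
    boolean similar = distance >= 0 && distance <= 3;
}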
The ES index mapping needs four new SimHash-related fields:
"simHashA": {
    "type": "short"
},
"simHashB": {
    "type": "short"
},
"simHashC": {
    "type": "short"
},
"simHashD": {
    "type": "short"
}
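Correspondingly, when an article is saved, the four fields have to be filled in from the computed SimHash. Below is a rough sketch, assuming an Elasticsearch 7.x RestHighLevelClient field named client and an index called "article" (both assumptions); hId and publishDate are the fields the query code below relies on.

// Sketch: store the four SimHash segments alongside the article document.
public void saveArticle(String hId, String publishDate, String articleHtml) throws IOException {
    String simHash = SimHashService.get(articleHtml); // may be null for very short texts
    Map<String, Object> doc = new HashMap<>();
    doc.put("hId", hId);
    doc.put("publishDate", publishDate);
    if (simHash != null) {
        doc.put("simHashA", SimHashService.toShort(simHash, 0));
        doc.put("simHashB", SimHashService.toShort(simHash, 1));
        doc.put("simHashC", SimHashService.toShort(simHash, 2));
        doc.put("simHashD", SimHashService.toShort(simHash, 3));
    }
    IndexRequest indexRequest = new IndexRequest("article").id(hId).source(doc); // "article" index name is an assumption
    client.index(indexRequest, RequestOptions.DEFAULT);
}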
Finally, the ES query logic. Given an input SimHash, first ask ES for documents in which at least one of the four 16-bit segments matches exactly (if two 64-bit hashes are within Hamming distance 3, the pigeonhole principle guarantees that at least one of the four segments is identical), then check in Java whether the remaining segments are close enough.
/**
 * Queries for articles similar to the given SimHash.
 *
 * @param indexNames the names of the indices to query (comma-separated, multiple allowed)
 * @param simHashA   value of simHashA
 * @param simHashB   value of simHashB
 * @param simHashC   value of simHashC
 * @param simHashD   value of simHashD
 * @return a list of RowKeys of similar articles
 */
public List<String> searchBySimHash(String indexNames, Short simHashA, Short simHashB, Short simHashC, Short simHashD) {
String simHash = SimHashService.toSimHash(simHashA, simHashB, simHashC, simHashD);
return searchBySimHash(indexNames, simHash);
}
/**
 * Queries for articles similar to the given SimHash.
 *
 * @param indexNames the names of the indices to query (comma-separated, multiple allowed)
 * @param simHash    the SimHash to query (a 64-character 0/1 string)
 * @return a list of RowKeys of similar articles
*/
public List<String> searchBySimHash(String indexNames, String simHash) {
List<String> resultList = new ArrayList<>();
if (simHash == null) {
return resultList;
}
try {
String scrollId = "";
while (true) {
if (scrollId == null) {
break;
}
SearchResponse response = null;
if (scrollId.isEmpty()) {
// first request: an ordinary search
SearchRequest request = new SearchRequest(indexNames.split(","));
BoolQueryBuilder bqBuilder = QueryBuilders.boolQuery();
bqBuilder.should(QueryBuilders.termQuery("simHashA", SimHashService.toShort(simHash, 0)));
bqBuilder.should(QueryBuilders.termQuery("simHashB", SimHashService.toShort(simHash, 1)));
bqBuilder.should(QueryBuilders.termQuery("simHashC", SimHashService.toShort(simHash, 2)));
bqBuilder.should(QueryBuilders.termQuery("simHashD", SimHashService.toShort(simHash, 3)));
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder().size(10000);
sourceBuilder.query(bqBuilder);
sourceBuilder.from(0);
sourceBuilder.size(10000);
sourceBuilder.timeout(TimeValue.timeValueSeconds(60));
sourceBuilder.fetchSource(new String[]{"hId", "simHashA", "simHashB", "simHashC", "simHashD"}, new String[]{});
sourceBuilder.sort("publishDate", SortOrder.DESC);
request.source(sourceBuilder);
request.scroll(TimeValue.timeValueSeconds(60));
response = client.search(request, RequestOptions.DEFAULT);
} else {
// subsequent requests: continue through the scroll cursor
SearchScrollRequest searchScrollRequest = new SearchScrollRequest(scrollId).scroll(TimeValue.timeValueMinutes(10));
response = client.scroll(searchScrollRequest, RequestOptions.DEFAULT);
}
if (response != null && response.getHits().getHits().length > 0) {
// Every hit necessarily has one simHash segment equal to the input, but the segments still have to be combined to check the total distance against the threshold.
// There may be tens of thousands of hits, yet only a handful of documents survive the filter; only their IDs are kept.
for (SearchHit hit : response.getHits().getHits()) {
Map<String, Object> sourceAsMap = hit.getSourceAsMap();
String hId = String.valueOf(sourceAsMap.get("hId"));
Short simHashA = Short.valueOf(sourceAsMap.get("simHashA").toString());
Short simHashB = Short.valueOf(sourceAsMap.get("simHashB").toString());
Short simHashC = Short.valueOf(sourceAsMap.get("simHashC").toString());
Short simHashD = Short.valueOf(sourceAsMap.get("simHashD").toString());
int hammingDistance = SimHashService.hammingDistance(simHash, simHashA, simHashB, simHashC, simHashD);
if (hammingDistance < 4) {
System.out.println(hammingDistance + "\t" + hId);
resultList.add(sourceAsMap.get("hId").toString());
}
}
scrollId = response.getScrollId();
} else {
break;
}
}
} catch (IOException e) {
e.printStackTrace();
}
return resultList;
}
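One caveat: if the queried article is itself stored in one of these indices, it will come back as its own closest match (Hamming distance 0), so the caller normally drops its own ID from the result. A small usage sketch, where the index names and articleHId are assumptions:

// Query two monthly indices at once and exclude the queried article itself.
List<String> similarIds = searchBySimHash("article_202101,article_202102", simHash); // index names are assumptions
similarIds.remove(articleHId); // remove the queried article's own hId if it was indexed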
With about 900,000 documents on a single-node ES cluster (roughly 100,000 of which carry the SimHash fields), the query latency is currently around 0.2 seconds.
Anyway, that is the idea I wanted to share; the code surely has rough edges, so comments and suggestions are welcome.
