From e946138503bb8e050c08c0af1ba81358ea3a69c8 Mon Sep 17 00:00:00 2001 From: jolunoluo Date: Wed, 5 Jun 2024 18:35:06 +0800 Subject: [PATCH] (improvement)(Headless) Filtering based on dataSetIds during Mapper detection Compatible with term --- .../com/hankcs/hanlp/LoadRemoveService.java | 15 +----- .../collection/trie/bintrie/BaseNode.java | 21 ++++---- .../core/chat/knowledge/SearchService.java | 50 +++++++++++-------- 3 files changed, 40 insertions(+), 46 deletions(-) diff --git a/common/src/main/java/com/hankcs/hanlp/LoadRemoveService.java b/common/src/main/java/com/hankcs/hanlp/LoadRemoveService.java index 745555398c..b31e87d281 100644 --- a/common/src/main/java/com/hankcs/hanlp/LoadRemoveService.java +++ b/common/src/main/java/com/hankcs/hanlp/LoadRemoveService.java @@ -10,7 +10,6 @@ import java.util.ArrayList; import java.util.List; import java.util.Objects; -import java.util.Set; @Data @Slf4j @@ -19,23 +18,11 @@ public class LoadRemoveService { @Value("${mapper.remove.nature.prefix:}") private String mapperRemoveNaturePrefix; - public List removeNatures(List value, Set detectModelIds) { + public List removeNatures(List value) { if (CollectionUtils.isEmpty(value)) { return value; } List resultList = new ArrayList<>(value); - if (!CollectionUtils.isEmpty(detectModelIds)) { - resultList.removeIf(nature -> { - if (Objects.isNull(nature)) { - return false; - } - Long modelId = getDataSetId(nature); - if (Objects.nonNull(modelId)) { - return !detectModelIds.contains(modelId); - } - return false; - }); - } if (StringUtils.isNotBlank(mapperRemoveNaturePrefix)) { resultList.removeIf(nature -> { if (Objects.isNull(nature)) { diff --git a/common/src/main/java/com/hankcs/hanlp/collection/trie/bintrie/BaseNode.java b/common/src/main/java/com/hankcs/hanlp/collection/trie/bintrie/BaseNode.java index 0f50c76ba0..f81ffb2930 100644 --- a/common/src/main/java/com/hankcs/hanlp/collection/trie/bintrie/BaseNode.java +++ b/common/src/main/java/com/hankcs/hanlp/collection/trie/bintrie/BaseNode.java @@ -2,6 +2,9 @@ import com.hankcs.hanlp.LoadRemoveService; import com.hankcs.hanlp.corpus.io.ByteArray; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import java.io.DataOutputStream; import java.io.IOException; import java.io.ObjectInput; @@ -14,8 +17,6 @@ import java.util.Objects; import java.util.Queue; import java.util.Set; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; public abstract class BaseNode implements Comparable { @@ -286,12 +287,12 @@ public String toString() { + '}'; } - public void walkNode(Set> entrySet, Set detectModelIds) { + public void walkNode(Set> entrySet) { if (status == Status.WORD_MIDDLE_2 || status == Status.WORD_END_3) { - logger.debug("detectModelIds:{},before:{}", detectModelIds, value.toString()); - List natures = new LoadRemoveService().removeNatures((List) value, detectModelIds); + logger.debug("walkNode before:{}",value.toString()); + List natures = new LoadRemoveService().removeNatures((List) value); String name = this.prefix != null ? this.prefix + c : "" + c; - logger.debug("name:{},after:{},natures:{}", name, (List) value, natures); + logger.debug("walkNode name:{},after:{},natures:{}", name, (List) value, natures); entrySet.add(new TrieEntry(name, (V) natures)); } } @@ -300,21 +301,17 @@ public void walkNode(Set> entrySet, Set detectModelId * walk limit * @param sb * @param entrySet - * @param limit */ - public void walkLimit(StringBuilder sb, Set> entrySet, int limit, Set detectModelIds) { + public void walkLimit(StringBuilder sb, Set> entrySet) { Queue queue = new ArrayDeque<>(); this.prefix = sb.toString(); queue.add(this); while (!queue.isEmpty()) { - if (entrySet.size() >= limit) { - break; - } BaseNode root = queue.poll(); if (root == null) { continue; } - root.walkNode(entrySet, detectModelIds); + root.walkNode(entrySet); if (root.child == null) { continue; } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/knowledge/SearchService.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/knowledge/SearchService.java index 4174b82f18..55c1d0f35f 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/knowledge/SearchService.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/knowledge/SearchService.java @@ -48,22 +48,16 @@ public static List prefixSearch(String key, int limit, Map prefixSearch(String key, int limit, BinTrie> binTrie, Map> modelIdToDataSetIds, Set detectDataSetIds) { - Set>> result = prefixSearchLimit(key, limit, binTrie, - modelIdToDataSetIds, detectDataSetIds); + Set>> result = prefixSearch(key, binTrie); List hanlpMapResults = result.stream().map( entry -> { String name = entry.getKey().replace("#", " "); return new HanlpMapResult(name, entry.getValue(), key); } ).sorted((a, b) -> -(b.getName().length() - a.getName().length())) - .limit(SEARCH_SIZE) .collect(Collectors.toList()); - for (HanlpMapResult hanlpMapResult : hanlpMapResults) { - List natures = hanlpMapResult.getNatures().stream() - .map(nature -> NatureHelper.changeModel2DataSet(nature, modelIdToDataSetIds)) - .flatMap(Collection::stream).collect(Collectors.toList()); - hanlpMapResult.setNatures(natures); - } + hanlpMapResults = transformAndFilterByDataSet(hanlpMapResults, modelIdToDataSetIds, + detectDataSetIds, limit); return hanlpMapResults; } @@ -80,11 +74,8 @@ public static List suffixSearch(String key, int limit, Map suffixSearch(String key, int limit, BinTrie> binTrie, Map> modelIdToDataSetIds, Set detectDataSetIds) { - - Set>> result = prefixSearchLimit(key, limit, binTrie, modelIdToDataSetIds, - detectDataSetIds); - - return result.stream().map( + Set>> result = prefixSearch(key, binTrie); + List hanlpMapResults = result.stream().map( entry -> { String name = entry.getKey().replace("#", " "); List natures = entry.getValue().stream() @@ -94,15 +85,34 @@ public static List suffixSearch(String key, int limit, BinTrie -(b.getName().length() - a.getName().length())) - .limit(SEARCH_SIZE) .collect(Collectors.toList()); + return transformAndFilterByDataSet(hanlpMapResults, modelIdToDataSetIds, detectDataSetIds, limit); } - private static Set>> prefixSearchLimit(String key, int limit, - BinTrie> binTrie, Map> modelIdToDataSetIds, Set detectDataSetIds) { - - Set detectModelIds = NatureHelper.getModelIds(modelIdToDataSetIds, detectDataSetIds); + private static List transformAndFilterByDataSet(List hanlpMapResults, + Map> modelIdToDataSetIds, + Set detectDataSetIds, int limit) { + return hanlpMapResults.stream().peek(hanlpMapResult -> { + List natures = hanlpMapResult.getNatures().stream() + .map(nature -> NatureHelper.changeModel2DataSet(nature, modelIdToDataSetIds)) + .flatMap(Collection::stream) + .filter(nature -> { + if (CollectionUtils.isEmpty(detectDataSetIds)) { + return true; + } + Long dataSetId = NatureHelper.getDataSetId(nature); + if (dataSetId != null) { + return detectDataSetIds.contains(dataSetId); + } + return false; + }).collect(Collectors.toList()); + hanlpMapResult.setNatures(natures); + }).filter(hanlpMapResult -> !CollectionUtils.isEmpty(hanlpMapResult.getNatures())) + .limit(limit).collect(Collectors.toList()); + } + private static Set>> prefixSearch(String key, + BinTrie> binTrie) { key = key.toLowerCase(); Set>> entrySet = new TreeSet>>(); @@ -122,7 +132,7 @@ private static Set>> prefixSearchLimit(String key if (branch == null) { return entrySet; } - branch.walkLimit(sb, entrySet, limit, detectModelIds); + branch.walkLimit(sb, entrySet); return entrySet; }