Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feat]加权融合策略 #19

Merged
merged 2 commits into from
Jan 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/main/java/com/search/docsearch/constant/Constants.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,15 @@ public class Constants {


public static final String HTTPS_PREFIX = "https://";


/**
* Maxsocre that used to normlize the result
*/
public static final int MAX_SCORE = 10000000;

/**
* Min socre that used to normlize the result
*/
public static final int MIN_SCORE = -1;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/* Copyright (c) 2024 openEuler Community
EasySoftware is licensed under the Mulan PSL v2.
You can use this software according to the terms and conditions of the Mulan PSL v2.
You may obtain a copy of Mulan PSL v2 at:
http://license.coscl.org.cn/MulanPSL2
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
See the Mulan PSL v2 for more details.
*/
package com.search.docsearch.factorys;

import java.io.IOException;
import java.net.HttpURLConnection;
import java.net.URL;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Component;

import com.search.docsearch.multirecall.recall.cstrategy.GSearchStrategy;

@Component
public class HttpConnectFactory {

/**
* logger.
*/
private static final Logger LOGGER = LoggerFactory.getLogger(GSearchStrategy.class);

/**
* create a http connection with a url string
* @param urlString
* @return
* @throws IOException
*/
public HttpURLConnection createConnection(String urlString) throws IOException {
URL url = new URL(urlString);
return (HttpURLConnection) url.openConnection();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,13 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.search.docsearch.constant.Constants;
import com.search.docsearch.utils.MergeUtil;

public class DataComposite implements Component {

Expand Down Expand Up @@ -110,8 +114,40 @@ public Map<String, Object> mergeResult(){
aresList.add(pos, bresList.get(pos));
}
}

ares.put("records", aresList);
return ares;
}
}

/**
* merge the other recall results into one way, based one the index 0 of children
*
* @return the merged result lists
*/
public List<Map<String, Object>> weightedMerge(int pageSize){
List<Map<String, Object>> mergeList = new ArrayList<>();

for (Component recall : this.children){
double minScore = Constants.MAX_SCORE;
double maxScore = Constants.MIN_SCORE;
List<Map<String, Object>> rcords = (List<Map<String, Object>>) recall.getResList().get("records");
// find min and max
for (Map<String, Object> entity : rcords) {
double score = (double) entity.get("score");
minScore = Math.min(score,minScore);
maxScore = Math.max(score, maxScore);
}
// do norm
for (Map<String, Object> entity : rcords) {
double score = (double) entity.get("score");
double normedScore = MergeUtil.normalize(score, minScore, maxScore);
entity.put("score", normedScore);
mergeList.add(entity);
}
}

mergeList = mergeList.stream().sorted((a, b) -> Double.compare((Double) b.get("score"), (Double) a.get("score"))).collect(Collectors.toList());

return mergeList.subList(0, Math.min(pageSize, mergeList.size()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ public DataComposite executeMultiSearch(SearchCondition condition) {
//do recall, the validtaion of condition should implement by concrte strategy
Component recallRes = searchStgy.search(condition);
//in case one way does't recall anything
if (recallRes != null){
if (recallRes != null && recallRes.getResList().size() > 0){
res.add(recallRes);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import com.search.docsearch.entity.vo.GoogleSearchParams;
import com.search.docsearch.entity.vo.SearchCondition;
import com.search.docsearch.except.ServiceImplException;
import com.search.docsearch.factorys.HttpConnectFactory;
import com.search.docsearch.multirecall.composite.Component;
import com.search.docsearch.multirecall.composite.cdata.GRecallData;
import com.search.docsearch.multirecall.recall.SearchStrategy;
Expand All @@ -47,8 +48,14 @@ public class GSearchStrategy implements SearchStrategy {
*/
private GoogleSearchProperties gProperties;

public GSearchStrategy(GoogleSearchProperties gProperties) {
/**
* insert httpConnectionFactory to creat a URL
*/
private HttpConnectFactory httpConnectFactory;

public GSearchStrategy(GoogleSearchProperties gProperties, HttpConnectFactory httpConnectFactory) {
this.gProperties = gProperties;
this.httpConnectFactory = httpConnectFactory;
}

/**
Expand Down Expand Up @@ -78,6 +85,9 @@ public Component search(SearchCondition condition) {
* @throws IOException
*/
private Component searchByCondition(SearchCondition condition) throws ServiceImplException, IOException {
// google search 处理无效字符
condition.setKeyword(condition.getKeyword().replace(" ", ""));
condition.setKeyword(condition.getKeyword().replace(".", ""));
GoogleSearchParams googleSearchParams = new GoogleSearchParams();
googleSearchParams.setKeyWord(condition.getKeyword());
if ("en".equals(condition.getLang())) {
Expand All @@ -94,9 +104,8 @@ private Component searchByCondition(SearchCondition condition) throws ServiceImp
int count = 0;
String keyWord = googleSearchParams.getKeyWord();
String urlString = googleSearchParams.buildUrl(gProperties.getUrl(), gProperties.getKey(), gProperties.getCx());
// 创建URL对象
URL url = new URL(urlString);
HttpURLConnection connection = (HttpURLConnection) url.openConnection();
// 创建connection对象
HttpURLConnection connection = httpConnectFactory.createConnection(urlString);
try {
connection.setRequestMethod("GET");
int timeout = 15000; // 设置超时时间为15秒
Expand Down Expand Up @@ -129,7 +138,7 @@ private Component searchByCondition(SearchCondition condition) throws ServiceImp
} else {
map.put("lang", "zh");
}
map.put("score", 5000 - (count + start) * 50);
map.put("score", (double) (5000 - (count + start) * 50));
count++;
data.add(map);
}
Expand All @@ -140,9 +149,10 @@ private Component searchByCondition(SearchCondition condition) throws ServiceImp
}
} else {
LOGGER.error("GET request not worked, response code: {}", responseCode);
return null;
}
} catch (Exception e) {
LOGGER.error(e.getMessage());
LOGGER.error("google search error: {}", e.getMessage());
} finally {
connection.disconnect();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.search.docsearch.entity.vo.SearchTags;
import com.search.docsearch.except.ServiceException;
import com.search.docsearch.except.ServiceImplException;
import com.search.docsearch.factorys.HttpConnectFactory;
import com.search.docsearch.multirecall.composite.DataComposite;
import com.search.docsearch.multirecall.recall.MultiSearchContext;
import com.search.docsearch.multirecall.recall.cstrategy.EsSearchStrategy;
Expand Down Expand Up @@ -112,8 +113,17 @@ public class SearchServiceImpl implements SearchService {
@Value("${api.npsApi}")
private String npsApi;

/**
* insert google serach properties
*/
@Autowired
private GoogleSearchProperties gProperties;

/**
* insert httpConnectionFactory to creat a URL
*/
@Autowired
private HttpConnectFactory httpConnectFactory;

@Autowired
private EsfunctionScoreConfig esfunctionScoreConfig;
Expand Down Expand Up @@ -213,7 +223,7 @@ public Map<String, Object> getSuggestion(String keyword, String lang) throws Ser
public Map<String, Object> searchByCondition(SearchCondition condition) throws ServiceImplException {
//create es search strategy
EsSearchStrategy esRecall = new EsSearchStrategy(restHighLevelClient,mySystem.index,trie,esfunctionScoreConfig);
GSearchStrategy gRecall = new GSearchStrategy(gProperties);
GSearchStrategy gRecall = new GSearchStrategy(gProperties, httpConnectFactory);
MultiSearchContext multirecall = new MultiSearchContext();
//set es search into search contex
multirecall.setSearchStrategy(esRecall);
Expand Down
28 changes: 28 additions & 0 deletions src/main/java/com/search/docsearch/utils/MergeUtil.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/* Copyright (c) 2024 openEuler Community
EasySoftware is licensed under the Mulan PSL v2.
You can use this software according to the terms and conditions of the Mulan PSL v2.
You may obtain a copy of Mulan PSL v2 at:
http://license.coscl.org.cn/MulanPSL2
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
See the Mulan PSL v2 for more details.
*/
package com.search.docsearch.utils;

public class MergeUtil {

/**
* normalize the score according to their own score
*
* @return the normalied score of search results
*/
public static double normalize(double score, double minScore, double maxScore) {
// 检查范围是否有效
if (maxScore <= minScore) {
throw new IllegalArgumentException("maxScore 必须大于 minScore");
}
// 归一化公式 (score - minScore) / (maxScore - minScore)
return (score - minScore) / (maxScore - minScore);
}
}
52 changes: 51 additions & 1 deletion src/test/java/com/search/docsearch/CompositeTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
package com.search.docsearch;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;

import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.Test;
import org.springframework.boot.test.context.SpringBootTest;

Expand All @@ -20,7 +20,11 @@
import com.search.docsearch.multirecall.composite.cdata.EsRecallData;
import com.search.docsearch.multirecall.composite.cdata.GRecallData;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@SpringBootTest
public class CompositeTest {

Expand Down Expand Up @@ -103,4 +107,50 @@ void testFliteringRecallWithError() {
assertEquals("error when process the recall res",exception.getMessage());
}

/**
* 测试: normlize加权融合测试
*/
@Test
void testWeightedMerge() {
// 设置mockComponent1的返回数据
List<Map<String, Object>> records1 = new ArrayList<>();
Map<String, Object> record1_1 = new HashMap<>();
record1_1.put("score", 3.0);
records1.add(record1_1);

Map<String, Object> record1_2 = new HashMap<>();
record1_2.put("score", 1.0);
records1.add(record1_2);

// 设置mockComponent2的返回数据
List<Map<String, Object>> records2 = new ArrayList<>();
Map<String, Object> record2_1 = new HashMap<>();
record2_1.put("score", 2.5);
records2.add(record2_1);

Map<String, Object> record2_2 = new HashMap<>();
record2_2.put("score", 4.0);
records2.add(record2_2);

DataComposite dataComposite = new DataComposite();

Component mockComponent1 = new EsRecallData(Collections.singletonMap("records", records1));
Component mockComponent2 = new EsRecallData(Collections.singletonMap("records", records2));

dataComposite.add(mockComponent1);
dataComposite.add(mockComponent2);
// 校验是否按pagesize返回正确个数
int pageSize = 3;
List<Map<String, Object>> result = dataComposite.weightedMerge(pageSize);
assertEquals(pageSize, result.size());

// 验证结果是否按分数降序排列
for (int i = 0; i < result.size() - 1; i++) {
double score1 = (Double) result.get(i).get("score");
double score2 = (Double) result.get(i + 1).get("score");
assertTrue(score1 >= score2, "Results should be sorted in descending order by score");
}
}


}
15 changes: 12 additions & 3 deletions src/test/java/com/search/docsearch/GSearchStrategyTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
Expand All @@ -28,6 +29,7 @@
import org.mockito.junit.jupiter.MockitoExtension;
import org.junit.jupiter.api.BeforeEach;
import com.search.docsearch.entity.vo.SearchCondition;
import com.search.docsearch.factorys.HttpConnectFactory;
import com.search.docsearch.multirecall.composite.Component;
import com.search.docsearch.multirecall.composite.cdata.GRecallData;
import com.search.docsearch.multirecall.recall.cstrategy.GSearchStrategy;
Expand All @@ -41,6 +43,12 @@ public class GSearchStrategyTests {
@Mock
private GoogleSearchProperties gProperties;

/**
* insert httpConnectionFactory to creat a URL
*/
@Mock
HttpConnectFactory httpConnectFactory;

/**
* the search service
*/
Expand All @@ -64,8 +72,7 @@ void setUp() {
searchCondition.setLang("en");
searchCondition.setPage(1);
searchCondition.setPageSize(10);

gSearchStrategy = new GSearchStrategy(gProperties);
//gSearchStrategy = new GSearchStrategy(gProperties, httpConnectFactory);
}
/**
* 测试:获得google搜索结果
Expand All @@ -77,10 +84,10 @@ public void testGoogleSearchApi() throws IOException {

// Mock the http connection
HttpURLConnection mockConnection = mock(HttpURLConnection.class);
when(mockConnection.getRequestMethod()).thenReturn("GET");
when(mockConnection.getResponseCode()).thenReturn(HttpURLConnection.HTTP_OK);
String mockResponse = "{\"items\":[{\"title\":\"openeuler开源社区\",\"link\":\"http://euler.com\",\"snippet\":\"openeuler提供了一系列...\"}]}";
when(mockConnection.getInputStream()).thenReturn(new ByteArrayInputStream(mockResponse.getBytes(StandardCharsets.UTF_8)));
when(httpConnectFactory.createConnection(anyString())).thenReturn(mockConnection);

// Perform the search
Component result = gSearchStrategy.search(searchCondition);
Expand Down Expand Up @@ -109,6 +116,8 @@ void testGetGoogleSearchApiWithError() throws IOException {

// Mock the http connection
HttpURLConnection mockConnection = mock(HttpURLConnection.class);
when(mockConnection.getResponseCode()).thenReturn(HttpURLConnection.HTTP_BAD_REQUEST);
when(httpConnectFactory.createConnection(anyString())).thenReturn(mockConnection);

// Perform the search
Component result = gSearchStrategy.search(searchCondition);
Expand Down
Loading