Skip to content

Commit

Permalink
Merge pull request #24036 from vespa-engine/havardpe/multi-bitvector-…
Browse files Browse the repository at this point in the history
…global-filter

add support for multi-bitvector global filter
  • Loading branch information
baldersheim authored Sep 13, 2022
2 parents 2f32564 + 295e925 commit 04c3414
Show file tree
Hide file tree
Showing 8 changed files with 227 additions and 24 deletions.
5 changes: 3 additions & 2 deletions searchlib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,9 @@ vespa_define_module(
src/tests/attribute/multi_value_mapping
src/tests/attribute/multi_value_read_view
src/tests/attribute/posting_list_merger
src/tests/attribute/posting_store
src/tests/attribute/postinglist
src/tests/attribute/postinglistattribute
src/tests/attribute/posting_store
src/tests/attribute/reference_attribute
src/tests/attribute/save_target
src/tests/attribute/searchable
Expand All @@ -112,8 +112,8 @@ vespa_define_module(
src/tests/common/summaryfeatures
src/tests/diskindex/bitvector
src/tests/diskindex/diskindex
src/tests/diskindex/fieldwriter
src/tests/diskindex/field_length_scanner
src/tests/diskindex/fieldwriter
src/tests/diskindex/fusion
src/tests/diskindex/pagedict4
src/tests/docstore/chunk
Expand Down Expand Up @@ -193,6 +193,7 @@ vespa_define_module(
src/tests/queryeval/equiv
src/tests/queryeval/fake_searchable
src/tests/queryeval/getnodeweight
src/tests/queryeval/global_filter
src/tests/queryeval/matching_elements_search
src/tests/queryeval/monitoring_search_iterator
src/tests/queryeval/multibitvectoriterator
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1114,7 +1114,7 @@ TEST_F("NN blueprint handles empty filter (post-filtering)", NearestNeighborBlue
TEST_F("NN blueprint handles strong filter (pre-filtering)", NearestNeighborBlueprintFixture)
{
auto bp = f.make_blueprint();
auto filter = search::BitVector::create(11);
auto filter = search::BitVector::create(1,11);
filter->setBit(3);
filter->invalidateCachedCount();
auto strong_filter = GlobalFilter::create(std::move(filter));
Expand All @@ -1128,7 +1128,7 @@ TEST_F("NN blueprint handles strong filter (pre-filtering)", NearestNeighborBlue
TEST_F("NN blueprint handles weak filter (pre-filtering)", NearestNeighborBlueprintFixture)
{
auto bp = f.make_blueprint();
auto filter = search::BitVector::create(11);
auto filter = search::BitVector::create(1,11);
filter->setBit(1);
filter->setBit(3);
filter->setBit(5);
Expand All @@ -1147,7 +1147,7 @@ TEST_F("NN blueprint handles weak filter (pre-filtering)", NearestNeighborBluepr
TEST_F("NN blueprint handles strong filter triggering exact search", NearestNeighborBlueprintFixture)
{
auto bp = f.make_blueprint(true, 0.2);
auto filter = search::BitVector::create(11);
auto filter = search::BitVector::create(1,11);
filter->setBit(3);
filter->invalidateCachedCount();
auto strong_filter = GlobalFilter::create(std::move(filter));
Expand Down
9 changes: 9 additions & 0 deletions searchlib/src/tests/queryeval/global_filter/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
# Test executable for the global filter, built from global_filter_test.cpp
# and depending on searchlib and GTest.
vespa_add_executable(searchlib_queryeval_global_filter_test_app TEST
SOURCES
global_filter_test.cpp
DEPENDS
searchlib
GTest::GTest
)
# Register the executable above as a test that is run under the same name.
vespa_add_test(NAME searchlib_queryeval_global_filter_test_app COMMAND searchlib_queryeval_global_filter_test_app)
139 changes: 139 additions & 0 deletions searchlib/src/tests/queryeval/global_filter/global_filter_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.

#include <vespa/vespalib/gtest/gtest.h>
#include <vespa/vespalib/util/require.h>
#include <vespa/searchlib/queryeval/global_filter.h>
#include <vespa/searchlib/common/bitvector.h>

#include <gmock/gmock.h>
#include <vector>

using namespace testing;

using search::BitVector;
using search::queryeval::GlobalFilter;
using vespalib::RequireFailedException;

TEST(GlobalFilterTest, create_can_make_inactive_filter) {
    // The argument-less factory is expected to hand back a filter that reports itself as inactive.
    const auto inactive = GlobalFilter::create();
    EXPECT_FALSE(inactive->is_active());
}

/**
 * Shared expectations for the reference filter used by the tests in this file:
 * active, size 100, exactly three hits, and only docids 11, 22 and 33 pass.
 * Docid 0 is not probed; the scan starts at 1.
 */
void verify(const GlobalFilter &filter) {
    EXPECT_TRUE(filter.is_active());
    EXPECT_EQ(filter.size(), 100);
    EXPECT_EQ(filter.count(), 3);
    for (size_t docid = 1; docid < 100; ++docid) {
        const bool should_pass = (docid == 11) || (docid == 22) || (docid == 33);
        if (!should_pass) {
            EXPECT_FALSE(filter.check(docid));
            continue;
        }
        EXPECT_TRUE(filter.check(docid));
    }
}

TEST(GlobalFilterTest, create_can_make_test_filter) {
    // Build the reference filter straight from an ascending docid list.
    std::vector<uint32_t> docids{11, 22, 33};
    auto filter = GlobalFilter::create(docids, 100);
    verify(*filter);
}

TEST(GlobalFilterTest, test_filter_requires_docs_in_order) {
    // 33 is listed before 22, so the docids are not ascending and creation must be rejected.
    std::vector<uint32_t> unsorted_docids{11, 33, 22};
    const auto make_filter = [&]() { GlobalFilter::create(unsorted_docids, 100); };
    EXPECT_THAT(make_filter, Throws<RequireFailedException>());
}

TEST(GlobalFilterTest, test_filter_requires_docs_in_range) {
    // 133 lies outside a filter of size 100, so creation must be rejected.
    std::vector<uint32_t> oversized_docids{11, 22, 133};
    const auto make_filter = [&]() { GlobalFilter::create(oversized_docids, 100); };
    EXPECT_THAT(make_filter, Throws<RequireFailedException>());
}

TEST(GlobalFilterTest, test_filter_docid_0_not_allowed) {
    // Docid 0 may not appear in the list handed to the factory.
    std::vector<uint32_t> docids_with_zero{0, 22, 33};
    const auto make_filter = [&]() { GlobalFilter::create(docids_with_zero, 100); };
    EXPECT_THAT(make_filter, Throws<RequireFailedException>());
}

TEST(GlobalFilterTest, create_can_make_single_bitvector_filter) {
    // One bit vector covering [1, 100) with the three reference docids set.
    auto single = BitVector::create(1, 100);
    for (uint32_t docid : {11u, 22u, 33u}) {
        single->setBit(docid);
    }
    single->invalidateCachedCount();
    EXPECT_EQ(single->countTrueBits(), 3);
    auto filter = GlobalFilter::create(std::move(single));
    verify(*filter);
}

TEST(GlobalFilterTest, global_filter_pointer_guard) {
    auto active = GlobalFilter::create(BitVector::create(1,100));
    auto inactive = GlobalFilter::create();
    EXPECT_TRUE(active->is_active());
    EXPECT_FALSE(inactive->is_active());
    // ptr_if_active() exposes the filter itself only while it is active.
    const GlobalFilter *from_active = active->ptr_if_active();
    const GlobalFilter *from_inactive = inactive->ptr_if_active();
    EXPECT_TRUE(from_active == active.get());
    EXPECT_TRUE(from_inactive == nullptr);
}

TEST(GlobalFilterTest, create_can_make_multi_bitvector_filter) {
    // Four adjacent ranges covering [1, 100): each one starts where the previous one ends.
    std::vector<std::unique_ptr<BitVector>> parts;
    const auto add_range = [&parts](uint32_t start, uint32_t end) {
        parts.push_back(BitVector::create(start, end));
    };
    add_range(1, 11);
    add_range(11, 23);
    add_range(23, 25);
    add_range(25, 100);
    // 11 and 22 fall in the second range, 33 in the fourth.
    parts[1]->setBit(11);
    parts[1]->setBit(22);
    parts[3]->setBit(33);
    for (size_t i = 0; i < parts.size(); ++i) {
        parts[i]->invalidateCachedCount();
    }
    auto filter = GlobalFilter::create(std::move(parts));
    verify(*filter);
}

TEST(GlobalFilterTest, multi_bitvector_filter_with_empty_vectors) {
    // Same coverage as the plain multi-vector case, but with two zero-length
    // ranges [23, 23) inserted in the middle.
    std::vector<std::unique_ptr<BitVector>> parts;
    const auto add_range = [&parts](uint32_t start, uint32_t end) {
        parts.push_back(BitVector::create(start, end));
    };
    add_range(1, 11);
    add_range(11, 23);
    add_range(23, 23);
    add_range(23, 23);
    add_range(23, 25);
    add_range(25, 100);
    // 11 and 22 fall in the second range, 33 in the last one.
    parts[1]->setBit(11);
    parts[1]->setBit(22);
    parts[5]->setBit(33);
    for (size_t i = 0; i < parts.size(); ++i) {
        parts[i]->invalidateCachedCount();
    }
    auto filter = GlobalFilter::create(std::move(parts));
    verify(*filter);
}

TEST(GlobalFilterTest, multi_bitvector_filter_with_no_vectors) {
    // An empty list of bit vectors still gives an active filter, just with nothing in it.
    auto filter = GlobalFilter::create(std::vector<std::unique_ptr<BitVector>>());
    EXPECT_TRUE(filter->is_active());
    EXPECT_EQ(filter->size(), 0);
    EXPECT_EQ(filter->count(), 0);
}

TEST(GlobalFilterTest, multi_bitvector_filter_requires_no_gaps) {
    // The first vector ends at 11 but the second starts at 12, leaving docid 11 uncovered.
    std::vector<std::unique_ptr<BitVector>> parts;
    parts.emplace_back(BitVector::create(1, 11));
    parts.emplace_back(BitVector::create(12, 100));
    const auto make_filter = [&]() { GlobalFilter::create(std::move(parts)); };
    EXPECT_THAT(make_filter, Throws<RequireFailedException>());
}

TEST(GlobalFilterTest, multi_bitvector_filter_requires_no_overlap) {
    // The first vector ends at 11 but the second already starts at 10, so docid 10 is covered twice.
    std::vector<std::unique_ptr<BitVector>> parts;
    parts.emplace_back(BitVector::create(1, 11));
    parts.emplace_back(BitVector::create(10, 100));
    const auto make_filter = [&]() { GlobalFilter::create(std::move(parts)); };
    EXPECT_THAT(make_filter, Throws<RequireFailedException>());
}

TEST(GlobalFilterTest, multi_bitvector_filter_requires_correct_order) {
    // Both ranges are individually fine, but the higher range is listed first.
    std::vector<std::unique_ptr<BitVector>> parts;
    parts.emplace_back(BitVector::create(11, 100));
    parts.emplace_back(BitVector::create(1, 11));
    const auto make_filter = [&]() { GlobalFilter::create(std::move(parts)); };
    EXPECT_THAT(make_filter, Throws<RequireFailedException>());
}

GTEST_MAIN_RUN_ALL_TESTS()
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,7 @@ struct Fixture

// Installs a global filter, sized to the attribute's document count, built from the given docids.
void setFilter(std::vector<uint32_t> docids) {
uint32_t sz = _attr->getNumDocs();
// NOTE(review): _global_filter is assigned twice below and only the final
// GlobalFilter::create(docids, sz) assignment sticks. Presumably the
// hand-rolled bit vector setup is the superseded version shown by the
// diff view - confirm against the real file.
auto bit_vector = BitVector::create(sz);
for (uint32_t id : docids) {
EXPECT_LT(id, sz);
bit_vector->setBit(id);
}
_global_filter = GlobalFilter::create(std::move(bit_vector));
_global_filter = GlobalFilter::create(docids, sz);
}

void setTensor(uint32_t docId, const Value &tensor) {
Expand Down
15 changes: 3 additions & 12 deletions searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,6 @@ class HnswIndexTest : public ::testing::Test {

~HnswIndexTest() {}

// Raw pointer to the fixture's global filter, or nullptr while that filter is inactive.
const GlobalFilter *global_filter_ptr() const {
    if (!global_filter->is_active()) {
        return nullptr;
    }
    return global_filter.get();
}

void init(bool heuristic_select_neighbors) {
auto generator = std::make_unique<LevelGenerator>();
level_generator = generator.get();
Expand All @@ -110,12 +106,7 @@ class HnswIndexTest : public ::testing::Test {
}
// Installs a global filter of fixed size 10 built from the given docids.
void set_filter(std::vector<uint32_t> docids) {
uint32_t sz = 10;
// NOTE(review): global_filter is assigned twice below and only the final
// GlobalFilter::create(docids, sz) assignment sticks. Presumably the
// hand-rolled bit vector setup is the superseded version shown by the
// diff view - confirm against the real file.
auto bit_vector = BitVector::create(sz);
for (uint32_t id : docids) {
EXPECT_LT(id, sz);
bit_vector->setBit(id);
}
global_filter = GlobalFilter::create(std::move(bit_vector));
global_filter = GlobalFilter::create(docids, sz);
}
GenerationHandler::Guard take_read_guard() {
return gen_handler.takeGuard();
Expand Down Expand Up @@ -149,7 +140,7 @@ class HnswIndexTest : public ::testing::Test {
void expect_top_3(uint32_t docid, std::vector<uint32_t> exp_hits) {
uint32_t k = 3;
auto qv = vectors.get_vector(docid);
auto rv = index->top_k_candidates(qv, k, global_filter_ptr()).peek();
auto rv = index->top_k_candidates(qv, k, global_filter->ptr_if_active()).peek();
std::sort(rv.begin(), rv.end(), LesserDistance());
size_t idx = 0;
for (const auto & hit : rv) {
Expand All @@ -170,7 +161,7 @@ class HnswIndexTest : public ::testing::Test {
void check_with_distance_threshold(uint32_t docid) {
auto qv = vectors.get_vector(docid);
uint32_t k = 3;
auto rv = index->top_k_candidates(qv, k, global_filter_ptr()).peek();
auto rv = index->top_k_candidates(qv, k, global_filter->ptr_if_active()).peek();
std::sort(rv.begin(), rv.end(), LesserDistance());
EXPECT_EQ(rv.size(), 3);
EXPECT_LE(rv[0].distance, rv[1].distance);
Expand Down
64 changes: 63 additions & 1 deletion searchlib/src/vespa/searchlib/queryeval/global_filter.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.

#include "global_filter.h"
#include <vespa/vespalib/util/require.h>

namespace search::queryeval {

Expand All @@ -23,6 +24,31 @@ struct BitVectorFilter : public GlobalFilter {
bool check(uint32_t docid) const override { return vector->testBit(docid); }
};

/**
 * Global filter backed by several bit vectors that together cover one
 * contiguous docid range.
 *
 * 'splits' holds the end of every vector except the last one, so vector i
 * answers for docids in [splits[i-1], splits[i]) and the last vector answers
 * for everything from the final split up to 'total_size'. The caller
 * constructing this object supplies the splits and the pre-computed totals;
 * nothing is validated here.
 */
struct MultiBitVectorFilter : public GlobalFilter {
    std::vector<std::unique_ptr<BitVector>> vectors;
    std::vector<uint32_t> splits;
    uint32_t total_size;
    uint32_t total_count;
    MultiBitVectorFilter(std::vector<std::unique_ptr<BitVector>> vectors_in,
                         std::vector<uint32_t> splits_in,
                         uint32_t total_size_in,
                         uint32_t total_count_in)
      : vectors(std::move(vectors_in)),
        splits(std::move(splits_in)),
        total_size(total_size_in),
        total_count(total_count_in) {}
    bool is_active() const override { return true; }
    uint32_t size() const override { return total_size; }
    uint32_t count() const override { return total_count; }
    /**
     * Returns true if 'docid' is set in the bit vector covering it.
     *
     * Docids at or beyond size() are never part of the filter. Answering
     * false for them up front also keeps a filter built from zero vectors
     * (size 0) from indexing into the empty vector list.
     */
    bool check(uint32_t docid) const override {
        if (docid >= total_size) {
            return false;
        }
        // Skip every vector that ends at or before docid; the number of
        // vectors is expected to be small, so a linear scan is sufficient.
        size_t i = 0;
        while ((i < splits.size()) && (docid >= splits[i])) {
            ++i;
        }
        return vectors[i]->testBit(docid);
    }
};

}

GlobalFilter::GlobalFilter() = default;
Expand All @@ -34,8 +60,44 @@ GlobalFilter::create() {
}

// Builds a single-bit-vector filter covering [1, size) with exactly the given
// docids set. Each docid must be greater than the previous one (and therefore
// greater than 0) and less than size; a violated REQUIRE aborts the creation.
std::shared_ptr<GlobalFilter>
// NOTE(review): the next line looks like the superseded signature left behind
// by the diff view; the definition documented here is the (docids, size) one
// that follows it - confirm against the real file.
GlobalFilter::create(std::unique_ptr<BitVector> vector) {
GlobalFilter::create(std::vector<uint32_t> docids, uint32_t size)
{
// Starting at 0 makes the ordering check below reject docid 0 as well.
uint32_t prev = 0;
auto bits = BitVector::create(1, size);
for (uint32_t docid: docids) {
REQUIRE(docid > prev);
REQUIRE(docid < size);
bits->setBit(docid);
prev = docid;
}
// Presumably needed so the bit count is recomputed after the direct setBit calls.
bits->invalidateCachedCount();
return create(std::move(bits));
}

// Wraps a single bit vector, taking ownership of it, in a BitVectorFilter.
std::shared_ptr<GlobalFilter>
GlobalFilter::create(std::unique_ptr<BitVector> vector)
{
    std::shared_ptr<GlobalFilter> filter = std::make_shared<BitVectorFilter>(std::move(vector));
    return filter;
}

// Combines a list of adjacent bit vectors into one filter. Every vector must
// start exactly where the previous one ends (REQUIRE_EQ fails otherwise). The
// size of the last vector becomes the filter size and the true-bit counts of
// all vectors are summed. An empty list gives a filter of size and count 0.
std::shared_ptr<GlobalFilter>
GlobalFilter::create(std::vector<std::unique_ptr<BitVector>> vectors)
{
    std::vector<uint32_t> splits;
    uint32_t total_count = 0;
    uint32_t total_size = 0;
    for (size_t i = 0; i < vectors.size(); ++i) {
        total_count += vectors[i]->countTrueBits();
        if ((i + 1) < vectors.size()) {
            // Not the last vector: its end must meet the start of the next
            // one, and that boundary is recorded as a split point.
            REQUIRE_EQ(vectors[i]->size(), vectors[i + 1]->getStartIndex());
            splits.push_back(vectors[i]->size());
            continue;
        }
        // Last vector: its end is the size of the whole filter.
        total_size = vectors[i]->size();
    }
    return std::make_shared<MultiBitVectorFilter>(std::move(vectors), std::move(splits),
                                                  total_size, total_count);
}

}
6 changes: 6 additions & 0 deletions searchlib/src/vespa/searchlib/queryeval/global_filter.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,14 @@ class GlobalFilter : public std::enable_shared_from_this<GlobalFilter>
virtual bool check(uint32_t docid) const = 0;
virtual ~GlobalFilter();

// Returns this filter when it is active and nullptr otherwise, for callers
// that take an optional filter as a raw pointer.
const GlobalFilter *ptr_if_active() const {
    if (is_active()) {
        return this;
    }
    return nullptr;
}

// Creates an inactive filter.
static std::shared_ptr<GlobalFilter> create();
// Creates a filter of the given size with exactly the listed docids set.
// The docids must be strictly ascending, greater than 0 and less than size.
static std::shared_ptr<GlobalFilter> create(std::vector<uint32_t> docids, uint32_t size);
// Creates a filter backed by a single bit vector, taking ownership of it.
static std::shared_ptr<GlobalFilter> create(std::unique_ptr<BitVector> vector);
// Creates a filter backed by several bit vectors, taking ownership of them.
// Each vector must start where the previous one ends.
static std::shared_ptr<GlobalFilter> create(std::vector<std::unique_ptr<BitVector>> vectors);
};

} // namespace

0 comments on commit 04c3414

Please sign in to comment.