-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
61db7c9
commit d1ca352
Showing
56 changed files
with
16,601 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
set(FAISS_TEST_SRC | ||
test_binary_flat.cpp | ||
test_dealloc_invlists.cpp | ||
test_ivfpq_codec.cpp | ||
test_ivfpq_indexing.cpp | ||
test_lowlevel_ivf.cpp | ||
test_merge.cpp | ||
test_omp_threads.cpp | ||
test_ondisk_ivf.cpp | ||
test_pairs_decoding.cpp | ||
test_params_override.cpp | ||
test_pq_encoding.cpp | ||
test_sliding_ivf.cpp | ||
test_threaded_index.cpp | ||
test_transfer_invlists.cpp | ||
test_mem_leak.cpp | ||
test_cppcontrib_sa_decode.cpp | ||
test_cppcontrib_uintreader.cpp | ||
test_simdlib.cpp | ||
) | ||
|
||
add_executable(faiss_test ${FAISS_TEST_SRC}) | ||
|
||
if(FAISS_OPT_LEVEL STREQUAL "avx2") | ||
if(NOT WIN32) | ||
target_compile_options(faiss_test PRIVATE $<$<COMPILE_LANGUAGE:CXX>:-mavx2 -mfma>) | ||
else() | ||
target_compile_options(faiss_test PRIVATE $<$<COMPILE_LANGUAGE:CXX>:/arch:AVX2>) | ||
endif() | ||
target_link_libraries(faiss_test PRIVATE faiss_avx2) | ||
else() | ||
target_link_libraries(faiss_test PRIVATE faiss) | ||
endif() | ||
|
||
include(FetchContent) | ||
FetchContent_Declare(googletest | ||
URL "https://github.com/google/googletest/archive/release-1.12.1.tar.gz") | ||
set(BUILD_GMOCK CACHE BOOL OFF) | ||
set(INSTALL_GTEST CACHE BOOL OFF) | ||
FetchContent_MakeAvailable(googletest) | ||
|
||
find_package(OpenMP REQUIRED) | ||
|
||
target_link_libraries(faiss_test PRIVATE | ||
OpenMP::OpenMP_CXX | ||
gtest_main | ||
) | ||
|
||
# Defines `gtest_discover_tests()`. | ||
include(GoogleTest) | ||
gtest_discover_tests(faiss_test) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
# a few common functions for the tests | ||
|
||
from __future__ import absolute_import, division, print_function, unicode_literals | ||
|
||
import numpy as np | ||
import faiss | ||
|
||
# reduce number of threads to avoid excessive nb of threads in opt | ||
# mode (recuces runtime from 100s to 4s!) | ||
faiss.omp_set_num_threads(4) | ||
|
||
|
||
def random_unitary(n, d, seed): | ||
x = faiss.randn(n * d, seed).reshape(n, d) | ||
faiss.normalize_L2(x) | ||
return x | ||
|
||
|
||
class Randu10k: | ||
|
||
def __init__(self): | ||
self.nb = 10000 | ||
self.nq = 1000 | ||
self.nt = 10000 | ||
self.d = 128 | ||
|
||
self.xb = random_unitary(self.nb, self.d, 1) | ||
self.xt = random_unitary(self.nt, self.d, 2) | ||
self.xq = random_unitary(self.nq, self.d, 3) | ||
|
||
dotprods = np.dot(self.xq, self.xb.T) | ||
self.gt = dotprods.argmax(1) | ||
self.k = 100 | ||
|
||
def launch(self, name, index): | ||
if not index.is_trained: | ||
index.train(self.xt) | ||
index.add(self.xb) | ||
return index.search(self.xq, self.k) | ||
|
||
def evalres(self, DI): | ||
D, I = DI | ||
e = {} | ||
for rank in 1, 10, 100: | ||
e[rank] = ((I[:, :rank] == self.gt.reshape(-1, 1)).sum() / | ||
float(self.nq)) | ||
print("1-recalls: %s" % e) | ||
return e | ||
|
||
|
||
class Randu10kUnbalanced(Randu10k): | ||
|
||
def __init__(self): | ||
Randu10k.__init__(self) | ||
|
||
weights = 0.95 ** np.arange(self.d) | ||
rs = np.random.RandomState(123) | ||
weights = weights[rs.permutation(self.d)] | ||
self.xb *= weights | ||
self.xb /= np.linalg.norm(self.xb, axis=1)[:, np.newaxis] | ||
self.xq *= weights | ||
self.xq /= np.linalg.norm(self.xq, axis=1)[:, np.newaxis] | ||
self.xt *= weights | ||
self.xt /= np.linalg.norm(self.xt, axis=1)[:, np.newaxis] | ||
|
||
dotprods = np.dot(self.xq, self.xb.T) | ||
self.gt = dotprods.argmax(1) | ||
self.k = 100 | ||
|
||
|
||
def get_dataset(d, nb, nt, nq): | ||
rs = np.random.RandomState(123) | ||
xb = rs.rand(nb, d).astype('float32') | ||
xt = rs.rand(nt, d).astype('float32') | ||
xq = rs.rand(nq, d).astype('float32') | ||
|
||
return (xt, xb, xq) | ||
|
||
|
||
def get_dataset_2(d, nt, nb, nq): | ||
"""A dataset that is not completely random but still challenging to | ||
index | ||
""" | ||
d1 = 10 # intrinsic dimension (more or less) | ||
n = nb + nt + nq | ||
rs = np.random.RandomState(1338) | ||
x = rs.normal(size=(n, d1)) | ||
x = np.dot(x, rs.rand(d1, d)) | ||
# now we have a d1-dim ellipsoid in d-dimensional space | ||
# higher factor (>4) -> higher frequency -> less linear | ||
x = x * (rs.rand(d) * 4 + 0.1) | ||
x = np.sin(x) | ||
x = x.astype('float32') | ||
return x[:nt], x[nt:nt + nb], x[nt + nb:] | ||
|
||
|
||
def make_binary_dataset(d, nt, nb, nq): | ||
assert d % 8 == 0 | ||
rs = np.random.RandomState(123) | ||
x = rs.randint(256, size=(nb + nq + nt, int(d / 8))).astype('uint8') | ||
return x[:nt], x[nt:-nq], x[-nq:] | ||
|
||
|
||
def compare_binary_result_lists(D1, I1, D2, I2): | ||
"""comparing result lists is difficult because there are many | ||
ties. Here we sort by (distance, index) pairs and ignore the largest | ||
distance of each result. Compatible result lists should pass this.""" | ||
assert D1.shape == I1.shape == D2.shape == I2.shape | ||
n, k = D1.shape | ||
ndiff = (D1 != D2).sum() | ||
assert ndiff == 0, '%d differences in distance matrix %s' % ( | ||
ndiff, D1.shape) | ||
|
||
def normalize_DI(D, I): | ||
norm = I.max() + 1.0 | ||
Dr = D.astype('float64') + I / norm | ||
# ignore -1s and elements on last column | ||
Dr[I1 == -1] = 1e20 | ||
Dr[D == D[:, -1:]] = 1e20 | ||
Dr.sort(axis=1) | ||
return Dr | ||
ndiff = (normalize_DI(D1, I1) != normalize_DI(D2, I2)).sum() | ||
assert ndiff == 0, '%d differences in normalized D matrix' % ndiff |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
|
||
import unittest | ||
import faiss | ||
|
||
|
||
class TestParameterSpace(unittest.TestCase): | ||
|
||
def test_nprobe(self): | ||
index = faiss.index_factory(32, "IVF32,Flat") | ||
ps = faiss.ParameterSpace() | ||
ps.set_index_parameter(index, "nprobe", 5) | ||
self.assertEqual(index.nprobe, 5) | ||
|
||
def test_nprobe_2(self): | ||
index = faiss.index_factory(32, "IDMap,IVF32,Flat") | ||
ps = faiss.ParameterSpace() | ||
ps.set_index_parameter(index, "nprobe", 5) | ||
index2 = faiss.downcast_index(index.index) | ||
self.assertEqual(index2.nprobe, 5) | ||
|
||
def test_nprobe_3(self): | ||
index = faiss.index_factory(32, "IVF32,SQ8,RFlat") | ||
ps = faiss.ParameterSpace() | ||
ps.set_index_parameter(index, "nprobe", 5) | ||
index2 = faiss.downcast_index(index.base_index) | ||
self.assertEqual(index2.nprobe, 5) | ||
|
||
def test_nprobe_4(self): | ||
index = faiss.index_factory(32, "PCAR32,IVF32,SQ8,RFlat") | ||
ps = faiss.ParameterSpace() | ||
|
||
ps.set_index_parameter(index, "nprobe", 5) | ||
index2 = faiss.downcast_index(index.base_index) | ||
index2 = faiss.downcast_index(index2.index) | ||
self.assertEqual(index2.nprobe, 5) | ||
|
||
def test_efSearch(self): | ||
index = faiss.index_factory(32, "IVF32_HNSW32,SQ8") | ||
ps = faiss.ParameterSpace() | ||
ps.set_index_parameter(index, "quantizer_efSearch", 5) | ||
index2 = faiss.downcast_index(index.quantizer) | ||
self.assertEqual(index2.hnsw.efSearch, 5) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from __future__ import absolute_import, division, print_function | ||
|
||
import unittest | ||
import faiss | ||
|
||
|
||
class TestBinaryFactory(unittest.TestCase): | ||
|
||
def test_factory_IVF(self): | ||
|
||
index = faiss.index_binary_factory(16, "BIVF10") | ||
assert index.invlists is not None | ||
assert index.nlist == 10 | ||
assert index.code_size == 2 | ||
|
||
def test_factory_Flat(self): | ||
|
||
index = faiss.index_binary_factory(16, "BFlat") | ||
assert index.code_size == 2 | ||
|
||
def test_factory_HNSW(self): | ||
|
||
index = faiss.index_binary_factory(256, "BHNSW32") | ||
assert index.code_size == 32 | ||
|
||
def test_factory_IVF_HNSW(self): | ||
|
||
index = faiss.index_binary_factory(256, "BIVF1024_BHNSW32") | ||
assert index.code_size == 32 | ||
assert index.nlist == 1024 | ||
|
||
def test_factory_Hash(self): | ||
index = faiss.index_binary_factory(256, "BHash12") | ||
assert index.b == 12 | ||
|
||
def test_factory_MultiHash(self): | ||
index = faiss.index_binary_factory(256, "BHash5x6") | ||
assert index.b == 6 | ||
assert index.nhash == 5 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
/** | ||
* Copyright (c) Facebook, Inc. and its affiliates. | ||
* | ||
* This source code is licensed under the MIT license found in the | ||
* LICENSE file in the root directory of this source tree. | ||
*/ | ||
|
||
#include <cstdio> | ||
#include <cstdlib> | ||
|
||
#include <gtest/gtest.h> | ||
|
||
#include <faiss/IndexBinaryFlat.h> | ||
#include <faiss/utils/hamming.h> | ||
|
||
TEST(BinaryFlat, accuracy) { | ||
// dimension of the vectors to index | ||
int d = 64; | ||
|
||
// size of the database we plan to index | ||
size_t nb = 1000; | ||
|
||
// make the index object and train it | ||
faiss::IndexBinaryFlat index(d); | ||
|
||
std::vector<uint8_t> database(nb * (d / 8)); | ||
for (size_t i = 0; i < nb * (d / 8); i++) { | ||
database[i] = rand() % 0x100; | ||
} | ||
|
||
{ // populating the database | ||
index.add(nb, database.data()); | ||
} | ||
|
||
size_t nq = 200; | ||
|
||
{ // searching the database | ||
|
||
std::vector<uint8_t> queries(nq * (d / 8)); | ||
for (size_t i = 0; i < nq * (d / 8); i++) { | ||
queries[i] = rand() % 0x100; | ||
} | ||
|
||
int k = 5; | ||
std::vector<faiss::idx_t> nns(k * nq); | ||
std::vector<int> dis(k * nq); | ||
|
||
index.search(nq, queries.data(), k, dis.data(), nns.data()); | ||
|
||
for (size_t i = 0; i < nq; ++i) { | ||
faiss::HammingComputer8 hc(queries.data() + i * (d / 8), d / 8); | ||
hamdis_t dist_min = hc.hamming(database.data()); | ||
for (size_t j = 1; j < nb; ++j) { | ||
hamdis_t dist = hc.hamming(database.data() + j * (d / 8)); | ||
if (dist < dist_min) { | ||
dist_min = dist; | ||
} | ||
} | ||
EXPECT_EQ(dist_min, dis[k * i]); | ||
} | ||
} | ||
} |
Oops, something went wrong.