-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
f6bd737
commit 98c682c
Showing
14 changed files
with
3,129 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
|
||
# The contrib modules | ||
|
||
The contrib directory contains helper modules for Faiss for various tasks. | ||
|
||
## Code structure | ||
|
||
The contrib directory gets compiled into the module faiss.contrib.
Note that although some of the modules may depend on additional modules (e.g. GPU Faiss, pytorch, hdf5), they are not necessarily compiled in, to avoid adding dependencies. It is the user's responsibility to provide them.
|
||
In contrib, we are progressively dropping python2 support. | ||
|
||
## List of contrib modules | ||
|
||
### rpc.py | ||
|
||
A very simple Remote Procedure Call library, where function parameters and results are pickled, for use with client_server.py | ||
|
||
### client_server.py | ||
|
||
The server handles requests to a Faiss index. The client calls the remote index. | ||
This is mainly to shard datasets over several machines, see [Distributed index](https://github.com/facebookresearch/faiss/wiki/Indexes-that-do-not-fit-in-RAM#distributed-index) | ||
|
||
### ondisk.py | ||
|
||
Encloses the main logic to merge indexes into an on-disk index. | ||
See [On-disk storage](https://github.com/facebookresearch/faiss/wiki/Indexes-that-do-not-fit-in-RAM#on-disk-storage) | ||
|
||
### exhaustive_search.py | ||
|
||
Computes the ground-truth search results for a dataset that possibly does not fit in RAM. Uses GPU if available. | ||
Tested in `tests/test_contrib.TestComputeGT` | ||
|
||
### torch_utils.py | ||
|
||
Interoperability functions for pytorch and Faiss: Importing this will allow pytorch Tensors (CPU or GPU) to be used as arguments to Faiss indexes and other functions. Torch GPU tensors can only be used with Faiss GPU indexes. If this is imported with a package that supports Faiss GPU, the necessary stream synchronization with the current pytorch stream will be automatically performed. | ||
|
||
Numpy ndarrays can continue to be used in the Faiss python interface after importing this file. All arguments must be uniformly either numpy ndarrays or Torch tensors; no mixing is allowed. | ||
|
||
Tested in `tests/test_contrib_torch.py` (CPU) and `gpu/test/test_contrib_torch_gpu.py` (GPU). | ||
|
||
### inspect_tools.py | ||
|
||
Functions to inspect C++ objects wrapped by SWIG. Most often this just means reading | ||
fields and converting them to the proper python array. | ||
|
||
### ivf_tools.py | ||
|
||
A few functions to override the coarse quantizer in IVF, providing additional flexibility for assignment. | ||
|
||
### datasets.py | ||
|
||
(may require h5py) | ||
|
||
Definition of how to access data for some standard datasets. | ||
|
||
### factory_tools.py | ||
|
||
Functions related to factory strings. | ||
|
||
### evaluation.py | ||
|
||
A few non-trivial evaluation functions for search results | ||
|
||
### clustering.py | ||
|
||
Contains: | ||
|
||
- a Python implementation of kmeans that can be used for special datatypes (e.g. sparse matrices).
|
||
- a 2-level clustering routine and a function that can apply it to train an IndexIVF |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from multiprocessing.pool import ThreadPool | ||
import faiss | ||
from typing import List, Tuple | ||
|
||
from . import rpc | ||
|
||
############################################################ | ||
# Server implementation | ||
############################################################ | ||
|
||
|
||
class SearchServer(rpc.Server):
    """RPC server wrapping a Faiss index so it can be queried remotely.

    Method calls not defined here are forwarded to the wrapped index via
    __getattr__, so a client can invoke most index methods (search, add,
    ...) over the RPC connection.
    """

    def __init__(self, s: int, index: faiss.Index):
        """
        Parameters
        ----------
        s: socket that the rpc.Server communicates on
        index: the Faiss index to serve
        """
        rpc.Server.__init__(self, s)
        self.index = index
        # keep a direct handle on the IVF sub-index so that nprobe
        # can be adjusted remotely via set_nprobe
        self.index_ivf = faiss.extract_index_ivf(index)

    def set_nprobe(self, nprobe: int) -> None:
        """set the nprobe field of the IVF index.

        BUGFIX: was annotated ``-> int`` although nothing is returned.
        """
        self.index_ivf.nprobe = nprobe

    def get_ntotal(self) -> int:
        """return the number of vectors stored in the index"""
        return self.index.ntotal

    def __getattr__(self, f):
        # all other attribute accesses are forwarded to the index,
        # which exposes the full index API to RPC clients
        return getattr(self.index, f)
|
||
|
||
def run_index_server(index: faiss.Index, port: int, v6: bool = False):
    """Serve requests for that index forever (blocking call).

    index: the Faiss index to expose over RPC
    port: TCP port to listen on
    v6: whether to listen on IPv6 instead of IPv4
    """
    rpc.run_server(
        lambda s: SearchServer(s, index),
        port, v6=v6)
|
||
|
||
############################################################ | ||
# Client implementation | ||
############################################################ | ||
|
||
class ClientIndex:
    """manages a set of distance sub-indexes. The sub_indexes search a
    subset of the inverted lists. Searches are merged afterwards
    """

    def __init__(self, machine_ports: List[Tuple[str, int]], v6: bool = False):
        """ connect to a series of (host, port) pairs """
        self.sub_indexes = []
        for machine, port in machine_ports:
            self.sub_indexes.append(rpc.Client(machine, port, v6))

        self.ni = len(self.sub_indexes)
        # pool of threads. Each thread manages one sub-index.
        self.pool = ThreadPool(self.ni)
        # test connection... (get_ntotal performs one RPC per sub-index)
        self.ntotal = self.get_ntotal()
        self.verbose = False

    def set_nprobe(self, nprobe: int) -> None:
        """set nprobe on every sub-index"""
        self.pool.map(
            lambda idx: idx.set_nprobe(nprobe),
            self.sub_indexes
        )

    def set_omp_num_threads(self, nt: int) -> None:
        """set the number of OpenMP threads on every sub-index"""
        self.pool.map(
            lambda idx: idx.set_omp_num_threads(nt),
            self.sub_indexes
        )

    def get_ntotal(self) -> int:
        """total number of vectors over all sub-indexes.

        BUGFIX: was annotated ``-> None`` although it returns an int.
        """
        return sum(self.pool.map(
            lambda idx: idx.get_ntotal(),
            self.sub_indexes
        ))

    def search(self, x, k: int):
        """search the (n, d) query matrix x in all sub-indexes in
        parallel and merge the per-machine top-k results.

        Returns the (distances, labels) pair of (n, k) arrays.
        """
        rh = faiss.ResultHeap(x.shape[0], k)
        # each sub-index is queried in its own pool thread; partial
        # results are merged into the heap as they arrive
        for Di, Ii in self.pool.imap(lambda idx: idx.search(x, k), self.sub_indexes):
            rh.add_result(Di, Ii)
        rh.finalize()
        return rh.D, rh.I
Oops, something went wrong.