Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Add bf16 support for adding to the index #558

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions cpp/test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -355,8 +355,13 @@ void test_minimal_three_vectors(index_at& index, //
// Search again over reconstructed index
{
matched_count = index.search(vector_first.data(), 5, args...).dump_to(matched_keys, matched_distances);

std::printf("matched_count: %zu \n", matched_count);
expect_eq(matched_count, 3);
std::printf("matched_keys[0]: %zu \n", matched_keys[0]);
std::printf("key_first: %zu \n", key_first);
expect_eq(matched_keys[0], key_first);
std::printf("matched_distances[0]: %f \n", matched_distances[0]);
expect(std::abs(matched_distances[0]) < 0.01);
}

Expand Down Expand Up @@ -1100,6 +1105,9 @@ int main(int, char**) {
test_uint40();
test_cosine<float, std::int64_t, uint40_t>(10, 10);

// Test for bf16 scalar elements type
test_cosine<bf16_t, std::int64_t, uint40_t>(10, 10);

// Test plugins, like K-Means clustering.
{
std::size_t vectors_count = 1000, centroids_count = 10, dimensions = 256;
Expand Down
6 changes: 6 additions & 0 deletions include/usearch/index_dense.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -760,36 +760,42 @@ class index_dense_gt {
add_result_t add(vector_key_t key, b1x8_t const* vector, std::size_t thread = any_thread(), bool force_vector_copy = true) { return add_(key, vector, thread, force_vector_copy, casts_.from.b1x8); }
add_result_t add(vector_key_t key, i8_t const* vector, std::size_t thread = any_thread(), bool force_vector_copy = true) { return add_(key, vector, thread, force_vector_copy, casts_.from.i8); }
add_result_t add(vector_key_t key, f16_t const* vector, std::size_t thread = any_thread(), bool force_vector_copy = true) { return add_(key, vector, thread, force_vector_copy, casts_.from.f16); }
add_result_t add(vector_key_t key, bf16_t const* vector, std::size_t thread = any_thread(), bool force_vector_copy = true) { return add_(key, vector, thread, force_vector_copy, casts_.from.bf16); }
add_result_t add(vector_key_t key, f32_t const* vector, std::size_t thread = any_thread(), bool force_vector_copy = true) { return add_(key, vector, thread, force_vector_copy, casts_.from.f32); }
add_result_t add(vector_key_t key, f64_t const* vector, std::size_t thread = any_thread(), bool force_vector_copy = true) { return add_(key, vector, thread, force_vector_copy, casts_.from.f64); }

search_result_t search(b1x8_t const* vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from.b1x8); }
search_result_t search(i8_t const* vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from.i8); }
search_result_t search(f16_t const* vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from.f16); }
search_result_t search(bf16_t const* vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from.bf16); }
search_result_t search(f32_t const* vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from.f32); }
search_result_t search(f64_t const* vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from.f64); }

template <typename predicate_at> search_result_t filtered_search(b1x8_t const* vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, std::forward<predicate_at>(predicate), thread, exact, casts_.from.b1x8); }
template <typename predicate_at> search_result_t filtered_search(i8_t const* vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, std::forward<predicate_at>(predicate), thread, exact, casts_.from.i8); }
template <typename predicate_at> search_result_t filtered_search(f16_t const* vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, std::forward<predicate_at>(predicate), thread, exact, casts_.from.f16); }
template <typename predicate_at> search_result_t filtered_search(bf16_t const* vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, std::forward<predicate_at>(predicate), thread, exact, casts_.from.bf16); }
template <typename predicate_at> search_result_t filtered_search(f32_t const* vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, std::forward<predicate_at>(predicate), thread, exact, casts_.from.f32); }
template <typename predicate_at> search_result_t filtered_search(f64_t const* vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, std::forward<predicate_at>(predicate), thread, exact, casts_.from.f64); }

std::size_t get(vector_key_t key, b1x8_t* vector, std::size_t vectors_count = 1) const { return get_(key, vector, vectors_count, casts_.to.b1x8); }
std::size_t get(vector_key_t key, i8_t* vector, std::size_t vectors_count = 1) const { return get_(key, vector, vectors_count, casts_.to.i8); }
std::size_t get(vector_key_t key, f16_t* vector, std::size_t vectors_count = 1) const { return get_(key, vector, vectors_count, casts_.to.f16); }
std::size_t get(vector_key_t key, bf16_t* vector, std::size_t vectors_count = 1) const { return get_(key, vector, vectors_count, casts_.to.bf16); }
std::size_t get(vector_key_t key, f32_t* vector, std::size_t vectors_count = 1) const { return get_(key, vector, vectors_count, casts_.to.f32); }
std::size_t get(vector_key_t key, f64_t* vector, std::size_t vectors_count = 1) const { return get_(key, vector, vectors_count, casts_.to.f64); }

cluster_result_t cluster(b1x8_t const* vector, std::size_t level, std::size_t thread = any_thread()) const { return cluster_(vector, level, thread, casts_.from.b1x8); }
cluster_result_t cluster(i8_t const* vector, std::size_t level, std::size_t thread = any_thread()) const { return cluster_(vector, level, thread, casts_.from.i8); }
cluster_result_t cluster(f16_t const* vector, std::size_t level, std::size_t thread = any_thread()) const { return cluster_(vector, level, thread, casts_.from.f16); }
cluster_result_t cluster(bf16_t const* vector, std::size_t level, std::size_t thread = any_thread()) const { return cluster_(vector, level, thread, casts_.from.bf16); }
cluster_result_t cluster(f32_t const* vector, std::size_t level, std::size_t thread = any_thread()) const { return cluster_(vector, level, thread, casts_.from.f32); }
cluster_result_t cluster(f64_t const* vector, std::size_t level, std::size_t thread = any_thread()) const { return cluster_(vector, level, thread, casts_.from.f64); }

aggregated_distances_t distance_between(vector_key_t key, b1x8_t const* vector, std::size_t thread = any_thread()) const { return distance_between_(key, vector, thread, casts_.to.b1x8); }
aggregated_distances_t distance_between(vector_key_t key, i8_t const* vector, std::size_t thread = any_thread()) const { return distance_between_(key, vector, thread, casts_.to.i8); }
aggregated_distances_t distance_between(vector_key_t key, f16_t const* vector, std::size_t thread = any_thread()) const { return distance_between_(key, vector, thread, casts_.to.f16); }
aggregated_distances_t distance_between(vector_key_t key, bf16_t const* vector, std::size_t thread = any_thread()) const { return distance_between_(key, vector, thread, casts_.to.bf16); }
aggregated_distances_t distance_between(vector_key_t key, f32_t const* vector, std::size_t thread = any_thread()) const { return distance_between_(key, vector, thread, casts_.to.f32); }
aggregated_distances_t distance_between(vector_key_t key, f64_t const* vector, std::size_t thread = any_thread()) const { return distance_between_(key, vector, thread, casts_.to.f64); }
// clang-format on
Expand Down
10 changes: 9 additions & 1 deletion include/usearch/index_plugins.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,11 @@ class bf16_bits_t {
uint16_ = f32_to_bf16(v / bf16_to_f32(uint16_));
return *this;
}

inline bf16_bits_t& operator=(int v) noexcept {
uint16_ = f32_to_bf16(static_cast<float>(v));
return *this;
}
};

/**
Expand Down Expand Up @@ -1223,6 +1228,7 @@ struct casts_punned_t {
cast_punned_t b1x8{};
cast_punned_t i8{};
cast_punned_t f16{};
cast_punned_t bf16{};
cast_punned_t f32{};
cast_punned_t f64{};

Expand All @@ -1231,7 +1237,7 @@ struct casts_punned_t {
case scalar_kind_t::f64_k: return f64;
case scalar_kind_t::f32_k: return f32;
case scalar_kind_t::f16_k: return f16;
case scalar_kind_t::bf16_k: return f16;
case scalar_kind_t::bf16_k: return bf16;
case scalar_kind_t::i8_k: return i8;
case scalar_kind_t::b1x8_k: return b1x8;
default: return nullptr;
Expand All @@ -1246,12 +1252,14 @@ struct casts_punned_t {
result.from.b1x8 = &cast_gt<b1x8_t, scalar_at>::try_;
result.from.i8 = &cast_gt<i8_t, scalar_at>::try_;
result.from.f16 = &cast_gt<f16_t, scalar_at>::try_;
result.from.bf16 = &cast_gt<bf16_t, scalar_at>::try_;
result.from.f32 = &cast_gt<f32_t, scalar_at>::try_;
result.from.f64 = &cast_gt<f64_t, scalar_at>::try_;

result.to.b1x8 = &cast_gt<scalar_at, b1x8_t>::try_;
result.to.i8 = &cast_gt<scalar_at, i8_t>::try_;
result.to.f16 = &cast_gt<scalar_at, f16_t>::try_;
result.to.bf16 = &cast_gt<scalar_at, bf16_t>::try_;
result.to.f32 = &cast_gt<scalar_at, f32_t>::try_;
result.to.f64 = &cast_gt<scalar_at, f64_t>::try_;

Expand Down