Skip to content

Commit

Permalink
PR fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
marksg07 committed Jan 10, 2025
1 parent 4bcba8a commit 3e0301e
Show file tree
Hide file tree
Showing 7 changed files with 93 additions and 48 deletions.
4 changes: 2 additions & 2 deletions src/mc-str-encode-string-sets-private.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@
#ifndef MONGOCRYPT_STR_ENCODE_STRING_SETS_PRIVATE_H
#define MONGOCRYPT_STR_ENCODE_STRING_SETS_PRIVATE_H

#include "mongocrypt-buffer-private.h"
#include "mongocrypt.h"

// Represents a valid unicode string with the bad character 0xFF appended to the end. This is our base string which
// we build substring trees on. Stores all the valid code points in the string, plus one code point for 0xFF.
// Exposed for testing.
// Holds the bytes of the base string plus the byte offset of each code point within those bytes.
// NOTE(review): this listing comes from a diff view, so `data`/`len` and `buf` appear side by side. The
// accompanying .c changes replace every ret->data / ret->len use with ret->buf, so `data` and `len` look like the
// removed side of the change -- confirm against the real header before relying on them.
typedef struct {
// Byte storage: the input bytes followed by one trailing bad-character byte.
char *data;
// Byte length of `data`, i.e. input length + 1.
uint32_t len;
// Buffer form of the same storage: sized to input length + 1, with the bad character written at the last byte.
_mongocrypt_buffer_t buf;
// codepoint_offsets[i] is the byte offset within the storage at which code point i begins.
uint32_t *codepoint_offsets;
// Number of code points, counting the trailing bad character as one.
uint32_t codepoint_len;
} mc_utf8_string_with_bad_char_t;
Expand Down
24 changes: 12 additions & 12 deletions src/mc-str-encode-string-sets.c
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/

#include "mc-str-encode-string-sets-private.h"
#include "mongocrypt-buffer-private.h"
#include <bson/bson.h>
#include <stdint.h>

Expand All @@ -23,10 +24,9 @@
// Input must be pre-validated by bson_utf8_validate().
mc_utf8_string_with_bad_char_t *mc_utf8_string_with_bad_char_from_buffer(const char *buf, uint32_t len) {
mc_utf8_string_with_bad_char_t *ret = bson_malloc0(sizeof(mc_utf8_string_with_bad_char_t));
ret->data = bson_malloc0(len + 1);
ret->len = len + 1;
memcpy(ret->data, buf, len);
ret->data[len] = BAD_CHAR;
_mongocrypt_buffer_init_size(&ret->buf, len + 1);
memcpy(ret->buf.data, buf, len);
ret->buf.data[len] = BAD_CHAR;
// max # offsets is the total length
ret->codepoint_offsets = bson_malloc0(sizeof(uint32_t) * (len + 1));
const char *cur = buf;
Expand All @@ -48,7 +48,7 @@ void mc_utf8_string_with_bad_char_destroy(mc_utf8_string_with_bad_char_t *utf8)
return;
}
bson_free(utf8->codepoint_offsets);
bson_free(utf8->data);
_mongocrypt_buffer_cleanup(&utf8->buf);
bson_free(utf8);
}

Expand Down Expand Up @@ -121,11 +121,11 @@ bool mc_affix_set_iter_next(mc_affix_set_iter_t *it, const char **str, uint32_t
uint32_t end_idx = it->set->end_indices[idx];
uint32_t start_byte_offset = it->set->base_string->codepoint_offsets[start_idx];
// Pointing to the end of the codepoints represents the end of the string.
uint32_t end_byte_offset = it->set->base_string->len;
uint32_t end_byte_offset = it->set->base_string->buf.len;
if (end_idx != it->set->base_string->codepoint_len) {
end_byte_offset = it->set->base_string->codepoint_offsets[end_idx];
}
*str = &it->set->base_string->data[start_byte_offset];
*str = (const char *)it->set->base_string->buf.data + start_byte_offset;
*len = end_byte_offset - start_byte_offset;
*count = it->set->substring_counts[idx];
return true;
Expand Down Expand Up @@ -206,7 +206,7 @@ bool mc_substring_set_insert(mc_substring_set_t *set, uint32_t base_start_idx, u
return false;
}
uint32_t start_byte_offset = set->base_string->codepoint_offsets[base_start_idx];
const char *start = set->base_string->data + start_byte_offset;
const char *start = (const char *)set->base_string->buf.data + start_byte_offset;
uint32_t len = set->base_string->codepoint_offsets[base_end_idx] - start_byte_offset;
uint32_t hash = fnv1a(start, len);
uint32_t idx = hash % HASHSET_SIZE;
Expand All @@ -216,7 +216,7 @@ bool mc_substring_set_insert(mc_substring_set_t *set, uint32_t base_start_idx, u
mc_substring_set_node_t *prev;
while (node) {
prev = node;
if (len == node->len && memcmp(start, set->base_string->data + node->start_offset, len) == 0) {
if (len == node->len && memcmp(start, set->base_string->buf.data + node->start_offset, len) == 0) {
// Match, no insertion
return false;
}
Expand Down Expand Up @@ -252,8 +252,8 @@ bool mc_substring_set_iter_next(mc_substring_set_iter_t *it, const char **str, u
// Almost done with iteration; return base string if count is not 0.
if (it->set->base_string_count) {
*count = it->set->base_string_count;
*str = it->set->base_string->data;
*len = it->set->base_string->len;
*str = (const char *)it->set->base_string->buf.data;
*len = it->set->base_string->buf.len;
return true;
}
return false;
Expand All @@ -264,7 +264,7 @@ bool mc_substring_set_iter_next(mc_substring_set_iter_t *it, const char **str, u
mc_substring_set_node_t *cur = (mc_substring_set_node_t *)(it->cur_node);
// Count is always 1 for substrings in the hashset
*count = 1;
*str = &it->set->base_string->data[cur->start_offset];
*str = (const char *)it->set->base_string->buf.data + cur->start_offset;
*len = cur->len;
it->cur_node = (void *)cur->next;
return true;
Expand Down
3 changes: 1 addition & 2 deletions src/mc-text-search-str-encode-private.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,7 @@ typedef struct {
// Set of encoded substrings.
mc_substring_set_t *substring_set;
// Encoded exact string.
char *exact;
size_t exact_len;
_mongocrypt_buffer_t exact;
} mc_str_encode_sets_t;

// Run StrEncode with the given spec.
Expand Down
37 changes: 35 additions & 2 deletions src/mc-text-search-str-encode.c
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include "mc-str-encode-string-sets-private.h"
#include "mc-text-search-str-encode-private.h"
#include "mongocrypt-buffer-private.h"
#include "mongocrypt.h"
#include <bson/bson.h>
#include <stdint.h>
Expand Down Expand Up @@ -94,6 +95,32 @@ static mc_substring_set_t *generate_substring_tree(const mc_utf8_string_with_bad
// No valid substrings, return empty tree
return NULL;
}

// If you are following along with the OST paper, a slightly different calculation of msize is used. The following
// justifies why that calculation and this calculation are equivalent.
// At this point, it is established that:
// beta <= mlen
// lb <= cbclen
// lb <= ub <= mlen
//
// So, the following formula for msize in the OST paper:
// maxkgram_1 = sum_(j=lb, ub, (mlen - j + 1))
// maxkgram_2 = sum_(j=lb, min(ub, cbclen), (cbclen - j + 1))
// msize = min(maxkgram_1, maxkgram_2)
// can be simplified to:
// msize = sum_(j=lb, min(ub, cbclen), (min(mlen, cbclen) - j + 1))
//
// because if cbclen <= ub, then it follows that cbclen <= ub <= mlen, and so
// maxkgram_1 = sum_(j=lb, ub, (mlen - j + 1)) # as above
// maxkgram_2 = sum_(j=lb, cbclen, (cbclen - j + 1)) # less or equal to maxkgram_1
// msize = maxkgram_2
// and if cbclen > ub, then it follows that:
// maxkgram_1 = sum_(j=lb, ub, (mlen - j + 1)) # as above
// maxkgram_2 = sum_(j=lb, ub, (cbclen - j + 1)) # same sum bounds as maxkgram_1
// msize = sum_(j=lb, ub, (min(mlen, cbclen) - j + 1))
// in both cases, msize can be rewritten as:
// msize = sum_(j=lb, min(ub, cbclen), (min(mlen, cbclen) - j + 1))

uint32_t folded_codepoint_len = base_str->codepoint_len - 1;
// If mlen < cbclen, we only need to pad to mlen
uint32_t padded_len = BSON_MIN(spec->mlen, cbclen);
Expand Down Expand Up @@ -155,11 +182,17 @@ mc_str_encode_sets_t *mc_text_search_str_encode_helper(const mc_FLE2TextSearchIn
sets->prefix_set = generate_prefix_tree(sets->base_string, unfolded_codepoint_len, &spec->prefix.value);
}
if (spec->substr.set) {
if (unfolded_codepoint_len > spec->substr.value.mlen) {
CLIENT_ERR("StrEncode: String passed in was longer than the maximum length for substring indexing -- "
"String len: %u, max len: %u",
unfolded_codepoint_len,
spec->substr.value.mlen);
return NULL;
}
sets->substring_set = generate_substring_tree(sets->base_string, unfolded_codepoint_len, &spec->substr.value);
}
// Exact string is always the first len characters of the base string
sets->exact = sets->base_string->data;
sets->exact_len = folded_str_bytes_len;
_mongocrypt_buffer_from_data(&sets->exact, sets->base_string->buf.data, folded_str_bytes_len);
return sets;
}

Expand Down
5 changes: 5 additions & 0 deletions src/mongocrypt-buffer-private.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,11 @@ bool _mongocrypt_buffer_steal_from_string(_mongocrypt_buffer_t *buf, char *str)
* - Caller must call _mongocrypt_buffer_cleanup. */
bool _mongocrypt_buffer_from_string(_mongocrypt_buffer_t *buf, const char *str) MONGOCRYPT_WARN_UNUSED_RESULT;

/* _mongocrypt_buffer_from_data initializes @buf from @data with length @len.
* @buf retains a pointer to @data.
* @data must outlive @buf. */
void _mongocrypt_buffer_from_data(_mongocrypt_buffer_t *buf, const uint8_t *data, uint32_t len);

/* _mongocrypt_buffer_copy_from_uint64_le initializes @buf from the
* little-endian byte representation of @value. Caller must call
* _mongocrypt_buffer_cleanup.
Expand Down
10 changes: 10 additions & 0 deletions src/mongocrypt-buffer.c
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,16 @@ bool _mongocrypt_buffer_from_string(_mongocrypt_buffer_t *buf, const char *str)
return true;
}

/* Point @buf at the @len bytes starting at @data without copying them.
 * @buf keeps the caller's pointer and is marked as not owning it, so @data must remain valid for as long as
 * @buf is in use. Both @buf and @data are checked with BSON_ASSERT_PARAM. */
void _mongocrypt_buffer_from_data(_mongocrypt_buffer_t *buf, const uint8_t *data, uint32_t len) {
    BSON_ASSERT_PARAM(buf);
    BSON_ASSERT_PARAM(data);

    // The const qualifier is dropped only so the pointer can be stored in buf->data.
    uint8_t *const borrowed = (uint8_t *)data;

    // Initialize before assigning the individual fields below.
    _mongocrypt_buffer_init(buf);
    buf->owned = false; // presumably keeps _mongocrypt_buffer_cleanup from freeing @data -- confirm
    buf->len = len;
    buf->data = borrowed;
}

void _mongocrypt_buffer_copy_from_uint64_le(_mongocrypt_buffer_t *buf, uint64_t value) {
uint64_t value_le = MONGOCRYPT_UINT64_TO_LE(value);

Expand Down
58 changes: 28 additions & 30 deletions test/test-mc-text-search-str-encode.c
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,13 @@ static void test_nofold_suffix_prefix_case(_mongocrypt_tester_t *tester,
sets = mc_text_search_str_encode_helper(&spec, unfolded_codepoint_len, status);
}
ASSERT_OR_PRINT(sets, status);
ASSERT(sets->base_string->len == byte_len + 1);
ASSERT(sets->base_string->buf.len == byte_len + 1);
ASSERT(sets->base_string->codepoint_len == codepoint_len + 1);
ASSERT(0 == memcmp(sets->base_string->data, str, byte_len));
ASSERT(sets->base_string->data[byte_len] == (char)0xFF);
ASSERT(0 == memcmp(sets->base_string->buf.data, str, byte_len));
ASSERT(sets->base_string->buf.data[byte_len] == (uint8_t)0xFF);
ASSERT(sets->substring_set == NULL);
ASSERT(sets->exact_len == byte_len);
ASSERT(0 == memcmp(sets->exact, str, byte_len));
ASSERT(sets->exact.len == byte_len);
ASSERT(0 == memcmp(sets->exact.data, str, byte_len));

if (lb > max_padded_len) {
ASSERT(sets->suffix_set == NULL);
Expand Down Expand Up @@ -114,8 +114,8 @@ static void test_nofold_suffix_prefix_case(_mongocrypt_tester_t *tester,
// indices.
fprintf(stderr,
"Affix starting %lld, ending %lld, count %u\n",
(long long)(affix - sets->base_string->data),
(long long)(affix - sets->base_string->data + affix_len),
(long long)((uint8_t *)affix - sets->base_string->buf.data),
(long long)((uint8_t *)affix - sets->base_string->buf.data + affix_len),
affix_count);
if (affix_len == byte_len + 1) {
// This is padding, so there should be no more entries due to how we ordered them
Expand All @@ -130,11 +130,11 @@ static void test_nofold_suffix_prefix_case(_mongocrypt_tester_t *tester,
// slightly easier when testing.
if (suffix) {
uint32_t start_offset = sets->base_string->codepoint_offsets[codepoint_len - (lb + idx)];
ASSERT(affix == sets->base_string->data + start_offset);
ASSERT((uint8_t *)affix == sets->base_string->buf.data + start_offset);
ASSERT(affix_len == sets->base_string->codepoint_offsets[codepoint_len] - start_offset)
} else {
uint32_t end_offset = sets->base_string->codepoint_offsets[lb + idx];
ASSERT(affix == sets->base_string->data);
ASSERT((uint8_t *)affix == sets->base_string->buf.data);
ASSERT(affix_len == end_offset);
}
// The count should always be 1, except for padding.
Expand All @@ -145,7 +145,7 @@ static void test_nofold_suffix_prefix_case(_mongocrypt_tester_t *tester,
ASSERT(total_real_affix_count == n_real_affixes);
if (affix_len == byte_len + 1) {
// Padding
ASSERT(affix == sets->base_string->data);
ASSERT((uint8_t *)affix == sets->base_string->buf.data);
ASSERT(affix_count == n_padding);
} else {
// No padding found
Expand Down Expand Up @@ -192,7 +192,8 @@ static uint32_t calc_unique_substrings(const mc_utf8_string_with_bad_char_t *str
uint32_t j_start_byte = str->codepoint_offsets[j];
uint32_t j_end_byte = str->codepoint_offsets[j + ss_len];
if (i_end_byte - i_start_byte == j_end_byte - j_start_byte
&& memcmp(&str->data[i_start_byte], &str->data[j_start_byte], i_end_byte - i_start_byte) == 0) {
&& memcmp(&str->buf.data[i_start_byte], &str->buf.data[j_start_byte], i_end_byte - i_start_byte)
== 0) {
idx_is_dupe[j] = 1;
dupes++;
}
Expand Down Expand Up @@ -226,17 +227,21 @@ static void test_nofold_substring_case(_mongocrypt_tester_t *tester,
mc_FLE2TextSearchInsertSpec_t spec =
{str, byte_len, {{mlen, lb, ub}, true}, {{0, 0}, false}, {{0, 0}, false}, false, false};
sets = mc_text_search_str_encode_helper(&spec, unfolded_codepoint_len, status);

if (unfolded_codepoint_len > mlen) {
ASSERT_FAILS_STATUS(sets, status, "longer than the maximum length");
mongocrypt_status_destroy(status);
return;
}
ASSERT_OR_PRINT(sets, status);
mongocrypt_status_destroy(status);
ASSERT(sets->base_string->len == byte_len + 1);
ASSERT(sets->base_string->buf.len == byte_len + 1);
ASSERT(sets->base_string->codepoint_len == codepoint_len + 1);
ASSERT(0 == memcmp(sets->base_string->data, str, byte_len));
ASSERT(sets->base_string->data[byte_len] == (char)0xFF);
ASSERT(0 == memcmp(sets->base_string->buf.data, str, byte_len));
ASSERT(sets->base_string->buf.data[byte_len] == (uint8_t)0xFF);
ASSERT(sets->suffix_set == NULL)
ASSERT(sets->prefix_set == NULL);
ASSERT(sets->exact_len == byte_len);
ASSERT(0 == memcmp(sets->exact, str, byte_len));
ASSERT(sets->exact.len == byte_len);
ASSERT(0 == memcmp(sets->exact.data, str, byte_len));

if (unfolded_codepoint_len > mlen || lb > max_padded_len) {
ASSERT(sets->substring_set == NULL);
Expand All @@ -258,18 +263,15 @@ static void test_nofold_substring_case(_mongocrypt_tester_t *tester,
mc_substring_set_iter_t it;
mc_substring_set_iter_init(&it, set);
const char *substring;
// 2D array: counts[i + j*len] is the number of substrings returned which started at byte i
// and ended at byte j (inclusive) of the base string.
uint32_t *counts = calloc(byte_len * byte_len, sizeof(uint32_t));

uint32_t substring_len = 0;
uint32_t substring_count = 0;
uint32_t total_real_substring_count = 0;
while (mc_substring_set_iter_next(&it, &substring, &substring_len, &substring_count)) {
fprintf(stderr,
"Substring starting %lld, ending %lld, count %u: \"%.*s\"\n",
(long long)(substring - sets->base_string->data),
(long long)(substring - sets->base_string->data + substring_len),
(long long)((uint8_t *)substring - sets->base_string->buf.data),
(long long)((uint8_t *)substring - sets->base_string->buf.data + substring_len),
substring_count,
substring_len,
substring);
Expand All @@ -279,25 +281,21 @@ static void test_nofold_substring_case(_mongocrypt_tester_t *tester,
break;
}

ASSERT(substring + substring_len <= sets->base_string->data + byte_len);
ASSERT((uint8_t *)substring + substring_len <= sets->base_string->buf.data + byte_len);
ASSERT(substring_len <= byte_len);
ASSERT(0 < substring_len);
ASSERT(1 == substring_count);
total_real_substring_count++;
uint32_t start_offset = (uint32_t)(substring - sets->base_string->data);

counts[start_offset + (start_offset + substring_len - 1) * byte_len]++;
}
ASSERT(total_real_substring_count == n_real_substrings);
if (substring_len == byte_len + 1) {
// Padding
ASSERT(substring == sets->base_string->data);
ASSERT((uint8_t *)substring == sets->base_string->buf.data);
ASSERT(substring_count == n_padding);
} else {
// No padding found
ASSERT(n_padding == 0)
}
free(counts);
mc_str_encode_sets_destroy(sets);
}

Expand Down Expand Up @@ -582,8 +580,8 @@ static void _test_text_search_str_encode_multiple(_mongocrypt_tester_t *tester)
ASSERT(0 == memcmp("123456789", str, 9));
ASSERT(count == 1);

ASSERT(sets->exact_len == 9);
ASSERT(0 == memcmp(sets->exact, str, 9));
ASSERT(sets->exact.len == 9);
ASSERT(0 == memcmp(sets->exact.data, str, 9));

mc_str_encode_sets_destroy(sets);
}
Expand Down

0 comments on commit 3e0301e

Please sign in to comment.