Skip to content

Commit

Permalink
Use a KMP search algorithm for strings contains()
Browse files Browse the repository at this point in the history
  • Loading branch information
davidwendt committed Nov 14, 2024
1 parent a7194f6 commit 3e15aac
Showing 1 changed file with 89 additions and 13 deletions.
102 changes: 89 additions & 13 deletions cpp/src/strings/search/find.cu
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <cudf/detail/null_mask.hpp>
#include <cudf/detail/nvtx/ranges.hpp>
#include <cudf/detail/utilities/cuda.cuh>
#include <cudf/detail/utilities/vector_factories.hpp>
#include <cudf/scalar/scalar_factories.hpp>
#include <cudf/strings/detail/utilities.hpp>
#include <cudf/strings/find.hpp>
Expand Down Expand Up @@ -350,13 +351,13 @@ CUDF_KERNEL void contains_warp_parallel_fn(column_device_view const d_strings,
string_view const d_target,
bool* d_results)
{
size_type const idx = static_cast<size_type>(threadIdx.x + blockIdx.x * blockDim.x);
using warp_reduce = cub::WarpReduce<bool>;
auto const idx = cudf::detail::grid_1d::global_thread_id();
using warp_reduce = cub::WarpReduce<bool>;
__shared__ typename warp_reduce::TempStorage temp_storage;

if (idx >= (d_strings.size() * cudf::detail::warp_size)) { return; }
auto const str_idx = idx / cudf::detail::warp_size;
if (str_idx >= d_strings.size()) { return; }

auto const str_idx = idx / cudf::detail::warp_size;
auto const lane_idx = idx % cudf::detail::warp_size;
if (d_strings.is_null(str_idx)) { return; }
// get the string for this warp
Expand All @@ -369,7 +370,7 @@ CUDF_KERNEL void contains_warp_parallel_fn(column_device_view const d_strings,
i += cudf::detail::warp_size * bytes_per_warp) {
// check the target matches this part of the d_str data
// this is definitely faster for very long strings > 128B
for (auto j = 0; j < bytes_per_warp; j++) {
for (auto j = 0; !found && (j < bytes_per_warp); j++) {
if (((i + j + d_target.size_bytes()) <= d_str.size_bytes()) &&
d_target.compare(d_str.data() + i + j, d_target.size_bytes()) == 0) {
found = true;
Expand All @@ -386,7 +387,6 @@ std::unique_ptr<column> contains_warp_parallel(strings_column_view const& input,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr)
{
CUDF_EXPECTS(target.is_valid(stream), "Parameter target must be valid.");
auto d_target = string_view(target.data(), target.size());

// create output column
Expand All @@ -410,7 +410,6 @@ std::unique_ptr<column> contains_warp_parallel(strings_column_view const& input,
contains_warp_parallel_fn<<<grid.num_blocks, grid.num_threads_per_block, 0, stream.value()>>>(
*d_strings, d_target, results_view.data<bool>());
}
results->set_null_count(input.null_count());
return results;
}

Expand Down Expand Up @@ -534,24 +533,101 @@ std::unique_ptr<column> contains_fn(strings_column_view const& strings,
return results;
}

// number of possible values in a char type
constexpr std::size_t vals_in_char = 1 << CHAR_BIT;

/**
 * @brief Initialize the KMP matrix for the given target
 *
 * The matrix is stored row-major with one row per target byte and
 * `vals_in_char` columns: entry `[j * vals_in_char + c]` is the number of
 * target bytes matched after reading byte value `c` while `j` bytes were
 * already matched.
 *
 * This must be launched as a single block because the rows are built one
 * after another using block-level barriers. The expected configuration is
 * `vals_in_char` threads, but the columns are covered with a block-stride
 * loop so any block size produces the same table.
 *
 * @param d_target Target string to encode; an empty target leaves the output untouched
 * @param d_tomato KMP matrix output; must hold `d_target.size_bytes() * vals_in_char` entries
 */
CUDF_KERNEL void init_kmp_tomato(string_view const d_target, size_type* d_tomato)
{
  // d_target is a kernel argument, so this early exit is uniform across the block
  // and no thread is left waiting at a barrier below
  auto const tgt_size = d_target.size_bytes();
  if (tgt_size <= 0) { return; }

  // restart state: the row whose transitions are copied into the next row
  __shared__ int x;

  auto const tid    = threadIdx.x;
  auto const tgt_u8 = reinterpret_cast<uint8_t const*>(d_target.data());
  auto const first  = static_cast<decltype(tid)>(tgt_u8[0]);

  if (tid == 0) { x = 0; }
  // first row: only the first target byte advances the match, everything else stays at 0
  for (auto c = tid; c < vals_in_char; c += blockDim.x) {
    d_tomato[c] = (c == first) ? 1 : 0;
  }
  __syncthreads();

  for (int j = 1; j < tgt_size; ++j) {
    auto j_itr = d_tomato + (j * vals_in_char);
    auto x_itr = d_tomato + (x * vals_in_char);
    // mismatch transitions for row j are inherited from the restart row
    for (auto c = tid; c < vals_in_char; c += blockDim.x) {
      j_itr[c] = x_itr[c];
    }
    __syncthreads();  // row j must be fully copied before the match entry is patched
    if (tid == 0) {
      auto const curr_idx = tgt_u8[j];
      j_itr[curr_idx]     = j + 1;  // matching byte advances to the next state
      // next
      x = x_itr[curr_idx];
    }
    __syncthreads();  // publish the new x (and the patched entry) before the next row
  }
}

} // namespace

/**
 * @brief Check each string in `input` for the presence of `target`
 *
 * Returns a BOOL8 column with one row per input row. The output carries a copy
 * of the input's null mask and null count.
 *
 * An empty target produces `true` for every row. Long strings (by average
 * byte width of the non-null rows) are handled by the warp-parallel
 * implementation; otherwise a thread-per-row KMP search is used.
 *
 * @param input  Strings column to search
 * @param target String to search for; must be valid
 * @param stream CUDA stream used for device memory operations and kernel launches
 * @param mr     Device memory resource used to allocate the returned column
 * @return BOOL8 column of search results
 */
std::unique_ptr<column> contains(strings_column_view const& input,
                                 string_scalar const& target,
                                 rmm::cuda_stream_view stream,
                                 rmm::device_async_resource_ref mr)
{
  if (input.is_empty()) { return make_empty_column(type_id::BOOL8); }

  CUDF_EXPECTS(target.is_valid(stream), "Parameter target must be valid.");
  // empty target string returns true;
  // all null input returns all null output
  if ((target.size() == 0) || (input.size() == input.null_count())) {
    auto const true_scalar = make_fixed_width_scalar<bool>(true, stream);
    auto results           = make_column_from_scalar(*true_scalar, input.size(), stream, mr);
    results->set_null_mask(cudf::detail::copy_bitmask(input.parent(), stream, mr),
                           input.null_count());
    return results;
  }

  // use warp parallel when the average string width is greater than the threshold;
  // the average is taken over the non-null rows only (at least one exists here
  // because the all-null case returned above)
  if ((input.null_count() < input.size()) &&
      ((input.chars_size(stream) / (input.size() - input.null_count())) >
       AVG_CHAR_BYTES_THRESHOLD)) {
    return contains_warp_parallel(input, target, stream, mr);
  }

  // use KMP algorithm for thread-per-row implementation
  auto d_target  = string_view(target.data(), target.size());
  auto d_strings = column_device_view::create(input.parent(), stream);
  // create output column
  auto results   = make_numeric_column(data_type{type_id::BOOL8},
                                     input.size(),
                                     cudf::detail::copy_bitmask(input.parent(), stream, mr),
                                     input.null_count(),
                                     stream,
                                     mr);
  auto d_results = results->mutable_view().data<bool>();

  // one row of vals_in_char transitions per target byte
  auto tomato   = rmm::device_uvector<size_type>(target.size() * vals_in_char, stream);
  auto d_tomato = tomato.data();
  // launch on `stream` so the table build is ordered after the allocation above
  // and before the transform below, both of which use the same stream
  init_kmp_tomato<<<1, vals_in_char, 0, stream.value()>>>(d_target, d_tomato);

  thrust::transform(
    rmm::exec_policy(stream),
    thrust::counting_iterator<size_type>(0),
    thrust::counting_iterator<size_type>(input.size()),
    d_results,
    [d_strings = *d_strings, target_size = target.size(), d_tomato] __device__(size_type idx) {
      if (d_strings.is_null(idx)) { return false; }
      auto const d_str = d_strings.element<string_view>(idx);
      // number of target bytes matched so far; reaching target_size means found
      size_type result = 0;
      for (size_type i = 0; i < d_str.size_bytes() && result < target_size; ++i) {
        auto curr_idx = static_cast<uint8_t>(d_str.data()[i]);
        result        = d_tomato[result * vals_in_char + curr_idx];
      }
      return result == target_size;
    });

  return results;
}

std::unique_ptr<column> contains(strings_column_view const& strings,
Expand Down

0 comments on commit 3e15aac

Please sign in to comment.