[FEA] Improved performance for strings finder_warp_parallel_fn / contains_warp_parallel_fn kernels #15405
Comments
This is definitely more in @davidwendt's wheelhouse than in mine. I'm trying to familiarize myself with the kernels in question.
I'll explore the data and post more info here. I'm still looking at the call stack, etc. There are certainly wins to be had by switching the query's approach. The user query is of the form: `CASE WHEN instr(lower(my_str), 'my_sub_str') > 0 THEN ...`. I'd like to check the feasibility of translating that predicate directly.
I have done some exploration of the data in question, and the query.
I'm not sure why/how that is. I think @revans2's changes have seen to it already. I haven't run a profile on the sample yet; I'll try to get that going tomorrow.
I'm working on a block-parallel version of the kernel.
As an aside, I should mention that the data distributions I mentioned above can be ignored, for the moment. The sample is not representative of the user's data.
I have a naive block-parallel implementation here. This change switches to block-parallel if the average string length reaches 256 or 512. (I've tried both.) Here are some results from running the comparison. It appears that benefits aren't apparent unless the average string sizes reach around 4K-8K, and that gets slightly worse at higher row counts:
From exploring the customer's data, it appears that the majority of the search input strings are under 256 bytes long, although there are outliers (some being 64K long). The average length amounts to 145 bytes. I don't think going wider than 256 threads per block is reasonable. I've gotten some numbers, and I've also run both versions. For block-parallel:
For warp-per-string:
I wonder if I might be missing a trick here with the block-parallel implementation. cc @davidwendt, @nvdbaranec, @hyperbolic2346, whom I've been consulting on this.
I've generated a local dataset with the search key distributions in much the same way as that of the reported slow case. (This includes the order of if-else clauses, with a similar match rate.) At 4M rows, with an average string length of 256 and search keys averaging 12 characters, the total runtimes are a near match. It's not looking like the kernel runtimes have an appreciable effect on the total runtime. If there's anything afflicting the execution, it isn't showing up there. NSight Compute analysis did seem to indicate the following warning:
I'm trying to understand what can be changed here, but I'm wondering if we should be considering an algorithmic fix:
P.S. I think I left the impression that the kernel runtimes don't matter. More correctly, the profiles of the Spark tasks indicate that these kernels account for a large share of the GPU time. Even a small improvement to the kernel is likely to amplify at that scale.
The first approach (processing 1 string/threadblock instead of 1 string/warp) was a bust. At the user's average string size of 144 characters, it appeared that too many threads in the block had too little to do. The second approach (processing N strings in the same kernel, instead of running the single `contains_warp_parallel_fn` kernel N times) should have reduced the processing time by amortizing the cost of kernel launches. This also seemed like a bust: tests at the user site indicated that it was taking slightly longer as well. The current thought (hat tip to @nvdbaranec) is that this might have something to do with null string inputs. It's possible that GPU occupancy drops when there are more null input rows than non-null: the null-row threads exit early and wait for the completion of the non-null threads.
For a variety of null distributions, the fastest way to find 5 sub-strings across 1M input rows still seems to be calling the single-target kernel repeatedly. It appears that the null-row theory isn't completely accurate. :/
Reference #15405

Updates the benchmarks for `cudf::strings::contains()` to use nvbench and also introduces a hit-test axis. The logic has been updated to remove the unneeded `fill()` call for long strings. Also cleaned up code and updated the logic to process 4 bytes per warp thread.

Authors:
- David Wendt (https://github.com/davidwendt)

Approvers:
- Yunsong Wang (https://github.com/PointKernel)
- Nghia Truong (https://github.com/ttnghia)

URL: #15495
Is your feature request related to a problem? Please describe.
A customer has a query that performs many string find/contains operations, often on long strings. Nsight traces show most of the GPU time being spent in `finder_warp_parallel_fn` or `contains_warp_parallel_fn`, significantly more than in Parquet decompress and decode, which are typically the top GPU kernels.
Describe the solution you'd like
Improved performance for these kernels.