diff --git a/cpp/include/cugraph/sampling_functions.hpp b/cpp/include/cugraph/sampling_functions.hpp index c83e1f48972..a4d7a162a90 100644 --- a/cpp/include/cugraph/sampling_functions.hpp +++ b/cpp/include/cugraph/sampling_functions.hpp @@ -38,9 +38,12 @@ namespace cugraph { * we can find the minimum (hop, flag) pairs for every unique vertex ID (hop is the primary key and * flag is the secondary key, flag=major is considered smaller than flag=minor if hop numbers are * same). Vertex IDs with smaller (hop, flag) pairs precede vertex IDs with larger (hop, flag) pairs - * in renumbering. Ordering can be arbitrary among the vertices with the same (hop, flag) pairs. + * in renumbering. Ordering can be arbitrary among the vertices with the same (hop, flag) pairs. If + * @p seed_vertices.has_value() is true, we assume (hop=0, flag=major) for every vertex in @p + * *seed_vertices in renumbering (this is relevant when there are seed vertices with no neighbors). * 2. If @p edgelist_hops is invalid, unique vertex IDs in edge majors precede vertex IDs that - * appear only in edge minors. + * appear only in edge minors. If @p seed_vertices.has_value() is true, vertices in @p + * *seed_vertices precede vertex IDs that appear only in edge minors as well. * 3. If edgelist_label_offsets.has_value() is true, edge lists for different labels will be * renumbered separately. * @@ -54,9 +57,10 @@ namespace cugraph { * (if @p src_is_major is true) or DCSC (if @p src_is_major is false). If @p doubly_compress is * false, the CSR/CSC offset array size is the number of vertices (which is the maximum vertex ID + * 1) + 1. Here, the maximum vertex ID is the maximum major vertex ID in the edges to compress if @p - * compress_per_hop is false or for hop 0. If @p compress_per_hop is true and hop number is 1 or - * larger, the maximum vertex ID is the larger of the maximum major vertex ID for this hop and the - * maximum vertex ID for the edges in the previous hops. + * compress_per_hop is false or for hop 0 (@p seed_vertices should be included if valid). If @p + * compress_per_hop is true and hop number is 1 or larger, the maximum vertex ID is the larger of + * the maximum major vertex ID for this hop and the maximum vertex ID for the edges in the previous + * hops. * * If both @p compress_per_hop is false and @p edgelist_hops.has_value() is true, majors should be * non-decreasing within each label after renumbering and sorting by (hop, major, minor). Also, @@ -82,11 +86,19 @@ namespace cugraph { * edgelist_srcs.size() if valid). * @param edgelist_edge_types An optional vector storing edgelist edge types (size = @p * edgelist_srcs.size() if valid). - * @param edgelist_hops An optional tuple having a vector storing edge list hop numbers (size = @p - * edgelist_srcs.size() if valid) and the number of hops. - * @param edgelist_label_offsets An optional tuple storing a pointer to the array storing label - * offsets to the input edges (size = std::get<1>(*edgelist_label_offsets) + 1) and the number of - * labels. + * @param edgelist_hops An optional vector storing edge list hop numbers (size = @p + * edgelist_srcs.size() if valid). @p edgelist_hops should be valid if @p num_hops >= 2. + * @param seed_vertices An optional pointer to the array storing seed vertices in hop 0. + * @param seed_vertex_label_offsets An optional pointer to the array storing label offsets to the + * seed vertices (size = @p num_labels + 1). @p seed_vertex_label_offsets should be valid if @p + * num_labels >= 2 and @p seed_vertices is valid and invalid otherwise. + * @param edgelist_label_offsets An optional pointer to the array storing label offsets to the input + * edges (size = @p num_labels + 1). @p edgelist_label_offsets should be valid if @p num_labels + * >= 2. + * @param num_labels Number of labels. Labels are considered if @p num_labels >=2 and ignored if @p + * num_labels = 1. + * @param num_hops Number of hops. Hop numbers are considered if @p num_hops >=2 and ignored if @p + * num_hops = 1. * @param src_is_major A flag to determine whether to use the source or destination as the * major key in renumbering and compression. * @param compress_per_hop A flag to determine whether to compress edges with different hop numbers @@ -100,13 +112,10 @@ namespace cugraph { * edgelist_weights.has_value() is true), optional edge IDs (valid only if @p * edgelist_edge_ids.has_value() is true), optional edge types (valid only if @p * edgelist_edge_types.has_value() is true), optional (label, hop) offset values to the - * (D)CSR|(D)CSC offset array (size = # labels * # hops + 1, where # labels = - * std::get<1>(*edgelist_label_offsets) if @p edgelist_label_offsets.has_value() is true and 1 - * otherwise and # hops = std::get<1>(*edgelist_hops) if edgelist_hops.has_value() is true and 1 - * otherwise, valid only if at least one of @p edgelist_label_offsets.has_value() or @p - * edgelist_hops.has_value() is true), renumber_map to query original vertices (size = # unique - * vertices or aggregate # unique vertices for every label), and label offsets to the renumber_map - * (size = std::get<1>(*edgelist_label_offsets) + 1, valid only if @p + * (D)CSR|(D)CSC offset array (size = @p num_labels * @p num_hops + 1, valid only when @p + * edgelist_hops.has_value() or @p edgelist_label_offsets.has_value() is true), renumber_map to + * query original vertices (size = # unique or aggregate # unique_vertices for each label), and + * label offsets to the renumber_map (size = num_labels + 1, valid only if @p * edgelist_label_offsets.has_value() is true). */ template >&& edgelist_weights, std::optional>&& edgelist_edge_ids, std::optional>&& edgelist_edge_types, - std::optional, size_t>>&& edgelist_hops, - std::optional, size_t>> edgelist_label_offsets, + std::optional>&& edgelist_hops, + std::optional> seed_vertices, + std::optional> seed_vertex_label_offsets, + std::optional> edgelist_label_offsets, + size_t num_labels, + size_t num_hops, bool src_is_major = true, bool compress_per_hop = false, bool doubly_compress = false, @@ -150,9 +163,12 @@ renumber_and_compress_sampled_edgelist( * we can find the minimum (hop, flag) pairs for every unique vertex ID (hop is the primary key and * flag is the secondary key, flag=major is considered smaller than flag=minor if hop numbers are * same). Vertex IDs with smaller (hop, flag) pairs precede vertex IDs with larger (hop, flag) pairs - * in renumbering. Ordering can be arbitrary among the vertices with the same (hop, flag) pairs. + * in renumbering. Ordering can be arbitrary among the vertices with the same (hop, flag) pairs. If + * @p seed_vertices.has-value() is true, we assume (hop=0, flag=major) for every vertex in @p + * *seed_vertices in renumbering (this is relevant when there are seed vertices with no neighbors). * 2. If @p edgelist_hops is invalid, unique vertex IDs in edge majors precede vertex IDs that - * appear only in edge minors. + * appear only in edge minors. If @p seed_vertices.has_value() is true, vertices in @p + * *seed_vertices precede vertex IDs that appear only in edge minors as well. * 3. If edgelist_label_offsets.has_value() is true, edge lists for different labels will be * renumbered separately. * @@ -180,12 +196,19 @@ renumber_and_compress_sampled_edgelist( * edgelist_srcs.size() if valid). * @param edgelist_edge_types An optional vector storing edgelist edge types (size = @p * edgelist_srcs.size() if valid). - * @param edgelist_hops An optional tuple having a vector storing edge list hop numbers (size = @p - * edgelist_srcs.size() if valid) and the number of hops. The hop vector values should be - * non-decreasing within each label. - * @param edgelist_label_offsets An optional tuple storing a pointer to the array storing label - * offsets to the input edges (size = std::get<1>(*edgelist_label_offsets) + 1) and the number of - * labels. + * @param edgelist_hops An optional vector storing edge list hop numbers (size = @p + * edgelist_srcs.size() if valid). @p edgelist_hops should be valid if @p num_hops >= 2. + * @param seed_vertices An optional pointer to the array storing seed vertices in hop 0. + * @param seed_vertex_label_offsets An optional pointer to the array storing label offsets to the + * seed vertices (size = @p num_labels + 1). @p seed_vertex_label_offsets should be valid if @p + * num_labels >= 2 and @p seed_vertices is valid and invalid otherwise. + * @param edgelist_label_offsets An optional pointer to the array storing label offsets to the input + * edges (size = @p num_labels + 1). @p edgelist_label_offsets should be valid if @p num_labels + * >= 2. + * @param num_labels Number of labels. Labels are considered if @p num_labels >=2 and ignored if @p + * num_labels = 1. + * @param num_hops Number of hops. Hop numbers are considered if @p num_hops >=2 and ignored if @p + * num_hops = 1. * @param src_is_major A flag to determine whether to use the source or destination as the * major key in renumbering and sorting. * @param do_expensive_check A flag to run expensive checks for input arguments (if set to `true`). @@ -193,13 +216,10 @@ renumber_and_compress_sampled_edgelist( * only if @p edgelist_weights.has_value() is true), optional edge IDs (valid only if @p * edgelist_edge_ids.has_value() is true), optional edge types (valid only if @p * edgelist_edge_types.has_value() is true), optional (label, hop) offset values to the renumbered - * and sorted edges (size = # labels * # hops + 1, where # labels = - * std::get<1>(*edgelist_label_offsets) if @p edgelist_label_offsets.has_value() is true and 1 - * otherwise and # hops = std::get<1>(*edgelist_hops) if edgelist_hops.has_value() is true and 1 - * otherwise, valid only if at least one of @p edgelist_label_offsets.has_value() or @p - * edgelist_hops.has_value() is true), renumber_map to query original vertices (size = # unique - * vertices or aggregate # unique vertices for every label), and label offsets to the renumber_map - * (size = std::get<1>(*edgelist_label_offsets) + 1, valid only if @p + * and sorted edges (size = @p num_labels * @p num_hops + 1, valid only when @p + * edgelist_hops.has_value() or @p edgelist_label_offsetes.has_value() is true), renumber_map to + * query original vertices (size = # unique or aggregate # unique vertices for each label), and + * label offsets to the renumber map (size = @p num_labels + 1, valid only if @p * edgelist_label_offsets.has_value() is true). */ template >&& edgelist_weights, std::optional>&& edgelist_edge_ids, std::optional>&& edgelist_edge_types, - std::optional, size_t>>&& edgelist_hops, - std::optional, size_t>> edgelist_label_offsets, + std::optional>&& edgelist_hops, + std::optional> seed_vertices, + std::optional> seed_vertex_label_offsets, + std::optional> edgelist_label_offsets, + size_t num_labels, + size_t num_hops, bool src_is_major = true, bool do_expensive_check = false); @@ -253,24 +277,23 @@ renumber_and_sort_sampled_edgelist( * edgelist_srcs.size() if valid). * @param edgelist_edge_types An optional vector storing edgelist edge types (size = @p * edgelist_srcs.size() if valid). - * @param edgelist_hops An optional tuple having a vector storing edge list hop numbers (size = @p - * edgelist_srcs.size() if valid) and the number of hops. The hop vector values should be - * non-decreasing within each label. - * @param edgelist_label_offsets An optional tuple storing a pointer to the array storing label - * offsets to the input edges (size = std::get<1>(*edgelist_label_offsets) + 1) and the number of - * labels. + * @param edgelist_hops An optional vector storing edge list hop numbers (size = @p + * edgelist_srcs.size() if valid). @p edgelist_hops must be valid if @p num_hops >= 2. + * @param edgelist_label_offsets An optional pointer to the array storing label offsets to the input + * edges (size = @p num_labels + 1). @p edgelist_label_offsets must be valid if @p num_labels >= 2. + * @param num_labels Number of labels. Labels are considered if @p num_labels >=2 and ignored if @p + * num_labels = 1. + * @param num_hops Number of hops. Hop numbers are considered if @p num_hops >=2 and ignored if @p + * num_hops = 1. * @param src_is_major A flag to determine whether to use the source or destination as the * major key in renumbering and sorting. * @param do_expensive_check A flag to run expensive checks for input arguments (if set to `true`). * @return Tuple of vectors storing edge sources, edge destinations, optional edge weights (valid * only if @p edgelist_weights.has_value() is true), optional edge IDs (valid only if @p * edgelist_edge_ids.has_value() is true), optional edge types (valid only if @p - * edgelist_edge_types.has_value() is true), and optional (label, hop) offset values to the - * renumbered and sorted edges (size = # labels * # hops + 1, where # labels = - * std::get<1>(*edgelist_label_offsets) if @p edgelist_label_offsets.has_value() is true and 1 - * otherwise and # hops = std::get<1>(*edgelist_hops) if edgelist_hops.has_value() is true and 1 - * otherwise, valid only if at least one of @p edgelist_label_offsets.has_value() or @p - * edgelist_hops.has_value() is true) + * edgelist_edge_types.has_value() is true), and optional (label, hop) offset values to the sorted + * edges (size = @p num_labels * @p num_hops + 1, valid only when @p edgelist_hops.has_value() or @p + * edgelist_label_offsets.has_value() is true). */ template , // srcs std::optional>, // edge IDs std::optional>, // edge types std::optional>> // (label, hop) offsets to the edges -sort_sampled_edgelist( - raft::handle_t const& handle, - rmm::device_uvector&& edgelist_srcs, - rmm::device_uvector&& edgelist_dsts, - std::optional>&& edgelist_weights, - std::optional>&& edgelist_edge_ids, - std::optional>&& edgelist_edge_types, - std::optional, size_t>>&& edgelist_hops, - std::optional, size_t>> edgelist_label_offsets, - bool src_is_major = true, - bool do_expensive_check = false); +sort_sampled_edgelist(raft::handle_t const& handle, + rmm::device_uvector&& edgelist_srcs, + rmm::device_uvector&& edgelist_dsts, + std::optional>&& edgelist_weights, + std::optional>&& edgelist_edge_ids, + std::optional>&& edgelist_edge_types, + std::optional>&& edgelist_hops, + std::optional> edgelist_label_offsets, + size_t num_labels, + size_t num_hops, + bool src_is_major = true, + bool do_expensive_check = false); } // namespace cugraph diff --git a/cpp/src/c_api/uniform_neighbor_sampling.cpp b/cpp/src/c_api/uniform_neighbor_sampling.cpp index 44018e088f7..100e81a5bd2 100644 --- a/cpp/src/c_api/uniform_neighbor_sampling.cpp +++ b/cpp/src/c_api/uniform_neighbor_sampling.cpp @@ -178,7 +178,7 @@ struct uniform_neighbor_sampling_functor : public cugraph::c_api::abstract_funct } // - // Need to renumber personalization_vertices + // Need to renumber start_vertices // cugraph::renumber_local_ext_vertices( handle_, @@ -189,8 +189,6 @@ struct uniform_neighbor_sampling_functor : public cugraph::c_api::abstract_funct graph_view.local_vertex_partition_range_last(), do_expensive_check_); - bool has_labels = start_vertex_labels_ != nullptr; - auto&& [src, dst, wgt, edge_id, edge_type, hop, edge_label, offsets] = cugraph::uniform_neighbor_sample( handle_, @@ -261,19 +259,21 @@ struct uniform_neighbor_sampling_functor : public cugraph::c_api::abstract_funct label_hop_offsets, output_renumber_map, renumber_map_offsets) = - cugraph::renumber_and_sort_sampled_edgelist( + cugraph::renumber_and_sort_sampled_edgelist( handle_, std::move(src), std::move(dst), - wgt ? std::move(wgt) : std::nullopt, - edge_id ? std::move(edge_id) : std::nullopt, - edge_type ? std::move(edge_type) : std::nullopt, - hop ? std::make_optional(std::make_tuple(std::move(*hop), fan_out_->size_)) - : std::nullopt, - offsets ? std::make_optional(std::make_tuple( - raft::device_span{offsets->data(), offsets->size()}, - edge_label->size())) + std::move(wgt), + std::move(edge_id), + std::move(edge_type), + std::move(hop), + std::nullopt, + std::nullopt, + offsets ? std::make_optional( + raft::device_span{offsets->data(), offsets->size()}) : std::nullopt, + edge_label ? edge_label->size() : size_t{1}, + hop ? fan_out_->size_ : size_t{1}, src_is_major, do_expensive_check_); @@ -296,19 +296,21 @@ struct uniform_neighbor_sampling_functor : public cugraph::c_api::abstract_funct label_hop_offsets, output_renumber_map, renumber_map_offsets) = - cugraph::renumber_and_compress_sampled_edgelist( + cugraph::renumber_and_compress_sampled_edgelist( handle_, std::move(src), std::move(dst), - wgt ? std::move(wgt) : std::nullopt, - edge_id ? std::move(edge_id) : std::nullopt, - edge_type ? std::move(edge_type) : std::nullopt, - hop ? std::make_optional(std::make_tuple(std::move(*hop), fan_out_->size_)) - : std::nullopt, - offsets ? std::make_optional(std::make_tuple( - raft::device_span{offsets->data(), offsets->size()}, - edge_label->size())) + std::move(wgt), + std::move(edge_id), + std::move(edge_type), + std::move(hop), + std::nullopt, + std::nullopt, + offsets ? std::make_optional( + raft::device_span{offsets->data(), offsets->size()}) : std::nullopt, + edge_label ? edge_label->size() : size_t{1}, + hop ? fan_out_->size_ : size_t{1}, src_is_major, options_.compress_per_hop_, doubly_compress, @@ -327,21 +329,21 @@ struct uniform_neighbor_sampling_functor : public cugraph::c_api::abstract_funct } std::tie(src, dst, wgt, edge_id, edge_type, label_hop_offsets) = - cugraph::sort_sampled_edgelist( - handle_, - std::move(src), - std::move(dst), - wgt ? std::move(wgt) : std::nullopt, - edge_id ? std::move(edge_id) : std::nullopt, - edge_type ? std::move(edge_type) : std::nullopt, - hop ? std::make_optional(std::make_tuple(std::move(*hop), fan_out_->size_)) - : std::nullopt, - offsets ? std::make_optional(std::make_tuple( - raft::device_span{offsets->data(), offsets->size()}, - edge_label->size())) - : std::nullopt, - src_is_major, - do_expensive_check_); + cugraph::sort_sampled_edgelist(handle_, + std::move(src), + std::move(dst), + std::move(wgt), + std::move(edge_id), + std::move(edge_type), + std::move(hop), + offsets + ? std::make_optional(raft::device_span{ + offsets->data(), offsets->size()}) + : std::nullopt, + edge_label ? edge_label->size() : size_t{1}, + hop ? fan_out_->size_ : size_t{1}, + src_is_major, + do_expensive_check_); majors.emplace(std::move(src)); minors = std::move(dst); diff --git a/cpp/src/sampling/sampling_post_processing_impl.cuh b/cpp/src/sampling/sampling_post_processing_impl.cuh index 299aae13718..b0b3bb5f4f2 100644 --- a/cpp/src/sampling/sampling_post_processing_impl.cuh +++ b/cpp/src/sampling/sampling_post_processing_impl.cuh @@ -49,7 +49,7 @@ namespace cugraph { namespace { -template +template struct edge_order_t { thrust::optional> edgelist_label_offsets{thrust::nullopt}; thrust::optional> edgelist_hops{thrust::nullopt}; @@ -91,7 +91,7 @@ struct edge_order_t { }; template -struct is_first_in_run_t { +struct is_first_triplet_in_run_t { thrust::optional> edgelist_label_offsets{thrust::nullopt}; thrust::optional> edgelist_hops{thrust::nullopt}; raft::device_span edgelist_majors{}; @@ -154,75 +154,190 @@ template -void check_input_edges( - raft::handle_t const& handle, - rmm::device_uvector const& edgelist_srcs, - rmm::device_uvector const& edgelist_dsts, - std::optional> const& edgelist_weights, - std::optional> const& edgelist_edge_ids, - std::optional> const& edgelist_edge_types, - std::optional, size_t>> const& edgelist_hops, - std::optional, size_t>> edgelist_label_offsets, - bool do_expensive_check) +void check_input_edges(raft::handle_t const& handle, + rmm::device_uvector const& edgelist_majors, + rmm::device_uvector const& edgelist_minors, + std::optional> const& edgelist_weights, + std::optional> const& edgelist_edge_ids, + std::optional> const& edgelist_edge_types, + std::optional> const& edgelist_hops, + std::optional> seed_vertices, + std::optional> seed_vertex_label_offsets, + std::optional> edgelist_label_offsets, + size_t num_labels, + size_t num_hops, + bool do_expensive_check) { - CUGRAPH_EXPECTS(!edgelist_label_offsets || (std::get<1>(*edgelist_label_offsets) <= - std::numeric_limits::max()), - "Invalid input arguments: current implementation assumes that the number of " - "unique labels is no larger than std::numeric_limits::max()."); - - CUGRAPH_EXPECTS( - !edgelist_label_offsets.has_value() || - (std::get<0>(*edgelist_label_offsets).size() == std::get<1>(*edgelist_label_offsets) + 1), - "Invalid input arguments: if edgelist_label_offsets is valid, " - "std::get<0>(*edgelist_label_offsets).size() (size of the offset array) should be " - "std::get<1>(*edgelist_label_offsets) (number of unique labels) + 1."); - CUGRAPH_EXPECTS( - !edgelist_hops || (std::get<1>(*edgelist_hops) <= std::numeric_limits::max()), - "Invalid input arguments: current implementation assumes that the number of " - "hops is no larger than std::numeric_limits::max()."); - CUGRAPH_EXPECTS(!edgelist_hops || std::get<1>(*edgelist_hops) > 0, - "Invlaid input arguments: number of hops should be larger than 0 if " - "edgelist_hops.has_value() is true."); - - CUGRAPH_EXPECTS( - edgelist_srcs.size() == edgelist_dsts.size(), + edgelist_majors.size() == edgelist_minors.size(), "Invalid input arguments: edgelist_srcs.size() and edgelist_dsts.size() should coincide."); CUGRAPH_EXPECTS( - !edgelist_weights.has_value() || (edgelist_srcs.size() == (*edgelist_weights).size()), - "Invalid input arguments: if edgelist_weights is valid, std::get<0>(*edgelist_weights).size() " - "and edgelist_srcs.size() should coincide."); + !edgelist_weights.has_value() || (edgelist_majors.size() == (*edgelist_weights).size()), + "Invalid input arguments: if edgelist_weights is valid, (*edgelist_weights).size() and " + "edgelist_(srcs|dsts).size() should coincide."); CUGRAPH_EXPECTS( - !edgelist_edge_ids.has_value() || (edgelist_srcs.size() == (*edgelist_edge_ids).size()), - "Invalid input arguments: if edgelist_edge_ids is valid, " - "std::get<0>(*edgelist_edge_ids).size() and edgelist_srcs.size() should coincide."); + !edgelist_edge_ids.has_value() || (edgelist_majors.size() == (*edgelist_edge_ids).size()), + "Invalid input arguments: if edgelist_edge_ids is valid, (*edgelist_edge_ids).size() and " + "edgelist_(srcs|dsts).size() should coincide."); CUGRAPH_EXPECTS( - !edgelist_edge_types.has_value() || (edgelist_srcs.size() == (*edgelist_edge_types).size()), - "Invalid input arguments: if edgelist_edge_types is valid, " - "std::get<0>(*edgelist_edge_types).size() and edgelist_srcs.size() should coincide."); + !edgelist_edge_types.has_value() || (edgelist_majors.size() == (*edgelist_edge_types).size()), + "Invalid input arguments: if edgelist_edge_types is valid, (*edgelist_edge_types).size() and " + "edgelist_(srcs|dsts).size() should coincide."); + CUGRAPH_EXPECTS(!edgelist_hops.has_value() || (edgelist_majors.size() == (*edgelist_hops).size()), + "Invalid input arguments: if edgelist_hops is valid, (*edgelist_hops).size() and " + "edgelist_(srcs|dsts).size() should coincide."); + CUGRAPH_EXPECTS( - !edgelist_hops.has_value() || (edgelist_srcs.size() == std::get<0>(*edgelist_hops).size()), - "Invalid input arguments: if edgelist_hops is valid, std::get<0>(*edgelist_hops).size() and " - "edgelist_srcs.size() should coincide."); + !edgelist_label_offsets.has_value() || ((*edgelist_label_offsets).size() == num_labels + 1), + "Invalid input arguments: if edgelist_label_offsets is valid, (*edgelist_label_offsets).size() " + "(size of the offset array) should be num_labels + 1."); + + if (edgelist_majors.size() > 0) { + CUGRAPH_EXPECTS((num_labels >= 1) && (num_labels <= std::numeric_limits::max()), + "Invalid input arguments: num_labels should be a positive integer and the " + "current implementation assumes that the number of unique labels is no larger " + "than std::numeric_limits::max()."); + CUGRAPH_EXPECTS((num_labels == 1) || edgelist_label_offsets.has_value(), + "Invalid input arguments: edgelist_label_offsets.has_value() should be true if " + "num_labels >= 2."); + + CUGRAPH_EXPECTS( + (num_hops >= 1) && (num_hops <= std::numeric_limits::max()), + "Invalid input arguments: num_hops should be a positive integer and the current " + "implementation " + "assumes that the number of hops is no larger than std::numeric_limits::max()."); + CUGRAPH_EXPECTS( + (num_hops == 1) || edgelist_hops.has_value(), + "Invalid input arguments: edgelist_hops.has_value() should be true if num_hops >= 2."); + } else { + CUGRAPH_EXPECTS( + "num_labels == 0", + "Invalid input arguments: num_labels should be 0 if the input edge list is empty."); + CUGRAPH_EXPECTS( + "num_hops == 0", + "Invalid input arguments: num_hops should be 0 if the input edge list is empty."); + } + + CUGRAPH_EXPECTS((!seed_vertices.has_value() && !seed_vertex_label_offsets.has_value()) || + (seed_vertices.has_value() && + (edgelist_label_offsets.has_value() == seed_vertex_label_offsets.has_value())), + "Invaild input arguments: if seed_vertices.has_value() is false, " + "seed_vertex_label_offsets.has_value() should be false as well. If " + "seed_vertices.has_value( ) is true, seed_vertex_label_offsets.has_value() " + "should coincide with edgelist_label_offsets.has_value()."); + CUGRAPH_EXPECTS( + !seed_vertex_label_offsets.has_value() || + ((*seed_vertex_label_offsets).size() == num_labels + 1), + "Invalid input arguments: if seed_vertex_label_offsets is valid, " + "(*seed_vertex_label_offsets).size() (size of the offset array) should be num_labels + 1."); if (do_expensive_check) { if (edgelist_label_offsets) { CUGRAPH_EXPECTS(thrust::is_sorted(handle.get_thrust_policy(), - std::get<0>(*edgelist_label_offsets).begin(), - std::get<0>(*edgelist_label_offsets).end()), + (*edgelist_label_offsets).begin(), + (*edgelist_label_offsets).end()), "Invalid input arguments: if edgelist_label_offsets is valid, " - "std::get<0>(*edgelist_label_offsets) should be sorted."); - size_t back_element{}; + "*edgelist_label_offsets should be sorted."); + size_t front_element{}; raft::update_host( - &back_element, - std::get<0>(*edgelist_label_offsets).data() + std::get<1>(*edgelist_label_offsets), - size_t{1}, - handle.get_stream()); - handle.get_stream(); + &front_element, (*edgelist_label_offsets).data(), size_t{1}, handle.get_stream()); + size_t back_element{}; + raft::update_host(&back_element, + (*edgelist_label_offsets).data() + num_labels, + size_t{1}, + handle.get_stream()); + handle.sync_stream(); CUGRAPH_EXPECTS( - back_element == edgelist_srcs.size(), + front_element == size_t{0}, + "Invalid input arguments: if edgelist_label_offsets is valid, the first element of " + "*edgelist_label_offsets should be 0."); + CUGRAPH_EXPECTS( + back_element == edgelist_majors.size(), "Invalid input arguments: if edgelist_label_offsets is valid, the last element of " - "std::get<0>(*edgelist_label_offsets) and edgelist_srcs.size() should coincide."); + "*edgelist_label_offsets and edgelist_(srcs|dsts).size() should coincide."); + } + + if (seed_vertices) { + for (size_t i = 0; i < num_labels; ++i) { + rmm::device_uvector this_label_seed_vertices(0, handle.get_stream()); + { + size_t start_offset{0}; + auto end_offset = (*seed_vertices).size(); + if (seed_vertex_label_offsets) { + raft::update_host( + &start_offset, (*seed_vertex_label_offsets).data() + i, 1, handle.get_stream()); + raft::update_host( + &end_offset, (*seed_vertex_label_offsets).data() + (i + 1), 1, handle.get_stream()); + handle.sync_stream(); + } + this_label_seed_vertices.resize(end_offset - start_offset, handle.get_stream()); + thrust::copy(handle.get_thrust_policy(), + (*seed_vertices).begin() + start_offset, + (*seed_vertices).begin() + end_offset, + this_label_seed_vertices.begin()); + thrust::sort(handle.get_thrust_policy(), + this_label_seed_vertices.begin(), + this_label_seed_vertices.end()); + this_label_seed_vertices.resize( + thrust::distance(this_label_seed_vertices.begin(), + thrust::unique(handle.get_thrust_policy(), + this_label_seed_vertices.begin(), + this_label_seed_vertices.end())), + handle.get_stream()); + } + + rmm::device_uvector this_label_zero_hop_majors(0, handle.get_stream()); + { + size_t start_offset{0}; + auto end_offset = edgelist_majors.size(); + if (edgelist_label_offsets) { + raft::update_host( + &start_offset, (*edgelist_label_offsets).data() + i, 1, handle.get_stream()); + raft::update_host( + &end_offset, (*edgelist_label_offsets).data() + (i + 1), 1, handle.get_stream()); + handle.sync_stream(); + } + this_label_zero_hop_majors.resize(end_offset - start_offset, handle.get_stream()); + if (edgelist_hops) { + this_label_zero_hop_majors.resize( + thrust::distance(this_label_zero_hop_majors.begin(), + thrust::copy_if(handle.get_thrust_policy(), + edgelist_majors.begin() + start_offset, + edgelist_majors.begin() + end_offset, + (*edgelist_hops).begin() + start_offset, + this_label_zero_hop_majors.begin(), + detail::is_equal_t{0})), + handle.get_stream()); + } else { + thrust::copy(handle.get_thrust_policy(), + edgelist_majors.begin() + start_offset, + edgelist_majors.begin() + end_offset, + this_label_zero_hop_majors.begin()); + } + thrust::sort(handle.get_thrust_policy(), + this_label_zero_hop_majors.begin(), + this_label_zero_hop_majors.end()); + this_label_zero_hop_majors.resize( + thrust::distance(this_label_zero_hop_majors.begin(), + thrust::unique(handle.get_thrust_policy(), + this_label_zero_hop_majors.begin(), + this_label_zero_hop_majors.end())), + handle.get_stream()); + } + + rmm::device_uvector zero_hop_majors_minus_seed_vertices( + this_label_zero_hop_majors.size(), handle.get_stream()); + CUGRAPH_EXPECTS(thrust::distance( + zero_hop_majors_minus_seed_vertices.begin(), + thrust::set_difference(handle.get_thrust_policy(), + this_label_zero_hop_majors.begin(), + this_label_zero_hop_majors.end(), + this_label_seed_vertices.begin(), + this_label_seed_vertices.end(), + zero_hop_majors_minus_seed_vertices.begin())) == 0, + "Invalid input arguments: if seed_vertices.has_value() is true, " + "seed_vertices should include all zero-hop majors."); + } } } } @@ -235,70 +350,121 @@ std::tuple> /* label indices */ std::optional> /* label offsets for the output */> compute_min_hop_for_unique_label_vertex_pairs( raft::handle_t const& handle, - raft::device_span vertices, - std::optional> hops, - std::optional> label_indices, - std::optional> label_offsets) + raft::device_span edgelist_vertices, + std::optional> edgelist_hops, + std::optional> seed_vertices, + std::optional> seed_vertex_label_offsets, + std::optional> edgelist_label_offsets) { auto approx_edges_to_sort_per_iteration = static_cast(handle.get_device_properties().multiProcessorCount) * - (1 << 20) /* tuning parameter */; // for segmented sort - - if (label_indices) { - auto num_labels = (*label_offsets).size() - 1; - - rmm::device_uvector tmp_label_indices((*label_indices).size(), - handle.get_stream()); - thrust::copy(handle.get_thrust_policy(), - (*label_indices).begin(), - (*label_indices).end(), - tmp_label_indices.begin()); + (1 << 18) /* tuning parameter */; // for segmented sort + if (edgelist_label_offsets) { + rmm::device_uvector tmp_label_indices(0, handle.get_stream()); rmm::device_uvector tmp_vertices(0, handle.get_stream()); std::optional> tmp_hops{std::nullopt}; - if (hops) { - tmp_vertices.resize(vertices.size(), handle.get_stream()); - thrust::copy( - handle.get_thrust_policy(), vertices.begin(), vertices.end(), tmp_vertices.begin()); - tmp_hops = rmm::device_uvector((*hops).size(), handle.get_stream()); - thrust::copy(handle.get_thrust_policy(), (*hops).begin(), (*hops).end(), (*tmp_hops).begin()); + auto [h_label_offsets, h_edge_offsets] = + detail::compute_offset_aligned_element_chunks(handle, + *edgelist_label_offsets, + edgelist_vertices.size(), + approx_edges_to_sort_per_iteration); + auto num_chunks = h_label_offsets.size() - 1; - auto triplet_first = thrust::make_zip_iterator( - tmp_label_indices.begin(), tmp_vertices.begin(), (*tmp_hops).begin()); - thrust::sort( - handle.get_thrust_policy(), triplet_first, triplet_first + tmp_label_indices.size()); - auto key_first = thrust::make_zip_iterator(tmp_label_indices.begin(), tmp_vertices.begin()); - auto num_uniques = static_cast( - thrust::distance(key_first, - thrust::get<0>(thrust::unique_by_key(handle.get_thrust_policy(), - key_first, - key_first + tmp_label_indices.size(), - (*tmp_hops).begin())))); - tmp_label_indices.resize(num_uniques, handle.get_stream()); - tmp_vertices.resize(num_uniques, handle.get_stream()); - (*tmp_hops).resize(num_uniques, handle.get_stream()); - tmp_label_indices.shrink_to_fit(handle.get_stream()); - tmp_vertices.shrink_to_fit(handle.get_stream()); - (*tmp_hops).shrink_to_fit(handle.get_stream()); - } else { - rmm::device_uvector segment_sorted_vertices(vertices.size(), handle.get_stream()); + if (edgelist_hops) { + rmm::device_uvector tmp_indices(edgelist_vertices.size(), handle.get_stream()); + thrust::sequence( + handle.get_thrust_policy(), tmp_indices.begin(), tmp_indices.end(), size_t{0}); - rmm::device_uvector d_tmp_storage(0, handle.get_stream()); + // cub::DeviceSegmentedSort currently does not suuport thrust::tuple type keys, sorting in + // chunks still helps in limiting the binary search range and improving memory locality + for (size_t i = 0; i < num_chunks; ++i) { + thrust::sort( + handle.get_thrust_policy(), + tmp_indices.begin() + h_edge_offsets[i], + tmp_indices.begin() + h_edge_offsets[i + 1], + [edgelist_label_offsets = + raft::device_span((*edgelist_label_offsets).data() + h_label_offsets[i], + (h_label_offsets[i + 1] - h_label_offsets[i]) + 1), + edgelist_vertices, + edgelist_hops = *edgelist_hops] __device__(size_t l_idx, size_t r_idx) { + auto l_it = thrust::upper_bound( + thrust::seq, edgelist_label_offsets.begin() + 1, edgelist_label_offsets.end(), l_idx); + auto r_it = thrust::upper_bound( + thrust::seq, edgelist_label_offsets.begin() + 1, edgelist_label_offsets.end(), r_idx); + if (l_it != r_it) { return l_it < r_it; } + + auto l_vertex = edgelist_vertices[l_idx]; + auto r_vertex = edgelist_vertices[r_idx]; + if (l_vertex != r_vertex) { return l_vertex < r_vertex; } + + auto l_hop = edgelist_hops[l_idx]; + auto r_hop = edgelist_hops[r_idx]; + return l_hop < r_hop; + }); + } - auto [h_label_offsets, h_edge_offsets] = detail::compute_offset_aligned_element_chunks( - handle, *label_offsets, vertices.size(), approx_edges_to_sort_per_iteration); - auto num_chunks = h_label_offsets.size() - 1; + tmp_indices.resize( + thrust::distance( + tmp_indices.begin(), + thrust::unique(handle.get_thrust_policy(), + tmp_indices.begin(), + tmp_indices.end(), + [edgelist_label_offsets = *edgelist_label_offsets, + edgelist_vertices, + edgelist_hops = *edgelist_hops] __device__(size_t l_idx, size_t r_idx) { + auto l_it = thrust::upper_bound(thrust::seq, + edgelist_label_offsets.begin() + 1, + edgelist_label_offsets.end(), + l_idx); + auto r_it = thrust::upper_bound(thrust::seq, + edgelist_label_offsets.begin() + 1, + edgelist_label_offsets.end(), + r_idx); + if (l_it != r_it) { return false; } + + auto l_vertex = edgelist_vertices[l_idx]; + auto r_vertex = edgelist_vertices[r_idx]; + return l_vertex == r_vertex; + })), + handle.get_stream()); + tmp_label_indices.resize(tmp_indices.size(), handle.get_stream()); + tmp_vertices.resize(tmp_indices.size(), handle.get_stream()); + tmp_hops = rmm::device_uvector(tmp_indices.size(), handle.get_stream()); + + auto triplet_first = thrust::make_transform_iterator( + tmp_indices.begin(), + cuda::proclaim_return_type>( + [edgelist_label_offsets = *edgelist_label_offsets, + edgelist_vertices, + edgelist_hops = *edgelist_hops] __device__(size_t i) { + auto label_idx = static_cast(thrust::distance( + edgelist_label_offsets.begin() + 1, + thrust::upper_bound( + thrust::seq, edgelist_label_offsets.begin() + 1, edgelist_label_offsets.end(), i))); + return thrust::make_tuple(label_idx, edgelist_vertices[i], edgelist_hops[i]); + })); + thrust::copy(handle.get_thrust_policy(), + triplet_first, + triplet_first + tmp_indices.size(), + thrust::make_zip_iterator( + tmp_label_indices.begin(), tmp_vertices.begin(), (*tmp_hops).begin())); + } else { + rmm::device_uvector segment_sorted_vertices(edgelist_vertices.size(), + handle.get_stream()); + + rmm::device_uvector d_tmp_storage(0, handle.get_stream()); for (size_t i = 0; i < num_chunks; ++i) { size_t tmp_storage_bytes{0}; auto offset_first = - thrust::make_transform_iterator((*label_offsets).data() + h_label_offsets[i], + thrust::make_transform_iterator((*edgelist_label_offsets).data() + h_label_offsets[i], detail::shift_left_t{h_edge_offsets[i]}); cub::DeviceSegmentedSort::SortKeys(static_cast(nullptr), tmp_storage_bytes, - vertices.begin() + h_edge_offsets[i], + edgelist_vertices.begin() + h_edge_offsets[i], segment_sorted_vertices.begin() + h_edge_offsets[i], h_edge_offsets[i + 1] - h_edge_offsets[i], h_label_offsets[i + 1] - h_label_offsets[i], @@ -312,7 +478,7 @@ compute_min_hop_for_unique_label_vertex_pairs( cub::DeviceSegmentedSort::SortKeys(d_tmp_storage.data(), tmp_storage_bytes, - vertices.begin() + h_edge_offsets[i], + edgelist_vertices.begin() + h_edge_offsets[i], segment_sorted_vertices.begin() + h_edge_offsets[i], h_edge_offsets[i + 1] - h_edge_offsets[i], h_label_offsets[i + 1] - h_label_offsets[i], @@ -323,27 +489,227 @@ compute_min_hop_for_unique_label_vertex_pairs( d_tmp_storage.resize(0, handle.get_stream()); d_tmp_storage.shrink_to_fit(handle.get_stream()); - auto pair_first = - thrust::make_zip_iterator(tmp_label_indices.begin(), segment_sorted_vertices.begin()); - auto num_uniques = static_cast(thrust::distance( - pair_first, - thrust::unique( - handle.get_thrust_policy(), pair_first, pair_first + tmp_label_indices.size()))); + tmp_label_indices.resize(segment_sorted_vertices.size(), handle.get_stream()); + tmp_vertices.resize(segment_sorted_vertices.size(), handle.get_stream()); + + auto input_pair_first = thrust::make_transform_iterator( + thrust::make_counting_iterator(size_t{0}), + cuda::proclaim_return_type>( + [edgelist_label_offsets = *edgelist_label_offsets, + edgelist_vertices = raft::device_span( + segment_sorted_vertices.data(), segment_sorted_vertices.size())] __device__(size_t i) { + auto label_idx = static_cast(thrust::distance( + edgelist_label_offsets.begin() + 1, + thrust::upper_bound( + thrust::seq, edgelist_label_offsets.begin() + 1, edgelist_label_offsets.end(), i))); + return thrust::make_tuple(label_idx, edgelist_vertices[i]); + })); + auto output_pair_first = + thrust::make_zip_iterator(tmp_label_indices.begin(), tmp_vertices.begin()); + auto num_uniques = + thrust::distance(output_pair_first, + thrust::unique_copy(handle.get_thrust_policy(), + input_pair_first, + input_pair_first + segment_sorted_vertices.size(), + output_pair_first)); tmp_label_indices.resize(num_uniques, handle.get_stream()); - segment_sorted_vertices.resize(num_uniques, handle.get_stream()); + tmp_vertices.resize(num_uniques, handle.get_stream()); tmp_label_indices.shrink_to_fit(handle.get_stream()); - segment_sorted_vertices.shrink_to_fit(handle.get_stream()); + tmp_vertices.shrink_to_fit(handle.get_stream()); + } + + if (seed_vertices) { + /* label segmented sort */ + + rmm::device_uvector segment_sorted_vertices((*seed_vertices).size(), + handle.get_stream()); + + rmm::device_uvector d_tmp_storage(0, handle.get_stream()); + size_t tmp_storage_bytes{0}; + + cub::DeviceSegmentedSort::SortKeys(static_cast(nullptr), + tmp_storage_bytes, + (*seed_vertices).begin(), + segment_sorted_vertices.begin(), + (*seed_vertices).size(), + (*seed_vertex_label_offsets).size() - 1, + (*seed_vertex_label_offsets).begin(), + (*seed_vertex_label_offsets).begin() + 1, + handle.get_stream()); + + if (tmp_storage_bytes > d_tmp_storage.size()) { + d_tmp_storage = rmm::device_uvector(tmp_storage_bytes, handle.get_stream()); + } + + cub::DeviceSegmentedSort::SortKeys(d_tmp_storage.data(), + tmp_storage_bytes, + (*seed_vertices).begin(), + segment_sorted_vertices.begin(), + (*seed_vertices).size(), + (*seed_vertex_label_offsets).size() - 1, + (*seed_vertex_label_offsets).begin(), + (*seed_vertex_label_offsets).begin() + 1, + handle.get_stream()); + + /* enumerate unique (label, vertex) pairs */ + + rmm::device_uvector unique_seed_vertex_label_indices((*seed_vertices).size(), + handle.get_stream()); + rmm::device_uvector unique_seed_vertices((*seed_vertices).size(), + handle.get_stream()); + auto input_pair_first = thrust::make_transform_iterator( + thrust::make_counting_iterator(size_t{0}), + cuda::proclaim_return_type>( + [seed_vertex_label_offsets = *seed_vertex_label_offsets, + seed_vertices = raft::device_span( + segment_sorted_vertices.data(), segment_sorted_vertices.size())] __device__(size_t i) { + auto label_idx = static_cast( + thrust::distance(seed_vertex_label_offsets.begin() + 1, + thrust::upper_bound(thrust::seq, + seed_vertex_label_offsets.begin() + 1, + seed_vertex_label_offsets.end(), + i))); + return thrust::make_tuple(label_idx, seed_vertices[i]); + })); + auto output_pair_first = thrust::make_zip_iterator(unique_seed_vertex_label_indices.begin(), + unique_seed_vertices.begin()); + auto num_uniques = + thrust::distance(output_pair_first, + thrust::unique_copy(handle.get_thrust_policy(), + input_pair_first, + input_pair_first + segment_sorted_vertices.size(), + output_pair_first)); + unique_seed_vertex_label_indices.resize( + thrust::distance(output_pair_first, + thrust::unique_copy(handle.get_thrust_policy(), + input_pair_first, + input_pair_first + segment_sorted_vertices.size(), + output_pair_first)), + handle.get_stream()); + unique_seed_vertices.resize(unique_seed_vertex_label_indices.size(), handle.get_stream()); + + /* merge with the (label, vertex, min. hop) triplets from the edgelist */ - tmp_vertices = std::move(segment_sorted_vertices); + if (edgelist_hops) { + auto triplet_from_edgelist_first = thrust::make_zip_iterator( + tmp_label_indices.begin(), tmp_vertices.begin(), (*tmp_hops).begin()); + auto key_pair_from_seed_vertex_first = thrust::make_zip_iterator( + unique_seed_vertex_label_indices.begin(), unique_seed_vertices.begin()); + thrust::for_each( + handle.get_thrust_policy(), + key_pair_from_seed_vertex_first, + key_pair_from_seed_vertex_first + unique_seed_vertex_label_indices.size(), + [triplet_from_edgelist_first, + triplet_from_edgelist_last = + triplet_from_edgelist_first + tmp_label_indices.size()] __device__(auto pair) { + auto it = thrust::lower_bound( + thrust::seq, + triplet_from_edgelist_first, + triplet_from_edgelist_last, + thrust::make_tuple(thrust::get<0>(pair), thrust::get<1>(pair), int32_t{0})); + if ((it != triplet_from_edgelist_last) && + (thrust::get<0>(*it) == thrust::get<0>(pair)) && + (thrust::get<1>(*it) == thrust::get<1>(pair))) { + // update min. hop to 0 + if (thrust::get<2>(*it) != int32_t{0}) { thrust::get<2>(*it) = int32_t{0}; } + } + }); + + unique_seed_vertex_label_indices.resize( + thrust::distance( + key_pair_from_seed_vertex_first, + thrust::remove_if( + handle.get_thrust_policy(), + key_pair_from_seed_vertex_first, + key_pair_from_seed_vertex_first + unique_seed_vertices.size(), + [triplet_from_edgelist_first, + triplet_from_edgelist_last = + triplet_from_edgelist_first + tmp_label_indices.size()] __device__(auto pair) { + auto it = thrust::lower_bound( + thrust::seq, + triplet_from_edgelist_first, + triplet_from_edgelist_last, + thrust::make_tuple(thrust::get<0>(pair), thrust::get<1>(pair), int32_t{0})); + return (it != triplet_from_edgelist_last) && + (thrust::get<0>(*it) == thrust::get<0>(pair)) && + (thrust::get<1>(*it) == thrust::get<1>(pair)); + })), + handle.get_stream()); + unique_seed_vertices.resize(unique_seed_vertex_label_indices.size(), handle.get_stream()); + if (unique_seed_vertex_label_indices.size() > 0) { + rmm::device_uvector merged_label_indices( + tmp_label_indices.size() + unique_seed_vertex_label_indices.size(), + handle.get_stream()); + rmm::device_uvector merged_vertices(merged_label_indices.size(), + handle.get_stream()); + rmm::device_uvector merged_hops(merged_label_indices.size(), + handle.get_stream()); + auto triplet_from_seed_vertex_first = + thrust::make_zip_iterator(unique_seed_vertex_label_indices.begin(), + unique_seed_vertices.begin(), + thrust::make_constant_iterator(int32_t{0})); + thrust::merge( + handle.get_thrust_policy(), + triplet_from_edgelist_first, + triplet_from_edgelist_first + tmp_label_indices.size(), + triplet_from_seed_vertex_first, + triplet_from_seed_vertex_first + unique_seed_vertex_label_indices.size(), + thrust::make_zip_iterator( + merged_label_indices.begin(), merged_vertices.begin(), merged_hops.begin())); + tmp_label_indices = std::move(merged_label_indices); + tmp_vertices = std::move(merged_vertices); + tmp_hops = std::move(merged_hops); + } + } else { + auto pair_from_edgelist_first = + thrust::make_zip_iterator(tmp_label_indices.begin(), tmp_vertices.begin()); + auto pair_from_seed_vertex_first = thrust::make_zip_iterator( + unique_seed_vertex_label_indices.begin(), unique_seed_vertices.begin()); + unique_seed_vertex_label_indices.resize( + thrust::distance( + pair_from_seed_vertex_first, + thrust::remove_if( + handle.get_thrust_policy(), + pair_from_seed_vertex_first, + pair_from_seed_vertex_first + unique_seed_vertex_label_indices.size(), + [pair_from_edgelist_first, + pair_from_edgelist_last = + pair_from_edgelist_first + tmp_label_indices.size()] __device__(auto pair) { + auto it = thrust::lower_bound( + thrust::seq, pair_from_edgelist_first, pair_from_edgelist_last, pair); + return (it != pair_from_edgelist_last) && (*it == pair); + })), + handle.get_stream()); + unique_seed_vertices.resize(unique_seed_vertex_label_indices.size(), handle.get_stream()); + if (unique_seed_vertex_label_indices.size() > 0) { + rmm::device_uvector merged_label_indices( + tmp_label_indices.size() + unique_seed_vertex_label_indices.size(), + handle.get_stream()); + rmm::device_uvector merged_vertices(merged_label_indices.size(), + handle.get_stream()); + pair_from_seed_vertex_first = thrust::make_zip_iterator( + unique_seed_vertex_label_indices.begin(), unique_seed_vertices.begin()); + thrust::merge( + handle.get_thrust_policy(), + pair_from_edgelist_first, + pair_from_edgelist_first + tmp_label_indices.size(), + pair_from_seed_vertex_first, + pair_from_seed_vertex_first + unique_seed_vertex_label_indices.size(), + thrust::make_zip_iterator(merged_label_indices.begin(), merged_vertices.begin())); + tmp_label_indices = std::move(merged_label_indices); + tmp_vertices = std::move(merged_vertices); + } + } } - rmm::device_uvector tmp_label_offsets(num_labels + 1, handle.get_stream()); + rmm::device_uvector tmp_label_offsets((*edgelist_label_offsets).size(), + handle.get_stream()); tmp_label_offsets.set_element_to_zero_async(0, handle.get_stream()); thrust::upper_bound(handle.get_thrust_policy(), tmp_label_indices.begin(), tmp_label_indices.end(), thrust::make_counting_iterator(size_t{0}), - thrust::make_counting_iterator(num_labels), + thrust::make_counting_iterator(tmp_label_offsets.size() - 1), tmp_label_offsets.begin() + 1); return std::make_tuple(std::move(tmp_label_indices), @@ -351,28 +717,34 @@ compute_min_hop_for_unique_label_vertex_pairs( std::move(tmp_hops), std::move(tmp_label_offsets)); } else { - rmm::device_uvector tmp_vertices(vertices.size(), handle.get_stream()); - thrust::copy( - handle.get_thrust_policy(), vertices.begin(), vertices.end(), tmp_vertices.begin()); + rmm::device_uvector tmp_vertices(edgelist_vertices.size(), handle.get_stream()); + thrust::copy(handle.get_thrust_policy(), + edgelist_vertices.begin(), + edgelist_vertices.end(), + tmp_vertices.begin()); + std::optional> tmp_hops{std::nullopt}; - if (hops) { - rmm::device_uvector tmp_hops((*hops).size(), handle.get_stream()); - thrust::copy(handle.get_thrust_policy(), (*hops).begin(), (*hops).end(), tmp_hops.begin()); + if (edgelist_hops) { + tmp_hops = rmm::device_uvector((*edgelist_hops).size(), handle.get_stream()); + thrust::copy(handle.get_thrust_policy(), + (*edgelist_hops).begin(), + (*edgelist_hops).end(), + (*tmp_hops).begin()); auto pair_first = thrust::make_zip_iterator( - tmp_vertices.begin(), tmp_hops.begin()); // vertex is a primary key, hop is a secondary key + tmp_vertices.begin(), + (*tmp_hops).begin()); // vertex is a primary key, hop is a secondary key thrust::sort(handle.get_thrust_policy(), pair_first, pair_first + tmp_vertices.size()); tmp_vertices.resize( thrust::distance(tmp_vertices.begin(), thrust::get<0>(thrust::unique_by_key(handle.get_thrust_policy(), tmp_vertices.begin(), tmp_vertices.end(), - tmp_hops.begin()))), + (*tmp_hops).begin()))), handle.get_stream()); - tmp_hops.resize(tmp_vertices.size(), handle.get_stream()); - - return std::make_tuple( - std::nullopt, std::move(tmp_vertices), std::move(tmp_hops), std::nullopt); + (*tmp_hops).resize(tmp_vertices.size(), handle.get_stream()); + tmp_vertices.shrink_to_fit(handle.get_stream()); + (*tmp_hops).shrink_to_fit(handle.get_stream()); } else { thrust::sort(handle.get_thrust_policy(), tmp_vertices.begin(), tmp_vertices.end()); tmp_vertices.resize( @@ -381,9 +753,109 @@ compute_min_hop_for_unique_label_vertex_pairs( thrust::unique(handle.get_thrust_policy(), tmp_vertices.begin(), tmp_vertices.end())), handle.get_stream()); tmp_vertices.shrink_to_fit(handle.get_stream()); + } + + if (seed_vertices) { + /* sort and enumerate unique verties */ + + rmm::device_uvector unique_seed_vertices((*seed_vertices).size(), + handle.get_stream()); + thrust::copy(handle.get_thrust_policy(), + (*seed_vertices).begin(), + (*seed_vertices).end(), + unique_seed_vertices.begin()); + thrust::sort( + handle.get_thrust_policy(), unique_seed_vertices.begin(), unique_seed_vertices.end()); + unique_seed_vertices.resize(thrust::distance(unique_seed_vertices.begin(), + thrust::unique(handle.get_thrust_policy(), + unique_seed_vertices.begin(), + unique_seed_vertices.end())), + handle.get_stream()); + + /* merge with the (vertex, min. hop) pairs from the edgelist */ - return std::make_tuple(std::nullopt, std::move(tmp_vertices), std::nullopt, std::nullopt); + if (edgelist_hops) { + auto pair_from_edgelist_first = + thrust::make_zip_iterator(tmp_vertices.begin(), (*tmp_hops).begin()); + thrust::for_each(handle.get_thrust_policy(), + unique_seed_vertices.begin(), + unique_seed_vertices.end(), + [pair_from_edgelist_first, + pair_from_edgelist_last = + pair_from_edgelist_first + tmp_vertices.size()] __device__(auto v) { + auto it = thrust::lower_bound(thrust::seq, + pair_from_edgelist_first, + pair_from_edgelist_last, + thrust::make_tuple(v, int32_t{0})); + if ((it != pair_from_edgelist_last) && (thrust::get<0>(*it) == v)) { + // update min. hop to 0 + if (thrust::get<1>(*it) != int32_t{0}) { + thrust::get<1>(*it) = int32_t{0}; + } + } + }); + + unique_seed_vertices.resize( + thrust::distance(unique_seed_vertices.begin(), + thrust::remove_if( + handle.get_thrust_policy(), + unique_seed_vertices.begin(), + unique_seed_vertices.end(), + [pair_from_edgelist_first, + pair_from_edgelist_last = + pair_from_edgelist_first + tmp_vertices.size()] __device__(auto v) { + auto it = thrust::lower_bound(thrust::seq, + pair_from_edgelist_first, + pair_from_edgelist_last, + thrust::make_tuple(v, int32_t{0})); + return (it != pair_from_edgelist_last) && (thrust::get<0>(*it) == v); + })), + handle.get_stream()); + if (unique_seed_vertices.size() > 0) { + rmm::device_uvector merged_vertices( + tmp_vertices.size() + unique_seed_vertices.size(), handle.get_stream()); + rmm::device_uvector merged_hops(merged_vertices.size(), handle.get_stream()); + auto pair_from_seed_vertex_first = thrust::make_zip_iterator( + unique_seed_vertices.begin(), thrust::make_constant_iterator(int32_t{0})); + thrust::merge(handle.get_thrust_policy(), + pair_from_edgelist_first, + pair_from_edgelist_first + tmp_vertices.size(), + pair_from_seed_vertex_first, + pair_from_seed_vertex_first + unique_seed_vertices.size(), + thrust::make_zip_iterator(merged_vertices.begin(), merged_hops.begin())); + tmp_vertices = std::move(merged_vertices); + tmp_hops = std::move(merged_hops); + } + } else { + unique_seed_vertices.resize( + thrust::distance( + unique_seed_vertices.begin(), + thrust::remove_if(handle.get_thrust_policy(), + unique_seed_vertices.begin(), + unique_seed_vertices.end(), + [tmp_vertices = raft::device_span( + tmp_vertices.data(), tmp_vertices.size())] __device__(auto v) { + auto it = thrust::lower_bound( + thrust::seq, tmp_vertices.begin(), tmp_vertices.end(), v); + return (it != tmp_vertices.end()) && (*it == v); + })), + handle.get_stream()); + if (unique_seed_vertices.size() > 0) { + rmm::device_uvector merged_vertices( + tmp_vertices.size() + unique_seed_vertices.size(), handle.get_stream()); + thrust::merge(handle.get_thrust_policy(), + tmp_vertices.begin(), + tmp_vertices.end(), + unique_seed_vertices.begin(), + unique_seed_vertices.end(), + merged_vertices.begin()); + tmp_vertices = std::move(merged_vertices); + } + } } + + return std::make_tuple( + std::nullopt, std::move(tmp_vertices), std::move(tmp_hops), std::nullopt); } } @@ -393,45 +865,32 @@ compute_renumber_map(raft::handle_t const& handle, raft::device_span edgelist_majors, raft::device_span edgelist_minors, std::optional> edgelist_hops, + std::optional> seed_vertices, + std::optional> seed_vertex_label_offsets, std::optional> edgelist_label_offsets) { auto approx_edges_to_sort_per_iteration = static_cast(handle.get_device_properties().multiProcessorCount) * (1 << 20) /* tuning parameter */; // for segmented sort - std::optional> edgelist_label_indices{std::nullopt}; - if (edgelist_label_offsets) { - edgelist_label_indices = - detail::expand_sparse_offsets(*edgelist_label_offsets, label_index_t{0}, handle.get_stream()); - } - auto [unique_label_major_pair_label_indices, unique_label_major_pair_vertices, unique_label_major_pair_hops, unique_label_major_pair_label_offsets] = - compute_min_hop_for_unique_label_vertex_pairs( + compute_min_hop_for_unique_label_vertex_pairs( handle, edgelist_majors, edgelist_hops, - edgelist_label_indices ? std::make_optional>( - (*edgelist_label_indices).data(), (*edgelist_label_indices).size()) - : std::nullopt, + seed_vertices, + seed_vertex_label_offsets, edgelist_label_offsets); auto [unique_label_minor_pair_label_indices, unique_label_minor_pair_vertices, unique_label_minor_pair_hops, unique_label_minor_pair_label_offsets] = - compute_min_hop_for_unique_label_vertex_pairs( - handle, - edgelist_minors, - edgelist_hops, - edgelist_label_indices ? std::make_optional>( - (*edgelist_label_indices).data(), (*edgelist_label_indices).size()) - : std::nullopt, - edgelist_label_offsets); - - edgelist_label_indices = std::nullopt; + compute_min_hop_for_unique_label_vertex_pairs( + handle, edgelist_minors, edgelist_hops, std::nullopt, std::nullopt, edgelist_label_offsets); if (edgelist_label_offsets) { auto num_labels = (*edgelist_label_offsets).size() - 1; @@ -640,20 +1099,23 @@ compute_renumber_map(raft::handle_t const& handle, } } -// this function does not reorder edges (the i'th returned edge is the renumbered output of the i'th -// input edge) +// this function does not reorder edges (the i'th returned edge is the renumbered output of the +// i'th input edge) template -std::tuple, - rmm::device_uvector, - rmm::device_uvector, - std::optional>> -renumber_sampled_edgelist( - raft::handle_t const& handle, - rmm::device_uvector&& edgelist_majors, - rmm::device_uvector&& edgelist_minors, - std::optional, size_t>>&& edgelist_hops, - std::optional, size_t>> edgelist_label_offsets, - bool do_expensive_check) +std::tuple, // edgelist_majors + rmm::device_uvector, // edgelist minors + std::optional>, // seed_vertices, + rmm::device_uvector, // renumber_map + std::optional>> // renumber_map_label_offsets +renumber_sampled_edgelist(raft::handle_t const& handle, + rmm::device_uvector&& edgelist_majors, + rmm::device_uvector&& edgelist_minors, + std::optional> edgelist_hops, + std::optional>&& seed_vertices, + std::optional> seed_vertex_label_offsets, + std::optional> edgelist_label_offsets, + size_t num_labels, + bool do_expensive_check) { // 1. compute renumber_map @@ -661,12 +1123,12 @@ renumber_sampled_edgelist( handle, raft::device_span(edgelist_majors.data(), edgelist_majors.size()), raft::device_span(edgelist_minors.data(), edgelist_minors.size()), - edgelist_hops ? std::make_optional>( - std::get<0>(*edgelist_hops).data(), std::get<0>(*edgelist_hops).size()) + edgelist_hops, + seed_vertices ? std::make_optional>((*seed_vertices).data(), + (*seed_vertices).size()) : std::nullopt, - edgelist_label_offsets - ? std::make_optional>(std::get<0>(*edgelist_label_offsets)) - : std::nullopt); + seed_vertex_label_offsets, + edgelist_label_offsets); // 2. compute renumber map offsets for each label @@ -686,8 +1148,7 @@ renumber_sampled_edgelist( unique_label_indices.begin(), vertex_counts.begin()); - renumber_map_label_offsets = - rmm::device_uvector(std::get<1>(*edgelist_label_offsets) + 1, handle.get_stream()); + renumber_map_label_offsets = rmm::device_uvector(num_labels + 1, handle.get_stream()); thrust::fill(handle.get_thrust_policy(), (*renumber_map_label_offsets).begin(), (*renumber_map_label_offsets).end(), @@ -724,8 +1185,6 @@ renumber_sampled_edgelist( (*renumber_map_label_indices).resize(0, handle.get_stream()); (*renumber_map_label_indices).shrink_to_fit(handle.get_stream()); - auto num_labels = std::get<0>(*edgelist_label_offsets).size(); - rmm::device_uvector segment_sorted_renumber_map(renumber_map.size(), handle.get_stream()); rmm::device_uvector segment_sorted_new_vertices(new_vertices.size(), @@ -784,27 +1243,30 @@ renumber_sampled_edgelist( new_vertices.shrink_to_fit(handle.get_stream()); d_tmp_storage.shrink_to_fit(handle.get_stream()); - auto edgelist_label_indices = detail::expand_sparse_offsets( - std::get<0>(*edgelist_label_offsets), label_index_t{0}, handle.get_stream()); - auto pair_first = - thrust::make_zip_iterator(edgelist_majors.begin(), edgelist_label_indices.begin()); + thrust::make_zip_iterator(edgelist_majors.begin(), thrust::make_counting_iterator(size_t{0})); thrust::transform( handle.get_thrust_policy(), pair_first, pair_first + edgelist_majors.size(), edgelist_majors.begin(), - [renumber_map_label_offsets = raft::device_span( + [edgelist_label_offsets = *edgelist_label_offsets, + renumber_map_label_offsets = raft::device_span( (*renumber_map_label_offsets).data(), (*renumber_map_label_offsets).size()), old_vertices = raft::device_span(segment_sorted_renumber_map.data(), segment_sorted_renumber_map.size()), new_vertices = raft::device_span( segment_sorted_new_vertices.data(), segment_sorted_new_vertices.size())] __device__(auto pair) { - auto old_vertex = thrust::get<0>(pair); - auto label_index = thrust::get<1>(pair); - auto label_start_offset = renumber_map_label_offsets[label_index]; - auto label_end_offset = renumber_map_label_offsets[label_index + 1]; + auto old_vertex = thrust::get<0>(pair); + auto label_idx = static_cast( + thrust::distance(edgelist_label_offsets.begin() + 1, + thrust::upper_bound(thrust::seq, + edgelist_label_offsets.begin() + 1, + edgelist_label_offsets.end(), + thrust::get<1>(pair)))); + auto label_start_offset = renumber_map_label_offsets[label_idx]; + auto label_end_offset = renumber_map_label_offsets[label_idx + 1]; auto it = thrust::lower_bound(thrust::seq, old_vertices.begin() + label_start_offset, old_vertices.begin() + label_end_offset, @@ -813,23 +1275,30 @@ renumber_sampled_edgelist( return *(new_vertices.begin() + thrust::distance(old_vertices.begin(), it)); }); - pair_first = thrust::make_zip_iterator(edgelist_minors.begin(), edgelist_label_indices.begin()); + pair_first = + thrust::make_zip_iterator(edgelist_minors.begin(), thrust::make_counting_iterator(size_t{0})); thrust::transform( handle.get_thrust_policy(), pair_first, pair_first + edgelist_minors.size(), edgelist_minors.begin(), - [renumber_map_label_offsets = raft::device_span( + [edgelist_label_offsets = *edgelist_label_offsets, + renumber_map_label_offsets = raft::device_span( (*renumber_map_label_offsets).data(), (*renumber_map_label_offsets).size()), old_vertices = raft::device_span(segment_sorted_renumber_map.data(), segment_sorted_renumber_map.size()), new_vertices = raft::device_span( segment_sorted_new_vertices.data(), segment_sorted_new_vertices.size())] __device__(auto pair) { - auto old_vertex = thrust::get<0>(pair); - auto label_index = thrust::get<1>(pair); - auto label_start_offset = renumber_map_label_offsets[label_index]; - auto label_end_offset = renumber_map_label_offsets[label_index + 1]; + auto old_vertex = thrust::get<0>(pair); + auto label_idx = static_cast( + thrust::distance(edgelist_label_offsets.begin() + 1, + thrust::upper_bound(thrust::seq, + edgelist_label_offsets.begin() + 1, + edgelist_label_offsets.end(), + thrust::get<1>(pair)))); + auto label_start_offset = renumber_map_label_offsets[label_idx]; + auto label_end_offset = renumber_map_label_offsets[label_idx + 1]; auto it = thrust::lower_bound(thrust::seq, old_vertices.begin() + label_start_offset, old_vertices.begin() + label_end_offset, @@ -837,6 +1306,40 @@ renumber_sampled_edgelist( assert(*it == old_vertex); return new_vertices[thrust::distance(old_vertices.begin(), it)]; }); + + if (seed_vertices) { + pair_first = thrust::make_zip_iterator((*seed_vertices).begin(), + thrust::make_counting_iterator(size_t{0})); + thrust::transform( + handle.get_thrust_policy(), + pair_first, + pair_first + (*seed_vertices).size(), + (*seed_vertices).begin(), + [seed_vertex_label_offsets = *seed_vertex_label_offsets, + renumber_map_label_offsets = raft::device_span( + (*renumber_map_label_offsets).data(), (*renumber_map_label_offsets).size()), + old_vertices = raft::device_span(segment_sorted_renumber_map.data(), + segment_sorted_renumber_map.size()), + new_vertices = raft::device_span( + segment_sorted_new_vertices.data(), + segment_sorted_new_vertices.size())] __device__(auto pair) { + auto old_vertex = thrust::get<0>(pair); + auto label_idx = static_cast( + thrust::distance(seed_vertex_label_offsets.begin() + 1, + thrust::upper_bound(thrust::seq, + seed_vertex_label_offsets.begin() + 1, + seed_vertex_label_offsets.end(), + thrust::get<1>(pair)))); + auto label_start_offset = renumber_map_label_offsets[label_idx]; + auto label_end_offset = renumber_map_label_offsets[label_idx + 1]; + auto it = thrust::lower_bound(thrust::seq, + old_vertices.begin() + label_start_offset, + old_vertices.begin() + label_end_offset, + old_vertex); + assert(*it == old_vertex); + return new_vertices[thrust::distance(old_vertices.begin(), it)]; + }); + } } else { kv_store_t kv_store(renumber_map.begin(), renumber_map.end(), @@ -850,10 +1353,18 @@ renumber_sampled_edgelist( edgelist_majors.begin(), edgelist_majors.end(), edgelist_majors.begin(), handle.get_stream()); kv_store_view.find( edgelist_minors.begin(), edgelist_minors.end(), edgelist_minors.begin(), handle.get_stream()); + + if (seed_vertices) { + kv_store_view.find((*seed_vertices).begin(), + (*seed_vertices).end(), + (*seed_vertices).begin(), + handle.get_stream()); + } } return std::make_tuple(std::move(edgelist_majors), std::move(edgelist_minors), + std::move(seed_vertices), std::move(renumber_map), std::move(renumber_map_label_offsets)); } @@ -886,16 +1397,15 @@ std::tuple, std::optional>, std::optional>, std::optional>, - std::optional, size_t>>> -sort_sampled_edge_tuples( - raft::handle_t const& handle, - rmm::device_uvector&& edgelist_majors, - rmm::device_uvector&& edgelist_minors, - std::optional>&& edgelist_weights, - std::optional>&& edgelist_edge_ids, - std::optional>&& edgelist_edge_types, - std::optional, size_t>>&& edgelist_hops, - std::optional, size_t>> edgelist_label_offsets) + std::optional>> +sort_sampled_edge_tuples(raft::handle_t const& handle, + rmm::device_uvector&& edgelist_majors, + rmm::device_uvector&& edgelist_minors, + std::optional>&& edgelist_weights, + std::optional>&& edgelist_edge_ids, + std::optional>&& edgelist_edge_types, + std::optional>&& edgelist_hops, + std::optional> edgelist_label_offsets) { std::vector h_label_offsets{}; std::vector h_edge_offsets{}; @@ -905,11 +1415,8 @@ sort_sampled_edge_tuples( static_cast(handle.get_device_properties().multiProcessorCount) * (1 << 20) /* tuning parameter */; // for sorts in chunks - std::tie(h_label_offsets, h_edge_offsets) = - detail::compute_offset_aligned_element_chunks(handle, - std::get<0>(*edgelist_label_offsets), - edgelist_majors.size(), - approx_edges_to_sort_per_iteration); + std::tie(h_label_offsets, h_edge_offsets) = detail::compute_offset_aligned_element_chunks( + handle, *edgelist_label_offsets, edgelist_majors.size(), approx_edges_to_sort_per_iteration); } else { h_label_offsets = {0, 1}; h_edge_offsets = {0, edgelist_majors.size()}; @@ -920,13 +1427,13 @@ sort_sampled_edge_tuples( rmm::device_uvector indices(h_edge_offsets[i + 1] - h_edge_offsets[i], handle.get_stream()); thrust::sequence(handle.get_thrust_policy(), indices.begin(), indices.end(), size_t{0}); - edge_order_t edge_order_comp{ + edge_order_t edge_order_comp{ edgelist_label_offsets ? thrust::make_optional>( - std::get<0>(*edgelist_label_offsets).data() + h_label_offsets[i], + (*edgelist_label_offsets).data() + h_label_offsets[i], (h_label_offsets[i + 1] - h_label_offsets[i]) + 1) : thrust::nullopt, edgelist_hops ? thrust::make_optional>( - std::get<0>(*edgelist_hops).data() + h_edge_offsets[i], indices.size()) + (*edgelist_hops).data() + h_edge_offsets[i], indices.size()) : thrust::nullopt, raft::device_span(edgelist_majors.data() + h_edge_offsets[i], indices.size()), raft::device_span(edgelist_minors.data() + h_edge_offsets[i], @@ -955,10 +1462,8 @@ sort_sampled_edge_tuples( } if (edgelist_hops) { - permute_array(handle, - indices.begin(), - indices.end(), - std::get<0>(*edgelist_hops).begin() + h_edge_offsets[i]); + permute_array( + handle, indices.begin(), indices.end(), (*edgelist_hops).begin() + h_edge_offsets[i]); } } @@ -982,10 +1487,10 @@ std::tuple>, // dcsr/dcsc major std::optional>, // weights std::optional>, // edge IDs std::optional>, // edge types - std::optional>, // (label, hop) offsets to the (d)csr/(d)csc - // offset array - rmm::device_uvector, // renumber map - std::optional>> // label offsets to the renumber map + std::optional>, // (label, hop) offsets to the + // (d)csr/(d)csc offset array + rmm::device_uvector, // renumber map + std::optional>> // label offsets to the renumber map renumber_and_compress_sampled_edgelist( raft::handle_t const& handle, rmm::device_uvector&& edgelist_srcs, @@ -993,8 +1498,12 @@ renumber_and_compress_sampled_edgelist( std::optional>&& edgelist_weights, std::optional>&& edgelist_edge_ids, std::optional>&& edgelist_edge_types, - std::optional, size_t>>&& edgelist_hops, - std::optional, size_t>> edgelist_label_offsets, + std::optional>&& edgelist_hops, + std::optional> seed_vertices, + std::optional> seed_vertex_label_offsets, + std::optional> edgelist_label_offsets, + size_t num_labels, + size_t num_hops, bool src_is_major, bool compress_per_hop, bool doubly_compress, @@ -1002,19 +1511,23 @@ renumber_and_compress_sampled_edgelist( { using label_index_t = uint32_t; - auto num_labels = edgelist_label_offsets ? std::get<1>(*edgelist_label_offsets) : size_t{1}; - auto num_hops = edgelist_hops ? std::get<1>(*edgelist_hops) : size_t{1}; + auto edgelist_majors = src_is_major ? std::move(edgelist_srcs) : std::move(edgelist_dsts); + auto edgelist_minors = src_is_major ? std::move(edgelist_dsts) : std::move(edgelist_srcs); // 1. check input arguments check_input_edges(handle, - edgelist_srcs, - edgelist_dsts, + edgelist_majors, + edgelist_minors, edgelist_weights, edgelist_edge_ids, edgelist_edge_types, edgelist_hops, + seed_vertices, + seed_vertex_label_offsets, edgelist_label_offsets, + num_labels, + num_hops, do_expensive_check); CUGRAPH_EXPECTS( @@ -1026,22 +1539,33 @@ renumber_and_compress_sampled_edgelist( // 2. renumber - auto edgelist_majors = src_is_major ? std::move(edgelist_srcs) : std::move(edgelist_dsts); - auto edgelist_minors = src_is_major ? std::move(edgelist_dsts) : std::move(edgelist_srcs); - + std::optional> renumbered_seed_vertices{std::nullopt}; + if (seed_vertices) { + renumbered_seed_vertices = + rmm::device_uvector((*seed_vertices).size(), handle.get_stream()); + thrust::copy(handle.get_thrust_policy(), + (*seed_vertices).begin(), + (*seed_vertices).end(), + (*renumbered_seed_vertices).begin()); + } rmm::device_uvector renumber_map(0, handle.get_stream()); std::optional> renumber_map_label_offsets{std::nullopt}; - std::tie(edgelist_majors, edgelist_minors, renumber_map, renumber_map_label_offsets) = + std::tie(edgelist_majors, + edgelist_minors, + renumbered_seed_vertices, + renumber_map, + renumber_map_label_offsets) = renumber_sampled_edgelist( handle, std::move(edgelist_majors), std::move(edgelist_minors), - edgelist_hops ? std::make_optional(std::make_tuple( - raft::device_span(std::get<0>(*edgelist_hops).data(), - std::get<0>(*edgelist_hops).size()), - num_hops)) + edgelist_hops ? std::make_optional(raft::device_span((*edgelist_hops).data(), + (*edgelist_hops).size())) : std::nullopt, + std::move(renumbered_seed_vertices), + seed_vertex_label_offsets, edgelist_label_offsets, + num_labels, do_expensive_check); // 3. sort by ((l), (h), major, minor) @@ -1060,6 +1584,20 @@ renumber_and_compress_sampled_edgelist( std::move(edgelist_hops), edgelist_label_offsets); + if (renumbered_seed_vertices) { + if (seed_vertex_label_offsets) { + auto label_indices = detail::expand_sparse_offsets( + *seed_vertex_label_offsets, label_index_t{0}, handle.get_stream()); + auto pair_first = + thrust::make_zip_iterator(label_indices.begin(), (*renumbered_seed_vertices).begin()); + thrust::sort(handle.get_thrust_policy(), pair_first, pair_first + label_indices.size()); + } else { + thrust::sort(handle.get_thrust_policy(), + (*renumbered_seed_vertices).begin(), + (*renumbered_seed_vertices).end()); + } + } + if (do_expensive_check) { if (!compress_per_hop && edgelist_hops) { rmm::device_uvector min_vertices(num_labels * num_hops, handle.get_stream()); @@ -1068,10 +1606,9 @@ renumber_and_compress_sampled_edgelist( auto label_index_first = thrust::make_transform_iterator( thrust::make_counting_iterator(size_t{0}), optionally_compute_label_index_t{ - edgelist_label_offsets ? thrust::make_optional(std::get<0>(*edgelist_label_offsets)) + edgelist_label_offsets ? thrust::make_optional(*edgelist_label_offsets) : thrust::nullopt}); - auto input_key_first = - thrust::make_zip_iterator(label_index_first, std::get<0>(*edgelist_hops).begin()); + auto input_key_first = thrust::make_zip_iterator(label_index_first, (*edgelist_hops).begin()); rmm::device_uvector unique_key_label_indices(min_vertices.size(), handle.get_stream()); rmm::device_uvector unique_key_hops(min_vertices.size(), handle.get_stream()); @@ -1097,6 +1634,34 @@ renumber_and_compress_sampled_edgelist( max_vertices.begin(), thrust::equal_to>{}, thrust::maximum{}); + + if (renumbered_seed_vertices) { + thrust::for_each( + handle.get_thrust_policy(), + thrust::make_counting_iterator(size_t{0}), + thrust::make_counting_iterator(num_labels), + [seed_vertices = raft::device_span((*renumbered_seed_vertices).data(), + (*renumbered_seed_vertices).size()), + seed_vertex_label_offsets = detail::to_thrust_optional(seed_vertex_label_offsets), + num_hops, + min_vertices = raft::device_span(min_vertices.data(), min_vertices.size()), + max_vertices = raft::device_span( + max_vertices.data(), max_vertices.size())] __device__(size_t l_idx) { + size_t label_start_offset{0}; + auto label_end_offset = seed_vertices.size(); + if (seed_vertex_label_offsets) { + label_start_offset = (*seed_vertex_label_offsets)[l_idx]; + label_end_offset = (*seed_vertex_label_offsets)[l_idx + 1]; + } + if (label_start_offset < label_end_offset) { + min_vertices[l_idx * num_hops] = + cuda::std::min(min_vertices[l_idx * num_hops], seed_vertices[label_start_offset]); + max_vertices[l_idx * num_hops] = + cuda::std::max(max_vertices[l_idx * num_hops], seed_vertices[label_end_offset - 1]); + } + }); + } + if (num_unique_keys > 1) { auto num_invalids = thrust::count_if( handle.get_thrust_policy(), @@ -1116,10 +1681,13 @@ renumber_and_compress_sampled_edgelist( return false; } }); - CUGRAPH_EXPECTS(num_invalids == 0, - "Invalid input arguments: if @p compress_per_hop is false and @p " - "edgelist_hops.has_value() is true, the minimum majors with hop N + 1 " - "should be larger than the maximum majors with hop N after renumbering."); + CUGRAPH_EXPECTS( + num_invalids == 0, + "Invalid input arguments: if both compress_per_hop is false and " + "edgelist_hops.has_value() is true, majors should be non-decreasing within each label " + "after renumbering and sorting by (hop, major, minor). Also, majors in hop N should not " + "appear in any of the previous hops. This condition is satisfied if majors in hop N + 1 " + "does not have any vertices from the previous hops excluding the minors from hop N."); } } } @@ -1131,11 +1699,10 @@ renumber_and_compress_sampled_edgelist( handle.get_thrust_policy(), thrust::make_counting_iterator(size_t{0}), thrust::make_counting_iterator(edgelist_majors.size()), - is_first_in_run_t{ - edgelist_label_offsets ? thrust::make_optional(std::get<0>(*edgelist_label_offsets)) - : thrust::nullopt, + is_first_triplet_in_run_t{ + detail::to_thrust_optional(edgelist_label_offsets), edgelist_hops ? thrust::make_optional>( - std::get<0>(*edgelist_hops).data(), std::get<0>(*edgelist_hops).size()) + (*edgelist_hops).data(), (*edgelist_hops).size()) : thrust::nullopt, raft::device_span( edgelist_majors.data(), @@ -1155,11 +1722,11 @@ renumber_and_compress_sampled_edgelist( if (edgelist_label_offsets) { auto label_index_first = thrust::make_transform_iterator( thrust::make_counting_iterator(size_t{0}), - compute_label_index_t{std::get<0>(*edgelist_label_offsets)}); + compute_label_index_t{*edgelist_label_offsets}); if (edgelist_hops) { auto input_key_first = thrust::make_zip_iterator( - label_index_first, std::get<0>(*edgelist_hops).begin(), edgelist_majors.begin()); + label_index_first, (*edgelist_hops).begin(), edgelist_majors.begin()); auto output_key_first = thrust::make_zip_iterator((*compressed_label_indices).begin(), (*compressed_hops).begin(), compressed_nzd_vertices.begin()); @@ -1183,7 +1750,7 @@ renumber_and_compress_sampled_edgelist( } else { if (edgelist_hops) { auto input_key_first = - thrust::make_zip_iterator(std::get<0>(*edgelist_hops).begin(), edgelist_majors.begin()); + thrust::make_zip_iterator((*edgelist_hops).begin(), edgelist_majors.begin()); auto output_key_first = thrust::make_zip_iterator((*compressed_hops).begin(), compressed_nzd_vertices.begin()); thrust::reduce_by_key(handle.get_thrust_policy(), @@ -1208,8 +1775,8 @@ renumber_and_compress_sampled_edgelist( compressed_offsets.end(), compressed_offsets.begin()); - // 5. update compressed_offsets to include zero degree vertices (if doubly_compress is false) and - // compressed_offset_label_hop_offsets (if edgelist_label_offsets.has_value() or + // 5. update compressed_offsets to include zero degree vertices (if doubly_compress is false) + // and compressed_offset_label_hop_offsets (if edgelist_label_offsets.has_value() or // edgelist_hops.has_value() is true) std::optional> compressed_offset_label_hop_offsets{std::nullopt}; @@ -1217,10 +1784,7 @@ renumber_and_compress_sampled_edgelist( if (edgelist_label_offsets || edgelist_hops) { rmm::device_uvector offset_array_offsets(num_labels * num_hops + 1, handle.get_stream()); - thrust::fill(handle.get_thrust_policy(), - offset_array_offsets.begin(), - offset_array_offsets.end(), - size_t{0}); + offset_array_offsets.set_element_to_zero_async(0, handle.get_stream()); if (edgelist_label_offsets) { if (edgelist_hops) { @@ -1265,54 +1829,75 @@ renumber_and_compress_sampled_edgelist( handle.get_thrust_policy(), major_vertex_counts.begin(), major_vertex_counts.end(), - [edgelist_label_offsets = edgelist_label_offsets - ? thrust::make_optional(std::get<0>(*edgelist_label_offsets)) - : thrust::nullopt, - edgelist_hops = edgelist_hops - ? thrust::make_optional>( - std::get<0>(*edgelist_hops).data(), std::get<0>(*edgelist_hops).size()) - : thrust::nullopt, + [edgelist_label_offsets = detail::to_thrust_optional(edgelist_label_offsets), + edgelist_hops = edgelist_hops ? thrust::make_optional>( + (*edgelist_hops).data(), (*edgelist_hops).size()) + : thrust::nullopt, edgelist_majors = raft::device_span(edgelist_majors.data(), edgelist_majors.size()), + seed_vertices = renumbered_seed_vertices + ? thrust::make_optional>( + (*renumbered_seed_vertices).data(), (*renumbered_seed_vertices).size()) + : thrust::nullopt, + seed_vertex_label_offsets = detail::to_thrust_optional(seed_vertex_label_offsets), num_hops, compress_per_hop] __device__(size_t i) { - size_t start_offset{0}; - auto end_offset = edgelist_majors.size(); - auto label_start_offset = start_offset; - auto label_end_offset = end_offset; - - if (edgelist_label_offsets) { - auto l_idx = static_cast(i / num_hops); - start_offset = (*edgelist_label_offsets)[l_idx]; - end_offset = (*edgelist_label_offsets)[l_idx + 1]; - label_start_offset = start_offset; - label_end_offset = end_offset; - } + vertex_t num_vertices_from_edgelist{0}; + { + size_t start_offset{0}; + auto end_offset = edgelist_majors.size(); + auto label_start_offset = start_offset; + auto label_end_offset = end_offset; + + if (edgelist_label_offsets) { + auto l_idx = static_cast(i / num_hops); + start_offset = (*edgelist_label_offsets)[l_idx]; + end_offset = (*edgelist_label_offsets)[l_idx + 1]; + label_start_offset = start_offset; + label_end_offset = end_offset; + } - if (num_hops > 1) { - auto h = static_cast(i % num_hops); - auto lower_it = thrust::lower_bound(thrust::seq, - (*edgelist_hops).begin() + start_offset, - (*edgelist_hops).begin() + end_offset, - h); - auto upper_it = thrust::upper_bound(thrust::seq, - (*edgelist_hops).begin() + start_offset, - (*edgelist_hops).begin() + end_offset, - h); - start_offset = static_cast(thrust::distance((*edgelist_hops).begin(), lower_it)); - end_offset = static_cast(thrust::distance((*edgelist_hops).begin(), upper_it)); - } - if (compress_per_hop) { - return (start_offset < end_offset) ? (edgelist_majors[end_offset - 1] + 1) : vertex_t{0}; - } else { - if (end_offset != label_end_offset) { - return edgelist_majors[end_offset]; - } else if (label_start_offset < label_end_offset) { - return edgelist_majors[end_offset - 1] + 1; + if (num_hops > 1) { + auto h = static_cast(i % num_hops); + auto lower_it = thrust::lower_bound(thrust::seq, + (*edgelist_hops).begin() + start_offset, + (*edgelist_hops).begin() + end_offset, + h); + auto upper_it = thrust::upper_bound(thrust::seq, + (*edgelist_hops).begin() + start_offset, + (*edgelist_hops).begin() + end_offset, + h); + start_offset = + static_cast(thrust::distance((*edgelist_hops).begin(), lower_it)); + end_offset = static_cast(thrust::distance((*edgelist_hops).begin(), upper_it)); + } + if (compress_per_hop) { + if (start_offset < end_offset) + num_vertices_from_edgelist = edgelist_majors[end_offset - 1] + 1; } else { - return vertex_t{0}; + if (end_offset != label_end_offset) { + num_vertices_from_edgelist = edgelist_majors[end_offset]; + } else if (label_start_offset < label_end_offset) { + num_vertices_from_edgelist = edgelist_majors[end_offset - 1] + 1; + } } } + + vertex_t num_vertices_from_seed_vertices{0}; + if (seed_vertices && (!compress_per_hop || (i % num_hops == 0))) { + size_t label_start_offset{0}; + auto label_end_offset = (*seed_vertices).size(); + if (seed_vertex_label_offsets) { + auto l_idx = static_cast(i / num_hops); + label_start_offset = (*seed_vertex_label_offsets)[l_idx]; + label_end_offset = (*seed_vertex_label_offsets)[l_idx + 1]; + } + if (label_start_offset < label_end_offset) { + num_vertices_from_seed_vertices = (*seed_vertices)[label_end_offset - 1] + 1; + } + } + + return cuda::std::max(num_vertices_from_edgelist, num_vertices_from_seed_vertices); }); std::optional> minor_vertex_counts{std::nullopt}; @@ -1549,46 +2134,62 @@ renumber_and_sort_sampled_edgelist( std::optional>&& edgelist_weights, std::optional>&& edgelist_edge_ids, std::optional>&& edgelist_edge_types, - std::optional, size_t>>&& edgelist_hops, - std::optional, size_t>> edgelist_label_offsets, + std::optional>&& edgelist_hops, + std::optional> seed_vertices, + std::optional> seed_vertex_label_offsets, + std::optional> edgelist_label_offsets, + size_t num_labels, + size_t num_hops, bool src_is_major, bool do_expensive_check) { using label_index_t = uint32_t; - auto num_labels = edgelist_label_offsets ? std::get<1>(*edgelist_label_offsets) : size_t{1}; - auto num_hops = edgelist_hops ? std::get<1>(*edgelist_hops) : size_t{1}; + auto edgelist_majors = src_is_major ? std::move(edgelist_srcs) : std::move(edgelist_dsts); + auto edgelist_minors = src_is_major ? std::move(edgelist_dsts) : std::move(edgelist_srcs); // 1. check input arguments check_input_edges(handle, - edgelist_srcs, - edgelist_dsts, + edgelist_majors, + edgelist_minors, edgelist_weights, edgelist_edge_ids, edgelist_edge_types, edgelist_hops, + seed_vertices, + seed_vertex_label_offsets, edgelist_label_offsets, + num_labels, + num_hops, do_expensive_check); // 2. renumber - auto edgelist_majors = src_is_major ? std::move(edgelist_srcs) : std::move(edgelist_dsts); - auto edgelist_minors = src_is_major ? std::move(edgelist_dsts) : std::move(edgelist_srcs); - + std::optional> renumbered_seed_vertices{std::nullopt}; + if (seed_vertices) { + renumbered_seed_vertices = + rmm::device_uvector((*seed_vertices).size(), handle.get_stream()); + thrust::copy(handle.get_thrust_policy(), + (*seed_vertices).begin(), + (*seed_vertices).end(), + (*renumbered_seed_vertices).begin()); + } rmm::device_uvector renumber_map(0, handle.get_stream()); std::optional> renumber_map_label_offsets{std::nullopt}; - std::tie(edgelist_majors, edgelist_minors, renumber_map, renumber_map_label_offsets) = + std::tie( + edgelist_majors, edgelist_minors, std::ignore, renumber_map, renumber_map_label_offsets) = renumber_sampled_edgelist( handle, std::move(edgelist_majors), std::move(edgelist_minors), - edgelist_hops ? std::make_optional(std::make_tuple( - raft::device_span(std::get<0>(*edgelist_hops).data(), - std::get<0>(*edgelist_hops).size()), - num_hops)) + edgelist_hops ? std::make_optional(raft::device_span((*edgelist_hops).data(), + (*edgelist_hops).size())) : std::nullopt, + std::move(renumbered_seed_vertices), + seed_vertex_label_offsets, edgelist_label_offsets, + num_labels, do_expensive_check); // 3. sort by ((l), (h), major, minor) @@ -1617,47 +2218,44 @@ renumber_and_sort_sampled_edgelist( (*edgelist_label_hop_offsets).begin(), (*edgelist_label_hop_offsets).end(), size_t{0}); - // FIXME: the device lambda should be placed in cuda::proclaim_return_type() - // once we update CCCL version to 2.x thrust::transform( handle.get_thrust_policy(), thrust::make_counting_iterator(size_t{0}), thrust::make_counting_iterator(num_labels * num_hops), (*edgelist_label_hop_offsets).begin(), - [edgelist_label_offsets = edgelist_label_offsets - ? thrust::make_optional(std::get<0>(*edgelist_label_offsets)) - : thrust::nullopt, - edgelist_hops = edgelist_hops - ? thrust::make_optional>( - std::get<0>(*edgelist_hops).data(), std::get<0>(*edgelist_hops).size()) - : thrust::nullopt, - num_hops, - num_edges = edgelist_majors.size()] __device__(size_t i) { - size_t start_offset{0}; - auto end_offset = num_edges; - - if (edgelist_label_offsets) { - auto l_idx = static_cast(i / num_hops); - start_offset = (*edgelist_label_offsets)[l_idx]; - end_offset = (*edgelist_label_offsets)[l_idx + 1]; - } + cuda::proclaim_return_type( + [edgelist_label_offsets = detail::to_thrust_optional(edgelist_label_offsets), + edgelist_hops = edgelist_hops ? thrust::make_optional>( + (*edgelist_hops).data(), (*edgelist_hops).size()) + : thrust::nullopt, + num_hops, + num_edges = edgelist_majors.size()] __device__(size_t i) { + size_t start_offset{0}; + auto end_offset = num_edges; + + if (edgelist_label_offsets) { + auto l_idx = static_cast(i / num_hops); + start_offset = (*edgelist_label_offsets)[l_idx]; + end_offset = (*edgelist_label_offsets)[l_idx + 1]; + } - if (edgelist_hops) { - auto h = static_cast(i % num_hops); - auto lower_it = thrust::lower_bound(thrust::seq, - (*edgelist_hops).begin() + start_offset, - (*edgelist_hops).begin() + end_offset, - h); - auto upper_it = thrust::upper_bound(thrust::seq, - (*edgelist_hops).begin() + start_offset, - (*edgelist_hops).begin() + end_offset, - h); - start_offset = static_cast(thrust::distance((*edgelist_hops).begin(), lower_it)); - end_offset = static_cast(thrust::distance((*edgelist_hops).begin(), upper_it)); - } + if (edgelist_hops) { + auto h = static_cast(i % num_hops); + auto lower_it = thrust::lower_bound(thrust::seq, + (*edgelist_hops).begin() + start_offset, + (*edgelist_hops).begin() + end_offset, + h); + auto upper_it = thrust::upper_bound(thrust::seq, + (*edgelist_hops).begin() + start_offset, + (*edgelist_hops).begin() + end_offset, + h); + start_offset = + static_cast(thrust::distance((*edgelist_hops).begin(), lower_it)); + end_offset = static_cast(thrust::distance((*edgelist_hops).begin(), upper_it)); + } - return end_offset - start_offset; - }); + return end_offset - start_offset; + })); thrust::exclusive_scan(handle.get_thrust_policy(), (*edgelist_label_hop_offsets).begin(), (*edgelist_label_hop_offsets).end(), @@ -1686,40 +2284,42 @@ std::tuple, // srcs std::optional>, // edge IDs std::optional>, // edge types std::optional>> // (label, hop) offsets to the edges -sort_sampled_edgelist( - raft::handle_t const& handle, - rmm::device_uvector&& edgelist_srcs, - rmm::device_uvector&& edgelist_dsts, - std::optional>&& edgelist_weights, - std::optional>&& edgelist_edge_ids, - std::optional>&& edgelist_edge_types, - std::optional, size_t>>&& edgelist_hops, - std::optional, size_t>> edgelist_label_offsets, - bool src_is_major, - bool do_expensive_check) +sort_sampled_edgelist(raft::handle_t const& handle, + rmm::device_uvector&& edgelist_srcs, + rmm::device_uvector&& edgelist_dsts, + std::optional>&& edgelist_weights, + std::optional>&& edgelist_edge_ids, + std::optional>&& edgelist_edge_types, + std::optional>&& edgelist_hops, + std::optional> edgelist_label_offsets, + size_t num_labels, + size_t num_hops, + bool src_is_major, + bool do_expensive_check) { using label_index_t = uint32_t; - auto num_labels = edgelist_label_offsets ? std::get<1>(*edgelist_label_offsets) : size_t{1}; - auto num_hops = edgelist_hops ? std::get<1>(*edgelist_hops) : size_t{1}; + auto edgelist_majors = src_is_major ? std::move(edgelist_srcs) : std::move(edgelist_dsts); + auto edgelist_minors = src_is_major ? std::move(edgelist_dsts) : std::move(edgelist_srcs); // 1. check input arguments - check_input_edges(handle, - edgelist_srcs, - edgelist_dsts, - edgelist_weights, - edgelist_edge_ids, - edgelist_edge_types, - edgelist_hops, - edgelist_label_offsets, - do_expensive_check); + check_input_edges(handle, + edgelist_majors, + edgelist_minors, + edgelist_weights, + edgelist_edge_ids, + edgelist_edge_types, + edgelist_hops, + std::nullopt, + std::nullopt, + edgelist_label_offsets, + num_labels, + num_hops, + do_expensive_check); // 2. sort by ((l), (h), major, minor) - auto edgelist_majors = src_is_major ? std::move(edgelist_srcs) : std::move(edgelist_dsts); - auto edgelist_minors = src_is_major ? std::move(edgelist_dsts) : std::move(edgelist_srcs); - std::tie(edgelist_majors, edgelist_minors, edgelist_weights, @@ -1744,47 +2344,44 @@ sort_sampled_edgelist( (*edgelist_label_hop_offsets).begin(), (*edgelist_label_hop_offsets).end(), size_t{0}); - // FIXME: the device lambda should be placed in cuda::proclaim_return_type() - // once we update CCCL version to 2.x thrust::transform( handle.get_thrust_policy(), thrust::make_counting_iterator(size_t{0}), thrust::make_counting_iterator(num_labels * num_hops), (*edgelist_label_hop_offsets).begin(), - [edgelist_label_offsets = edgelist_label_offsets - ? thrust::make_optional(std::get<0>(*edgelist_label_offsets)) - : thrust::nullopt, - edgelist_hops = edgelist_hops - ? thrust::make_optional>( - std::get<0>(*edgelist_hops).data(), std::get<0>(*edgelist_hops).size()) - : thrust::nullopt, - num_hops, - num_edges = edgelist_majors.size()] __device__(size_t i) { - size_t start_offset{0}; - auto end_offset = num_edges; - - if (edgelist_label_offsets) { - auto l_idx = static_cast(i / num_hops); - start_offset = (*edgelist_label_offsets)[l_idx]; - end_offset = (*edgelist_label_offsets)[l_idx + 1]; - } + cuda::proclaim_return_type( + [edgelist_label_offsets = detail::to_thrust_optional(edgelist_label_offsets), + edgelist_hops = edgelist_hops ? thrust::make_optional>( + (*edgelist_hops).data(), (*edgelist_hops).size()) + : thrust::nullopt, + num_hops, + num_edges = edgelist_majors.size()] __device__(size_t i) { + size_t start_offset{0}; + auto end_offset = num_edges; + + if (edgelist_label_offsets) { + auto l_idx = static_cast(i / num_hops); + start_offset = (*edgelist_label_offsets)[l_idx]; + end_offset = (*edgelist_label_offsets)[l_idx + 1]; + } - if (edgelist_hops) { - auto h = static_cast(i % num_hops); - auto lower_it = thrust::lower_bound(thrust::seq, - (*edgelist_hops).begin() + start_offset, - (*edgelist_hops).begin() + end_offset, - h); - auto upper_it = thrust::upper_bound(thrust::seq, - (*edgelist_hops).begin() + start_offset, - (*edgelist_hops).begin() + end_offset, - h); - start_offset = static_cast(thrust::distance((*edgelist_hops).begin(), lower_it)); - end_offset = static_cast(thrust::distance((*edgelist_hops).begin(), upper_it)); - } + if (edgelist_hops) { + auto h = static_cast(i % num_hops); + auto lower_it = thrust::lower_bound(thrust::seq, + (*edgelist_hops).begin() + start_offset, + (*edgelist_hops).begin() + end_offset, + h); + auto upper_it = thrust::upper_bound(thrust::seq, + (*edgelist_hops).begin() + start_offset, + (*edgelist_hops).begin() + end_offset, + h); + start_offset = + static_cast(thrust::distance((*edgelist_hops).begin(), lower_it)); + end_offset = static_cast(thrust::distance((*edgelist_hops).begin(), upper_it)); + } - return end_offset - start_offset; - }); + return end_offset - start_offset; + })); thrust::exclusive_scan(handle.get_thrust_policy(), (*edgelist_label_hop_offsets).begin(), (*edgelist_label_hop_offsets).end(), diff --git a/cpp/src/sampling/sampling_post_processing_sg.cu b/cpp/src/sampling/sampling_post_processing_sg.cu index 5a243c9cb6b..3c6734559ed 100644 --- a/cpp/src/sampling/sampling_post_processing_sg.cu +++ b/cpp/src/sampling/sampling_post_processing_sg.cu @@ -36,8 +36,12 @@ renumber_and_compress_sampled_edgelist( std::optional>&& edgelist_weights, std::optional>&& edgelist_edge_ids, std::optional>&& edgelist_edge_types, - std::optional, size_t>>&& edgelist_hops, - std::optional, size_t>> label_offsets, + std::optional>&& edgelist_hops, + std::optional> seed_vertices, + std::optional> seed_vertex_label_offsets, + std::optional> edgelist_label_offsets, + size_t num_labels, + size_t num_hops, bool src_is_major, bool compress_per_hop, bool doubly_compress, @@ -59,8 +63,12 @@ renumber_and_compress_sampled_edgelist( std::optional>&& edgelist_weights, std::optional>&& edgelist_edge_ids, std::optional>&& edgelist_edge_types, - std::optional, size_t>>&& edgelist_hops, - std::optional, size_t>> label_offsets, + std::optional>&& edgelist_hops, + std::optional> seed_vertices, + std::optional> seed_vertex_label_offsets, + std::optional> edgelist_label_offsets, + size_t num_labels, + size_t num_hops, bool src_is_major, bool compress_per_hop, bool doubly_compress, @@ -82,8 +90,12 @@ renumber_and_compress_sampled_edgelist( std::optional>&& edgelist_weights, std::optional>&& edgelist_edge_ids, std::optional>&& edgelist_edge_types, - std::optional, size_t>>&& edgelist_hops, - std::optional, size_t>> label_offsets, + std::optional>&& edgelist_hops, + std::optional> seed_vertices, + std::optional> seed_vertex_label_offsets, + std::optional> edgelist_label_offsets, + size_t num_labels, + size_t num_hops, bool src_is_major, bool compress_per_hop, bool doubly_compress, @@ -105,8 +117,12 @@ renumber_and_compress_sampled_edgelist( std::optional>&& edgelist_weights, std::optional>&& edgelist_edge_ids, std::optional>&& edgelist_edge_types, - std::optional, size_t>>&& edgelist_hops, - std::optional, size_t>> label_offsets, + std::optional>&& edgelist_hops, + std::optional> seed_vertices, + std::optional> seed_vertex_label_offsets, + std::optional> edgelist_label_offsets, + size_t num_labels, + size_t num_hops, bool src_is_major, bool compress_per_hop, bool doubly_compress, @@ -128,8 +144,12 @@ renumber_and_compress_sampled_edgelist( std::optional>&& edgelist_weights, std::optional>&& edgelist_edge_ids, std::optional>&& edgelist_edge_types, - std::optional, size_t>>&& edgelist_hops, - std::optional, size_t>> label_offsets, + std::optional>&& edgelist_hops, + std::optional> seed_vertices, + std::optional> seed_vertex_label_offsets, + std::optional> edgelist_label_offsets, + size_t num_labels, + size_t num_hops, bool src_is_major, bool compress_per_hop, bool doubly_compress, @@ -151,8 +171,12 @@ renumber_and_compress_sampled_edgelist( std::optional>&& edgelist_weights, std::optional>&& edgelist_edge_ids, std::optional>&& edgelist_edge_types, - std::optional, size_t>>&& edgelist_hops, - std::optional, size_t>> label_offsets, + std::optional>&& edgelist_hops, + std::optional> seed_vertices, + std::optional> seed_vertex_label_offsets, + std::optional> edgelist_label_offsets, + size_t num_labels, + size_t num_hops, bool src_is_major, bool compress_per_hop, bool doubly_compress, @@ -173,8 +197,12 @@ renumber_and_sort_sampled_edgelist( std::optional>&& edgelist_weights, std::optional>&& edgelist_edge_ids, std::optional>&& edgelist_edge_types, - std::optional, size_t>>&& edgelist_hops, - std::optional, size_t>> edgelist_label_offsets, + std::optional>&& edgelist_hops, + std::optional> seed_vertices, + std::optional> seed_vertex_label_offsets, + std::optional> edgelist_label_offsets, + size_t num_labels, + size_t num_hops, bool src_is_major, bool do_expensive_check); @@ -193,8 +221,12 @@ renumber_and_sort_sampled_edgelist( std::optional>&& edgelist_weights, std::optional>&& edgelist_edge_ids, std::optional>&& edgelist_edge_types, - std::optional, size_t>>&& edgelist_hops, - std::optional, size_t>> edgelist_label_offsets, + std::optional>&& edgelist_hops, + std::optional> seed_vertices, + std::optional> seed_vertex_label_offsets, + std::optional> edgelist_label_offsets, + size_t num_labels, + size_t num_hops, bool src_is_major, bool do_expensive_check); @@ -213,8 +245,12 @@ renumber_and_sort_sampled_edgelist( std::optional>&& edgelist_weights, std::optional>&& edgelist_edge_ids, std::optional>&& edgelist_edge_types, - std::optional, size_t>>&& edgelist_hops, - std::optional, size_t>> edgelist_label_offsets, + std::optional>&& edgelist_hops, + std::optional> seed_vertices, + std::optional> seed_vertex_label_offsets, + std::optional> edgelist_label_offsets, + size_t num_labels, + size_t num_hops, bool src_is_major, bool do_expensive_check); @@ -233,8 +269,12 @@ renumber_and_sort_sampled_edgelist( std::optional>&& edgelist_weights, std::optional>&& edgelist_edge_ids, std::optional>&& edgelist_edge_types, - std::optional, size_t>>&& edgelist_hops, - std::optional, size_t>> edgelist_label_offsets, + std::optional>&& edgelist_hops, + std::optional> seed_vertices, + std::optional> seed_vertex_label_offsets, + std::optional> edgelist_label_offsets, + size_t num_labels, + size_t num_hops, bool src_is_major, bool do_expensive_check); @@ -253,8 +293,12 @@ renumber_and_sort_sampled_edgelist( std::optional>&& edgelist_weights, std::optional>&& edgelist_edge_ids, std::optional>&& edgelist_edge_types, - std::optional, size_t>>&& edgelist_hops, - std::optional, size_t>> edgelist_label_offsets, + std::optional>&& edgelist_hops, + std::optional> seed_vertices, + std::optional> seed_vertex_label_offsets, + std::optional> edgelist_label_offsets, + size_t num_labels, + size_t num_hops, bool src_is_major, bool do_expensive_check); @@ -273,8 +317,12 @@ renumber_and_sort_sampled_edgelist( std::optional>&& edgelist_weights, std::optional>&& edgelist_edge_ids, std::optional>&& edgelist_edge_types, - std::optional, size_t>>&& edgelist_hops, - std::optional, size_t>> edgelist_label_offsets, + std::optional>&& edgelist_hops, + std::optional> seed_vertices, + std::optional> seed_vertex_label_offsets, + std::optional> edgelist_label_offsets, + size_t num_labels, + size_t num_hops, bool src_is_major, bool do_expensive_check); @@ -284,17 +332,18 @@ template std::tuple, std::optional>, std::optional>, std::optional>> -sort_sampled_edgelist( - raft::handle_t const& handle, - rmm::device_uvector&& edgelist_srcs, - rmm::device_uvector&& edgelist_dsts, - std::optional>&& edgelist_weights, - std::optional>&& edgelist_edge_ids, - std::optional>&& edgelist_edge_types, - std::optional, size_t>>&& edgelist_hops, - std::optional, size_t>> edgelist_label_offsets, - bool src_is_major, - bool do_expensive_check); +sort_sampled_edgelist(raft::handle_t const& handle, + rmm::device_uvector&& edgelist_srcs, + rmm::device_uvector&& edgelist_dsts, + std::optional>&& edgelist_weights, + std::optional>&& edgelist_edge_ids, + std::optional>&& edgelist_edge_types, + std::optional>&& edgelist_hops, + std::optional> edgelist_label_offsets, + size_t num_labels, + size_t num_hops, + bool src_is_major, + bool do_expensive_check); template std::tuple, rmm::device_uvector, @@ -302,17 +351,18 @@ template std::tuple, std::optional>, std::optional>, std::optional>> -sort_sampled_edgelist( - raft::handle_t const& handle, - rmm::device_uvector&& edgelist_srcs, - rmm::device_uvector&& edgelist_dsts, - std::optional>&& edgelist_weights, - std::optional>&& edgelist_edge_ids, - std::optional>&& edgelist_edge_types, - std::optional, size_t>>&& edgelist_hops, - std::optional, size_t>> edgelist_label_offsets, - bool src_is_major, - bool do_expensive_check); +sort_sampled_edgelist(raft::handle_t const& handle, + rmm::device_uvector&& edgelist_srcs, + rmm::device_uvector&& edgelist_dsts, + std::optional>&& edgelist_weights, + std::optional>&& edgelist_edge_ids, + std::optional>&& edgelist_edge_types, + std::optional>&& edgelist_hops, + std::optional> edgelist_label_offsets, + size_t num_labels, + size_t num_hops, + bool src_is_major, + bool do_expensive_check); template std::tuple, rmm::device_uvector, @@ -320,17 +370,18 @@ template std::tuple, std::optional>, std::optional>, std::optional>> -sort_sampled_edgelist( - raft::handle_t const& handle, - rmm::device_uvector&& edgelist_srcs, - rmm::device_uvector&& edgelist_dsts, - std::optional>&& edgelist_weights, - std::optional>&& edgelist_edge_ids, - std::optional>&& edgelist_edge_types, - std::optional, size_t>>&& edgelist_hops, - std::optional, size_t>> edgelist_label_offsets, - bool src_is_major, - bool do_expensive_check); +sort_sampled_edgelist(raft::handle_t const& handle, + rmm::device_uvector&& edgelist_srcs, + rmm::device_uvector&& edgelist_dsts, + std::optional>&& edgelist_weights, + std::optional>&& edgelist_edge_ids, + std::optional>&& edgelist_edge_types, + std::optional>&& edgelist_hops, + std::optional> edgelist_label_offsets, + size_t num_labels, + size_t num_hops, + bool src_is_major, + bool do_expensive_check); template std::tuple, rmm::device_uvector, @@ -338,17 +389,18 @@ template std::tuple, std::optional>, std::optional>, std::optional>> -sort_sampled_edgelist( - raft::handle_t const& handle, - rmm::device_uvector&& edgelist_srcs, - rmm::device_uvector&& edgelist_dsts, - std::optional>&& edgelist_weights, - std::optional>&& edgelist_edge_ids, - std::optional>&& edgelist_edge_types, - std::optional, size_t>>&& edgelist_hops, - std::optional, size_t>> edgelist_label_offsets, - bool src_is_major, - bool do_expensive_check); +sort_sampled_edgelist(raft::handle_t const& handle, + rmm::device_uvector&& edgelist_srcs, + rmm::device_uvector&& edgelist_dsts, + std::optional>&& edgelist_weights, + std::optional>&& edgelist_edge_ids, + std::optional>&& edgelist_edge_types, + std::optional>&& edgelist_hops, + std::optional> edgelist_label_offsets, + size_t num_labels, + size_t num_hops, + bool src_is_major, + bool do_expensive_check); template std::tuple, rmm::device_uvector, @@ -356,17 +408,18 @@ template std::tuple, std::optional>, std::optional>, std::optional>> -sort_sampled_edgelist( - raft::handle_t const& handle, - rmm::device_uvector&& edgelist_srcs, - rmm::device_uvector&& edgelist_dsts, - std::optional>&& edgelist_weights, - std::optional>&& edgelist_edge_ids, - std::optional>&& edgelist_edge_types, - std::optional, size_t>>&& edgelist_hops, - std::optional, size_t>> edgelist_label_offsets, - bool src_is_major, - bool do_expensive_check); +sort_sampled_edgelist(raft::handle_t const& handle, + rmm::device_uvector&& edgelist_srcs, + rmm::device_uvector&& edgelist_dsts, + std::optional>&& edgelist_weights, + std::optional>&& edgelist_edge_ids, + std::optional>&& edgelist_edge_types, + std::optional>&& edgelist_hops, + std::optional> edgelist_label_offsets, + size_t num_labels, + size_t num_hops, + bool src_is_major, + bool do_expensive_check); template std::tuple, rmm::device_uvector, @@ -374,16 +427,17 @@ template std::tuple, std::optional>, std::optional>, std::optional>> -sort_sampled_edgelist( - raft::handle_t const& handle, - rmm::device_uvector&& edgelist_srcs, - rmm::device_uvector&& edgelist_dsts, - std::optional>&& edgelist_weights, - std::optional>&& edgelist_edge_ids, - std::optional>&& edgelist_edge_types, - std::optional, size_t>>&& edgelist_hops, - std::optional, size_t>> edgelist_label_offsets, - bool src_is_major, - bool do_expensive_check); +sort_sampled_edgelist(raft::handle_t const& handle, + rmm::device_uvector&& edgelist_srcs, + rmm::device_uvector&& edgelist_dsts, + std::optional>&& edgelist_weights, + std::optional>&& edgelist_edge_ids, + std::optional>&& edgelist_edge_types, + std::optional>&& edgelist_hops, + std::optional> edgelist_label_offsets, + size_t num_labels, + size_t num_hops, + bool src_is_major, + bool do_expensive_check); } // namespace cugraph diff --git a/cpp/tests/sampling/sampling_post_processing_test.cu b/cpp/tests/sampling/sampling_post_processing_test.cu index c93994ddfad..c87cc5b960b 100644 --- a/cpp/tests/sampling/sampling_post_processing_test.cu +++ b/cpp/tests/sampling/sampling_post_processing_test.cu @@ -47,6 +47,7 @@ struct SamplingPostProcessing_Usecase { bool sample_with_replacement{false}; bool src_is_major{true}; + bool renumber_with_seeds{false}; bool compress_per_hop{false}; bool doubly_compress{false}; bool check_correctness{true}; @@ -175,6 +176,7 @@ bool compare_edgelist(raft::handle_t const& handle, template bool check_renumber_map_invariants( raft::handle_t const& handle, + std::optional> starting_vertices, raft::device_span org_edgelist_srcs, raft::device_span org_edgelist_dsts, std::optional> org_edgelist_hops, @@ -193,6 +195,15 @@ bool check_renumber_map_invariants( org_edgelist_majors.begin(), org_edgelist_majors.end(), unique_majors.begin()); + if (starting_vertices) { + auto old_size = unique_majors.size(); + unique_majors.resize(old_size + (*starting_vertices).size(), handle.get_stream()); + thrust::copy(handle.get_thrust_policy(), + (*starting_vertices).begin(), + (*starting_vertices).end(), + unique_majors.begin() + old_size); + } + std::optional> unique_major_hops = org_edgelist_hops ? std::make_optional>( (*org_edgelist_hops).size(), handle.get_stream()) @@ -202,6 +213,14 @@ bool check_renumber_map_invariants( (*org_edgelist_hops).begin(), (*org_edgelist_hops).end(), (*unique_major_hops).begin()); + if (starting_vertices) { + auto old_size = (*unique_major_hops).size(); + (*unique_major_hops).resize(old_size + (*starting_vertices).size(), handle.get_stream()); + thrust::fill(handle.get_thrust_policy(), + (*unique_major_hops).begin() + old_size, + (*unique_major_hops).end(), + int32_t{0}); + } auto pair_first = thrust::make_zip_iterator(unique_majors.begin(), (*unique_major_hops).begin()); @@ -476,6 +495,11 @@ class Tests_SamplingPostProcessing ? std::make_optional>( starting_vertices.size(), handle.get_stream()) : std::nullopt; + auto starting_vertex_label_offsets = + (sampling_post_processing_usecase.num_labels > 1) + ? std::make_optional>( + sampling_post_processing_usecase.num_labels + 1, handle.get_stream()) + : std::nullopt; if (starting_vertex_labels) { thrust::tabulate( handle.get_thrust_policy(), @@ -483,6 +507,12 @@ class Tests_SamplingPostProcessing (*starting_vertex_labels).end(), [num_seeds_per_label = sampling_post_processing_usecase.num_seeds_per_label] __device__( size_t i) { return static_cast(i / num_seeds_per_label); }); + thrust::tabulate( + handle.get_thrust_policy(), + (*starting_vertex_label_offsets).begin(), + (*starting_vertex_label_offsets).end(), + [num_seeds_per_label = sampling_post_processing_usecase.num_seeds_per_label] __device__( + size_t i) { return num_seeds_per_label * i; }); } rmm::device_uvector org_edgelist_srcs(0, handle.get_stream()); @@ -530,10 +560,6 @@ class Tests_SamplingPostProcessing std::swap(org_edgelist_srcs, org_edgelist_dsts); } - starting_vertices.resize(0, handle.get_stream()); - starting_vertices.shrink_to_fit(handle.get_stream()); - starting_vertex_labels = std::nullopt; - { rmm::device_uvector renumbered_and_sorted_edgelist_srcs(org_edgelist_srcs.size(), handle.get_stream()); @@ -548,11 +574,9 @@ class Tests_SamplingPostProcessing std::optional> renumbered_and_sorted_edgelist_edge_types{ std::nullopt}; auto renumbered_and_sorted_edgelist_hops = - org_edgelist_hops - ? std::make_optional(std::make_tuple( - rmm::device_uvector((*org_edgelist_hops).size(), handle.get_stream()), - sampling_post_processing_usecase.fanouts.size())) - : std::nullopt; + org_edgelist_hops ? std::make_optional(rmm::device_uvector( + (*org_edgelist_hops).size(), handle.get_stream())) + : std::nullopt; raft::copy(renumbered_and_sorted_edgelist_srcs.data(), org_edgelist_srcs.data(), @@ -569,7 +593,7 @@ class Tests_SamplingPostProcessing handle.get_stream()); } if (renumbered_and_sorted_edgelist_hops) { - raft::copy(std::get<0>(*renumbered_and_sorted_edgelist_hops).data(), + raft::copy((*renumbered_and_sorted_edgelist_hops).data(), (*org_edgelist_hops).data(), (*org_edgelist_hops).size(), handle.get_stream()); @@ -581,14 +605,6 @@ class Tests_SamplingPostProcessing std::optional> renumbered_and_sorted_renumber_map_label_offsets{ std::nullopt}; - { - size_t free_size{}; - size_t total_size{}; - RAFT_CUDA_TRY(cudaMemGetInfo(&free_size, &total_size)); - std::cout << "free_size=" << free_size / (1024.0 * 1024.0 * 1024.0) - << "GB total_size=" << total_size / (1024.0 * 1024.0 * 1024.0) << "GB." - << std::endl; - } if (cugraph::test::g_perf) { RAFT_CUDA_TRY(cudaDeviceSynchronize()); // for consistent performance measurement hr_timer.start("Renumber and sort sampled edgelist"); @@ -602,7 +618,7 @@ class Tests_SamplingPostProcessing renumbered_and_sorted_edgelist_label_hop_offsets, renumbered_and_sorted_renumber_map, renumbered_and_sorted_renumber_map_label_offsets) = - cugraph::renumber_and_sort_sampled_edgelist( + cugraph::renumber_and_sort_sampled_edgelist( handle, std::move(renumbered_and_sorted_edgelist_srcs), std::move(renumbered_and_sorted_edgelist_dsts), @@ -610,12 +626,20 @@ class Tests_SamplingPostProcessing std::move(renumbered_and_sorted_edgelist_edge_ids), std::move(renumbered_and_sorted_edgelist_edge_types), std::move(renumbered_and_sorted_edgelist_hops), + sampling_post_processing_usecase.renumber_with_seeds + ? std::make_optional>(starting_vertices.data(), + starting_vertices.size()) + : std::nullopt, + (sampling_post_processing_usecase.renumber_with_seeds && starting_vertex_label_offsets) + ? std::make_optional>( + (*starting_vertex_label_offsets).data(), (*starting_vertex_label_offsets).size()) + : std::nullopt, org_edgelist_label_offsets - ? std::make_optional(std::make_tuple( - raft::device_span((*org_edgelist_label_offsets).data(), - (*org_edgelist_label_offsets).size()), - sampling_post_processing_usecase.num_labels)) + ? std::make_optional(raft::device_span( + (*org_edgelist_label_offsets).data(), (*org_edgelist_label_offsets).size())) : std::nullopt, + sampling_post_processing_usecase.num_labels, + sampling_post_processing_usecase.fanouts.size(), sampling_post_processing_usecase.src_is_major); if (cugraph::test::g_perf) { @@ -666,6 +690,15 @@ class Tests_SamplingPostProcessing } for (size_t i = 0; i < sampling_post_processing_usecase.num_labels; ++i) { + size_t starting_vertex_start_offset = + starting_vertex_label_offsets + ? (*starting_vertex_label_offsets).element(i, handle.get_stream()) + : size_t{0}; + size_t starting_vertex_end_offset = + starting_vertex_label_offsets + ? (*starting_vertex_label_offsets).element(i + 1, handle.get_stream()) + : starting_vertices.size(); + size_t edgelist_start_offset = org_edgelist_label_offsets ? (*org_edgelist_label_offsets).element(i, handle.get_stream()) @@ -676,6 +709,10 @@ class Tests_SamplingPostProcessing : org_edgelist_srcs.size(); if (edgelist_start_offset == edgelist_end_offset) continue; + auto this_label_starting_vertices = raft::device_span( + starting_vertices.data() + starting_vertex_start_offset, + starting_vertex_end_offset - starting_vertex_start_offset); + auto this_label_org_edgelist_srcs = raft::device_span(org_edgelist_srcs.data() + edgelist_start_offset, edgelist_end_offset - edgelist_start_offset); @@ -769,12 +806,17 @@ class Tests_SamplingPostProcessing // Check the invariants in renumber_map - ASSERT_TRUE(check_renumber_map_invariants(handle, - this_label_org_edgelist_srcs, - this_label_org_edgelist_dsts, - this_label_org_edgelist_hops, - this_label_output_renumber_map, - sampling_post_processing_usecase.src_is_major)) + ASSERT_TRUE(check_renumber_map_invariants( + handle, + sampling_post_processing_usecase.renumber_with_seeds + ? std::make_optional>( + this_label_starting_vertices.data(), this_label_starting_vertices.size()) + : std::nullopt, + this_label_org_edgelist_srcs, + this_label_org_edgelist_dsts, + this_label_org_edgelist_hops, + this_label_output_renumber_map, + sampling_post_processing_usecase.src_is_major)) << "Renumbered and sorted output renumber map violates invariants."; } } @@ -794,11 +836,9 @@ class Tests_SamplingPostProcessing std::optional> renumbered_and_compressed_edgelist_edge_types{ std::nullopt}; auto renumbered_and_compressed_edgelist_hops = - org_edgelist_hops - ? std::make_optional(std::make_tuple( - rmm::device_uvector((*org_edgelist_hops).size(), handle.get_stream()), - sampling_post_processing_usecase.fanouts.size())) - : std::nullopt; + org_edgelist_hops ? std::make_optional(rmm::device_uvector( + (*org_edgelist_hops).size(), handle.get_stream())) + : std::nullopt; raft::copy(renumbered_and_compressed_edgelist_srcs.data(), org_edgelist_srcs.data(), @@ -815,7 +855,7 @@ class Tests_SamplingPostProcessing handle.get_stream()); } if (renumbered_and_compressed_edgelist_hops) { - raft::copy(std::get<0>(*renumbered_and_compressed_edgelist_hops).data(), + raft::copy((*renumbered_and_compressed_edgelist_hops).data(), (*org_edgelist_hops).data(), (*org_edgelist_hops).size(), handle.get_stream()); @@ -846,7 +886,7 @@ class Tests_SamplingPostProcessing renumbered_and_compressed_offset_label_hop_offsets, renumbered_and_compressed_renumber_map, renumbered_and_compressed_renumber_map_label_offsets) = - cugraph::renumber_and_compress_sampled_edgelist( + cugraph::renumber_and_compress_sampled_edgelist( handle, std::move(renumbered_and_compressed_edgelist_srcs), std::move(renumbered_and_compressed_edgelist_dsts), @@ -854,12 +894,20 @@ class Tests_SamplingPostProcessing std::move(renumbered_and_compressed_edgelist_edge_ids), std::move(renumbered_and_compressed_edgelist_edge_types), std::move(renumbered_and_compressed_edgelist_hops), + sampling_post_processing_usecase.renumber_with_seeds + ? std::make_optional>(starting_vertices.data(), + starting_vertices.size()) + : std::nullopt, + (sampling_post_processing_usecase.renumber_with_seeds && starting_vertex_label_offsets) + ? std::make_optional>( + (*starting_vertex_label_offsets).data(), (*starting_vertex_label_offsets).size()) + : std::nullopt, org_edgelist_label_offsets - ? std::make_optional(std::make_tuple( - raft::device_span((*org_edgelist_label_offsets).data(), - (*org_edgelist_label_offsets).size()), - sampling_post_processing_usecase.num_labels)) + ? std::make_optional(raft::device_span( + (*org_edgelist_label_offsets).data(), (*org_edgelist_label_offsets).size())) : std::nullopt, + sampling_post_processing_usecase.num_labels, + sampling_post_processing_usecase.fanouts.size(), sampling_post_processing_usecase.src_is_major, sampling_post_processing_usecase.compress_per_hop, sampling_post_processing_usecase.doubly_compress); @@ -934,6 +982,15 @@ class Tests_SamplingPostProcessing } for (size_t i = 0; i < sampling_post_processing_usecase.num_labels; ++i) { + size_t starting_vertex_start_offset = + starting_vertex_label_offsets + ? (*starting_vertex_label_offsets).element(i, handle.get_stream()) + : size_t{0}; + size_t starting_vertex_end_offset = + starting_vertex_label_offsets + ? (*starting_vertex_label_offsets).element(i + 1, handle.get_stream()) + : starting_vertices.size(); + size_t edgelist_start_offset = org_edgelist_label_offsets ? (*org_edgelist_label_offsets).element(i, handle.get_stream()) @@ -944,6 +1001,10 @@ class Tests_SamplingPostProcessing : org_edgelist_srcs.size(); if (edgelist_start_offset == edgelist_end_offset) continue; + auto this_label_starting_vertices = raft::device_span( + starting_vertices.data() + starting_vertex_start_offset, + starting_vertex_end_offset - starting_vertex_start_offset); + auto this_label_org_edgelist_srcs = raft::device_span(org_edgelist_srcs.data() + edgelist_start_offset, edgelist_end_offset - edgelist_start_offset); @@ -1092,12 +1153,17 @@ class Tests_SamplingPostProcessing // Check the invariants in renumber_map - ASSERT_TRUE(check_renumber_map_invariants(handle, - this_label_org_edgelist_srcs, - this_label_org_edgelist_dsts, - this_label_org_edgelist_hops, - this_label_output_renumber_map, - sampling_post_processing_usecase.src_is_major)) + ASSERT_TRUE(check_renumber_map_invariants( + handle, + sampling_post_processing_usecase.renumber_with_seeds + ? std::make_optional>( + this_label_starting_vertices.data(), this_label_starting_vertices.size()) + : std::nullopt, + this_label_org_edgelist_srcs, + this_label_org_edgelist_dsts, + this_label_org_edgelist_hops, + this_label_output_renumber_map, + sampling_post_processing_usecase.src_is_major)) << "Renumbered and sorted output renumber map violates invariants."; } } @@ -1114,12 +1180,10 @@ class Tests_SamplingPostProcessing : std::nullopt; std::optional> sorted_edgelist_edge_ids{std::nullopt}; std::optional> sorted_edgelist_edge_types{std::nullopt}; - auto sorted_edgelist_hops = - org_edgelist_hops - ? std::make_optional(std::make_tuple( - rmm::device_uvector((*org_edgelist_hops).size(), handle.get_stream()), - sampling_post_processing_usecase.fanouts.size())) - : std::nullopt; + auto sorted_edgelist_hops = org_edgelist_hops + ? std::make_optional(rmm::device_uvector( + (*org_edgelist_hops).size(), handle.get_stream())) + : std::nullopt; raft::copy(sorted_edgelist_srcs.data(), org_edgelist_srcs.data(), @@ -1136,7 +1200,7 @@ class Tests_SamplingPostProcessing handle.get_stream()); } if (sorted_edgelist_hops) { - raft::copy(std::get<0>(*sorted_edgelist_hops).data(), + raft::copy((*sorted_edgelist_hops).data(), (*org_edgelist_hops).data(), (*org_edgelist_hops).size(), handle.get_stream()); @@ -1144,14 +1208,6 @@ class Tests_SamplingPostProcessing std::optional> sorted_edgelist_label_hop_offsets{std::nullopt}; - { - size_t free_size{}; - size_t total_size{}; - RAFT_CUDA_TRY(cudaMemGetInfo(&free_size, &total_size)); - std::cout << "free_size=" << free_size / (1024.0 * 1024.0 * 1024.0) - << "GB total_size=" << total_size / (1024.0 * 1024.0 * 1024.0) << "GB." - << std::endl; - } if (cugraph::test::g_perf) { RAFT_CUDA_TRY(cudaDeviceSynchronize()); // for consistent performance measurement hr_timer.start("Sort sampled edgelist"); @@ -1163,7 +1219,7 @@ class Tests_SamplingPostProcessing sorted_edgelist_edge_ids, sorted_edgelist_edge_types, sorted_edgelist_label_hop_offsets) = - cugraph::sort_sampled_edgelist( + cugraph::sort_sampled_edgelist( handle, std::move(sorted_edgelist_srcs), std::move(sorted_edgelist_dsts), @@ -1172,11 +1228,11 @@ class Tests_SamplingPostProcessing std::move(sorted_edgelist_edge_types), std::move(sorted_edgelist_hops), org_edgelist_label_offsets - ? std::make_optional(std::make_tuple( - raft::device_span((*org_edgelist_label_offsets).data(), - (*org_edgelist_label_offsets).size()), - sampling_post_processing_usecase.num_labels)) + ? std::make_optional(raft::device_span( + (*org_edgelist_label_offsets).data(), (*org_edgelist_label_offsets).size())) : std::nullopt, + sampling_post_processing_usecase.num_labels, + sampling_post_processing_usecase.fanouts.size(), sampling_post_processing_usecase.src_is_major); if (cugraph::test::g_perf) { @@ -1329,49 +1385,88 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Combine( // enable correctness checks ::testing::Values( + SamplingPostProcessing_Usecase{1, 16, {10}, false, false, false, false, false}, SamplingPostProcessing_Usecase{1, 16, {10}, false, false, false, false, true}, - SamplingPostProcessing_Usecase{1, 16, {10}, false, false, false, true, true}, + SamplingPostProcessing_Usecase{1, 16, {10}, false, false, true, false, false}, + SamplingPostProcessing_Usecase{1, 16, {10}, false, false, true, false, true}, + SamplingPostProcessing_Usecase{1, 16, {10}, false, true, false, false, false}, SamplingPostProcessing_Usecase{1, 16, {10}, false, true, false, false, true}, - SamplingPostProcessing_Usecase{1, 16, {10}, false, true, false, true, true}, + SamplingPostProcessing_Usecase{1, 16, {10}, false, true, true, false, false}, + SamplingPostProcessing_Usecase{1, 16, {10}, false, true, true, false, true}, + SamplingPostProcessing_Usecase{1, 16, {10}, true, false, false, false, false}, SamplingPostProcessing_Usecase{1, 16, {10}, true, false, false, false, true}, - SamplingPostProcessing_Usecase{1, 16, {10}, true, false, false, true, true}, + SamplingPostProcessing_Usecase{1, 16, {10}, true, false, true, false, false}, + SamplingPostProcessing_Usecase{1, 16, {10}, true, false, true, false, true}, + SamplingPostProcessing_Usecase{1, 16, {10}, true, true, false, false, false}, SamplingPostProcessing_Usecase{1, 16, {10}, true, true, false, false, true}, - SamplingPostProcessing_Usecase{1, 16, {10}, true, true, false, true, true}, + SamplingPostProcessing_Usecase{1, 16, {10}, true, true, true, false, false}, + SamplingPostProcessing_Usecase{1, 16, {10}, true, true, true, false, true}, + SamplingPostProcessing_Usecase{1, 16, {5, 10, 25}, false, false, false, false, false}, SamplingPostProcessing_Usecase{1, 16, {5, 10, 25}, false, false, false, false, true}, - SamplingPostProcessing_Usecase{1, 16, {5, 10, 25}, false, false, false, true, true}, - SamplingPostProcessing_Usecase{1, 4, {5, 10, 25}, false, false, true, false, true}, + SamplingPostProcessing_Usecase{1, 16, {5, 10, 25}, false, false, false, true, false}, + SamplingPostProcessing_Usecase{1, 16, {5, 10, 25}, false, false, true, false, false}, SamplingPostProcessing_Usecase{1, 16, {5, 10, 25}, false, false, true, false, true}, + SamplingPostProcessing_Usecase{1, 16, {5, 10, 25}, false, false, true, true, false}, + SamplingPostProcessing_Usecase{1, 16, {5, 10, 25}, false, true, false, false, false}, SamplingPostProcessing_Usecase{1, 16, {5, 10, 25}, false, true, false, false, true}, - SamplingPostProcessing_Usecase{1, 16, {5, 10, 25}, false, true, false, true, true}, + SamplingPostProcessing_Usecase{1, 16, {5, 10, 25}, false, true, false, true, false}, + SamplingPostProcessing_Usecase{1, 16, {5, 10, 25}, false, true, true, false, false}, SamplingPostProcessing_Usecase{1, 16, {5, 10, 25}, false, true, true, false, true}, + SamplingPostProcessing_Usecase{1, 16, {5, 10, 25}, false, true, true, true, false}, + SamplingPostProcessing_Usecase{1, 16, {5, 10, 25}, true, false, false, false, false}, SamplingPostProcessing_Usecase{1, 16, {5, 10, 25}, true, false, false, false, true}, - SamplingPostProcessing_Usecase{1, 16, {5, 10, 25}, true, false, false, true, true}, + SamplingPostProcessing_Usecase{1, 16, {5, 10, 25}, true, false, false, true, false}, + SamplingPostProcessing_Usecase{1, 16, {5, 10, 25}, true, false, true, false, false}, SamplingPostProcessing_Usecase{1, 16, {5, 10, 25}, true, false, true, false, true}, + SamplingPostProcessing_Usecase{1, 16, {5, 10, 25}, true, false, true, true, false}, + SamplingPostProcessing_Usecase{1, 16, {5, 10, 25}, true, true, false, false, false}, SamplingPostProcessing_Usecase{1, 16, {5, 10, 25}, true, true, false, false, true}, - SamplingPostProcessing_Usecase{1, 16, {5, 10, 25}, true, true, false, true, true}, + SamplingPostProcessing_Usecase{1, 16, {5, 10, 25}, true, true, false, true, false}, + SamplingPostProcessing_Usecase{1, 16, {5, 10, 25}, true, true, true, false, false}, SamplingPostProcessing_Usecase{1, 16, {5, 10, 25}, true, true, true, false, true}, + SamplingPostProcessing_Usecase{1, 16, {5, 10, 25}, true, true, true, true, false}, + SamplingPostProcessing_Usecase{32, 16, {10}, false, false, false, false, false}, SamplingPostProcessing_Usecase{32, 16, {10}, false, false, false, false, true}, - SamplingPostProcessing_Usecase{32, 16, {10}, false, false, false, true, true}, + SamplingPostProcessing_Usecase{32, 16, {10}, false, false, true, false, false}, + SamplingPostProcessing_Usecase{32, 16, {10}, false, false, true, false, true}, + SamplingPostProcessing_Usecase{32, 16, {10}, false, true, false, false, false}, SamplingPostProcessing_Usecase{32, 16, {10}, false, true, false, false, true}, - SamplingPostProcessing_Usecase{32, 16, {10}, false, true, false, true, true}, + SamplingPostProcessing_Usecase{32, 16, {10}, false, true, true, false, false}, + SamplingPostProcessing_Usecase{32, 16, {10}, false, true, true, false, true}, + SamplingPostProcessing_Usecase{32, 16, {10}, true, false, false, false, false}, SamplingPostProcessing_Usecase{32, 16, {10}, true, false, false, false, true}, - SamplingPostProcessing_Usecase{32, 16, {10}, true, false, false, true, true}, + SamplingPostProcessing_Usecase{32, 16, {10}, true, false, true, false, false}, + SamplingPostProcessing_Usecase{32, 16, {10}, true, false, true, false, true}, + SamplingPostProcessing_Usecase{32, 16, {10}, true, true, false, false, false}, SamplingPostProcessing_Usecase{32, 16, {10}, true, true, false, false, true}, - SamplingPostProcessing_Usecase{32, 16, {10}, true, true, false, true, true}, + SamplingPostProcessing_Usecase{32, 16, {10}, true, true, true, false, false}, + SamplingPostProcessing_Usecase{32, 16, {10}, true, true, true, false, true}, + SamplingPostProcessing_Usecase{32, 16, {5, 10, 25}, false, false, false, false, false}, SamplingPostProcessing_Usecase{32, 16, {5, 10, 25}, false, false, false, false, true}, - SamplingPostProcessing_Usecase{32, 16, {5, 10, 25}, false, false, false, true, true}, + SamplingPostProcessing_Usecase{32, 16, {5, 10, 25}, false, false, false, true, false}, + SamplingPostProcessing_Usecase{32, 16, {5, 10, 25}, false, false, true, false, false}, SamplingPostProcessing_Usecase{32, 16, {5, 10, 25}, false, false, true, false, true}, + SamplingPostProcessing_Usecase{32, 16, {5, 10, 25}, false, false, true, true, false}, + SamplingPostProcessing_Usecase{32, 16, {5, 10, 25}, false, true, false, false, false}, SamplingPostProcessing_Usecase{32, 16, {5, 10, 25}, false, true, false, false, true}, - SamplingPostProcessing_Usecase{32, 16, {5, 10, 25}, false, true, false, true, true}, + SamplingPostProcessing_Usecase{32, 16, {5, 10, 25}, false, true, false, true, false}, + SamplingPostProcessing_Usecase{32, 16, {5, 10, 25}, false, true, true, false, false}, SamplingPostProcessing_Usecase{32, 16, {5, 10, 25}, false, true, true, false, true}, + SamplingPostProcessing_Usecase{32, 16, {5, 10, 25}, false, true, true, true, false}, + SamplingPostProcessing_Usecase{32, 16, {5, 10, 25}, true, false, false, false, false}, SamplingPostProcessing_Usecase{32, 16, {5, 10, 25}, true, false, false, false, true}, - SamplingPostProcessing_Usecase{32, 16, {5, 10, 25}, true, false, false, true, true}, + SamplingPostProcessing_Usecase{32, 16, {5, 10, 25}, true, false, false, true, false}, + SamplingPostProcessing_Usecase{32, 16, {5, 10, 25}, true, false, true, false, false}, SamplingPostProcessing_Usecase{32, 16, {5, 10, 25}, true, false, true, false, true}, + SamplingPostProcessing_Usecase{32, 16, {5, 10, 25}, true, false, true, true, false}, + SamplingPostProcessing_Usecase{32, 16, {5, 10, 25}, true, true, false, false, false}, SamplingPostProcessing_Usecase{32, 16, {5, 10, 25}, true, true, false, false, true}, - SamplingPostProcessing_Usecase{32, 16, {5, 10, 25}, true, true, false, true, true}, - SamplingPostProcessing_Usecase{32, 16, {5, 10, 25}, true, true, true, false, true}), - ::testing::Values(cugraph::test::File_Usecase("karate.mtx"), - cugraph::test::File_Usecase("dolphins.mtx")))); + SamplingPostProcessing_Usecase{32, 16, {5, 10, 25}, true, true, false, true, false}, + SamplingPostProcessing_Usecase{32, 16, {5, 10, 25}, true, true, true, false, false}, + SamplingPostProcessing_Usecase{32, 16, {5, 10, 25}, true, true, true, false, true}, + SamplingPostProcessing_Usecase{32, 16, {5, 10, 25}, true, true, true, true, false}), + ::testing::Values(cugraph::test::File_Usecase("test/datasets/karate.mtx"), + cugraph::test::File_Usecase("test/datasets/dolphins.mtx")))); INSTANTIATE_TEST_SUITE_P( rmat_small_test, @@ -1379,46 +1474,86 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Combine( // enable correctness checks ::testing::Values( + SamplingPostProcessing_Usecase{1, 16, {10}, false, false, false, false, false}, SamplingPostProcessing_Usecase{1, 16, {10}, false, false, false, false, true}, - SamplingPostProcessing_Usecase{1, 16, {10}, false, false, false, true, true}, + SamplingPostProcessing_Usecase{1, 16, {10}, false, false, true, false, false}, + SamplingPostProcessing_Usecase{1, 16, {10}, false, false, true, false, true}, + SamplingPostProcessing_Usecase{1, 16, {10}, false, true, false, false, false}, SamplingPostProcessing_Usecase{1, 16, {10}, false, true, false, false, true}, - SamplingPostProcessing_Usecase{1, 16, {10}, false, true, false, true, true}, + SamplingPostProcessing_Usecase{1, 16, {10}, false, true, true, false, false}, + SamplingPostProcessing_Usecase{1, 16, {10}, false, true, true, false, true}, + SamplingPostProcessing_Usecase{1, 16, {10}, true, false, false, false, false}, SamplingPostProcessing_Usecase{1, 16, {10}, true, false, false, false, true}, - SamplingPostProcessing_Usecase{1, 16, {10}, true, false, false, true, true}, + SamplingPostProcessing_Usecase{1, 16, {10}, true, false, true, false, false}, + SamplingPostProcessing_Usecase{1, 16, {10}, true, false, true, false, true}, + SamplingPostProcessing_Usecase{1, 16, {10}, true, true, false, false, false}, SamplingPostProcessing_Usecase{1, 16, {10}, true, true, false, false, true}, - SamplingPostProcessing_Usecase{1, 16, {10}, true, true, false, true, true}, - SamplingPostProcessing_Usecase{1, 16, {5, 10, 15}, false, false, false, false, true}, - SamplingPostProcessing_Usecase{1, 16, {5, 10, 15}, false, false, false, true, true}, - SamplingPostProcessing_Usecase{1, 16, {5, 10, 15}, false, false, true, false, true}, - SamplingPostProcessing_Usecase{1, 16, {5, 10, 15}, false, true, false, false, true}, - SamplingPostProcessing_Usecase{1, 16, {5, 10, 15}, false, true, false, true, true}, - SamplingPostProcessing_Usecase{1, 16, {5, 10, 15}, false, true, true, false, true}, - SamplingPostProcessing_Usecase{1, 16, {5, 10, 15}, true, false, false, false, true}, - SamplingPostProcessing_Usecase{1, 16, {5, 10, 15}, true, false, false, true, true}, - SamplingPostProcessing_Usecase{1, 16, {5, 10, 15}, true, false, true, false, true}, - SamplingPostProcessing_Usecase{1, 16, {5, 10, 15}, true, true, false, false, true}, - SamplingPostProcessing_Usecase{1, 16, {5, 10, 15}, true, true, false, true, true}, - SamplingPostProcessing_Usecase{1, 16, {5, 10, 15}, true, true, true, false, true}, + SamplingPostProcessing_Usecase{1, 16, {10}, true, true, true, false, false}, + SamplingPostProcessing_Usecase{1, 16, {10}, true, true, true, false, true}, + SamplingPostProcessing_Usecase{1, 16, {5, 10, 25}, false, false, false, false, false}, + SamplingPostProcessing_Usecase{1, 16, {5, 10, 25}, false, false, false, false, true}, + SamplingPostProcessing_Usecase{1, 16, {5, 10, 25}, false, false, false, true, false}, + SamplingPostProcessing_Usecase{1, 16, {5, 10, 25}, false, false, true, false, false}, + SamplingPostProcessing_Usecase{1, 16, {5, 10, 25}, false, false, true, false, true}, + SamplingPostProcessing_Usecase{1, 16, {5, 10, 25}, false, false, true, true, false}, + SamplingPostProcessing_Usecase{1, 16, {5, 10, 25}, false, true, false, false, false}, + SamplingPostProcessing_Usecase{1, 16, {5, 10, 25}, false, true, false, false, true}, + SamplingPostProcessing_Usecase{1, 16, {5, 10, 25}, false, true, false, true, false}, + SamplingPostProcessing_Usecase{1, 16, {5, 10, 25}, false, true, true, false, false}, + SamplingPostProcessing_Usecase{1, 16, {5, 10, 25}, false, true, true, false, true}, + SamplingPostProcessing_Usecase{1, 16, {5, 10, 25}, false, true, true, true, false}, + SamplingPostProcessing_Usecase{1, 16, {5, 10, 25}, true, false, false, false, false}, + SamplingPostProcessing_Usecase{1, 16, {5, 10, 25}, true, false, false, false, true}, + SamplingPostProcessing_Usecase{1, 16, {5, 10, 25}, true, false, false, true, false}, + SamplingPostProcessing_Usecase{1, 16, {5, 10, 25}, true, false, true, false, false}, + SamplingPostProcessing_Usecase{1, 16, {5, 10, 25}, true, false, true, false, true}, + SamplingPostProcessing_Usecase{1, 16, {5, 10, 25}, true, false, true, true, false}, + SamplingPostProcessing_Usecase{1, 16, {5, 10, 25}, true, true, false, false, false}, + SamplingPostProcessing_Usecase{1, 16, {5, 10, 25}, true, true, false, false, true}, + SamplingPostProcessing_Usecase{1, 16, {5, 10, 25}, true, true, false, true, false}, + SamplingPostProcessing_Usecase{1, 16, {5, 10, 25}, true, true, true, false, false}, + SamplingPostProcessing_Usecase{1, 16, {5, 10, 25}, true, true, true, false, true}, + SamplingPostProcessing_Usecase{1, 16, {5, 10, 25}, true, true, true, true, false}, + SamplingPostProcessing_Usecase{32, 16, {10}, false, false, false, false, false}, SamplingPostProcessing_Usecase{32, 16, {10}, false, false, false, false, true}, - SamplingPostProcessing_Usecase{32, 16, {10}, false, false, false, true, true}, + SamplingPostProcessing_Usecase{32, 16, {10}, false, false, true, false, false}, + SamplingPostProcessing_Usecase{32, 16, {10}, false, false, true, false, true}, + SamplingPostProcessing_Usecase{32, 16, {10}, false, true, false, false, false}, SamplingPostProcessing_Usecase{32, 16, {10}, false, true, false, false, true}, - SamplingPostProcessing_Usecase{32, 16, {10}, false, true, false, true, true}, + SamplingPostProcessing_Usecase{32, 16, {10}, false, true, true, false, false}, + SamplingPostProcessing_Usecase{32, 16, {10}, false, true, true, false, true}, + SamplingPostProcessing_Usecase{32, 16, {10}, true, false, false, false, false}, SamplingPostProcessing_Usecase{32, 16, {10}, true, false, false, false, true}, - SamplingPostProcessing_Usecase{32, 16, {10}, true, false, false, true, true}, + SamplingPostProcessing_Usecase{32, 16, {10}, true, false, true, false, false}, + SamplingPostProcessing_Usecase{32, 16, {10}, true, false, true, false, true}, + SamplingPostProcessing_Usecase{32, 16, {10}, true, true, false, false, false}, SamplingPostProcessing_Usecase{32, 16, {10}, true, true, false, false, true}, - SamplingPostProcessing_Usecase{32, 16, {10}, true, true, false, true, true}, - SamplingPostProcessing_Usecase{32, 16, {5, 10, 15}, false, false, false, false, true}, - SamplingPostProcessing_Usecase{32, 16, {5, 10, 15}, false, false, false, true, true}, - SamplingPostProcessing_Usecase{32, 16, {5, 10, 15}, false, false, true, false, true}, - SamplingPostProcessing_Usecase{32, 16, {5, 10, 15}, false, true, false, false, true}, - SamplingPostProcessing_Usecase{32, 16, {5, 10, 15}, false, true, false, true, true}, - SamplingPostProcessing_Usecase{32, 16, {5, 10, 15}, false, true, true, false, true}, - SamplingPostProcessing_Usecase{32, 16, {5, 10, 15}, true, false, false, false, true}, - SamplingPostProcessing_Usecase{32, 16, {5, 10, 15}, true, false, false, true, true}, - SamplingPostProcessing_Usecase{32, 16, {5, 10, 15}, true, false, true, false, true}, - SamplingPostProcessing_Usecase{32, 16, {5, 10, 15}, true, true, false, false, true}, - SamplingPostProcessing_Usecase{32, 16, {5, 10, 15}, true, true, false, true, true}, - SamplingPostProcessing_Usecase{32, 16, {5, 10, 15}, true, true, true, false, true}), + SamplingPostProcessing_Usecase{32, 16, {10}, true, true, true, false, false}, + SamplingPostProcessing_Usecase{32, 16, {10}, true, true, true, false, true}, + SamplingPostProcessing_Usecase{32, 16, {5, 10, 25}, false, false, false, false, false}, + SamplingPostProcessing_Usecase{32, 16, {5, 10, 25}, false, false, false, false, true}, + SamplingPostProcessing_Usecase{32, 16, {5, 10, 25}, false, false, false, true, false}, + SamplingPostProcessing_Usecase{32, 16, {5, 10, 25}, false, false, true, false, false}, + SamplingPostProcessing_Usecase{32, 16, {5, 10, 25}, false, false, true, false, true}, + SamplingPostProcessing_Usecase{32, 16, {5, 10, 25}, false, false, true, true, false}, + SamplingPostProcessing_Usecase{32, 16, {5, 10, 25}, false, true, false, false, false}, + SamplingPostProcessing_Usecase{32, 16, {5, 10, 25}, false, true, false, false, true}, + SamplingPostProcessing_Usecase{32, 16, {5, 10, 25}, false, true, false, true, false}, + SamplingPostProcessing_Usecase{32, 16, {5, 10, 25}, false, true, true, false, false}, + SamplingPostProcessing_Usecase{32, 16, {5, 10, 25}, false, true, true, false, true}, + SamplingPostProcessing_Usecase{32, 16, {5, 10, 25}, false, true, true, true, false}, + SamplingPostProcessing_Usecase{32, 16, {5, 10, 25}, true, false, false, false, false}, + SamplingPostProcessing_Usecase{32, 16, {5, 10, 25}, true, false, false, false, true}, + SamplingPostProcessing_Usecase{32, 16, {5, 10, 25}, true, false, false, true, false}, + SamplingPostProcessing_Usecase{32, 16, {5, 10, 25}, true, false, true, false, false}, + SamplingPostProcessing_Usecase{32, 16, {5, 10, 25}, true, false, true, false, true}, + SamplingPostProcessing_Usecase{32, 16, {5, 10, 25}, true, false, true, true, false}, + SamplingPostProcessing_Usecase{32, 16, {5, 10, 25}, true, true, false, false, false}, + SamplingPostProcessing_Usecase{32, 16, {5, 10, 25}, true, true, false, false, true}, + SamplingPostProcessing_Usecase{32, 16, {5, 10, 25}, true, true, false, true, false}, + SamplingPostProcessing_Usecase{32, 16, {5, 10, 25}, true, true, true, false, false}, + SamplingPostProcessing_Usecase{32, 16, {5, 10, 25}, true, true, true, false, true}, + SamplingPostProcessing_Usecase{32, 16, {5, 10, 25}, true, true, true, true, false}), ::testing::Values(cugraph::test::Rmat_Usecase(10, 16, 0.57, 0.19, 0.19, 0, false, false)))); INSTANTIATE_TEST_SUITE_P( @@ -1427,46 +1562,87 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Combine( // enable correctness checks ::testing::Values( - SamplingPostProcessing_Usecase{1, 64, {10}, false, false, false, false, false}, - SamplingPostProcessing_Usecase{1, 64, {10}, false, false, false, true, false}, - SamplingPostProcessing_Usecase{1, 64, {10}, false, true, false, false, false}, - SamplingPostProcessing_Usecase{1, 64, {10}, false, true, false, true, false}, - SamplingPostProcessing_Usecase{1, 64, {10}, true, false, false, false, false}, - SamplingPostProcessing_Usecase{1, 64, {10}, true, false, false, true, false}, - SamplingPostProcessing_Usecase{1, 64, {10}, true, true, false, false, false}, - SamplingPostProcessing_Usecase{1, 64, {10}, true, true, false, true, false}, - SamplingPostProcessing_Usecase{1, 64, {5, 10, 15}, false, false, false, false, false}, - SamplingPostProcessing_Usecase{1, 64, {5, 10, 15}, false, false, false, true, false}, - SamplingPostProcessing_Usecase{1, 64, {5, 10, 15}, false, false, true, false, false}, - SamplingPostProcessing_Usecase{1, 64, {5, 10, 15}, false, true, false, false, false}, - SamplingPostProcessing_Usecase{1, 64, {5, 10, 15}, false, true, false, true, false}, - SamplingPostProcessing_Usecase{1, 64, {5, 10, 15}, false, true, true, false, false}, - SamplingPostProcessing_Usecase{1, 64, {5, 10, 15}, true, false, false, false, false}, - SamplingPostProcessing_Usecase{1, 64, {5, 10, 15}, true, false, false, true, false}, - SamplingPostProcessing_Usecase{1, 64, {5, 10, 15}, true, false, true, false, false}, - SamplingPostProcessing_Usecase{1, 64, {5, 10, 15}, true, true, false, false, false}, - SamplingPostProcessing_Usecase{1, 64, {5, 10, 15}, true, true, false, true, false}, - SamplingPostProcessing_Usecase{1, 64, {5, 10, 15}, true, true, true, false, false}, - SamplingPostProcessing_Usecase{256, 64, {10}, false, false, false, false, false}, - SamplingPostProcessing_Usecase{256, 64, {10}, false, false, false, true, false}, - SamplingPostProcessing_Usecase{256, 64, {10}, false, true, false, false, false}, - SamplingPostProcessing_Usecase{256, 64, {10}, false, true, false, true, false}, - SamplingPostProcessing_Usecase{256, 64, {10}, true, false, false, false, false}, - SamplingPostProcessing_Usecase{256, 64, {10}, true, false, false, true, false}, - SamplingPostProcessing_Usecase{256, 64, {10}, true, true, false, false, false}, - SamplingPostProcessing_Usecase{256, 64, {10}, true, true, false, true, false}, - SamplingPostProcessing_Usecase{256, 64, {5, 10, 15}, false, false, false, false, false}, - SamplingPostProcessing_Usecase{256, 64, {5, 10, 15}, false, false, false, true, false}, - SamplingPostProcessing_Usecase{256, 64, {5, 10, 15}, false, false, true, false, false}, - SamplingPostProcessing_Usecase{256, 64, {5, 10, 15}, false, true, false, false, false}, - SamplingPostProcessing_Usecase{256, 64, {5, 10, 15}, false, true, false, true, false}, - SamplingPostProcessing_Usecase{256, 64, {5, 10, 15}, false, true, true, false, false}, - SamplingPostProcessing_Usecase{256, 64, {5, 10, 15}, true, false, false, false, false}, - SamplingPostProcessing_Usecase{256, 64, {5, 10, 15}, true, false, false, true, false}, - SamplingPostProcessing_Usecase{256, 64, {5, 10, 15}, true, false, true, false, false}, - SamplingPostProcessing_Usecase{256, 64, {5, 10, 15}, true, true, false, false, false}, - SamplingPostProcessing_Usecase{256, 64, {5, 10, 15}, true, true, false, true, false}, - SamplingPostProcessing_Usecase{256, 64, {5, 10, 15}, true, true, true, false, false}), + SamplingPostProcessing_Usecase{1, 64, {10}, false, false, false, false, false, false}, + SamplingPostProcessing_Usecase{1, 64, {10}, false, false, false, false, true, false}, + SamplingPostProcessing_Usecase{1, 64, {10}, false, false, true, false, false, false}, + SamplingPostProcessing_Usecase{1, 64, {10}, false, false, true, false, true, false}, + SamplingPostProcessing_Usecase{1, 64, {10}, false, true, false, false, false, false}, + SamplingPostProcessing_Usecase{1, 64, {10}, false, true, false, false, true, false}, + SamplingPostProcessing_Usecase{1, 64, {10}, false, true, true, false, false, false}, + SamplingPostProcessing_Usecase{1, 64, {10}, false, true, true, false, true, false}, + SamplingPostProcessing_Usecase{1, 64, {10}, true, false, false, false, false, false}, + SamplingPostProcessing_Usecase{1, 64, {10}, true, false, false, false, true, false}, + SamplingPostProcessing_Usecase{1, 64, {10}, true, false, true, false, false, false}, + SamplingPostProcessing_Usecase{1, 64, {10}, true, false, true, false, true, false}, + SamplingPostProcessing_Usecase{1, 64, {10}, true, true, false, false, false, false}, + SamplingPostProcessing_Usecase{1, 64, {10}, true, true, false, false, true, false}, + SamplingPostProcessing_Usecase{1, 64, {10}, true, true, true, false, false, false}, + SamplingPostProcessing_Usecase{1, 64, {10}, true, true, true, false, true, false}, + SamplingPostProcessing_Usecase{1, 64, {5, 10, 15}, false, false, false, false, false, false}, + SamplingPostProcessing_Usecase{1, 64, {5, 10, 15}, false, false, false, false, true, false}, + SamplingPostProcessing_Usecase{1, 64, {5, 10, 15}, false, false, false, true, false, false}, + SamplingPostProcessing_Usecase{1, 64, {5, 10, 15}, false, false, true, false, false, false}, + SamplingPostProcessing_Usecase{1, 64, {5, 10, 15}, false, false, true, false, true, false}, + SamplingPostProcessing_Usecase{1, 64, {5, 10, 15}, false, false, true, true, false, false}, + SamplingPostProcessing_Usecase{1, 64, {5, 10, 15}, false, true, false, false, false, false}, + SamplingPostProcessing_Usecase{1, 64, {5, 10, 15}, false, true, false, false, true, false}, + SamplingPostProcessing_Usecase{1, 64, {5, 10, 15}, false, true, false, true, false, false}, + SamplingPostProcessing_Usecase{1, 64, {5, 10, 15}, false, true, true, false, false, false}, + SamplingPostProcessing_Usecase{1, 64, {5, 10, 15}, false, true, true, false, true, false}, + SamplingPostProcessing_Usecase{1, 64, {5, 10, 15}, false, true, true, true, false, false}, + SamplingPostProcessing_Usecase{1, 64, {5, 10, 15}, true, false, false, false, false, false}, + SamplingPostProcessing_Usecase{1, 64, {5, 10, 15}, true, false, false, false, true, false}, + SamplingPostProcessing_Usecase{1, 64, {5, 10, 15}, true, false, false, true, false, false}, + SamplingPostProcessing_Usecase{1, 64, {5, 10, 15}, true, false, true, false, false, false}, + SamplingPostProcessing_Usecase{1, 64, {5, 10, 15}, true, false, true, false, true, false}, + SamplingPostProcessing_Usecase{1, 64, {5, 10, 15}, true, false, true, true, false, false}, + SamplingPostProcessing_Usecase{1, 64, {5, 10, 15}, true, true, false, false, false, false}, + SamplingPostProcessing_Usecase{1, 64, {5, 10, 15}, true, true, false, false, true, false}, + SamplingPostProcessing_Usecase{1, 64, {5, 10, 15}, true, true, false, true, false, false}, + SamplingPostProcessing_Usecase{1, 64, {5, 10, 15}, true, true, true, false, false, false}, + SamplingPostProcessing_Usecase{1, 64, {5, 10, 15}, true, true, true, false, true, false}, + SamplingPostProcessing_Usecase{1, 64, {5, 10, 15}, true, true, true, true, false, false}, + SamplingPostProcessing_Usecase{256, 64, {10}, false, false, false, false, false, false}, + SamplingPostProcessing_Usecase{256, 64, {10}, false, false, false, false, true, false}, + SamplingPostProcessing_Usecase{256, 64, {10}, false, false, true, false, false, false}, + SamplingPostProcessing_Usecase{256, 64, {10}, false, false, true, false, true, false}, + SamplingPostProcessing_Usecase{256, 64, {10}, false, true, false, false, false, false}, + SamplingPostProcessing_Usecase{256, 64, {10}, false, true, false, false, true, false}, + SamplingPostProcessing_Usecase{256, 64, {10}, false, true, true, false, false, false}, + SamplingPostProcessing_Usecase{256, 64, {10}, false, true, true, false, true, false}, + SamplingPostProcessing_Usecase{256, 64, {10}, true, false, false, false, false, false}, + SamplingPostProcessing_Usecase{256, 64, {10}, true, false, false, false, true, false}, + SamplingPostProcessing_Usecase{256, 64, {10}, true, false, true, false, false, false}, + SamplingPostProcessing_Usecase{256, 64, {10}, true, false, true, false, true, false}, + SamplingPostProcessing_Usecase{256, 64, {10}, true, true, false, false, false, false}, + SamplingPostProcessing_Usecase{256, 64, {10}, true, true, false, false, true, false}, + SamplingPostProcessing_Usecase{256, 64, {10}, true, true, true, false, false, false}, + SamplingPostProcessing_Usecase{256, 64, {10}, true, true, true, false, true, false}, + SamplingPostProcessing_Usecase{ + 256, 64, {5, 10, 15}, false, false, false, false, false, false}, + SamplingPostProcessing_Usecase{256, 64, {5, 10, 15}, false, false, false, false, true, false}, + SamplingPostProcessing_Usecase{256, 64, {5, 10, 15}, false, false, false, true, false, false}, + SamplingPostProcessing_Usecase{256, 64, {5, 10, 15}, false, false, true, false, false, false}, + SamplingPostProcessing_Usecase{256, 64, {5, 10, 15}, false, false, true, false, true, false}, + SamplingPostProcessing_Usecase{256, 64, {5, 10, 15}, false, false, true, true, false, false}, + SamplingPostProcessing_Usecase{256, 64, {5, 10, 15}, false, true, false, false, false, false}, + SamplingPostProcessing_Usecase{256, 64, {5, 10, 15}, false, true, false, false, true, false}, + SamplingPostProcessing_Usecase{256, 64, {5, 10, 15}, false, true, false, true, false, false}, + SamplingPostProcessing_Usecase{256, 64, {5, 10, 15}, false, true, true, false, false, false}, + SamplingPostProcessing_Usecase{256, 64, {5, 10, 15}, false, true, true, false, true, false}, + SamplingPostProcessing_Usecase{256, 64, {5, 10, 15}, false, true, true, true, false, false}, + SamplingPostProcessing_Usecase{256, 64, {5, 10, 15}, true, false, false, false, false, false}, + SamplingPostProcessing_Usecase{256, 64, {5, 10, 15}, true, false, false, false, true, false}, + SamplingPostProcessing_Usecase{256, 64, {5, 10, 15}, true, false, false, true, false, false}, + SamplingPostProcessing_Usecase{256, 64, {5, 10, 15}, true, false, true, false, false, false}, + SamplingPostProcessing_Usecase{256, 64, {5, 10, 15}, true, false, true, false, true, false}, + SamplingPostProcessing_Usecase{256, 64, {5, 10, 15}, true, false, true, true, false, false}, + SamplingPostProcessing_Usecase{256, 64, {5, 10, 15}, true, true, false, false, false, false}, + SamplingPostProcessing_Usecase{256, 64, {5, 10, 15}, true, true, false, false, true, false}, + SamplingPostProcessing_Usecase{256, 64, {5, 10, 15}, true, true, false, true, false, false}, + SamplingPostProcessing_Usecase{256, 64, {5, 10, 15}, true, true, true, false, false, false}, + SamplingPostProcessing_Usecase{256, 64, {5, 10, 15}, true, true, true, false, true, false}, + SamplingPostProcessing_Usecase{256, 64, {5, 10, 15}, true, true, true, true, false, false}), ::testing::Values(cugraph::test::Rmat_Usecase(20, 32, 0.57, 0.19, 0.19, 0, false, false)))); CUGRAPH_TEST_PROGRAM_MAIN()