diff --git a/cpp/src/generators/erdos_renyi_generator.cu b/cpp/src/generators/erdos_renyi_generator.cu index db4de177331..97b944b0a55 100644 --- a/cpp/src/generators/erdos_renyi_generator.cu +++ b/cpp/src/generators/erdos_renyi_generator.cu @@ -23,10 +23,9 @@ #include #include #include -#include +#include #include #include -#include #include namespace cugraph { @@ -42,45 +41,38 @@ generate_erdos_renyi_graph_edgelist_gnp(raft::handle_t const& handle, CUGRAPH_EXPECTS(num_vertices < std::numeric_limits::max(), "Implementation cannot support specified value"); - auto random_iterator = thrust::make_transform_iterator( - thrust::make_counting_iterator(0), - cuda::proclaim_return_type([seed] __device__(size_t index) { - thrust::default_random_engine rng(seed); - thrust::uniform_real_distribution dist(0.0, 1.0); - rng.discard(index); - return dist(rng); - })); + size_t max_num_edges = static_cast(num_vertices) * num_vertices; - size_t count = thrust::count_if(handle.get_thrust_policy(), - random_iterator, - random_iterator + num_vertices * num_vertices, - [p] __device__(float prob) { return prob < p; }); - - rmm::device_uvector indices_v(count, handle.get_stream()); + auto generate_random_value = cuda::proclaim_return_type([seed] __device__(size_t index) { + thrust::default_random_engine rng(seed); + thrust::uniform_real_distribution dist(0.0, 1.0); + rng.discard(index); + return dist(rng); + }); - thrust::copy_if(handle.get_thrust_policy(), - random_iterator, - random_iterator + num_vertices * num_vertices, - indices_v.begin(), - [p] __device__(float prob) { return prob < p; }); + size_t count = thrust::count_if(handle.get_thrust_policy(), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(max_num_edges), + [generate_random_value, p] __device__(size_t index) { + return generate_random_value(index) < p; + }); rmm::device_uvector src_v(count, handle.get_stream()); rmm::device_uvector dst_v(count, handle.get_stream()); - thrust::transform(handle.get_thrust_policy(), - indices_v.begin(), - indices_v.end(), - thrust::make_zip_iterator(thrust::make_tuple(src_v.begin(), src_v.end())), + thrust::copy_if(handle.get_thrust_policy(), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(max_num_edges), + thrust::make_transform_output_iterator( + thrust::make_zip_iterator(src_v.begin(), dst_v.begin()), cuda::proclaim_return_type>( [num_vertices] __device__(size_t index) { - size_t src = index / num_vertices; - size_t dst = index % num_vertices; - - return thrust::make_tuple(static_cast(src), - static_cast(dst)); - })); - - handle.sync_stream(); + return thrust::make_tuple(static_cast(index / num_vertices), + static_cast(index % num_vertices)); + })), + [generate_random_value, p] __device__(size_t index) { + return generate_random_value(index) < p; + }); return std::make_tuple(std::move(src_v), std::move(dst_v)); } diff --git a/cpp/tests/generators/erdos_renyi_test.cpp b/cpp/tests/generators/erdos_renyi_test.cpp index 348799e041a..1fcda7152b6 100644 --- a/cpp/tests/generators/erdos_renyi_test.cpp +++ b/cpp/tests/generators/erdos_renyi_test.cpp @@ -87,6 +87,7 @@ void er_test(size_t num_vertices, float p) TEST_F(GenerateErdosRenyiTest, ERTest) { er_test(size_t{10}, float{0.1}); + er_test(size_t{10}, float{0.5}); er_test(size_t{20}, float{0.1}); er_test(size_t{50}, float{0.1}); er_test(size_t{10000}, float{0.1});