Skip to content

Commit

Permalink
fix HITS convergence error
Browse files Browse the repository at this point in the history
  • Loading branch information
seunghwak committed Dec 5, 2023
1 parent 20145b4 commit 61ce3d0
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 23 deletions.
3 changes: 2 additions & 1 deletion cpp/src/link_analysis/hits_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ std::tuple<result_t, size_t> hits(raft::handle_t const& handle,
if (num_vertices == 0) { return std::make_tuple(diff_sum, final_iteration_count); }

CUGRAPH_EXPECTS(epsilon >= 0.0, "Invalid input argument: epsilon should be non-negative.");
auto tolerance = static_cast<result_t>(graph_view.number_of_vertices()) * epsilon;

// Check validity of initial guess if supplied
if (has_initial_hubs_guess && do_expensive_check) {
Expand Down Expand Up @@ -171,7 +172,7 @@ std::tuple<result_t, size_t> hits(raft::handle_t const& handle,
std::swap(prev_hubs, curr_hubs);
iter++;

if (diff_sum < epsilon) {
if (diff_sum < tolerance) {
break;
} else if (iter >= max_iterations) {
CUGRAPH_FAIL("HITS failed to converge.");
Expand Down
28 changes: 16 additions & 12 deletions cpp/tests/link_analysis/hits_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,11 @@ std::tuple<std::vector<result_t>, std::vector<result_t>, double, size_t> hits_re
size_t max_iterations,
std::optional<result_t const*> starting_hub_values,
bool normalized,
double tolerance)
double epsilon)
{
CUGRAPH_EXPECTS(num_vertices > 1, "number of vertices expected to be non-zero");
auto tolerance = static_cast<result_t>(num_vertices) * epsilon;

std::vector<result_t> prev_hubs(num_vertices, result_t{1.0} / num_vertices);
std::vector<result_t> prev_authorities(num_vertices, result_t{1.0} / num_vertices);
std::vector<result_t> curr_hubs(num_vertices);
Expand Down Expand Up @@ -127,8 +129,8 @@ std::tuple<std::vector<result_t>, std::vector<result_t>, double, size_t> hits_re
}

struct Hits_Usecase {
bool check_correctness{true};
bool check_initial_input{false};
bool check_correctness{true};
};

template <typename input_usecase_t>
Expand Down Expand Up @@ -175,8 +177,8 @@ class Tests_Hits : public ::testing::TestWithParam<std::tuple<Hits_Usecase, inpu
// 3. run hits

auto graph_view = graph.view();
auto maximum_iterations = 500;
weight_t tolerance = 1e-5;
auto maximum_iterations = 200;
weight_t epsilon = 1e-7;
rmm::device_uvector<weight_t> d_hubs(graph_view.local_vertex_partition_range_size(),
handle.get_stream());

Expand All @@ -201,7 +203,7 @@ class Tests_Hits : public ::testing::TestWithParam<std::tuple<Hits_Usecase, inpu
graph_view,
d_hubs.data(),
d_authorities.data(),
tolerance,
epsilon,
maximum_iterations,
hits_usecase.check_initial_input,
true,
Expand Down Expand Up @@ -232,7 +234,7 @@ class Tests_Hits : public ::testing::TestWithParam<std::tuple<Hits_Usecase, inpu
(hits_usecase.check_initial_input) ? std::make_optional(initial_random_hubs.data())
: std::nullopt,
true,
tolerance);
epsilon);

std::vector<weight_t> h_cugraph_hits{};
if (renumber) {
Expand All @@ -246,8 +248,7 @@ class Tests_Hits : public ::testing::TestWithParam<std::tuple<Hits_Usecase, inpu
handle.sync_stream();
auto threshold_ratio = 1e-3;
auto threshold_magnitude =
(1.0 / static_cast<weight_t>(graph_view.number_of_vertices())) *
threshold_ratio; // skip comparison for low hits vertices (lowly ranked vertices)
1e-6; // skip comparison for low hits vertices (lowly ranked vertices)
auto nearly_equal = [threshold_ratio, threshold_magnitude](auto lhs, auto rhs) {
return std::abs(lhs - rhs) <=
std::max(std::max(lhs, rhs) * threshold_ratio, threshold_magnitude);
Expand Down Expand Up @@ -294,14 +295,17 @@ INSTANTIATE_TEST_SUITE_P(
Tests_Hits_File,
::testing::Combine(
// enable correctness checks
::testing::Values(Hits_Usecase{true, false}, Hits_Usecase{true, true}),
::testing::Values(Hits_Usecase{false, true}, Hits_Usecase{true, true}),
::testing::Values(cugraph::test::File_Usecase("test/datasets/karate.mtx"),
cugraph::test::File_Usecase("test/datasets/web-Google.mtx"),
cugraph::test::File_Usecase("test/datasets/ljournal-2008.mtx"),
cugraph::test::File_Usecase("test/datasets/webbase-1M.mtx"),
cugraph::test::File_Usecase("test/datasets/dolphins.mtx"))));

INSTANTIATE_TEST_SUITE_P(rmat_small_test,
Tests_Hits_Rmat,
// enable correctness checks
::testing::Combine(::testing::Values(Hits_Usecase{true, false},
::testing::Combine(::testing::Values(Hits_Usecase{false, true},
Hits_Usecase{true, true}),
::testing::Values(cugraph::test::Rmat_Usecase(
10, 16, 0.57, 0.19, 0.19, 0, false, false))));
Expand All @@ -315,7 +319,7 @@ INSTANTIATE_TEST_SUITE_P(
Tests_Hits_File,
::testing::Combine(
// disable correctness checks
::testing::Values(Hits_Usecase{false, false}, Hits_Usecase{false, true}),
::testing::Values(Hits_Usecase{false, false}, Hits_Usecase{true, false}),
::testing::Values(cugraph::test::File_Usecase("test/datasets/karate.mtx"))));

INSTANTIATE_TEST_SUITE_P(
Expand All @@ -327,7 +331,7 @@ INSTANTIATE_TEST_SUITE_P(
Tests_Hits_Rmat,
// disable correctness checks for large graphs
::testing::Combine(
::testing::Values(Hits_Usecase{false, false}, Hits_Usecase{false, true}),
::testing::Values(Hits_Usecase{false, false}, Hits_Usecase{true, false}),
::testing::Values(cugraph::test::Rmat_Usecase(10, 16, 0.57, 0.19, 0.19, 0, false, false))));

CUGRAPH_TEST_PROGRAM_MAIN()
18 changes: 8 additions & 10 deletions cpp/tests/link_analysis/mg_hits_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@
#include <gtest/gtest.h>

struct Hits_Usecase {
bool check_correctness{true};
bool check_initial_input{false};
bool check_correctness{true};
};

template <typename input_usecase_t>
Expand Down Expand Up @@ -81,7 +81,7 @@ class Tests_MGHits : public ::testing::TestWithParam<std::tuple<Hits_Usecase, in
auto mg_graph_view = mg_graph.view();

auto maximum_iterations = 200;
weight_t tolerance = 1e-8;
weight_t epsilon = 1e-7;
rmm::device_uvector<weight_t> d_mg_hubs(mg_graph_view.local_vertex_partition_range_size(),
handle_->get_stream());

Expand Down Expand Up @@ -110,7 +110,7 @@ class Tests_MGHits : public ::testing::TestWithParam<std::tuple<Hits_Usecase, in
mg_graph_view,
d_mg_hubs.data(),
d_mg_authorities.data(),
tolerance,
epsilon,
maximum_iterations,
hits_usecase.check_initial_input,
true,
Expand Down Expand Up @@ -205,7 +205,7 @@ class Tests_MGHits : public ::testing::TestWithParam<std::tuple<Hits_Usecase, in
sg_graph_view,
d_sg_hubs.data(),
d_sg_authorities.data(),
tolerance,
epsilon,
maximum_iterations,
hits_usecase.check_initial_input,
true,
Expand All @@ -218,9 +218,7 @@ class Tests_MGHits : public ::testing::TestWithParam<std::tuple<Hits_Usecase, in

auto threshold_ratio = 1e-3;
auto threshold_magnitude =
(1.0 / static_cast<result_t>(mg_graph_view.number_of_vertices())) *
threshold_ratio; // skip comparison for low Hits vertices (lowly ranked
// vertices)
1e-6; // skip comparison for low Hits vertices (lowly ranked vertices)
auto nearly_equal = [threshold_ratio, threshold_magnitude](auto lhs, auto rhs) {
return std::abs(lhs - rhs) <
std::max(std::max(lhs, rhs) * threshold_ratio, threshold_magnitude);
Expand Down Expand Up @@ -274,7 +272,7 @@ INSTANTIATE_TEST_SUITE_P(
Tests_MGHits_File,
::testing::Combine(
// enable correctness checks
::testing::Values(Hits_Usecase{true, false}, Hits_Usecase{true, true}),
::testing::Values(Hits_Usecase{false, true}, Hits_Usecase{true, true}),
::testing::Values(cugraph::test::File_Usecase("test/datasets/karate.mtx"),
cugraph::test::File_Usecase("test/datasets/web-Google.mtx"),
cugraph::test::File_Usecase("test/datasets/ljournal-2008.mtx"),
Expand All @@ -285,7 +283,7 @@ INSTANTIATE_TEST_SUITE_P(
Tests_MGHits_Rmat,
::testing::Combine(
// enable correctness checks
::testing::Values(Hits_Usecase{true, false}, Hits_Usecase{true, true}),
::testing::Values(Hits_Usecase{false, true}, Hits_Usecase{true, true}),
::testing::Values(cugraph::test::Rmat_Usecase(10, 16, 0.57, 0.19, 0.19, 0, false, false))));

INSTANTIATE_TEST_SUITE_P(
Expand All @@ -297,7 +295,7 @@ INSTANTIATE_TEST_SUITE_P(
Tests_MGHits_Rmat,
::testing::Combine(
// disable correctness checks for large graphs
::testing::Values(Hits_Usecase{false, false}),
::testing::Values(Hits_Usecase{false, false}, Hits_Usecase{true, false}),
::testing::Values(cugraph::test::Rmat_Usecase(20, 32, 0.57, 0.19, 0.19, 0, false, false))));

CUGRAPH_MG_TEST_PROGRAM_MAIN()

0 comments on commit 61ce3d0

Please sign in to comment.