diff --git a/ODESolvers/src/solve.cxx b/ODESolvers/src/solve.cxx index 3d9a54d75..957691219 100644 --- a/ODESolvers/src/solve.cxx +++ b/ODESolvers/src/solve.cxx @@ -684,7 +684,7 @@ extern "C" void ODESolvers_Solve(CCTK_ARGUMENTS) { static Timer timer_setup("ODESolvers::Solve::setup"); std::optional interval_setup(timer_setup); - statecomp_t var, rhs, pre; + statecomp_t var, rhs, pre; // call pre rhs_p and rhs_p_p (3 time levels) std::vector var_groups, rhs_groups, dep_groups; int nvars = 0; bool do_accumulate_nvars = true; @@ -708,6 +708,7 @@ extern "C" void ODESolvers_Solve(CCTK_ARGUMENTS) { const auto num_rhs_time_levels{rhs_groupdata.mfab.size()}; + // move outside the loop if (CCTK_EQUALS(method, "RKAB4") && num_rhs_time_levels != 2) { CCTK_VERROR("Method RKAB4 requires 1 time level in the RHS group %s, " "but %lu time levels were provided ", @@ -1140,6 +1141,9 @@ extern "C" void ODESolvers_Solve(CCTK_ARGUMENTS) { } else if (CCTK_EQUALS(method, "RKAB4")) { + // Initialize the prev. time levels regardelless of self starting. + // Self start later. + // Bootstrap the method with RK4 if (cctkGH->cctk_iteration <= 1) { // k1 = f(y0) @@ -1172,12 +1176,16 @@ extern "C" void ODESolvers_Solve(CCTK_ARGUMENTS) { make_valid_int()); } else { - // k1 = f(y(t - h)) - // k2 = f(y0) + // k1 = f(y(t - h)) this would be rhs_p_p + // k2 = f(y0) rhs_p + // k3 = f(y0 + h (c1 k2 + c2 k1)) // k4 = f(y0 + h (c3 k3 + c4 k2 + c5 k1)) // yn = y0 + h (c6 k1 + c7 k2 + c8 k3 + c9 k4); + // kn = f(yn) or f(y(t + h)) + // calcrhs here or schedule ODESolvers_CalcRHS in poststep + constexpr auto c1{0.3736646857963324}; constexpr auto c2{0.03127973625120939}; constexpr auto c3{-0.14797683066152537}; @@ -1190,6 +1198,7 @@ extern "C" void ODESolvers_Solve(CCTK_ARGUMENTS) { const auto old = copy_state(var, make_valid_all()); + // ki = rhs_p const auto k1 = copy_state(pre, make_valid_all()); calcupdate(1, dt, 0.0, reals<1>{1.0}, states<1>{&old}); @@ -1207,6 +1216,7 @@ extern "C" void ODESolvers_Solve(CCTK_ARGUMENTS) { calcupdate(4, dt, 0.0, reals<5>{1.0, dt * c6, dt * c7, dt * c8, dt * c9}, states<5>{&old, &k1, &k2, &k3, &rhs}); + // need to calc calcrhs() again here // Make sure that we store the correct RHS in the current time level statecomp_t::lincomb(rhs, 0.0, reals<1>{1.0}, states<1>{&k2}, make_valid_int());