From 78e2cf51d00fea05f90851da69bb723d1bb631b8 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Fri, 10 Jan 2025 11:04:33 +0200 Subject: [PATCH] routing+lnrpc: thread context through And remove all the context.TODO()s from the previous commit. --- lnrpc/routerrpc/router_backend.go | 5 +- lnrpc/routerrpc/router_backend_test.go | 4 +- lnrpc/routerrpc/router_server.go | 12 +-- routing/bandwidth.go | 10 +- routing/bandwidth_test.go | 6 +- routing/integrated_routing_context_test.go | 7 +- routing/mock_test.go | 13 +-- routing/pathfind.go | 12 +-- routing/pathfind_test.go | 6 +- routing/payment_lifecycle.go | 6 +- routing/payment_lifecycle_test.go | 20 ++-- routing/payment_session.go | 15 +-- routing/payment_session_source.go | 8 +- routing/payment_session_test.go | 10 +- routing/router.go | 25 +++-- routing/router_test.go | 107 +++++++++++---------- rpcserver.go | 2 +- 17 files changed, 146 insertions(+), 122 deletions(-) diff --git a/lnrpc/routerrpc/router_backend.go b/lnrpc/routerrpc/router_backend.go index 7d73681094..bac50b8a93 100644 --- a/lnrpc/routerrpc/router_backend.go +++ b/lnrpc/routerrpc/router_backend.go @@ -60,7 +60,8 @@ type RouterBackend struct { // FindRoute is a closure that abstracts away how we locate/query for // routes. - FindRoute func(*routing.RouteRequest) (*route.Route, float64, error) + FindRoute func(context.Context, *routing.RouteRequest) (*route.Route, + float64, error) MissionControl MissionControl @@ -169,7 +170,7 @@ func (r *RouterBackend) QueryRoutes(ctx context.Context, // Query the channel router for a possible path to the destination that // can carry `in.Amt` satoshis _including_ the total fee required on // the route - route, successProb, err := r.FindRoute(routeReq) + route, successProb, err := r.FindRoute(ctx, routeReq) if err != nil { return nil, err } diff --git a/lnrpc/routerrpc/router_backend_test.go b/lnrpc/routerrpc/router_backend_test.go index 877a3cc171..cb66d1925f 100644 --- a/lnrpc/routerrpc/router_backend_test.go +++ b/lnrpc/routerrpc/router_backend_test.go @@ -120,8 +120,8 @@ func testQueryRoutes(t *testing.T, useMissionControl bool, useMsat bool, } } - findRoute := func(req *routing.RouteRequest) (*route.Route, float64, - error) { + findRoute := func(_ context.Context, req *routing.RouteRequest) ( + *route.Route, float64, error) { if int64(req.Amount) != amtSat*1000 { t.Fatal("unexpected amount") diff --git a/lnrpc/routerrpc/router_server.go b/lnrpc/routerrpc/router_server.go index 9499fa25a3..1de8e3ad01 100644 --- a/lnrpc/routerrpc/router_server.go +++ b/lnrpc/routerrpc/router_server.go @@ -426,7 +426,7 @@ func (s *Server) EstimateRouteFee(ctx context.Context, return nil, errors.New("amount must be greater than 0") default: - return s.probeDestination(req.Dest, req.AmtSat) + return s.probeDestination(ctx, req.Dest, req.AmtSat) } case isProbeInvoice: @@ -440,8 +440,8 @@ func (s *Server) EstimateRouteFee(ctx context.Context, // probeDestination estimates fees along a route to a destination based on the // contents of the local graph. -func (s *Server) probeDestination(dest []byte, amtSat int64) (*RouteFeeResponse, - error) { +func (s *Server) probeDestination(ctx context.Context, dest []byte, + amtSat int64) (*RouteFeeResponse, error) { destNode, err := route.NewVertexFromBytes(dest) if err != nil { @@ -469,7 +469,7 @@ func (s *Server) probeDestination(dest []byte, amtSat int64) (*RouteFeeResponse, return nil, err } - route, _, err := s.cfg.Router.FindRoute(routeReq) + route, _, err := s.cfg.Router.FindRoute(ctx, routeReq) if err != nil { return nil, err } @@ -1429,7 +1429,7 @@ func (s *Server) trackPaymentStream(context context.Context, } // BuildRoute builds a route from a list of hop addresses. -func (s *Server) BuildRoute(_ context.Context, +func (s *Server) BuildRoute(ctx context.Context, req *BuildRouteRequest) (*BuildRouteResponse, error) { if len(req.HopPubkeys) == 0 { @@ -1490,7 +1490,7 @@ func (s *Server) BuildRoute(_ context.Context, // Build the route and return it to the caller. route, err := s.cfg.Router.BuildRoute( - amt, hops, outgoingChan, req.FinalCltvDelta, payAddr, + ctx, amt, hops, outgoingChan, req.FinalCltvDelta, payAddr, firstHopBlob, ) if err != nil { diff --git a/routing/bandwidth.go b/routing/bandwidth.go index 1bd8462638..c346855ee8 100644 --- a/routing/bandwidth.go +++ b/routing/bandwidth.go @@ -50,12 +50,10 @@ type bandwidthManager struct { // hints for the edges we directly have open ourselves. Obtaining these hints // allows us to reduce the number of extraneous attempts as we can skip channels // that are inactive, or just don't have enough bandwidth to carry the payment. -func newBandwidthManager(graph Graph, sourceNode route.Vertex, - linkQuery getLinkQuery, firstHopBlob fn.Option[tlv.Blob], - ts fn.Option[htlcswitch.AuxTrafficShaper]) (*bandwidthManager, - error) { - - ctx := context.TODO() +func newBandwidthManager(ctx context.Context, graph Graph, + sourceNode route.Vertex, linkQuery getLinkQuery, + firstHopBlob fn.Option[tlv.Blob], + ts fn.Option[htlcswitch.AuxTrafficShaper]) (*bandwidthManager, error) { manager := &bandwidthManager{ getLink: linkQuery, diff --git a/routing/bandwidth_test.go b/routing/bandwidth_test.go index b31d0095ac..2f7465cbea 100644 --- a/routing/bandwidth_test.go +++ b/routing/bandwidth_test.go @@ -1,6 +1,7 @@ package routing import ( + "context" "testing" "github.com/btcsuite/btcd/btcutil" @@ -15,6 +16,9 @@ import ( // TestBandwidthManager tests getting of bandwidth hints from a bandwidth // manager. func TestBandwidthManager(t *testing.T) { + t.Parallel() + ctx := context.Background() + var ( chan1ID uint64 = 101 chan2ID uint64 = 102 @@ -116,7 +120,7 @@ func TestBandwidthManager(t *testing.T) { ) m, err := newBandwidthManager( - g, sourceNode.pubkey, testCase.linkQuery, + ctx, g, sourceNode.pubkey, testCase.linkQuery, fn.None[[]byte](), fn.Some[htlcswitch.AuxTrafficShaper]( &mockTrafficShaper{}, diff --git a/routing/integrated_routing_context_test.go b/routing/integrated_routing_context_test.go index 816da87ac2..f5b73f8a45 100644 --- a/routing/integrated_routing_context_test.go +++ b/routing/integrated_routing_context_test.go @@ -174,7 +174,9 @@ func (c *integratedRoutingContext) testPayment(maxParts uint32, ) require.NoError(c.t, err) - getBandwidthHints := func(_ Graph) (bandwidthHints, error) { + getBandwidthHints := func(_ context.Context, _ Graph) (bandwidthHints, + error) { + // Create bandwidth hints based on local channel balances. bandwidthHints := map[uint64]lnwire.MilliSatoshi{} for _, ch := range c.graph.nodes[c.source.pubkey].channels { @@ -236,7 +238,8 @@ func (c *integratedRoutingContext) testPayment(maxParts uint32, // Find a route. route, err := session.RequestRoute( - amtRemaining, lnwire.MaxMilliSatoshi, inFlightHtlcs, 0, + context.Background(), amtRemaining, + lnwire.MaxMilliSatoshi, inFlightHtlcs, 0, lnwire.CustomRecords{ lnwire.MinCustomRecordsTlvType: []byte{1, 2, 3}, }, diff --git a/routing/mock_test.go b/routing/mock_test.go index 86fd765499..33f7aecd14 100644 --- a/routing/mock_test.go +++ b/routing/mock_test.go @@ -1,6 +1,7 @@ package routing import ( + "context" "fmt" "sync" @@ -168,9 +169,9 @@ type mockPaymentSessionOld struct { var _ PaymentSession = (*mockPaymentSessionOld)(nil) -func (m *mockPaymentSessionOld) RequestRoute(_, _ lnwire.MilliSatoshi, - _, height uint32, _ lnwire.CustomRecords) (*route.Route, - error) { +func (m *mockPaymentSessionOld) RequestRoute(_ context.Context, + _, _ lnwire.MilliSatoshi, _, _ uint32, _ lnwire.CustomRecords) ( + *route.Route, error) { if m.release != nil { m.release <- struct{}{} @@ -695,12 +696,12 @@ type mockPaymentSession struct { var _ PaymentSession = (*mockPaymentSession)(nil) -func (m *mockPaymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, - activeShards, height uint32, +func (m *mockPaymentSession) RequestRoute(ctx context.Context, maxAmt, + feeLimit lnwire.MilliSatoshi, activeShards, height uint32, firstHopCustomRecords lnwire.CustomRecords) (*route.Route, error) { args := m.Called( - maxAmt, feeLimit, activeShards, height, firstHopCustomRecords, + ctx, maxAmt, feeLimit, activeShards, height, firstHopCustomRecords, ) // Type assertion on nil will fail, so we check and return here. diff --git a/routing/pathfind.go b/routing/pathfind.go index 55836e9a21..a75e8273bc 100644 --- a/routing/pathfind.go +++ b/routing/pathfind.go @@ -51,7 +51,7 @@ const ( ) // pathFinder defines the interface of a path finding algorithm. -type pathFinder = func(g *graphParams, r *RestrictParams, +type pathFinder = func(ctx context.Context, g *graphParams, r *RestrictParams, cfg *PathFindingConfig, self, source, target route.Vertex, amt lnwire.MilliSatoshi, timePref float64, finalHtlcExpiry int32) ( []*unifiedEdge, float64, error) @@ -576,12 +576,10 @@ func getOutgoingBalance(ctx context.Context, node route.Vertex, // source. This is to properly accumulate fees that need to be paid along the // path and accurately check the amount to forward at every node against the // available bandwidth. -func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, - self, source, target route.Vertex, amt lnwire.MilliSatoshi, - timePref float64, finalHtlcExpiry int32) ([]*unifiedEdge, float64, - error) { - - ctx := context.TODO() +func findPath(ctx context.Context, g *graphParams, r *RestrictParams, + cfg *PathFindingConfig, self, source, target route.Vertex, + amt lnwire.MilliSatoshi, timePref float64, finalHtlcExpiry int32) ( + []*unifiedEdge, float64, error) { // Pathfinding can be a significant portion of the total payment // latency, especially on low-powered devices. Log several metrics to diff --git a/routing/pathfind_test.go b/routing/pathfind_test.go index 13d26aa223..5775b8bb90 100644 --- a/routing/pathfind_test.go +++ b/routing/pathfind_test.go @@ -2326,7 +2326,7 @@ func TestPathFindSpecExample(t *testing.T) { ) require.NoError(t, err, "invalid route request") - route, _, err := ctx.router.FindRoute(req) + route, _, err := ctx.router.FindRoute(context.Background(), req) require.NoError(t, err, "unable to find route") // Now we'll examine the route returned for correctness. @@ -2353,7 +2353,7 @@ func TestPathFindSpecExample(t *testing.T) { ) require.NoError(t, err, "invalid route request") - route, _, err = ctx.router.FindRoute(req) + route, _, err = ctx.router.FindRoute(context.Background(), req) require.NoError(t, err, "unable to find routes") // The route should be two hops. @@ -3236,7 +3236,7 @@ func dbFindPath(graph *graphdb.ChannelGraph, }() route, _, err := findPath( - &graphParams{ + context.Background(), &graphParams{ additionalEdges: additionalEdges, bandwidthHints: bandwidthHints, graph: graphSess, diff --git a/routing/payment_lifecycle.go b/routing/payment_lifecycle.go index 180d38a631..8064473e72 100644 --- a/routing/payment_lifecycle.go +++ b/routing/payment_lifecycle.go @@ -270,7 +270,7 @@ lifecycle: } // Now request a route to be used to create our HTLC attempt. - rt, err := p.requestRoute(ps) + rt, err := p.requestRoute(ctx, ps) if err != nil { return exitWithErr(err) } @@ -377,14 +377,14 @@ func (p *paymentLifecycle) checkContext(ctx context.Context) error { // requestRoute is responsible for finding a route to be used to create an HTLC // attempt. -func (p *paymentLifecycle) requestRoute( +func (p *paymentLifecycle) requestRoute(ctx context.Context, ps *channeldb.MPPaymentState) (*route.Route, error) { remainingFees := p.calcFeeBudget(ps.FeesPaid) // Query our payment session to construct a route. rt, err := p.paySession.RequestRoute( - ps.RemainingAmt, remainingFees, + ctx, ps.RemainingAmt, remainingFees, uint32(ps.NumAttemptsInFlight), uint32(p.currentHeight), p.firstHopCustomRecords, ) diff --git a/routing/payment_lifecycle_test.go b/routing/payment_lifecycle_test.go index 72aa631419..880a1faeb0 100644 --- a/routing/payment_lifecycle_test.go +++ b/routing/payment_lifecycle_test.go @@ -357,6 +357,7 @@ func TestCheckTimeoutOnRouterQuit(t *testing.T) { // a route. func TestRequestRouteSucceed(t *testing.T) { t.Parallel() + ctx := context.Background() p := createTestPaymentLifecycle() @@ -379,11 +380,11 @@ func TestRequestRouteSucceed(t *testing.T) { // Mock the paySession's `RequestRoute` method to return no error. paySession.On("RequestRoute", - mock.Anything, mock.Anything, mock.Anything, mock.Anything, + ctx, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, ).Return(dummyRoute, nil) - result, err := p.requestRoute(ps) + result, err := p.requestRoute(ctx, ps) require.NoError(t, err, "expect no error") require.Equal(t, dummyRoute, result, "returned route not matched") @@ -395,6 +396,7 @@ func TestRequestRouteSucceed(t *testing.T) { // successfully handle a critical error returned from payment session. func TestRequestRouteHandleCriticalErr(t *testing.T) { t.Parallel() + ctx := context.Background() p := createTestPaymentLifecycle() @@ -416,11 +418,11 @@ func TestRequestRouteHandleCriticalErr(t *testing.T) { // Mock the paySession's `RequestRoute` method to return an error. paySession.On("RequestRoute", - mock.Anything, mock.Anything, mock.Anything, mock.Anything, + ctx, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, ).Return(nil, errDummy) - result, err := p.requestRoute(ps) + result, err := p.requestRoute(ctx, ps) // Expect an error is returned since it's critical. require.ErrorIs(t, err, errDummy, "error not matched") @@ -434,6 +436,7 @@ func TestRequestRouteHandleCriticalErr(t *testing.T) { // handle the `noRouteError` returned from payment session. func TestRequestRouteHandleNoRouteErr(t *testing.T) { t.Parallel() + ctx := context.Background() // Create a paymentLifecycle with mockers. p, m := newTestPaymentLifecycle(t) @@ -451,7 +454,7 @@ func TestRequestRouteHandleNoRouteErr(t *testing.T) { // Mock the paySession's `RequestRoute` method to return a NoRouteErr // type. m.paySession.On("RequestRoute", - mock.Anything, mock.Anything, mock.Anything, mock.Anything, + ctx, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, ).Return(nil, errNoTlvPayload) @@ -460,7 +463,7 @@ func TestRequestRouteHandleNoRouteErr(t *testing.T) { p.identifier, channeldb.FailureReasonNoRoute, ).Return(nil).Once() - result, err := p.requestRoute(ps) + result, err := p.requestRoute(ctx, ps) // Expect no error is returned since it's not critical. require.NoError(t, err, "expected no error") @@ -471,6 +474,7 @@ func TestRequestRouteHandleNoRouteErr(t *testing.T) { // error from calling `FailPayment`. func TestRequestRouteFailPaymentError(t *testing.T) { t.Parallel() + ctx := context.Background() p := createTestPaymentLifecycle() @@ -499,11 +503,11 @@ func TestRequestRouteFailPaymentError(t *testing.T) { // Mock the paySession's `RequestRoute` method to return an error. paySession.On("RequestRoute", - mock.Anything, mock.Anything, mock.Anything, mock.Anything, + ctx, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, ).Return(nil, errNoTlvPayload) - result, err := p.requestRoute(ps) + result, err := p.requestRoute(ctx, ps) // Expect an error is returned. require.ErrorIs(t, err, errDummy, "error not matched") diff --git a/routing/payment_session.go b/routing/payment_session.go index 0afdf822fb..630752eef1 100644 --- a/routing/payment_session.go +++ b/routing/payment_session.go @@ -1,6 +1,7 @@ package routing import ( + "context" "fmt" "github.com/btcsuite/btcd/btcec/v2" @@ -137,7 +138,7 @@ type PaymentSession interface { // // A noRouteError is returned if a non-critical error is encountered // during path finding. - RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, + RequestRoute(ctx context.Context, maxAmt, feeLimit lnwire.MilliSatoshi, activeShards, height uint32, firstHopCustomRecords lnwire.CustomRecords) (*route.Route, error) @@ -169,7 +170,7 @@ type paymentSession struct { additionalEdges map[route.Vertex][]AdditionalEdge - getBandwidthHints func(Graph) (bandwidthHints, error) + getBandwidthHints func(context.Context, Graph) (bandwidthHints, error) payment *LightningPayment @@ -197,7 +198,7 @@ type paymentSession struct { // newPaymentSession instantiates a new payment session. func newPaymentSession(p *LightningPayment, selfNode route.Vertex, - getBandwidthHints func(Graph) (bandwidthHints, error), + getBandwidthHints func(context.Context, Graph) (bandwidthHints, error), graphSessFactory GraphSessionFactory, missionControl MissionControlQuerier, pathFindingConfig PathFindingConfig) (*paymentSession, error) { @@ -244,8 +245,8 @@ func newPaymentSession(p *LightningPayment, selfNode route.Vertex, // // NOTE: This function is safe for concurrent access. // NOTE: Part of the PaymentSession interface. -func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, - activeShards, height uint32, +func (p *paymentSession) RequestRoute(ctx context.Context, maxAmt, + feeLimit lnwire.MilliSatoshi, activeShards, height uint32, firstHopCustomRecords lnwire.CustomRecords) (*route.Route, error) { if p.empty { @@ -308,7 +309,7 @@ func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, // don't have enough bandwidth to carry the payment. New // bandwidth hints are queried for every new path finding // attempt, because concurrent payments may change balances. - bandwidthHints, err := p.getBandwidthHints(graph) + bandwidthHints, err := p.getBandwidthHints(ctx, graph) if err != nil { // Close routing graph session. if graphErr := closeGraph(); graphErr != nil { @@ -323,7 +324,7 @@ func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, // Find a route for the current amount. path, _, err := p.pathFinder( - &graphParams{ + ctx, &graphParams{ additionalEdges: p.additionalEdges, bandwidthHints: bandwidthHints, graph: graph, diff --git a/routing/payment_session_source.go b/routing/payment_session_source.go index 5e4eb23d7f..dc2cab399e 100644 --- a/routing/payment_session_source.go +++ b/routing/payment_session_source.go @@ -1,6 +1,8 @@ package routing import ( + "context" + "github.com/btcsuite/btcd/btcec/v2" "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" @@ -56,9 +58,11 @@ func (m *SessionSource) NewPaymentSession(p *LightningPayment, trafficShaper fn.Option[htlcswitch.AuxTrafficShaper]) (PaymentSession, error) { - getBandwidthHints := func(graph Graph) (bandwidthHints, error) { + getBandwidthHints := func(ctx context.Context, graph Graph) ( + bandwidthHints, error) { + return newBandwidthManager( - graph, m.SourceNode.PubKeyBytes, m.GetLink, + ctx, graph, m.SourceNode.PubKeyBytes, m.GetLink, firstHopBlob, trafficShaper, ) } diff --git a/routing/payment_session_test.go b/routing/payment_session_test.go index 278e090440..069b867c89 100644 --- a/routing/payment_session_test.go +++ b/routing/payment_session_test.go @@ -1,6 +1,7 @@ package routing import ( + "context" "testing" "time" @@ -115,7 +116,7 @@ func TestUpdateAdditionalEdge(t *testing.T) { // Create the paymentsession. session, err := newPaymentSession( payment, route.Vertex{}, - func(Graph) (bandwidthHints, error) { + func(context.Context, Graph) (bandwidthHints, error) { return &mockBandwidthHints{}, nil }, newMockGraphSessionFactory(&sessionGraph{}), @@ -193,7 +194,7 @@ func TestRequestRoute(t *testing.T) { session, err := newPaymentSession( payment, route.Vertex{}, - func(Graph) (bandwidthHints, error) { + func(context.Context, Graph) (bandwidthHints, error) { return &mockBandwidthHints{}, nil }, newMockGraphSessionFactory(&sessionGraph{}), @@ -205,8 +206,8 @@ func TestRequestRoute(t *testing.T) { } // Override pathfinder with a mock. - session.pathFinder = func(_ *graphParams, r *RestrictParams, - _ *PathFindingConfig, _, _, _ route.Vertex, + session.pathFinder = func(_ context.Context, _ *graphParams, + r *RestrictParams, _ *PathFindingConfig, _, _, _ route.Vertex, _ lnwire.MilliSatoshi, _ float64, _ int32) ([]*unifiedEdge, float64, error) { @@ -233,6 +234,7 @@ func TestRequestRoute(t *testing.T) { } route, err := session.RequestRoute( + context.Background(), payment.Amount, payment.FeeLimit, 0, height, lnwire.CustomRecords{ lnwire.MinCustomRecordsTlvType + 123: []byte{1, 2, 3}, diff --git a/routing/router.go b/routing/router.go index 93c3e1b594..69ea4204c0 100644 --- a/routing/router.go +++ b/routing/router.go @@ -515,8 +515,8 @@ func getTargetNode(target *route.Vertex, // FindRoute attempts to query the ChannelRouter for the optimum path to a // particular target destination to which it is able to send `amt` after // factoring in channel capacities and cumulative fees along the route. -func (r *ChannelRouter) FindRoute(req *RouteRequest) (*route.Route, float64, - error) { +func (r *ChannelRouter) FindRoute(ctx context.Context, req *RouteRequest) ( + *route.Route, float64, error) { log.Debugf("Searching for path to %v, sending %v", req.Target, req.Amount) @@ -524,7 +524,7 @@ func (r *ChannelRouter) FindRoute(req *RouteRequest) (*route.Route, float64, // We'll attempt to obtain a set of bandwidth hints that can help us // eliminate certain routes early on in the path finding process. bandwidthHints, err := newBandwidthManager( - r.cfg.RoutingGraph, r.cfg.SelfNode, r.cfg.GetLink, + ctx, r.cfg.RoutingGraph, r.cfg.SelfNode, r.cfg.GetLink, fn.None[tlv.Blob](), r.cfg.TrafficShaper, ) if err != nil { @@ -549,7 +549,7 @@ func (r *ChannelRouter) FindRoute(req *RouteRequest) (*route.Route, float64, } path, probability, err := findPath( - &graphParams{ + ctx, &graphParams{ additionalEdges: req.RouteHints, bandwidthHints: bandwidthHints, graph: r.cfg.RoutingGraph, @@ -616,12 +616,11 @@ type BlindedPathRestrictions struct { // FindBlindedPaths finds a selection of paths to the destination node that can // be used in blinded payment paths. -func (r *ChannelRouter) FindBlindedPaths(destination route.Vertex, - amt lnwire.MilliSatoshi, probabilitySrc probabilitySource, +func (r *ChannelRouter) FindBlindedPaths(ctx context.Context, + destination route.Vertex, amt lnwire.MilliSatoshi, + probabilitySrc probabilitySource, restrictions *BlindedPathRestrictions) ([]*route.Route, error) { - ctx := context.TODO() - // First, find a set of candidate paths given the destination node and // path length restrictions. paths, err := findBlindedPaths( @@ -1368,12 +1367,12 @@ func (e ErrNoChannel) Error() string { // BuildRoute returns a fully specified route based on a list of pubkeys. If // amount is nil, the minimum routable amount is used. To force a specific // outgoing channel, use the outgoingChan parameter. -func (r *ChannelRouter) BuildRoute(amt fn.Option[lnwire.MilliSatoshi], - hops []route.Vertex, outgoingChan *uint64, finalCltvDelta int32, +func (r *ChannelRouter) BuildRoute(ctx context.Context, + amt fn.Option[lnwire.MilliSatoshi], hops []route.Vertex, + outgoingChan *uint64, finalCltvDelta int32, payAddr fn.Option[[32]byte], firstHopBlob fn.Option[[]byte]) ( *route.Route, error) { - ctx := context.TODO() log.Tracef("BuildRoute called: hopsCount=%v, amt=%v", len(hops), amt) var outgoingChans map[uint64]struct{} @@ -1386,8 +1385,8 @@ func (r *ChannelRouter) BuildRoute(amt fn.Option[lnwire.MilliSatoshi], // We'll attempt to obtain a set of bandwidth hints that helps us select // the best outgoing channel to use in case no outgoing channel is set. bandwidthHints, err := newBandwidthManager( - r.cfg.RoutingGraph, r.cfg.SelfNode, r.cfg.GetLink, firstHopBlob, - r.cfg.TrafficShaper, + ctx, r.cfg.RoutingGraph, r.cfg.SelfNode, r.cfg.GetLink, + firstHopBlob, r.cfg.TrafficShaper, ) if err != nil { return nil, err diff --git a/routing/router_test.go b/routing/router_test.go index 22c9d14e50..34b1c185bf 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -2,6 +2,7 @@ package routing import ( "bytes" + "context" "fmt" "image/color" "math" @@ -271,7 +272,7 @@ func TestFindRoutesWithFeeLimit(t *testing.T) { ) require.NoError(t, err, "invalid route request") - route, _, err := ctx.router.FindRoute(req) + route, _, err := ctx.router.FindRoute(context.Background(), req) require.NoError(t, err, "unable to find any routes") require.Falsef(t, @@ -1530,6 +1531,9 @@ func TestSendToRouteMaxHops(t *testing.T) { // TestBuildRoute tests whether correct routes are built. func TestBuildRoute(t *testing.T) { + t.Parallel() + ctx := context.Background() + // Setup a three node network. chanCapSat := btcutil.Amount(100000) paymentAddrFeatures := lnwire.NewFeatureVector( @@ -1638,7 +1642,9 @@ func TestBuildRoute(t *testing.T) { const startingBlockHeight = 101 - ctx := createTestCtxFromGraphInstance(t, startingBlockHeight, testGraph) + tCtx := createTestCtxFromGraphInstance( + t, startingBlockHeight, testGraph, + ) checkHops := func(rt *route.Route, expected []uint64, payAddr [32]byte) { @@ -1664,27 +1670,28 @@ func TestBuildRoute(t *testing.T) { // Test that we can't build a route when no hops are given. hops = []route.Vertex{} - _, err = ctx.router.BuildRoute( - noAmt, hops, nil, 40, fn.None[[32]byte](), fn.None[[]byte](), + _, err = tCtx.router.BuildRoute( + ctx, noAmt, hops, nil, 40, fn.None[[32]byte](), + fn.None[[]byte](), ) require.Error(t, err) // Create hop list for an unknown destination. - hops := []route.Vertex{ctx.aliases["b"], ctx.aliases["y"]} - _, err = ctx.router.BuildRoute( - noAmt, hops, nil, 40, fn.Some(payAddr), fn.None[[]byte](), + hops := []route.Vertex{tCtx.aliases["b"], tCtx.aliases["y"]} + _, err = tCtx.router.BuildRoute( + ctx, noAmt, hops, nil, 40, fn.Some(payAddr), fn.None[[]byte](), ) noChanErr := ErrNoChannel{} require.ErrorAs(t, err, &noChanErr) require.Equal(t, 1, noChanErr.position) // Create hop list from the route node pubkeys. - hops = []route.Vertex{ctx.aliases["b"], ctx.aliases["c"]} + hops = []route.Vertex{tCtx.aliases["b"], tCtx.aliases["c"]} amt := lnwire.NewMSatFromSatoshis(100) // Build the route for the given amount. - rt, err := ctx.router.BuildRoute( - fn.Some(amt), hops, nil, 40, fn.Some(payAddr), + rt, err := tCtx.router.BuildRoute( + ctx, fn.Some(amt), hops, nil, 40, fn.Some(payAddr), fn.None[[]byte](), ) require.NoError(t, err) @@ -1696,8 +1703,8 @@ func TestBuildRoute(t *testing.T) { require.Equal(t, lnwire.MilliSatoshi(106000), rt.TotalAmount) // Build the route for the minimum amount. - rt, err = ctx.router.BuildRoute( - noAmt, hops, nil, 40, fn.Some(payAddr), fn.None[[]byte](), + rt, err = tCtx.router.BuildRoute( + ctx, noAmt, hops, nil, 40, fn.Some(payAddr), fn.None[[]byte](), ) require.NoError(t, err) @@ -1713,9 +1720,10 @@ func TestBuildRoute(t *testing.T) { // Test a route that contains incompatible channel htlc constraints. // There is no amount that can pass through both channel 5 and 4. - hops = []route.Vertex{ctx.aliases["e"], ctx.aliases["c"]} - _, err = ctx.router.BuildRoute( - noAmt, hops, nil, 40, fn.None[[32]byte](), fn.None[[]byte](), + hops = []route.Vertex{tCtx.aliases["e"], tCtx.aliases["c"]} + _, err = tCtx.router.BuildRoute( + ctx, noAmt, hops, nil, 40, fn.None[[32]byte](), + fn.None[[]byte](), ) require.Error(t, err) noChanErr = ErrNoChannel{} @@ -1733,9 +1741,9 @@ func TestBuildRoute(t *testing.T) { // could me more applicable, which is why we don't get back the highest // amount that could be delivered to the receiver of 21819 msat, using // policy of channel 3. - hops = []route.Vertex{ctx.aliases["b"], ctx.aliases["z"]} - rt, err = ctx.router.BuildRoute( - noAmt, hops, nil, 40, fn.Some(payAddr), fn.None[[]byte](), + hops = []route.Vertex{tCtx.aliases["b"], tCtx.aliases["z"]} + rt, err = tCtx.router.BuildRoute( + ctx, noAmt, hops, nil, 40, fn.Some(payAddr), fn.None[[]byte](), ) require.NoError(t, err) checkHops(rt, []uint64{1, 8}, payAddr) @@ -1746,10 +1754,10 @@ func TestBuildRoute(t *testing.T) { // inbound fees. We expect a similar amount as for the above case of // b->c, but reduced by the inbound discount on the channel a->d. // We get 106000 - 1000 (base in) - 0.001 * 106000 (rate in) = 104894. - hops = []route.Vertex{ctx.aliases["d"], ctx.aliases["f"]} + hops = []route.Vertex{tCtx.aliases["d"], tCtx.aliases["f"]} amt = lnwire.NewMSatFromSatoshis(100) - rt, err = ctx.router.BuildRoute( - fn.Some(amt), hops, nil, 40, fn.Some(payAddr), + rt, err = tCtx.router.BuildRoute( + ctx, fn.Some(amt), hops, nil, 40, fn.Some(payAddr), fn.None[[]byte](), ) require.NoError(t, err) @@ -1764,9 +1772,9 @@ func TestBuildRoute(t *testing.T) { // due to rounding. This would not be compatible with the sender amount // of 20179 msat, which results in underpayment of 1 msat in fee. There // is a third pass through newRoute in which this gets corrected to end - hops = []route.Vertex{ctx.aliases["d"], ctx.aliases["f"]} - rt, err = ctx.router.BuildRoute( - noAmt, hops, nil, 40, fn.Some(payAddr), fn.None[[]byte](), + hops = []route.Vertex{tCtx.aliases["d"], tCtx.aliases["f"]} + rt, err = tCtx.router.BuildRoute( + ctx, noAmt, hops, nil, 40, fn.Some(payAddr), fn.None[[]byte](), ) require.NoError(t, err) checkHops(rt, []uint64{9, 10}, payAddr) @@ -2904,7 +2912,7 @@ func TestAddEdgeUnknownVertexes(t *testing.T) { paymentAmt, 0, noRestrictions, nil, nil, nil, MinCLTVDelta, ) require.NoError(t, err, "invalid route request") - _, _, err = ctx.router.FindRoute(req) + _, _, err = ctx.router.FindRoute(context.Background(), req) require.NoError(t, err, "unable to find any routes") // Now check that we can update the node info for the partial node @@ -2943,7 +2951,7 @@ func TestAddEdgeUnknownVertexes(t *testing.T) { ) require.NoError(t, err, "invalid route request") - _, _, err = ctx.router.FindRoute(req) + _, _, err = ctx.router.FindRoute(context.Background(), req) require.NoError(t, err, "unable to find any routes") copy1, err := ctx.graph.FetchLightningNode(pub1) @@ -3081,6 +3089,7 @@ func createChannelEdge(bitcoinKey1, bitcoinKey2 []byte, // paths with the highest success probability. func TestFindBlindedPathsWithMC(t *testing.T) { t.Parallel() + ctx := context.Background() rbFeatureBits := []lnwire.FeatureBit{ lnwire.RouteBlindingOptional, @@ -3138,15 +3147,15 @@ func TestFindBlindedPathsWithMC(t *testing.T) { ) require.NoError(t, err) - ctx := createTestCtxFromGraphInstance(t, 101, testGraph) + tCtx := createTestCtxFromGraphInstance(t, 101, testGraph) var ( - alice = ctx.aliases["alice"] - bob = ctx.aliases["bob"] - charlie = ctx.aliases["charlie"] - dave = ctx.aliases["dave"] - eve = ctx.aliases["eve"] - frank = ctx.aliases["frank"] + alice = tCtx.aliases["alice"] + bob = tCtx.aliases["bob"] + charlie = tCtx.aliases["charlie"] + dave = tCtx.aliases["dave"] + eve = tCtx.aliases["eve"] + frank = tCtx.aliases["frank"] ) // Create a mission control store which initially sets the success @@ -3173,8 +3182,8 @@ func TestFindBlindedPathsWithMC(t *testing.T) { // All the probabilities are set to 1. So if we restrict the path length // to 2 and allow a max of 3 routes, then we expect three paths here. - routes, err := ctx.router.FindBlindedPaths( - dave, 1000, probabilitySrc, &BlindedPathRestrictions{ + routes, err := tCtx.router.FindBlindedPaths( + ctx, dave, 1000, probabilitySrc, &BlindedPathRestrictions{ MinDistanceFromIntroNode: 2, NumHops: 2, MaxNumPaths: 3, @@ -3191,12 +3200,12 @@ func TestFindBlindedPathsWithMC(t *testing.T) { var actualPaths []string for _, path := range paths { label := getAliasFromPubKey( - path.SourcePubKey, ctx.aliases, + path.SourcePubKey, tCtx.aliases, ) + "," for _, hop := range path.Hops { label += getAliasFromPubKey( - hop.PubKeyBytes, ctx.aliases, + hop.PubKeyBytes, tCtx.aliases, ) + "," } @@ -3218,8 +3227,8 @@ func TestFindBlindedPathsWithMC(t *testing.T) { // 3) A -> F -> D missionControl[bob][dave] = 0.5 missionControl[frank][dave] = 0.25 - routes, err = ctx.router.FindBlindedPaths( - dave, 1000, probabilitySrc, &BlindedPathRestrictions{ + routes, err = tCtx.router.FindBlindedPaths( + ctx, dave, 1000, probabilitySrc, &BlindedPathRestrictions{ MinDistanceFromIntroNode: 2, NumHops: 2, MaxNumPaths: 3, @@ -3235,8 +3244,8 @@ func TestFindBlindedPathsWithMC(t *testing.T) { // Just to show that the above result was not a fluke, let's change // the C->D link to be the weak one. missionControl[charlie][dave] = 0.125 - routes, err = ctx.router.FindBlindedPaths( - dave, 1000, probabilitySrc, &BlindedPathRestrictions{ + routes, err = tCtx.router.FindBlindedPaths( + ctx, dave, 1000, probabilitySrc, &BlindedPathRestrictions{ MinDistanceFromIntroNode: 2, NumHops: 2, MaxNumPaths: 3, @@ -3251,8 +3260,8 @@ func TestFindBlindedPathsWithMC(t *testing.T) { // Change the MaxNumPaths to 1 to assert that only the best route is // returned. - routes, err = ctx.router.FindBlindedPaths( - dave, 1000, probabilitySrc, &BlindedPathRestrictions{ + routes, err = tCtx.router.FindBlindedPaths( + ctx, dave, 1000, probabilitySrc, &BlindedPathRestrictions{ MinDistanceFromIntroNode: 2, NumHops: 2, MaxNumPaths: 1, @@ -3265,8 +3274,8 @@ func TestFindBlindedPathsWithMC(t *testing.T) { // Test the edge case where Dave, the recipient, is also the // introduction node. - routes, err = ctx.router.FindBlindedPaths( - dave, 1000, probabilitySrc, &BlindedPathRestrictions{ + routes, err = tCtx.router.FindBlindedPaths( + ctx, dave, 1000, probabilitySrc, &BlindedPathRestrictions{ MinDistanceFromIntroNode: 0, NumHops: 0, MaxNumPaths: 1, @@ -3280,8 +3289,8 @@ func TestFindBlindedPathsWithMC(t *testing.T) { // Finally, we make one of the routes have a probability less than the // minimum. This means we expect that route not to be chosen. missionControl[charlie][dave] = DefaultMinRouteProbability - routes, err = ctx.router.FindBlindedPaths( - dave, 1000, probabilitySrc, &BlindedPathRestrictions{ + routes, err = tCtx.router.FindBlindedPaths( + ctx, dave, 1000, probabilitySrc, &BlindedPathRestrictions{ MinDistanceFromIntroNode: 2, NumHops: 2, MaxNumPaths: 3, @@ -3295,8 +3304,8 @@ func TestFindBlindedPathsWithMC(t *testing.T) { // Test that if the user explicitly indicates that we should ignore // the Frank node during path selection, then this is done. - routes, err = ctx.router.FindBlindedPaths( - dave, 1000, probabilitySrc, &BlindedPathRestrictions{ + routes, err = tCtx.router.FindBlindedPaths( + ctx, dave, 1000, probabilitySrc, &BlindedPathRestrictions{ MinDistanceFromIntroNode: 2, NumHops: 2, MaxNumPaths: 3, diff --git a/rpcserver.go b/rpcserver.go index 8f6b302bd2..2e4f1c7fcd 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -6128,7 +6128,7 @@ func (r *rpcServer) AddInvoice(ctx context.Context, []*route.Route, error) { return r.server.chanRouter.FindBlindedPaths( - r.selfNode, amt, + ctx, r.selfNode, amt, r.server.defaultMC.GetProbability, blindingRestrictions, )