Skip to content

Commit

Permalink
routing+lnrpc: thread context through
Browse files Browse the repository at this point in the history
And remove all the context.TODO()s from the previous commit.
  • Loading branch information
ellemouton committed Jan 10, 2025
1 parent 5265d09 commit 78e2cf5
Show file tree
Hide file tree
Showing 17 changed files with 146 additions and 122 deletions.
5 changes: 3 additions & 2 deletions lnrpc/routerrpc/router_backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ type RouterBackend struct {

// FindRoute is a closure that abstracts away how we locate/query for
// routes.
FindRoute func(*routing.RouteRequest) (*route.Route, float64, error)
FindRoute func(context.Context, *routing.RouteRequest) (*route.Route,
float64, error)

MissionControl MissionControl

Expand Down Expand Up @@ -169,7 +170,7 @@ func (r *RouterBackend) QueryRoutes(ctx context.Context,
// Query the channel router for a possible path to the destination that
// can carry `in.Amt` satoshis _including_ the total fee required on
// the route
route, successProb, err := r.FindRoute(routeReq)
route, successProb, err := r.FindRoute(ctx, routeReq)
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions lnrpc/routerrpc/router_backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,8 @@ func testQueryRoutes(t *testing.T, useMissionControl bool, useMsat bool,
}
}

findRoute := func(req *routing.RouteRequest) (*route.Route, float64,
error) {
findRoute := func(_ context.Context, req *routing.RouteRequest) (
*route.Route, float64, error) {

if int64(req.Amount) != amtSat*1000 {
t.Fatal("unexpected amount")
Expand Down
12 changes: 6 additions & 6 deletions lnrpc/routerrpc/router_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,7 @@ func (s *Server) EstimateRouteFee(ctx context.Context,
return nil, errors.New("amount must be greater than 0")

default:
return s.probeDestination(req.Dest, req.AmtSat)
return s.probeDestination(ctx, req.Dest, req.AmtSat)
}

case isProbeInvoice:
Expand All @@ -440,8 +440,8 @@ func (s *Server) EstimateRouteFee(ctx context.Context,

// probeDestination estimates fees along a route to a destination based on the
// contents of the local graph.
func (s *Server) probeDestination(dest []byte, amtSat int64) (*RouteFeeResponse,
error) {
func (s *Server) probeDestination(ctx context.Context, dest []byte,
amtSat int64) (*RouteFeeResponse, error) {

destNode, err := route.NewVertexFromBytes(dest)
if err != nil {
Expand Down Expand Up @@ -469,7 +469,7 @@ func (s *Server) probeDestination(dest []byte, amtSat int64) (*RouteFeeResponse,
return nil, err
}

route, _, err := s.cfg.Router.FindRoute(routeReq)
route, _, err := s.cfg.Router.FindRoute(ctx, routeReq)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1429,7 +1429,7 @@ func (s *Server) trackPaymentStream(context context.Context,
}

// BuildRoute builds a route from a list of hop addresses.
func (s *Server) BuildRoute(_ context.Context,
func (s *Server) BuildRoute(ctx context.Context,
req *BuildRouteRequest) (*BuildRouteResponse, error) {

if len(req.HopPubkeys) == 0 {
Expand Down Expand Up @@ -1490,7 +1490,7 @@ func (s *Server) BuildRoute(_ context.Context,

// Build the route and return it to the caller.
route, err := s.cfg.Router.BuildRoute(
amt, hops, outgoingChan, req.FinalCltvDelta, payAddr,
ctx, amt, hops, outgoingChan, req.FinalCltvDelta, payAddr,
firstHopBlob,
)
if err != nil {
Expand Down
10 changes: 4 additions & 6 deletions routing/bandwidth.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,10 @@ type bandwidthManager struct {
// hints for the edges we directly have open ourselves. Obtaining these hints
// allows us to reduce the number of extraneous attempts as we can skip channels
// that are inactive, or just don't have enough bandwidth to carry the payment.
func newBandwidthManager(graph Graph, sourceNode route.Vertex,
linkQuery getLinkQuery, firstHopBlob fn.Option[tlv.Blob],
ts fn.Option[htlcswitch.AuxTrafficShaper]) (*bandwidthManager,
error) {

ctx := context.TODO()
func newBandwidthManager(ctx context.Context, graph Graph,
sourceNode route.Vertex, linkQuery getLinkQuery,
firstHopBlob fn.Option[tlv.Blob],
ts fn.Option[htlcswitch.AuxTrafficShaper]) (*bandwidthManager, error) {

manager := &bandwidthManager{
getLink: linkQuery,
Expand Down
6 changes: 5 additions & 1 deletion routing/bandwidth_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package routing

import (
"context"
"testing"

"github.com/btcsuite/btcd/btcutil"
Expand All @@ -15,6 +16,9 @@ import (
// TestBandwidthManager tests getting of bandwidth hints from a bandwidth
// manager.
func TestBandwidthManager(t *testing.T) {
t.Parallel()
ctx := context.Background()

var (
chan1ID uint64 = 101
chan2ID uint64 = 102
Expand Down Expand Up @@ -116,7 +120,7 @@ func TestBandwidthManager(t *testing.T) {
)

m, err := newBandwidthManager(
g, sourceNode.pubkey, testCase.linkQuery,
ctx, g, sourceNode.pubkey, testCase.linkQuery,
fn.None[[]byte](),
fn.Some[htlcswitch.AuxTrafficShaper](
&mockTrafficShaper{},
Expand Down
7 changes: 5 additions & 2 deletions routing/integrated_routing_context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,9 @@ func (c *integratedRoutingContext) testPayment(maxParts uint32,
)
require.NoError(c.t, err)

getBandwidthHints := func(_ Graph) (bandwidthHints, error) {
getBandwidthHints := func(_ context.Context, _ Graph) (bandwidthHints,
error) {

// Create bandwidth hints based on local channel balances.
bandwidthHints := map[uint64]lnwire.MilliSatoshi{}
for _, ch := range c.graph.nodes[c.source.pubkey].channels {
Expand Down Expand Up @@ -236,7 +238,8 @@ func (c *integratedRoutingContext) testPayment(maxParts uint32,

// Find a route.
route, err := session.RequestRoute(
amtRemaining, lnwire.MaxMilliSatoshi, inFlightHtlcs, 0,
context.Background(), amtRemaining,
lnwire.MaxMilliSatoshi, inFlightHtlcs, 0,
lnwire.CustomRecords{
lnwire.MinCustomRecordsTlvType: []byte{1, 2, 3},
},
Expand Down
13 changes: 7 additions & 6 deletions routing/mock_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package routing

import (
"context"
"fmt"
"sync"

Expand Down Expand Up @@ -168,9 +169,9 @@ type mockPaymentSessionOld struct {

var _ PaymentSession = (*mockPaymentSessionOld)(nil)

func (m *mockPaymentSessionOld) RequestRoute(_, _ lnwire.MilliSatoshi,
_, height uint32, _ lnwire.CustomRecords) (*route.Route,
error) {
func (m *mockPaymentSessionOld) RequestRoute(_ context.Context,
_, _ lnwire.MilliSatoshi, _, _ uint32, _ lnwire.CustomRecords) (
*route.Route, error) {

if m.release != nil {
m.release <- struct{}{}
Expand Down Expand Up @@ -695,12 +696,12 @@ type mockPaymentSession struct {

var _ PaymentSession = (*mockPaymentSession)(nil)

func (m *mockPaymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi,
activeShards, height uint32,
func (m *mockPaymentSession) RequestRoute(ctx context.Context, maxAmt,
feeLimit lnwire.MilliSatoshi, activeShards, height uint32,
firstHopCustomRecords lnwire.CustomRecords) (*route.Route, error) {

args := m.Called(
maxAmt, feeLimit, activeShards, height, firstHopCustomRecords,
ctx, maxAmt, feeLimit, activeShards, height, firstHopCustomRecords,

Check failure on line 704 in routing/mock_test.go

View workflow job for this annotation

GitHub Actions / lint code

the line is 83 characters long, which exceeds the maximum of 80 characters. (ll)
)

// Type assertion on nil will fail, so we check and return here.
Expand Down
12 changes: 5 additions & 7 deletions routing/pathfind.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ const (
)

// pathFinder defines the interface of a path finding algorithm.
type pathFinder = func(g *graphParams, r *RestrictParams,
type pathFinder = func(ctx context.Context, g *graphParams, r *RestrictParams,
cfg *PathFindingConfig, self, source, target route.Vertex,
amt lnwire.MilliSatoshi, timePref float64, finalHtlcExpiry int32) (
[]*unifiedEdge, float64, error)
Expand Down Expand Up @@ -576,12 +576,10 @@ func getOutgoingBalance(ctx context.Context, node route.Vertex,
// source. This is to properly accumulate fees that need to be paid along the
// path and accurately check the amount to forward at every node against the
// available bandwidth.
func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig,
self, source, target route.Vertex, amt lnwire.MilliSatoshi,
timePref float64, finalHtlcExpiry int32) ([]*unifiedEdge, float64,
error) {

ctx := context.TODO()
func findPath(ctx context.Context, g *graphParams, r *RestrictParams,
cfg *PathFindingConfig, self, source, target route.Vertex,
amt lnwire.MilliSatoshi, timePref float64, finalHtlcExpiry int32) (
[]*unifiedEdge, float64, error) {

// Pathfinding can be a significant portion of the total payment
// latency, especially on low-powered devices. Log several metrics to
Expand Down
6 changes: 3 additions & 3 deletions routing/pathfind_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2326,7 +2326,7 @@ func TestPathFindSpecExample(t *testing.T) {
)
require.NoError(t, err, "invalid route request")

route, _, err := ctx.router.FindRoute(req)
route, _, err := ctx.router.FindRoute(context.Background(), req)
require.NoError(t, err, "unable to find route")

// Now we'll examine the route returned for correctness.
Expand All @@ -2353,7 +2353,7 @@ func TestPathFindSpecExample(t *testing.T) {
)
require.NoError(t, err, "invalid route request")

route, _, err = ctx.router.FindRoute(req)
route, _, err = ctx.router.FindRoute(context.Background(), req)
require.NoError(t, err, "unable to find routes")

// The route should be two hops.
Expand Down Expand Up @@ -3236,7 +3236,7 @@ func dbFindPath(graph *graphdb.ChannelGraph,
}()

route, _, err := findPath(
&graphParams{
context.Background(), &graphParams{
additionalEdges: additionalEdges,
bandwidthHints: bandwidthHints,
graph: graphSess,
Expand Down
6 changes: 3 additions & 3 deletions routing/payment_lifecycle.go
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ lifecycle:
}

// Now request a route to be used to create our HTLC attempt.
rt, err := p.requestRoute(ps)
rt, err := p.requestRoute(ctx, ps)
if err != nil {
return exitWithErr(err)
}
Expand Down Expand Up @@ -377,14 +377,14 @@ func (p *paymentLifecycle) checkContext(ctx context.Context) error {

// requestRoute is responsible for finding a route to be used to create an HTLC
// attempt.
func (p *paymentLifecycle) requestRoute(
func (p *paymentLifecycle) requestRoute(ctx context.Context,
ps *channeldb.MPPaymentState) (*route.Route, error) {

remainingFees := p.calcFeeBudget(ps.FeesPaid)

// Query our payment session to construct a route.
rt, err := p.paySession.RequestRoute(
ps.RemainingAmt, remainingFees,
ctx, ps.RemainingAmt, remainingFees,
uint32(ps.NumAttemptsInFlight), uint32(p.currentHeight),
p.firstHopCustomRecords,
)
Expand Down
20 changes: 12 additions & 8 deletions routing/payment_lifecycle_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,7 @@ func TestCheckTimeoutOnRouterQuit(t *testing.T) {
// a route.
func TestRequestRouteSucceed(t *testing.T) {
t.Parallel()
ctx := context.Background()

p := createTestPaymentLifecycle()

Expand All @@ -379,11 +380,11 @@ func TestRequestRouteSucceed(t *testing.T) {

// Mock the paySession's `RequestRoute` method to return no error.
paySession.On("RequestRoute",
mock.Anything, mock.Anything, mock.Anything, mock.Anything,
ctx, mock.Anything, mock.Anything, mock.Anything, mock.Anything,
mock.Anything,
).Return(dummyRoute, nil)

result, err := p.requestRoute(ps)
result, err := p.requestRoute(ctx, ps)
require.NoError(t, err, "expect no error")
require.Equal(t, dummyRoute, result, "returned route not matched")

Expand All @@ -395,6 +396,7 @@ func TestRequestRouteSucceed(t *testing.T) {
// successfully handle a critical error returned from payment session.
func TestRequestRouteHandleCriticalErr(t *testing.T) {
t.Parallel()
ctx := context.Background()

p := createTestPaymentLifecycle()

Expand All @@ -416,11 +418,11 @@ func TestRequestRouteHandleCriticalErr(t *testing.T) {

// Mock the paySession's `RequestRoute` method to return an error.
paySession.On("RequestRoute",
mock.Anything, mock.Anything, mock.Anything, mock.Anything,
ctx, mock.Anything, mock.Anything, mock.Anything, mock.Anything,
mock.Anything,
).Return(nil, errDummy)

result, err := p.requestRoute(ps)
result, err := p.requestRoute(ctx, ps)

// Expect an error is returned since it's critical.
require.ErrorIs(t, err, errDummy, "error not matched")
Expand All @@ -434,6 +436,7 @@ func TestRequestRouteHandleCriticalErr(t *testing.T) {
// handle the `noRouteError` returned from payment session.
func TestRequestRouteHandleNoRouteErr(t *testing.T) {
t.Parallel()
ctx := context.Background()

// Create a paymentLifecycle with mockers.
p, m := newTestPaymentLifecycle(t)
Expand All @@ -451,7 +454,7 @@ func TestRequestRouteHandleNoRouteErr(t *testing.T) {
// Mock the paySession's `RequestRoute` method to return a NoRouteErr
// type.
m.paySession.On("RequestRoute",
mock.Anything, mock.Anything, mock.Anything, mock.Anything,
ctx, mock.Anything, mock.Anything, mock.Anything, mock.Anything,
mock.Anything,
).Return(nil, errNoTlvPayload)

Expand All @@ -460,7 +463,7 @@ func TestRequestRouteHandleNoRouteErr(t *testing.T) {
p.identifier, channeldb.FailureReasonNoRoute,
).Return(nil).Once()

result, err := p.requestRoute(ps)
result, err := p.requestRoute(ctx, ps)

// Expect no error is returned since it's not critical.
require.NoError(t, err, "expected no error")
Expand All @@ -471,6 +474,7 @@ func TestRequestRouteHandleNoRouteErr(t *testing.T) {
// error from calling `FailPayment`.
func TestRequestRouteFailPaymentError(t *testing.T) {
t.Parallel()
ctx := context.Background()

p := createTestPaymentLifecycle()

Expand Down Expand Up @@ -499,11 +503,11 @@ func TestRequestRouteFailPaymentError(t *testing.T) {

// Mock the paySession's `RequestRoute` method to return an error.
paySession.On("RequestRoute",
mock.Anything, mock.Anything, mock.Anything, mock.Anything,
ctx, mock.Anything, mock.Anything, mock.Anything, mock.Anything,
mock.Anything,
).Return(nil, errNoTlvPayload)

result, err := p.requestRoute(ps)
result, err := p.requestRoute(ctx, ps)

// Expect an error is returned.
require.ErrorIs(t, err, errDummy, "error not matched")
Expand Down
Loading

0 comments on commit 78e2cf5

Please sign in to comment.