Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Context reloaded #164

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions channelnotifier/channelnotifier.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package channelnotifier

import (
"context"
"sync"

"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/fn"
"github.com/lightningnetwork/lnd/subscribe"
)

Expand All @@ -18,6 +20,8 @@ type ChannelNotifier struct {
ntfnServer *subscribe.Server

chanDB *channeldb.ChannelStateDB

cancel fn.Option[context.CancelFunc]
}

// PendingOpenChannelEvent represents a new event where a new channel has
Expand Down Expand Up @@ -91,11 +95,14 @@ func New(chanDB *channeldb.ChannelStateDB) *ChannelNotifier {
}

// Start starts the ChannelNotifier and all goroutines it needs to carry out its task.
func (c *ChannelNotifier) Start() error {
func (c *ChannelNotifier) Start(ctx context.Context) error {
ctx, cancel := context.WithCancel(ctx)
c.cancel = fn.Some(cancel)

var err error
c.started.Do(func() {
log.Info("ChannelNotifier starting")
err = c.ntfnServer.Start()
err = c.ntfnServer.Start(ctx)
})
return err
}
Expand All @@ -107,6 +114,8 @@ func (c *ChannelNotifier) Stop() error {
log.Info("ChannelNotifier shutting down...")
defer log.Debug("ChannelNotifier shutdown complete")

c.cancel.WhenSome(func(fn context.CancelFunc) { fn() })

err = c.ntfnServer.Stop()
})
return err
Expand Down
14 changes: 12 additions & 2 deletions htlcswitch/htlcnotifier.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
package htlcswitch

import (
"context"
"fmt"
"strings"
"sync"
"time"

"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/channeldb/models"
"github.com/lightningnetwork/lnd/fn"
"github.com/lightningnetwork/lnd/htlcswitch/hop"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire"
Expand Down Expand Up @@ -65,6 +67,8 @@ type HtlcNotifier struct {
now func() time.Time

ntfnServer *subscribe.Server

cancel fn.Option[context.CancelFunc]
}

// NewHtlcNotifier creates a new HtlcNotifier which gets htlc forwarded,
Expand All @@ -79,11 +83,15 @@ func NewHtlcNotifier(now func() time.Time) *HtlcNotifier {

// Start starts the HtlcNotifier and all goroutines it needs to consume
// events and provide subscriptions to clients.
func (h *HtlcNotifier) Start() error {
func (h *HtlcNotifier) Start(ctx context.Context) error {
var err error
h.started.Do(func() {
ctx, cancel := context.WithCancel(ctx)
h.cancel = fn.Some(cancel)

log.Info("HtlcNotifier starting")
err = h.ntfnServer.Start()

err = h.ntfnServer.Start(ctx)
})
return err
}
Expand All @@ -95,6 +103,8 @@ func (h *HtlcNotifier) Stop() error {
log.Info("HtlcNotifier shutting down...")
defer log.Debug("HtlcNotifier shutdown complete")

h.cancel.WhenSome(func(fn context.CancelFunc) { fn() })

if err = h.ntfnServer.Stop(); err != nil {
log.Warnf("error stopping htlc notifier: %v", err)
}
Expand Down
7 changes: 4 additions & 3 deletions htlcswitch/switch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3410,6 +3410,7 @@ func testHtcNotifier(t *testing.T, testOpts []serverOption, iterations int,
getEvents htlcNotifierEvents) {

t.Parallel()
ctx := context.Background()

// First, we'll create our traditional three hop
// network.
Expand All @@ -3427,7 +3428,7 @@ func testHtcNotifier(t *testing.T, testOpts []serverOption, iterations int,
// Create htlc notifiers for each server in the three hop network and
// start them.
aliceNotifier := NewHtlcNotifier(mockTime)
if err := aliceNotifier.Start(); err != nil {
if err := aliceNotifier.Start(ctx); err != nil {
t.Fatalf("could not start alice notifier")
}
t.Cleanup(func() {
Expand All @@ -3437,7 +3438,7 @@ func testHtcNotifier(t *testing.T, testOpts []serverOption, iterations int,
})

bobNotifier := NewHtlcNotifier(mockTime)
if err := bobNotifier.Start(); err != nil {
if err := bobNotifier.Start(ctx); err != nil {
t.Fatalf("could not start bob notifier")
}
t.Cleanup(func() {
Expand All @@ -3447,7 +3448,7 @@ func testHtcNotifier(t *testing.T, testOpts []serverOption, iterations int,
})

carolNotifier := NewHtlcNotifier(mockTime)
if err := carolNotifier.Start(); err != nil {
if err := carolNotifier.Start(ctx); err != nil {
t.Fatalf("could not start carol notifier")
}
t.Cleanup(func() {
Expand Down
12 changes: 7 additions & 5 deletions lnd.go
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ func Main(cfg *Config, lisCfg ListenerCfg, implCfg *ImplementationCfg,
interceptorChain := rpcperms.NewInterceptorChain(
rpcsLog, cfg.NoMacaroons, cfg.RPCMiddleware.Mandatory,
)
if err := interceptorChain.Start(); err != nil {
if err := interceptorChain.Start(ctx); err != nil {
return mkErr("error starting interceptor chain: %v", err)
}
defer func() {
Expand Down Expand Up @@ -379,7 +379,7 @@ func Main(cfg *Config, lisCfg ListenerCfg, implCfg *ImplementationCfg,
// wildcard to prevent certificate issues when accessing the proxy
// externally.
stopProxy, err := startRestProxy(
cfg, rpcServer, restDialOpts, restListen,
ctx, cfg, rpcServer, restDialOpts, restListen,
)
if err != nil {
return mkErr("error starting REST proxy: %v", err)
Expand Down Expand Up @@ -731,10 +731,12 @@ func Main(cfg *Config, lisCfg ListenerCfg, implCfg *ImplementationCfg,
// case the startup of the subservers do not behave as expected.
errChan := make(chan error)
go func() {
errChan <- server.Start()
errChan <- server.Start(ctx)
}()

defer func() {
grpcServer.Stop()

err := server.Stop()
if err != nil {
ltndLog.Warnf("Stopping the server including all "+
Expand Down Expand Up @@ -921,7 +923,8 @@ func startGrpcListen(cfg *Config, grpcServer *grpc.Server,

// startRestProxy starts the given REST proxy on the listeners found in the
// config.
func startRestProxy(cfg *Config, rpcServer *rpcServer, restDialOpts []grpc.DialOption,
func startRestProxy(ctx context.Context, cfg *Config, rpcServer *rpcServer,
restDialOpts []grpc.DialOption,
restListen func(net.Addr) (net.Listener, error)) (func(), error) {

// We use the first RPC listener as the destination for our REST proxy.
Expand All @@ -948,7 +951,6 @@ func startRestProxy(cfg *Config, rpcServer *rpcServer, restDialOpts []grpc.DialO
}

// Start a REST proxy for our gRPC server.
ctx := context.Background()
ctx, cancel := context.WithCancel(ctx)
shutdownFuncs = append(shutdownFuncs, cancel)

Expand Down
5 changes: 4 additions & 1 deletion peer/test_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package peer

import (
"bytes"
"context"
crand "crypto/rand"
"encoding/binary"
"io"
Expand Down Expand Up @@ -559,6 +560,8 @@ func (m *mockMessageConn) Close() error {
// containing necessary handles and mock objects for conducting tests on peer
// functionalities.
func createTestPeer(t *testing.T) *peerTestCtx {
ctx := context.Background()

nodeKeyLocator := keychain.KeyLocator{
Family: keychain.KeyFamilyNodeKey,
}
Expand Down Expand Up @@ -655,7 +658,7 @@ func createTestPeer(t *testing.T) *peerTestCtx {

// TODO(yy): change ChannelNotifier to be an interface.
channelNotifier := channelnotifier.New(dbAlice.ChannelStateDB())
require.NoError(t, channelNotifier.Start())
require.NoError(t, channelNotifier.Start(ctx))
t.Cleanup(func() {
require.NoError(t, channelNotifier.Stop(),
"stop channel notifier failed")
Expand Down
13 changes: 11 additions & 2 deletions peernotifier/peernotifier.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package peernotifier

import (
"context"
"sync"

"github.com/lightningnetwork/lnd/fn"
"github.com/lightningnetwork/lnd/subscribe"
)

Expand All @@ -14,6 +16,8 @@ type PeerNotifier struct {
stopped sync.Once

ntfnServer *subscribe.Server

cancel fn.Option[context.CancelFunc]
}

// PeerOnlineEvent represents a new event where a peer comes online.
Expand All @@ -37,12 +41,15 @@ func New() *PeerNotifier {
}

// Start starts the PeerNotifier's subscription server.
func (p *PeerNotifier) Start() error {
func (p *PeerNotifier) Start(ctx context.Context) error {
var err error

p.started.Do(func() {
ctx, cancel := context.WithCancel(ctx)
p.cancel = fn.Some(cancel)

log.Info("PeerNotifier starting")
err = p.ntfnServer.Start()
err = p.ntfnServer.Start(ctx)
})

return err
Expand All @@ -55,6 +62,8 @@ func (p *PeerNotifier) Stop() error {
log.Info("PeerNotifier shutting down...")
defer log.Debug("PeerNotifier shutdown complete")

p.cancel.WhenSome(func(fn context.CancelFunc) { fn() })

err = p.ntfnServer.Stop()
})
return err
Expand Down
13 changes: 10 additions & 3 deletions rpcperms/interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (

"github.com/btcsuite/btclog/v2"
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
"github.com/lightningnetwork/lnd/fn"
"github.com/lightningnetwork/lnd/lnrpc"
"github.com/lightningnetwork/lnd/macaroons"
"github.com/lightningnetwork/lnd/monitoring"
Expand Down Expand Up @@ -189,7 +190,8 @@ type InterceptorChain struct {
// middleware crashes.
mandatoryMiddleware []string

quit chan struct{}
quit chan struct{}
cancel fn.Option[context.CancelFunc]
sync.RWMutex
}

Expand All @@ -215,10 +217,13 @@ func NewInterceptorChain(log btclog.Logger, noMacaroons bool,

// Start starts the InterceptorChain, which is needed to start the state
// subscription server it powers.
func (r *InterceptorChain) Start() error {
func (r *InterceptorChain) Start(ctx context.Context) error {
var err error
r.started.Do(func() {
err = r.ntfnServer.Start()
ctx, cancel := context.WithCancel(ctx)
r.cancel = fn.Some(cancel)

err = r.ntfnServer.Start(ctx)
})

return err
Expand All @@ -229,6 +234,8 @@ func (r *InterceptorChain) Stop() error {
var err error
r.stopped.Do(func() {
close(r.quit)
r.cancel.WhenSome(func(fn context.CancelFunc) { fn() })

err = r.ntfnServer.Stop()
})

Expand Down
20 changes: 12 additions & 8 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -349,9 +349,9 @@ type server struct {
// txPublisher is a publisher with fee-bumping capability.
txPublisher *sweep.TxPublisher

quit chan struct{}

wg sync.WaitGroup
quit chan struct{}
cancel fn.Option[context.CancelFunc]
wg sync.WaitGroup
}

// updatePersistentPeerAddrs subscribes to topology changes and stores
Expand Down Expand Up @@ -2046,7 +2046,10 @@ func (c cleaner) run() {
// NOTE: This function is safe for concurrent access.
//
//nolint:funlen
func (s *server) Start() error {
func (s *server) Start(ctx context.Context) error {
ctx, cancel := context.WithCancel(ctx)
s.cancel = fn.Some(cancel)

var startErr error

// If one sub system fails to start, the following code ensures that the
Expand All @@ -2056,7 +2059,7 @@ func (s *server) Start() error {

s.start.Do(func() {
cleanup = cleanup.add(s.customMessageServer.Stop)
if err := s.customMessageServer.Start(); err != nil {
if err := s.customMessageServer.Start(ctx); err != nil {
startErr = err
return
}
Expand Down Expand Up @@ -2113,21 +2116,21 @@ func (s *server) Start() error {
}

cleanup = cleanup.add(s.channelNotifier.Stop)
if err := s.channelNotifier.Start(); err != nil {
if err := s.channelNotifier.Start(ctx); err != nil {
startErr = err
return
}

cleanup = cleanup.add(func() error {
return s.peerNotifier.Stop()
})
if err := s.peerNotifier.Start(); err != nil {
if err := s.peerNotifier.Start(ctx); err != nil {
startErr = err
return
}

cleanup = cleanup.add(s.htlcNotifier.Stop)
if err := s.htlcNotifier.Start(); err != nil {
if err := s.htlcNotifier.Start(ctx); err != nil {
startErr = err
return
}
Expand Down Expand Up @@ -2453,6 +2456,7 @@ func (s *server) Stop() error {
atomic.StoreInt32(&s.stopping, 1)

close(s.quit)
s.cancel.WhenSome(func(fn context.CancelFunc) { fn() })

// Shutdown connMgr first to prevent conns during shutdown.
s.connMgr.Stop()
Expand Down
Loading
Loading