Skip to content

Commit

Permalink
subscribe: let all methods take a context
Browse files Browse the repository at this point in the history
remove embeded context
  • Loading branch information
ellemouton committed Nov 21, 2024
1 parent 780a2da commit 94aac23
Show file tree
Hide file tree
Showing 8 changed files with 64 additions and 54 deletions.
22 changes: 11 additions & 11 deletions channelnotifier/channelnotifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,11 @@ func (c *ChannelNotifier) Start(ctx context.Context) error {
c.started.Do(func() {
log.Info("ChannelNotifier starting")

_, cancel := context.WithCancel(ctx)
ctx, cancel := context.WithCancel(ctx)
c.cancel = fn.Some(cancel)
c.ctx = ctx

err = c.ntfnServer.Start()
err = c.ntfnServer.Start(ctx)
})
return err
}
Expand All @@ -133,7 +133,7 @@ func (c *ChannelNotifier) Stop() error {
// TODO(carlaKC): update to allow subscriptions to specify a block height from
// which we would like to subscribe to events.
func (c *ChannelNotifier) SubscribeChannelEvents() (*subscribe.Client, error) {
return c.ntfnServer.Subscribe()
return c.ntfnServer.Subscribe(c.ctx)
}

// NotifyPendingOpenChannelEvent notifies the channelEventNotifier goroutine
Expand All @@ -149,7 +149,7 @@ func (c *ChannelNotifier) NotifyPendingOpenChannelEvent(chanPoint wire.OutPoint,
PendingChannel: pendingChan,
}

if err := c.ntfnServer.SendUpdate(event); err != nil {
if err := c.ntfnServer.SendUpdate(c.ctx, event); err != nil {
log.Warnf("Unable to send pending open channel update: %v", err)
}
}
Expand All @@ -165,7 +165,7 @@ func (c *ChannelNotifier) NotifyOpenChannelEvent(chanPoint wire.OutPoint) {

// Send the open event to all channel event subscribers.
event := OpenChannelEvent{Channel: channel}
if err := c.ntfnServer.SendUpdate(event); err != nil {
if err := c.ntfnServer.SendUpdate(c.ctx, event); err != nil {
log.Warnf("Unable to send open channel update: %v", err)
}
}
Expand All @@ -181,7 +181,7 @@ func (c *ChannelNotifier) NotifyClosedChannelEvent(chanPoint wire.OutPoint) {

// Send the closed event to all channel event subscribers.
event := ClosedChannelEvent{CloseSummary: closeSummary}
if err := c.ntfnServer.SendUpdate(event); err != nil {
if err := c.ntfnServer.SendUpdate(c.ctx, event); err != nil {
log.Warnf("Unable to send closed channel update: %v", err)
}
}
Expand All @@ -193,7 +193,7 @@ func (c *ChannelNotifier) NotifyFullyResolvedChannelEvent(

// Send the resolved event to all channel event subscribers.
event := FullyResolvedChannelEvent{ChannelPoint: &chanPoint}
if err := c.ntfnServer.SendUpdate(event); err != nil {
if err := c.ntfnServer.SendUpdate(c.ctx, event); err != nil {
log.Warnf("Unable to send resolved channel update: %v", err)
}
}
Expand All @@ -202,7 +202,7 @@ func (c *ChannelNotifier) NotifyFullyResolvedChannelEvent(
// link has been added to the switch.
func (c *ChannelNotifier) NotifyActiveLinkEvent(chanPoint wire.OutPoint) {
event := ActiveLinkEvent{ChannelPoint: &chanPoint}
if err := c.ntfnServer.SendUpdate(event); err != nil {
if err := c.ntfnServer.SendUpdate(c.ctx, event); err != nil {
log.Warnf("Unable to send active link update: %v", err)
}
}
Expand All @@ -211,7 +211,7 @@ func (c *ChannelNotifier) NotifyActiveLinkEvent(chanPoint wire.OutPoint) {
// channel is active.
func (c *ChannelNotifier) NotifyActiveChannelEvent(chanPoint wire.OutPoint) {
event := ActiveChannelEvent{ChannelPoint: &chanPoint}
if err := c.ntfnServer.SendUpdate(event); err != nil {
if err := c.ntfnServer.SendUpdate(c.ctx, event); err != nil {
log.Warnf("Unable to send active channel update: %v", err)
}
}
Expand All @@ -220,7 +220,7 @@ func (c *ChannelNotifier) NotifyActiveChannelEvent(chanPoint wire.OutPoint) {
// link has been removed from the switch.
func (c *ChannelNotifier) NotifyInactiveLinkEvent(chanPoint wire.OutPoint) {
event := InactiveLinkEvent{ChannelPoint: &chanPoint}
if err := c.ntfnServer.SendUpdate(event); err != nil {
if err := c.ntfnServer.SendUpdate(c.ctx, event); err != nil {
log.Warnf("Unable to send inactive link update: %v", err)
}
}
Expand All @@ -229,7 +229,7 @@ func (c *ChannelNotifier) NotifyInactiveLinkEvent(chanPoint wire.OutPoint) {
// channel is inactive.
func (c *ChannelNotifier) NotifyInactiveChannelEvent(chanPoint wire.OutPoint) {
event := InactiveChannelEvent{ChannelPoint: &chanPoint}
if err := c.ntfnServer.SendUpdate(event); err != nil {
if err := c.ntfnServer.SendUpdate(c.ctx, event); err != nil {
log.Warnf("Unable to send inactive channel update: %v", err)
}
}
20 changes: 12 additions & 8 deletions htlcswitch/htlcnotifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ type HtlcNotifier struct {

ntfnServer *subscribe.Server

// TODO(elle): remove once contexts are provided via each API method.
ctx context.Context

cancel fn.Option[context.CancelFunc]
}

Expand All @@ -86,11 +89,12 @@ func NewHtlcNotifier(now func() time.Time) *HtlcNotifier {
func (h *HtlcNotifier) Start(ctx context.Context) error {
var err error
h.started.Do(func() {
_, cancel := context.WithCancel(ctx)
ctx, cancel := context.WithCancel(ctx)
h.cancel = fn.Some(cancel)
h.ctx = ctx

log.Info("HtlcNotifier starting")
err = h.ntfnServer.Start()
err = h.ntfnServer.Start(ctx)
})
return err
}
Expand All @@ -114,7 +118,7 @@ func (h *HtlcNotifier) Stop() error {
// SubscribeHtlcEvents returns a subscribe.Client that will receive updates
// any time the server is made aware of a new event.
func (h *HtlcNotifier) SubscribeHtlcEvents() (*subscribe.Client, error) {
return h.ntfnServer.Subscribe()
return h.ntfnServer.Subscribe(h.ctx)
}

// HtlcKey uniquely identifies the htlc.
Expand Down Expand Up @@ -335,7 +339,7 @@ func (h *HtlcNotifier) NotifyForwardingEvent(key HtlcKey, info HtlcInfo,
log.Tracef("Notifying forward event: %v over %v, %v", eventType, key,
info)

if err := h.ntfnServer.SendUpdate(event); err != nil {
if err := h.ntfnServer.SendUpdate(h.ctx, event); err != nil {
log.Warnf("Unable to send forwarding event: %v", err)
}
}
Expand All @@ -359,7 +363,7 @@ func (h *HtlcNotifier) NotifyLinkFailEvent(key HtlcKey, info HtlcInfo,
log.Tracef("Notifying link failure event: %v over %v, %v", eventType,
key, info)

if err := h.ntfnServer.SendUpdate(event); err != nil {
if err := h.ntfnServer.SendUpdate(h.ctx, event); err != nil {
log.Warnf("Unable to send link fail event: %v", err)
}
}
Expand All @@ -380,7 +384,7 @@ func (h *HtlcNotifier) NotifyForwardingFailEvent(key HtlcKey,
log.Tracef("Notifying forwarding failure event: %v over %v", eventType,
key)

if err := h.ntfnServer.SendUpdate(event); err != nil {
if err := h.ntfnServer.SendUpdate(h.ctx, event); err != nil {
log.Warnf("Unable to send forwarding fail event: %v", err)
}
}
Expand All @@ -401,7 +405,7 @@ func (h *HtlcNotifier) NotifySettleEvent(key HtlcKey,

log.Tracef("Notifying settle event: %v over %v", eventType, key)

if err := h.ntfnServer.SendUpdate(event); err != nil {
if err := h.ntfnServer.SendUpdate(h.ctx, event); err != nil {
log.Warnf("Unable to send settle event: %v", err)
}
}
Expand All @@ -422,7 +426,7 @@ func (h *HtlcNotifier) NotifyFinalHtlcEvent(key models.CircuitKey,

log.Tracef("Notifying final settle event: %v", key)

if err := h.ntfnServer.SendUpdate(event); err != nil {
if err := h.ntfnServer.SendUpdate(h.ctx, event); err != nil {
log.Warnf("Unable to send settle event: %v", err)
}
}
Expand Down
14 changes: 9 additions & 5 deletions peernotifier/peernotifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ type PeerNotifier struct {

ntfnServer *subscribe.Server

// TODO(elle): remove once contexts are provided via each API method.
ctx context.Context

Check failure on line 21 in peernotifier/peernotifier.go

View workflow job for this annotation

GitHub Actions / lint code

found a struct that contains a context.Context field (containedctx)

cancel fn.Option[context.CancelFunc]
}

Expand Down Expand Up @@ -46,10 +49,11 @@ func (p *PeerNotifier) Start(ctx context.Context) error {

p.started.Do(func() {
log.Info("PeerNotifier starting")
_, cancel := context.WithCancel(ctx)
ctx, cancel := context.WithCancel(ctx)
p.cancel = fn.Some(cancel)
p.ctx = ctx

err = p.ntfnServer.Start()
err = p.ntfnServer.Start(ctx)
})

return err
Expand All @@ -70,7 +74,7 @@ func (p *PeerNotifier) Stop() error {
// SubscribePeerEvents returns a subscribe.Client that will receive updates
// any time the Server is informed of a peer event.
func (p *PeerNotifier) SubscribePeerEvents() (*subscribe.Client, error) {
return p.ntfnServer.Subscribe()
return p.ntfnServer.Subscribe(p.ctx)
}

// NotifyPeerOnline sends a peer online event to all clients subscribed to the
Expand All @@ -80,7 +84,7 @@ func (p *PeerNotifier) NotifyPeerOnline(pubKey [33]byte) {

log.Debugf("PeerNotifier notifying peer: %x online", pubKey)

if err := p.ntfnServer.SendUpdate(event); err != nil {
if err := p.ntfnServer.SendUpdate(p.ctx, event); err != nil {
log.Warnf("Unable to send peer online update: %v", err)
}
}
Expand All @@ -92,7 +96,7 @@ func (p *PeerNotifier) NotifyPeerOffline(pubKey [33]byte) {

log.Debugf("PeerNotifier notifying peer: %x offline", pubKey)

if err := p.ntfnServer.SendUpdate(event); err != nil {
if err := p.ntfnServer.SendUpdate(p.ctx, event); err != nil {
log.Warnf("Unable to send peer offline update: %v", err)
}
}
19 changes: 11 additions & 8 deletions rpcperms/interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,9 @@ type InterceptorChain struct {
// middleware crashes.
mandatoryMiddleware []string

// TODO(elle): remove once contexts are provided via each API method.
ctx context.Context

cancel fn.Option[context.CancelFunc]
sync.RWMutex
}
Expand Down Expand Up @@ -218,10 +221,10 @@ func NewInterceptorChain(log btclog.Logger, noMacaroons bool,
func (r *InterceptorChain) Start(ctx context.Context) error {
var err error
r.started.Do(func() {
_, cancel := context.WithCancel(ctx)
ctx, cancel := context.WithCancel(ctx)
r.cancel = fn.Some(cancel)

err = r.ntfnServer.Start()
err = r.ntfnServer.Start(ctx)
})

return err
Expand All @@ -245,7 +248,7 @@ func (r *InterceptorChain) SetWalletNotCreated() {
defer r.Unlock()

r.state = walletNotCreated
_ = r.ntfnServer.SendUpdate(r.state)
_ = r.ntfnServer.SendUpdate(r.ctx, r.state)
}

// SetWalletLocked moves the RPC state from either walletNotCreated to
Expand All @@ -255,7 +258,7 @@ func (r *InterceptorChain) SetWalletLocked() {
defer r.Unlock()

r.state = walletLocked
_ = r.ntfnServer.SendUpdate(r.state)
_ = r.ntfnServer.SendUpdate(r.ctx, r.state)
}

// SetWalletUnlocked moves the RPC state from either walletNotCreated or
Expand All @@ -265,7 +268,7 @@ func (r *InterceptorChain) SetWalletUnlocked() {
defer r.Unlock()

r.state = walletUnlocked
_ = r.ntfnServer.SendUpdate(r.state)
_ = r.ntfnServer.SendUpdate(r.ctx, r.state)
}

// SetRPCActive moves the RPC state from walletUnlocked to rpcActive.
Expand All @@ -274,7 +277,7 @@ func (r *InterceptorChain) SetRPCActive() {
defer r.Unlock()

r.state = rpcActive
_ = r.ntfnServer.SendUpdate(r.state)
_ = r.ntfnServer.SendUpdate(r.ctx, r.state)
}

// SetServerActive moves the RPC state from walletUnlocked to rpcActive.
Expand All @@ -283,7 +286,7 @@ func (r *InterceptorChain) SetServerActive() {
defer r.Unlock()

r.state = serverActive
_ = r.ntfnServer.SendUpdate(r.state)
_ = r.ntfnServer.SendUpdate(r.ctx, r.state)
}

// rpcStateToWalletState converts rpcState to lnrpc.WalletState. Returns
Expand Down Expand Up @@ -332,7 +335,7 @@ func (r *InterceptorChain) SubscribeState(_ *lnrpc.SubscribeStateRequest,
}

// Subscribe to state updates.
client, err := r.ntfnServer.Subscribe()
client, err := r.ntfnServer.Subscribe(stream.Context())
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion rpcserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -8856,7 +8856,7 @@ func (r *rpcServer) SubscribeCustomMessages(
_ *lnrpc.SubscribeCustomMessagesRequest,
server lnrpc.Lightning_SubscribeCustomMessagesServer) error {

client, err := r.server.SubscribeCustomMessages()
client, err := r.server.SubscribeCustomMessages(server.Context())
if err != nil {
return err
}
Expand Down
10 changes: 6 additions & 4 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -2068,7 +2068,7 @@ func (s *server) Start(ctx context.Context) error {
s.cancel = fn.Some(cancel)

cleanup = cleanup.add(s.customMessageServer.Stop)
if err := s.customMessageServer.Start(); err != nil {
if err := s.customMessageServer.Start(ctx); err != nil {
startErr = err
return
}
Expand Down Expand Up @@ -4067,16 +4067,18 @@ func (s *server) handleCustomMessage(peer [33]byte, msg *lnwire.Custom) error {
srvrLog.Debugf("Custom message received: peer=%x, type=%d",
peer, msg.Type)

return s.customMessageServer.SendUpdate(&CustomMessage{
return s.customMessageServer.SendUpdate(s.ctx, &CustomMessage{
Peer: peer,
Msg: msg,
})
}

// SubscribeCustomMessages subscribes to a stream of incoming custom peer
// messages.
func (s *server) SubscribeCustomMessages() (*subscribe.Client, error) {
return s.customMessageServer.Subscribe()
func (s *server) SubscribeCustomMessages(ctx context.Context) (
*subscribe.Client, error) {

return s.customMessageServer.Subscribe(ctx)
}

// peerConnected is a function that handles initialization a newly connected
Expand Down
Loading

0 comments on commit 94aac23

Please sign in to comment.