Skip to content

Commit

Permalink
multi: remove kvdb.Tx from ChannelGraphSource.ForAllOutgoingChannels
Browse files Browse the repository at this point in the history
and the same for ChannelStateDB.FetchChannel. Most of the calls to these
methods provide a `nil` Tx anyways. The only place that currently
provides a non-nil tx is in the `localchans.Manager`. It takes the
transaction provided to the `ForAllOutgoingChannels` callback and passes
it to it's `updateEdge` method. Note, however, that the
`ForAllOutgoingChannels` call is a call to the graph db and the call to
`updateEdge` is a call to the `ChannelStateDB`. There is no reason that
these two calls need to happen under the same transaction as they are
reading from two completely disjoint databases. And so in the effort to
completely split untangle the relationship between the two databases, we
now dont use the same transaction for these two calls.
  • Loading branch information
ellemouton committed Nov 28, 2024
1 parent 6e13898 commit adcaa88
Show file tree
Hide file tree
Showing 15 changed files with 43 additions and 65 deletions.
6 changes: 2 additions & 4 deletions chanbackup/backup.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/fn"
"github.com/lightningnetwork/lnd/kvdb"
)

// LiveChannelSource is an interface that allows us to query for the set of
Expand All @@ -18,8 +17,7 @@ type LiveChannelSource interface {

// FetchChannel attempts to locate a live channel identified by the
// passed chanPoint. Optionally an existing db tx can be supplied.
FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) (
*channeldb.OpenChannel, error)
FetchChannel(chanPoint wire.OutPoint) (*channeldb.OpenChannel, error)
}

// assembleChanBackup attempts to assemble a static channel backup for the
Expand Down Expand Up @@ -97,7 +95,7 @@ func FetchBackupForChan(chanPoint wire.OutPoint, chanSource LiveChannelSource,

// First, we'll query the channel source to see if the channel is known
// and open within the database.
targetChan, err := chanSource.FetchChannel(nil, chanPoint)
targetChan, err := chanSource.FetchChannel(chanPoint)
if err != nil {
// If we can't find the channel, then we return with an error,
// as we have nothing to backup.
Expand Down
3 changes: 1 addition & 2 deletions chanbackup/backup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -40,7 +39,7 @@ func (m *mockChannelSource) FetchAllChannels() ([]*channeldb.OpenChannel, error)
return chans, nil
}

func (m *mockChannelSource) FetchChannel(_ kvdb.RTx, chanPoint wire.OutPoint) (
func (m *mockChannelSource) FetchChannel(chanPoint wire.OutPoint) (
*channeldb.OpenChannel, error) {

if m.failQuery {
Expand Down
9 changes: 4 additions & 5 deletions channeldb/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -657,9 +657,8 @@ func (c *ChannelStateDB) fetchNodeChannels(chainBucket kvdb.RBucket) (

// FetchChannel attempts to locate a channel specified by the passed channel
// point. If the channel cannot be found, then an error will be returned.
// Optionally an existing db tx can be supplied.
func (c *ChannelStateDB) FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) (
*OpenChannel, error) {
func (c *ChannelStateDB) FetchChannel(chanPoint wire.OutPoint) (*OpenChannel,
error) {

var targetChanPoint bytes.Buffer
err := graphdb.WriteOutpoint(&targetChanPoint, &chanPoint)
Expand All @@ -674,7 +673,7 @@ func (c *ChannelStateDB) FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) (
return targetChanPointBytes, &chanPoint, nil
}

return c.channelScanner(tx, selector)
return c.channelScanner(nil, selector)
}

// FetchChannelByID attempts to locate a channel specified by the passed channel
Expand Down Expand Up @@ -1366,7 +1365,7 @@ func (c *ChannelStateDB) AbandonChannel(chanPoint *wire.OutPoint,
// With the chanPoint constructed, we'll attempt to find the target
// channel in the database. If we can't find the channel, then we'll
// return the error back to the caller.
dbChan, err := c.FetchChannel(nil, *chanPoint)
dbChan, err := c.FetchChannel(*chanPoint)
switch {
// If the channel wasn't found, then it's possible that it was already
// abandoned from the database.
Expand Down
8 changes: 4 additions & 4 deletions channeldb/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ func TestFetchChannel(t *testing.T) {
channelState := createTestChannel(t, cdb, openChannelOption())

// Next, attempt to fetch the channel by its chan point.
dbChannel, err := cdb.FetchChannel(nil, channelState.FundingOutpoint)
dbChannel, err := cdb.FetchChannel(channelState.FundingOutpoint)
require.NoError(t, err, "unable to fetch channel")

// The decoded channel state should be identical to what we stored
Expand All @@ -270,7 +270,7 @@ func TestFetchChannel(t *testing.T) {
uniqueOutputIndex.Add(1)
channelState2.FundingOutpoint.Index = uniqueOutputIndex.Load()

_, err = cdb.FetchChannel(nil, channelState2.FundingOutpoint)
_, err = cdb.FetchChannel(channelState2.FundingOutpoint)
require.ErrorIs(t, err, ErrChannelNotFound)

chanID2 := lnwire.NewChanIDFromOutPoint(channelState2.FundingOutpoint)
Expand Down Expand Up @@ -410,7 +410,7 @@ func TestRestoreChannelShells(t *testing.T) {

// We should also be able to find the channel if we query for it
// directly.
_, err = cdb.FetchChannel(nil, channelShell.Chan.FundingOutpoint)
_, err = cdb.FetchChannel(channelShell.Chan.FundingOutpoint)
require.NoError(t, err, "unable to fetch channel")

// We should also be able to find the link node that was inserted by
Expand Down Expand Up @@ -459,7 +459,7 @@ func TestAbandonChannel(t *testing.T) {

// At this point, the channel should no longer be found in the set of
// open channels.
_, err = cdb.FetchChannel(nil, chanState.FundingOutpoint)
_, err = cdb.FetchChannel(chanState.FundingOutpoint)
if err != ErrChannelNotFound {
t.Fatalf("channel should not have been found: %v", err)
}
Expand Down
2 changes: 1 addition & 1 deletion channelnotifier/channelnotifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ func (c *ChannelNotifier) NotifyPendingOpenChannelEvent(chanPoint wire.OutPoint,
// channel has gone from pending open to open.
func (c *ChannelNotifier) NotifyOpenChannelEvent(chanPoint wire.OutPoint) {
// Fetch the relevant channel from the database.
channel, err := c.chanDB.FetchChannel(nil, chanPoint)
channel, err := c.chanDB.FetchChannel(chanPoint)
if err != nil {
log.Warnf("Unable to fetch open channel from the db: %v", err)
}
Expand Down
8 changes: 2 additions & 6 deletions contractcourt/chain_arbitrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -304,9 +304,7 @@ func (a *arbChannel) NewAnchorResolutions() (*lnwallet.AnchorResolutions,
// same instance that is used by the link.
chanPoint := a.channel.FundingOutpoint

channel, err := a.c.chanSource.ChannelStateDB().FetchChannel(
nil, chanPoint,
)
channel, err := a.c.chanSource.ChannelStateDB().FetchChannel(chanPoint)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -359,9 +357,7 @@ func (a *arbChannel) ForceCloseChan() (*wire.MsgTx, error) {
// Now that we know the link can't mutate the channel
// state, we'll read the channel from disk the target
// channel according to its channel point.
channel, err := a.c.chanSource.ChannelStateDB().FetchChannel(
nil, chanPoint,
)
channel, err := a.c.chanSource.ChannelStateDB().FetchChannel(chanPoint)
if err != nil {
return nil, err
}
Expand Down
2 changes: 0 additions & 2 deletions discovery/gossiper.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import (
graphdb "github.com/lightningnetwork/lnd/graph/db"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnpeer"
"github.com/lightningnetwork/lnd/lnutils"
"github.com/lightningnetwork/lnd/lnwallet"
Expand Down Expand Up @@ -1637,7 +1636,6 @@ func (d *AuthenticatedGossiper) retransmitStaleAnns(now time.Time) error {
edgesToUpdate []updateTuple
)
err := d.cfg.Graph.ForAllOutgoingChannels(func(
_ kvdb.RTx,
info *models.ChannelEdgeInfo,
edge *models.ChannelEdgePolicy) error {

Expand Down
9 changes: 3 additions & 6 deletions discovery/gossiper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ import (
graphdb "github.com/lightningnetwork/lnd/graph/db"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnpeer"
"github.com/lightningnetwork/lnd/lntest/mock"
"github.com/lightningnetwork/lnd/lntest/wait"
Expand Down Expand Up @@ -207,9 +206,8 @@ func (r *mockGraphSource) ForEachNode(func(node *models.LightningNode) error) er
return nil
}

func (r *mockGraphSource) ForAllOutgoingChannels(cb func(tx kvdb.RTx,
i *models.ChannelEdgeInfo,
c *models.ChannelEdgePolicy) error) error {
func (r *mockGraphSource) ForAllOutgoingChannels(cb func(
i *models.ChannelEdgeInfo, c *models.ChannelEdgePolicy) error) error {

r.mu.Lock()
defer r.mu.Unlock()
Expand All @@ -231,7 +229,7 @@ func (r *mockGraphSource) ForAllOutgoingChannels(cb func(tx kvdb.RTx,
}

for _, channel := range chans {
if err := cb(nil, channel.Info, channel.Policy1); err != nil {
if err := cb(channel.Info, channel.Policy1); err != nil {
return err
}
}
Expand Down Expand Up @@ -3483,7 +3481,6 @@ out:
const newTimeLockDelta = 100
var edgesToUpdate []EdgeWithInfo
err = ctx.router.ForAllOutgoingChannels(func(
_ kvdb.RTx,
info *models.ChannelEdgeInfo,
edge *models.ChannelEdgePolicy) error {

Expand Down
8 changes: 4 additions & 4 deletions graph/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -1645,11 +1645,11 @@ func (b *Builder) ForEachNode(
// the router.
//
// NOTE: This method is part of the ChannelGraphSource interface.
func (b *Builder) ForAllOutgoingChannels(cb func(kvdb.RTx,
*models.ChannelEdgeInfo, *models.ChannelEdgePolicy) error) error {
func (b *Builder) ForAllOutgoingChannels(cb func(*models.ChannelEdgeInfo,
*models.ChannelEdgePolicy) error) error {

return b.cfg.Graph.ForEachNodeChannel(b.cfg.SelfNode,
func(tx kvdb.RTx, c *models.ChannelEdgeInfo,
func(_ kvdb.RTx, c *models.ChannelEdgeInfo,
e *models.ChannelEdgePolicy,
_ *models.ChannelEdgePolicy) error {

Expand All @@ -1658,7 +1658,7 @@ func (b *Builder) ForAllOutgoingChannels(cb func(kvdb.RTx,
"has no policy")
}

return cb(tx, c, e)
return cb(c, e)
},
)
}
Expand Down
3 changes: 1 addition & 2 deletions graph/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,7 @@ type ChannelGraphSource interface {
// ForAllOutgoingChannels is used to iterate over all channels
// emanating from the "source" node which is the center of the
// star-graph.
ForAllOutgoingChannels(cb func(tx kvdb.RTx,
c *models.ChannelEdgeInfo,
ForAllOutgoingChannels(cb func(c *models.ChannelEdgeInfo,
e *models.ChannelEdgePolicy) error) error

// CurrentBlockHeight returns the block height from POV of the router
Expand Down
2 changes: 1 addition & 1 deletion pilot.go
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ func initAutoPilot(svr *server, cfg *lncfg.AutoPilot,
ChannelInfo: func(chanPoint wire.OutPoint) (
*autopilot.LocalChannel, error) {

channel, err := svr.chanStateDB.FetchChannel(nil, chanPoint)
channel, err := svr.chanStateDB.FetchChannel(chanPoint)
if err != nil {
return nil, err
}
Expand Down
26 changes: 10 additions & 16 deletions routing/localchans/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (
"github.com/lightningnetwork/lnd/discovery"
"github.com/lightningnetwork/lnd/fn"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnrpc"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing"
Expand All @@ -40,14 +39,13 @@ type Manager struct {

// ForAllOutgoingChannels is required to iterate over all our local
// channels. The ChannelEdgePolicy parameter may be nil.
ForAllOutgoingChannels func(cb func(kvdb.RTx,
*models.ChannelEdgeInfo,
ForAllOutgoingChannels func(cb func(*models.ChannelEdgeInfo,
*models.ChannelEdgePolicy) error) error

// FetchChannel is used to query local channel parameters. Optionally an
// existing db tx can be supplied.
FetchChannel func(tx kvdb.RTx, chanPoint wire.OutPoint) (
*channeldb.OpenChannel, error)
FetchChannel func(chanPoint wire.OutPoint) (*channeldb.OpenChannel,
error)

// AddEdge is used to add edge/channel to the topology of the router.
AddEdge func(edge *models.ChannelEdgeInfo) error
Expand Down Expand Up @@ -83,9 +81,7 @@ func (r *Manager) UpdatePolicy(newSchema routing.ChannelPolicy,
policiesToUpdate := make(map[wire.OutPoint]models.ForwardingPolicy)

// NOTE: edge may be nil when this function is called.
processChan := func(
tx kvdb.RTx,
info *models.ChannelEdgeInfo,
processChan := func(info *models.ChannelEdgeInfo,
edge *models.ChannelEdgePolicy) error {

// If we have a channel filter, and this channel isn't a part
Expand Down Expand Up @@ -114,9 +110,7 @@ func (r *Manager) UpdatePolicy(newSchema routing.ChannelPolicy,
}

// Apply the new policy to the edge.
err := r.updateEdge(
tx, info.ChannelPoint, edge, newSchema,
)
err := r.updateEdge(info.ChannelPoint, edge, newSchema)
if err != nil {
failedUpdates = append(failedUpdates,
makeFailureItem(info.ChannelPoint,
Expand Down Expand Up @@ -164,7 +158,7 @@ func (r *Manager) UpdatePolicy(newSchema routing.ChannelPolicy,

// Construct a list of failed policy updates.
for chanPoint := range unprocessedChans {
channel, err := r.FetchChannel(nil, chanPoint)
channel, err := r.FetchChannel(chanPoint)
switch {
case errors.Is(err, channeldb.ErrChannelNotFound):
failedUpdates = append(failedUpdates,
Expand Down Expand Up @@ -203,7 +197,7 @@ func (r *Manager) UpdatePolicy(newSchema routing.ChannelPolicy,
channel, newSchema,
)
if failedUpdate == nil {
err = processChan(nil, info, edge)
err = processChan(info, edge)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -261,7 +255,7 @@ func (r *Manager) createMissingEdge(channel *channeldb.OpenChannel,

// Validate the newly created edge policy with the user defined new
// schema before adding the edge to the database.
err = r.updateEdge(nil, channel.FundingOutpoint, edge, newSchema)
err = r.updateEdge(channel.FundingOutpoint, edge, newSchema)
if err != nil {
return nil, nil, makeFailureItem(
info.ChannelPoint,
Expand Down Expand Up @@ -351,11 +345,11 @@ func (r *Manager) createEdge(channel *channeldb.OpenChannel,
}

// updateEdge updates the given edge with the new schema.
func (r *Manager) updateEdge(tx kvdb.RTx, chanPoint wire.OutPoint,
func (r *Manager) updateEdge(chanPoint wire.OutPoint,
edge *models.ChannelEdgePolicy,
newSchema routing.ChannelPolicy) error {

channel, err := r.FetchChannel(tx, chanPoint)
channel, err := r.FetchChannel(chanPoint)
if err != nil {
return err
}
Expand Down
10 changes: 4 additions & 6 deletions routing/localchans/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (
"github.com/lightningnetwork/lnd/discovery"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnrpc"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing"
Expand Down Expand Up @@ -123,20 +122,19 @@ func TestManager(t *testing.T) {
return nil
}

forAllOutgoingChannels := func(cb func(kvdb.RTx,
*models.ChannelEdgeInfo,
forAllOutgoingChannels := func(cb func(*models.ChannelEdgeInfo,
*models.ChannelEdgePolicy) error) error {

for _, c := range channelSet {
if err := cb(nil, c.edgeInfo, &currentPolicy); err != nil {
if err := cb(c.edgeInfo, &currentPolicy); err != nil {
return err
}
}
return nil
}

fetchChannel := func(tx kvdb.RTx, chanPoint wire.OutPoint) (
*channeldb.OpenChannel, error) {
fetchChannel := func(chanPoint wire.OutPoint) (*channeldb.OpenChannel,
error) {

if chanPoint == chanPointMissing {
return &channeldb.OpenChannel{}, channeldb.ErrChannelNotFound
Expand Down
4 changes: 2 additions & 2 deletions rpcserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -2692,7 +2692,7 @@ func (r *rpcServer) CloseChannel(in *lnrpc.CloseChannelRequest,

// First, we'll fetch the channel as is, as we'll need to examine it
// regardless of if this is a force close or not.
channel, err := r.server.chanStateDB.FetchChannel(nil, *chanPoint)
channel, err := r.server.chanStateDB.FetchChannel(*chanPoint)
if err != nil {
return err
}
Expand Down Expand Up @@ -3140,7 +3140,7 @@ func (r *rpcServer) AbandonChannel(_ context.Context,
return nil, err
}

dbChan, err := r.server.chanStateDB.FetchChannel(nil, *chanPoint)
dbChan, err := r.server.chanStateDB.FetchChannel(*chanPoint)
switch {
// If the channel isn't found in the set of open channels, then we can
// continue on as it can't be loaded into the link/peer.
Expand Down
8 changes: 4 additions & 4 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -1130,17 +1130,17 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
s.localChanMgr = &localchans.Manager{
SelfPub: nodeKeyDesc.PubKey,
DefaultRoutingPolicy: cc.RoutingPolicy,
ForAllOutgoingChannels: func(cb func(kvdb.RTx,
*models.ChannelEdgeInfo, *models.ChannelEdgePolicy) error) error {
ForAllOutgoingChannels: func(cb func(*models.ChannelEdgeInfo,
*models.ChannelEdgePolicy) error) error {

return s.graphDB.ForEachNodeChannel(selfVertex,
func(tx kvdb.RTx, c *models.ChannelEdgeInfo,
func(_ kvdb.RTx, c *models.ChannelEdgeInfo,
e *models.ChannelEdgePolicy,
_ *models.ChannelEdgePolicy) error {

// NOTE: The invoked callback here may
// receive a nil channel policy.
return cb(tx, c, e)
return cb(c, e)
},
)
},
Expand Down

0 comments on commit adcaa88

Please sign in to comment.