From ac983a8dea513c27ffbe85d5c6ea5b42819dfebe Mon Sep 17 00:00:00 2001 From: Elle Mouton <elle.mouton@gmail.com> Date: Fri, 29 Sep 2023 08:42:23 +0200 Subject: [PATCH 01/33] lnwire: make MuSig2Nonce TLV type re-usable Before this commit, any MuSig2Nonce TLV field used within a message is expected to use the same tlv type number. This is changed in this commit so that each message must specify which type number it wishes to use. This is necessary for if there is ever more than one MuSig2Nonce used within the same message. --- lnwire/accept_channel.go | 24 ++++++++++++++++++------ lnwire/channel_ready.go | 23 ++++++++++++++++++----- lnwire/channel_reestablish.go | 23 ++++++++++++++++------- lnwire/musig2.go | 25 ++++++++++++++++++------- lnwire/open_channel.go | 25 +++++++++++++++++++------ lnwire/revoke_and_ack.go | 21 ++++++++++++++++----- 6 files changed, 105 insertions(+), 36 deletions(-) diff --git a/lnwire/accept_channel.go b/lnwire/accept_channel.go index 66dda815c6..842ee6eb3b 100644 --- a/lnwire/accept_channel.go +++ b/lnwire/accept_channel.go @@ -9,6 +9,12 @@ import ( "github.com/lightningnetwork/lnd/tlv" ) +const ( + // AcceptChanLocalNonceType is the tlv number associated with the + // local nonce TLV record in the accept_channel message. + AcceptChanLocalNonceType = tlv.Type(4) +) + // AcceptChannel is the message Bob sends to Alice after she initiates the // single funder channel workflow via an AcceptChannel message. Once Alice // receives Bob's response, then she has all the items necessary to construct @@ -142,7 +148,12 @@ func (a *AcceptChannel) Encode(w *bytes.Buffer, pver uint32) error { recordProducers = append(recordProducers, a.LeaseExpiry) } if a.LocalNonce != nil { - recordProducers = append(recordProducers, a.LocalNonce) + recordProducers = append(recordProducers, + &Musig2NonceRecordProducer{ + Musig2Nonce: *a.LocalNonce, + Type: AcceptChanLocalNonceType, + }, + ) } err := EncodeMessageExtraData(&a.ExtraData, recordProducers...) if err != nil { @@ -248,11 +259,12 @@ func (a *AcceptChannel) Decode(r io.Reader, pver uint32) error { var ( chanType ChannelType leaseExpiry LeaseExpiry - localNonce Musig2Nonce + localNonce = NewMusig2NonceRecordProducer( + AcceptChanLocalNonceType, + ) ) typeMap, err := tlvRecords.ExtractRecords( - &a.UpfrontShutdownScript, &chanType, &leaseExpiry, - &localNonce, + &a.UpfrontShutdownScript, &chanType, &leaseExpiry, localNonce, ) if err != nil { return err @@ -265,8 +277,8 @@ func (a *AcceptChannel) Decode(r io.Reader, pver uint32) error { if val, ok := typeMap[LeaseExpiryRecordType]; ok && val == nil { a.LeaseExpiry = &leaseExpiry } - if val, ok := typeMap[NonceRecordType]; ok && val == nil { - a.LocalNonce = &localNonce + if val, ok := typeMap[AcceptChanLocalNonceType]; ok && val == nil { + a.LocalNonce = &localNonce.Musig2Nonce } a.ExtraData = tlvRecords diff --git a/lnwire/channel_ready.go b/lnwire/channel_ready.go index 07872f8006..48a0cb0ed2 100644 --- a/lnwire/channel_ready.go +++ b/lnwire/channel_ready.go @@ -8,6 +8,12 @@ import ( "github.com/lightningnetwork/lnd/tlv" ) +const ( + // ChanReadyLocalNonceType is the tlv number associated with the local + // nonce TLV record in the channel_ready message. + ChanReadyLocalNonceType = tlv.Type(4) +) + // ChannelReady is the message that both parties to a new channel creation // send once they have observed the funding transaction being confirmed on the // blockchain. ChannelReady contains the signatures necessary for the channel @@ -77,10 +83,12 @@ func (c *ChannelReady) Decode(r io.Reader, _ uint32) error { // the AliasScidRecordType. var ( aliasScid ShortChannelID - localNonce Musig2Nonce + localNonce = NewMusig2NonceRecordProducer( + ChanReadyLocalNonceType, + ) ) typeMap, err := tlvRecords.ExtractRecords( - &aliasScid, &localNonce, + &aliasScid, localNonce, ) if err != nil { return err @@ -91,8 +99,8 @@ func (c *ChannelReady) Decode(r io.Reader, _ uint32) error { if val, ok := typeMap[AliasScidRecordType]; ok && val == nil { c.AliasScid = &aliasScid } - if val, ok := typeMap[NonceRecordType]; ok && val == nil { - c.NextLocalNonce = &localNonce + if val, ok := typeMap[ChanReadyLocalNonceType]; ok && val == nil { + c.NextLocalNonce = &localNonce.Musig2Nonce } if len(tlvRecords) != 0 { @@ -122,7 +130,12 @@ func (c *ChannelReady) Encode(w *bytes.Buffer, _ uint32) error { recordProducers = append(recordProducers, c.AliasScid) } if c.NextLocalNonce != nil { - recordProducers = append(recordProducers, c.NextLocalNonce) + recordProducers = append( + recordProducers, &Musig2NonceRecordProducer{ + Musig2Nonce: *c.NextLocalNonce, + Type: ChanReadyLocalNonceType, + }, + ) } err := EncodeMessageExtraData(&c.ExtraData, recordProducers...) if err != nil { diff --git a/lnwire/channel_reestablish.go b/lnwire/channel_reestablish.go index 1b6cfdffc3..efa8d946bc 100644 --- a/lnwire/channel_reestablish.go +++ b/lnwire/channel_reestablish.go @@ -8,6 +8,12 @@ import ( "github.com/lightningnetwork/lnd/tlv" ) +const ( + // ChanReestLocalNonceType is the tlv number associated with the local + // nonce TLV record in the channel_reestablish message. + ChanReestLocalNonceType = tlv.Type(4) +) + // ChannelReestablish is a message sent between peers that have an existing // open channel upon connection reestablishment. This message allows both sides // to report their local state, and their current knowledge of the state of the @@ -119,7 +125,12 @@ func (a *ChannelReestablish) Encode(w *bytes.Buffer, pver uint32) error { recordProducers := make([]tlv.RecordProducer, 0, 1) if a.LocalNonce != nil { - recordProducers = append(recordProducers, a.LocalNonce) + recordProducers = append(recordProducers, + &Musig2NonceRecordProducer{ + Musig2Nonce: *a.LocalNonce, + Type: ChanReestLocalNonceType, + }, + ) } err := EncodeMessageExtraData(&a.ExtraData, recordProducers...) if err != nil { @@ -179,16 +190,14 @@ func (a *ChannelReestablish) Decode(r io.Reader, pver uint32) error { return err } - var localNonce Musig2Nonce - typeMap, err := tlvRecords.ExtractRecords( - &localNonce, - ) + localNonce := NewMusig2NonceRecordProducer(ChanReestLocalNonceType) + typeMap, err := tlvRecords.ExtractRecords(localNonce) if err != nil { return err } - if val, ok := typeMap[NonceRecordType]; ok && val == nil { - a.LocalNonce = &localNonce + if val, ok := typeMap[ChanReestLocalNonceType]; ok && val == nil { + a.LocalNonce = &localNonce.Musig2Nonce } if len(tlvRecords) != 0 { diff --git a/lnwire/musig2.go b/lnwire/musig2.go index 6602ee6947..5d7c47195e 100644 --- a/lnwire/musig2.go +++ b/lnwire/musig2.go @@ -7,20 +7,31 @@ import ( "github.com/lightningnetwork/lnd/tlv" ) -const ( - // NonceRecordType is the TLV type used to encode a local musig2 nonce. - NonceRecordType tlv.Type = 4 -) - // Musig2Nonce represents a musig2 public nonce, which is the concatenation of // two EC points serialized in compressed format. type Musig2Nonce [musig2.PubNonceSize]byte +// Musig2NonceRecordProducer wraps a Musig2Nonce with the tlv type it should be +// encoded under. This can then be used to produce a TLV record. +type Musig2NonceRecordProducer struct { + Musig2Nonce + + Type tlv.Type +} + +// NewMusig2NonceRecordProducer constructs a new Musig2NonceRecordProducer with +// the given tlv type set. +func NewMusig2NonceRecordProducer(tlvType tlv.Type) *Musig2NonceRecordProducer { + return &Musig2NonceRecordProducer{ + Type: tlvType, + } +} + // Record returns a TLV record that can be used to encode/decode the musig2 // nonce from a given TLV stream. -func (m *Musig2Nonce) Record() tlv.Record { +func (m *Musig2NonceRecordProducer) Record() tlv.Record { return tlv.MakeStaticRecord( - NonceRecordType, m, musig2.PubNonceSize, nonceTypeEncoder, + m.Type, &m.Musig2Nonce, musig2.PubNonceSize, nonceTypeEncoder, nonceTypeDecoder, ) } diff --git a/lnwire/open_channel.go b/lnwire/open_channel.go index 9cb4bc41ad..c797801d00 100644 --- a/lnwire/open_channel.go +++ b/lnwire/open_channel.go @@ -21,6 +21,12 @@ const ( FFAnnounceChannel FundingFlag = 1 << iota ) +const ( + // OpenChanLocalNonceType is the tlv number associated with the local + // nonce TLV record in the open_channel message. + OpenChanLocalNonceType = tlv.Type(4) +) + // OpenChannel is the message Alice sends to Bob if we should like to create a // channel with Bob where she's the sole provider of funds to the channel. // Single funder channels simplify the initial funding workflow, are supported @@ -176,8 +182,14 @@ func (o *OpenChannel) Encode(w *bytes.Buffer, pver uint32) error { recordProducers = append(recordProducers, o.LeaseExpiry) } if o.LocalNonce != nil { - recordProducers = append(recordProducers, o.LocalNonce) + recordProducers = append(recordProducers, + &Musig2NonceRecordProducer{ + Musig2Nonce: *o.LocalNonce, + Type: OpenChanLocalNonceType, + }, + ) } + err := EncodeMessageExtraData(&o.ExtraData, recordProducers...) if err != nil { return err @@ -302,11 +314,12 @@ func (o *OpenChannel) Decode(r io.Reader, pver uint32) error { var ( chanType ChannelType leaseExpiry LeaseExpiry - localNonce Musig2Nonce + localNonce = NewMusig2NonceRecordProducer( + OpenChanLocalNonceType, + ) ) typeMap, err := tlvRecords.ExtractRecords( - &o.UpfrontShutdownScript, &chanType, &leaseExpiry, - &localNonce, + &o.UpfrontShutdownScript, &chanType, &leaseExpiry, localNonce, ) if err != nil { return err @@ -319,8 +332,8 @@ func (o *OpenChannel) Decode(r io.Reader, pver uint32) error { if val, ok := typeMap[LeaseExpiryRecordType]; ok && val == nil { o.LeaseExpiry = &leaseExpiry } - if val, ok := typeMap[NonceRecordType]; ok && val == nil { - o.LocalNonce = &localNonce + if val, ok := typeMap[OpenChanLocalNonceType]; ok && val == nil { + o.LocalNonce = &localNonce.Musig2Nonce } o.ExtraData = tlvRecords diff --git a/lnwire/revoke_and_ack.go b/lnwire/revoke_and_ack.go index 6b6b801671..673f31a937 100644 --- a/lnwire/revoke_and_ack.go +++ b/lnwire/revoke_and_ack.go @@ -8,6 +8,12 @@ import ( "github.com/lightningnetwork/lnd/tlv" ) +const ( + // RevAndAckLocalNonceType is the tlv number associated with the local + // nonce TLV record in the revoke_and_ack message. + RevAndAckLocalNonceType = tlv.Type(4) +) + // RevokeAndAck is sent by either side once a CommitSig message has been // received, and validated. This message serves to revoke the prior commitment // transaction, which was the most up to date version until a CommitSig message @@ -74,15 +80,15 @@ func (c *RevokeAndAck) Decode(r io.Reader, pver uint32) error { return err } - var musigNonce Musig2Nonce - typeMap, err := tlvRecords.ExtractRecords(&musigNonce) + musigNonce := NewMusig2NonceRecordProducer(RevAndAckLocalNonceType) + typeMap, err := tlvRecords.ExtractRecords(musigNonce) if err != nil { return err } // Set the corresponding TLV types if they were included in the stream. - if val, ok := typeMap[NonceRecordType]; ok && val == nil { - c.LocalNonce = &musigNonce + if val, ok := typeMap[RevAndAckLocalNonceType]; ok && val == nil { + c.LocalNonce = &musigNonce.Musig2Nonce } if len(tlvRecords) != 0 { @@ -99,7 +105,12 @@ func (c *RevokeAndAck) Decode(r io.Reader, pver uint32) error { func (c *RevokeAndAck) Encode(w *bytes.Buffer, pver uint32) error { recordProducers := make([]tlv.RecordProducer, 0, 1) if c.LocalNonce != nil { - recordProducers = append(recordProducers, c.LocalNonce) + recordProducers = append(recordProducers, + &Musig2NonceRecordProducer{ + Musig2Nonce: *c.LocalNonce, + Type: RevAndAckLocalNonceType, + }, + ) } err := EncodeMessageExtraData(&c.ExtraData, recordProducers...) if err != nil { From 780e7d6f0e3881ebb703ca6da560348e0ef6b350 Mon Sep 17 00:00:00 2001 From: Elle Mouton <elle.mouton@gmail.com> Date: Fri, 29 Sep 2023 09:11:33 +0200 Subject: [PATCH 02/33] lnwire: add RawFeatureVectorRecordProducer This commits defines the RawFeatureVectorRecordProducer type which will allow RawFeatureVector type to be used for a TLV record. --- lnwire/features.go | 67 +++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 66 insertions(+), 1 deletion(-) diff --git a/lnwire/features.go b/lnwire/features.go index 81472a7544..93d255c3d5 100644 --- a/lnwire/features.go +++ b/lnwire/features.go @@ -5,6 +5,8 @@ import ( "errors" "fmt" "io" + + "github.com/lightningnetwork/lnd/tlv" ) var ( @@ -372,7 +374,7 @@ func (fv RawFeatureVector) Equals(other *RawFeatureVector) bool { return true } -// Merges sets all feature bits in other on the receiver's feature vector. +// Merge sets all feature bits in other on the receiver's feature vector. func (fv *RawFeatureVector) Merge(other *RawFeatureVector) error { for bit := range other.features { err := fv.SafeSet(bit) @@ -721,3 +723,66 @@ func (fv *FeatureVector) Clone() *FeatureVector { features := fv.RawFeatureVector.Clone() return NewFeatureVector(features, fv.featureNames) } + +// featureBitLen returns the length in bytes of the encoded feature bits. +func (fv *RawFeatureVector) featureBitLen() uint64 { + return uint64(fv.SerializeSize()) +} + +// RawFeatureVectorRecordProducer wraps a RawFeatureVector with the TLV type +// that it should be encoded with. +type RawFeatureVectorRecordProducer struct { + RawFeatureVector + + Type tlv.Type +} + +// NewRawFeatureVectorRecord constructs a new RawFeatureVectorRecordProducer +// with the given TLV type. +func NewRawFeatureVectorRecord( + tlvType tlv.Type) *RawFeatureVectorRecordProducer { + + return &RawFeatureVectorRecordProducer{ + Type: tlvType, + } +} + +// Record returns a TLV record that can be used to encode/decode the channel +// type from a given TLV stream. +func (r *RawFeatureVectorRecordProducer) Record() tlv.Record { + return tlv.MakeDynamicRecord( + r.Type, &r.RawFeatureVector, r.featureBitLen, + rawFeatureVectorEncoder, rawFeatureVectorDecoder, + ) +} + +// rawFeatureVectorEncoder is a custom TLV encoder for a RawFeatureVector +// record. +func rawFeatureVectorEncoder(w io.Writer, val interface{}, _ *[8]byte) error { + if v, ok := val.(*RawFeatureVector); ok { + // Encode the feature bits as a byte slice without its length + // prepended, as that's already taken care of by the TLV record. + fv := *v + return fv.encode(w, fv.SerializeSize(), 8) + } + + return tlv.NewTypeForEncodingErr(val, "lnwire.RawFeatureVector") +} + +// rawFeatureVectorDecoder is a custom TLV decoder for a RawFeatureVector +// record. +func rawFeatureVectorDecoder(r io.Reader, val interface{}, _ *[8]byte, + l uint64) error { + + if v, ok := val.(*RawFeatureVector); ok { + fv := NewRawFeatureVector() + if err := fv.decode(r, int(l), 8); err != nil { + return err + } + *v = *fv + + return nil + } + + return tlv.NewTypeForEncodingErr(val, "lnwire.RawFeatureVector") +} From 10ce8d7cdef28495bf9e57766448db3fe6bb19f4 Mon Sep 17 00:00:00 2001 From: Elle Mouton <elle.mouton@gmail.com> Date: Fri, 29 Sep 2023 09:28:04 +0200 Subject: [PATCH 03/33] lnwire: add Encode and Pack methods for tlv.Records --- htlcswitch/failure_test.go | 2 +- lnwire/accept_channel.go | 2 +- lnwire/channel_ready.go | 2 +- lnwire/channel_reestablish.go | 2 +- lnwire/channel_type_test.go | 4 +-- lnwire/closing_signed.go | 2 +- lnwire/commit_sig.go | 2 +- lnwire/extra_bytes.go | 50 ++++++++++++++++++++---------- lnwire/extra_bytes_test.go | 7 ++--- lnwire/funding_created.go | 2 +- lnwire/funding_signed.go | 2 +- lnwire/open_channel.go | 2 +- lnwire/revoke_and_ack.go | 2 +- lnwire/short_channel_id_test.go | 4 +-- lnwire/shutdown.go | 2 +- lnwire/typed_delivery_addr_test.go | 4 +-- lnwire/typed_lease_expiry_test.go | 4 +-- 17 files changed, 56 insertions(+), 39 deletions(-) diff --git a/htlcswitch/failure_test.go b/htlcswitch/failure_test.go index 48ebc66821..b1bb5b402e 100644 --- a/htlcswitch/failure_test.go +++ b/htlcswitch/failure_test.go @@ -65,7 +65,7 @@ func TestLongFailureMessage(t *testing.T) { var value varBytesRecordProducer extraData := incorrectDetails.ExtraOpaqueData() - typeMap, err := extraData.ExtractRecords(&value) + typeMap, err := extraData.ExtractRecordsFromProducers(&value) require.NoError(t, err) require.Len(t, typeMap, 1) diff --git a/lnwire/accept_channel.go b/lnwire/accept_channel.go index 842ee6eb3b..3c99e00231 100644 --- a/lnwire/accept_channel.go +++ b/lnwire/accept_channel.go @@ -263,7 +263,7 @@ func (a *AcceptChannel) Decode(r io.Reader, pver uint32) error { AcceptChanLocalNonceType, ) ) - typeMap, err := tlvRecords.ExtractRecords( + typeMap, err := tlvRecords.ExtractRecordsFromProducers( &a.UpfrontShutdownScript, &chanType, &leaseExpiry, localNonce, ) if err != nil { diff --git a/lnwire/channel_ready.go b/lnwire/channel_ready.go index 48a0cb0ed2..bfd15e08a4 100644 --- a/lnwire/channel_ready.go +++ b/lnwire/channel_ready.go @@ -87,7 +87,7 @@ func (c *ChannelReady) Decode(r io.Reader, _ uint32) error { ChanReadyLocalNonceType, ) ) - typeMap, err := tlvRecords.ExtractRecords( + typeMap, err := tlvRecords.ExtractRecordsFromProducers( &aliasScid, localNonce, ) if err != nil { diff --git a/lnwire/channel_reestablish.go b/lnwire/channel_reestablish.go index efa8d946bc..044702edaa 100644 --- a/lnwire/channel_reestablish.go +++ b/lnwire/channel_reestablish.go @@ -191,7 +191,7 @@ func (a *ChannelReestablish) Decode(r io.Reader, pver uint32) error { } localNonce := NewMusig2NonceRecordProducer(ChanReestLocalNonceType) - typeMap, err := tlvRecords.ExtractRecords(localNonce) + typeMap, err := tlvRecords.ExtractRecordsFromProducers(localNonce) if err != nil { return err } diff --git a/lnwire/channel_type_test.go b/lnwire/channel_type_test.go index dd8e02439a..1d477151c4 100644 --- a/lnwire/channel_type_test.go +++ b/lnwire/channel_type_test.go @@ -17,10 +17,10 @@ func TestChannelTypeEncodeDecode(t *testing.T) { )) var extraData ExtraOpaqueData - require.NoError(t, extraData.PackRecords(&chanType)) + require.NoError(t, extraData.PackRecordsFromProducers(&chanType)) var chanType2 ChannelType - tlvs, err := extraData.ExtractRecords(&chanType2) + tlvs, err := extraData.ExtractRecordsFromProducers(&chanType2) require.NoError(t, err) require.Contains(t, tlvs, ChannelTypeRecordType) diff --git a/lnwire/closing_signed.go b/lnwire/closing_signed.go index 3e3651964d..9b68ca11a1 100644 --- a/lnwire/closing_signed.go +++ b/lnwire/closing_signed.go @@ -79,7 +79,7 @@ func (c *ClosingSigned) Decode(r io.Reader, pver uint32) error { var ( partialSig PartialSig ) - typeMap, err := tlvRecords.ExtractRecords(&partialSig) + typeMap, err := tlvRecords.ExtractRecordsFromProducers(&partialSig) if err != nil { return err } diff --git a/lnwire/commit_sig.go b/lnwire/commit_sig.go index d25d36a8ad..21a39759d0 100644 --- a/lnwire/commit_sig.go +++ b/lnwire/commit_sig.go @@ -84,7 +84,7 @@ func (c *CommitSig) Decode(r io.Reader, pver uint32) error { var ( partialSig PartialSigWithNonce ) - typeMap, err := tlvRecords.ExtractRecords(&partialSig) + typeMap, err := tlvRecords.ExtractRecordsFromProducers(&partialSig) if err != nil { return err } diff --git a/lnwire/extra_bytes.go b/lnwire/extra_bytes.go index 200f313ca8..8f5668efa0 100644 --- a/lnwire/extra_bytes.go +++ b/lnwire/extra_bytes.go @@ -51,13 +51,7 @@ func (e *ExtraOpaqueData) Decode(r io.Reader) error { // PackRecords attempts to encode the set of tlv records into the target // ExtraOpaqueData instance. The records will be encoded as a raw TLV stream // and stored within the backing slice pointer. -func (e *ExtraOpaqueData) PackRecords(recordProducers ...tlv.RecordProducer) error { - // First, assemble all the records passed in in series. - records := make([]tlv.Record, 0, len(recordProducers)) - for _, producer := range recordProducers { - records = append(records, producer.Record()) - } - +func (e *ExtraOpaqueData) PackRecords(records ...tlv.Record) error { // Ensure that the set of records are sorted before we encode them into // the stream, to ensure they're canonical. tlv.SortRecords(records) @@ -72,24 +66,32 @@ func (e *ExtraOpaqueData) PackRecords(recordProducers ...tlv.RecordProducer) err return err } - *e = ExtraOpaqueData(extraBytesWriter.Bytes()) + *e = extraBytesWriter.Bytes() return nil } -// ExtractRecords attempts to decode any types in the internal raw bytes as if -// it were a tlv stream. The set of raw parsed types is returned, and any -// passed records (if found in the stream) will be parsed into the proper -// tlv.Record. -func (e *ExtraOpaqueData) ExtractRecords(recordProducers ...tlv.RecordProducer) ( - tlv.TypeMap, error) { +// PackRecordsFromProducers attempts to encode the set of tlv records into the +// target ExtraOpaqueData instance. The records will be encoded as a raw TLV +// stream and stored within the backing slice pointer. +func (e *ExtraOpaqueData) PackRecordsFromProducers( + recordProducers ...tlv.RecordProducer) error { - // First, assemble all the records passed in in series. + // First, assemble all the records passed in, in series. records := make([]tlv.Record, 0, len(recordProducers)) for _, producer := range recordProducers { records = append(records, producer.Record()) } + return e.PackRecords(records...) +} + +// ExtractRecords attempts to decode any types in the internal raw bytes as if +// it were a tlv stream. The set of raw parsed types is returned, and any passed +// records (if found in the stream) will be parsed into the proper tlv.Record. +func (e *ExtraOpaqueData) ExtractRecords(records ...tlv.Record) (tlv.TypeMap, + error) { + // Ensure that the set of records are sorted before we attempt to // decode from the stream, to ensure they're canonical. tlv.SortRecords(records) @@ -106,6 +108,22 @@ func (e *ExtraOpaqueData) ExtractRecords(recordProducers ...tlv.RecordProducer) return tlvStream.DecodeWithParsedTypesP2P(extraBytesReader) } +// ExtractRecordsFromProducers attempts to decode any types in the internal raw +// bytes as if it were a tlv stream. The set of raw parsed types is returned, +// and any records produced by the passed record producers (if found in the +// stream) will be parsed into the proper tlv.Record. +func (e *ExtraOpaqueData) ExtractRecordsFromProducers( + recordProducers ...tlv.RecordProducer) (tlv.TypeMap, error) { + + // First, assemble all the records passed in, in series. + records := make([]tlv.Record, 0, len(recordProducers)) + for _, producer := range recordProducers { + records = append(records, producer.Record()) + } + + return e.ExtractRecords(records...) +} + // EncodeMessageExtraData encodes the given recordProducers into the given // extraData. func EncodeMessageExtraData(extraData *ExtraOpaqueData, @@ -119,5 +137,5 @@ func EncodeMessageExtraData(extraData *ExtraOpaqueData, // Pack in the series of TLV records into this message. The order we // pass them in doesn't matter, as the method will ensure that things // are all properly sorted. - return extraData.PackRecords(recordProducers...) + return extraData.PackRecordsFromProducers(recordProducers...) } diff --git a/lnwire/extra_bytes_test.go b/lnwire/extra_bytes_test.go index fd9f28841d..46100d7cff 100644 --- a/lnwire/extra_bytes_test.go +++ b/lnwire/extra_bytes_test.go @@ -118,9 +118,8 @@ func TestExtraOpaqueDataPackUnpackRecords(t *testing.T) { // Now that we have our set of sample records and types, we'll encode // them into the passed ExtraOpaqueData instance. var extraBytes ExtraOpaqueData - if err := extraBytes.PackRecords(testRecordsProducers...); err != nil { - t.Fatalf("unable to pack records: %v", err) - } + err := extraBytes.PackRecordsFromProducers(testRecordsProducers...) + require.NoError(t, err) // We'll now simulate decoding these types _back_ into records on the // other side. @@ -128,7 +127,7 @@ func TestExtraOpaqueDataPackUnpackRecords(t *testing.T) { &recordProducer{tlv.MakePrimitiveRecord(type1, &channelType2)}, &recordProducer{tlv.MakePrimitiveRecord(type2, &hop2)}, } - typeMap, err := extraBytes.ExtractRecords(newRecords...) + typeMap, err := extraBytes.ExtractRecordsFromProducers(newRecords...) require.NoError(t, err, "unable to extract record") // We should find that the new backing values have been populated with diff --git a/lnwire/funding_created.go b/lnwire/funding_created.go index f8128ff761..06f15c9dfc 100644 --- a/lnwire/funding_created.go +++ b/lnwire/funding_created.go @@ -95,7 +95,7 @@ func (f *FundingCreated) Decode(r io.Reader, pver uint32) error { var ( partialSig PartialSigWithNonce ) - typeMap, err := tlvRecords.ExtractRecords(&partialSig) + typeMap, err := tlvRecords.ExtractRecordsFromProducers(&partialSig) if err != nil { return err } diff --git a/lnwire/funding_signed.go b/lnwire/funding_signed.go index c7fb03d155..1bf1fc234c 100644 --- a/lnwire/funding_signed.go +++ b/lnwire/funding_signed.go @@ -81,7 +81,7 @@ func (f *FundingSigned) Decode(r io.Reader, pver uint32) error { var ( partialSig PartialSigWithNonce ) - typeMap, err := tlvRecords.ExtractRecords(&partialSig) + typeMap, err := tlvRecords.ExtractRecordsFromProducers(&partialSig) if err != nil { return err } diff --git a/lnwire/open_channel.go b/lnwire/open_channel.go index c797801d00..90e5adbe76 100644 --- a/lnwire/open_channel.go +++ b/lnwire/open_channel.go @@ -318,7 +318,7 @@ func (o *OpenChannel) Decode(r io.Reader, pver uint32) error { OpenChanLocalNonceType, ) ) - typeMap, err := tlvRecords.ExtractRecords( + typeMap, err := tlvRecords.ExtractRecordsFromProducers( &o.UpfrontShutdownScript, &chanType, &leaseExpiry, localNonce, ) if err != nil { diff --git a/lnwire/revoke_and_ack.go b/lnwire/revoke_and_ack.go index 673f31a937..80dd847269 100644 --- a/lnwire/revoke_and_ack.go +++ b/lnwire/revoke_and_ack.go @@ -81,7 +81,7 @@ func (c *RevokeAndAck) Decode(r io.Reader, pver uint32) error { } musigNonce := NewMusig2NonceRecordProducer(RevAndAckLocalNonceType) - typeMap, err := tlvRecords.ExtractRecords(musigNonce) + typeMap, err := tlvRecords.ExtractRecordsFromProducers(musigNonce) if err != nil { return err } diff --git a/lnwire/short_channel_id_test.go b/lnwire/short_channel_id_test.go index 2916f20d17..c440a0e370 100644 --- a/lnwire/short_channel_id_test.go +++ b/lnwire/short_channel_id_test.go @@ -53,10 +53,10 @@ func TestScidTypeEncodeDecode(t *testing.T) { } var extraData ExtraOpaqueData - require.NoError(t, extraData.PackRecords(&aliasScid)) + require.NoError(t, extraData.PackRecordsFromProducers(&aliasScid)) var aliasScid2 ShortChannelID - tlvs, err := extraData.ExtractRecords(&aliasScid2) + tlvs, err := extraData.ExtractRecordsFromProducers(&aliasScid2) require.NoError(t, err) require.Contains(t, tlvs, AliasScidRecordType) diff --git a/lnwire/shutdown.go b/lnwire/shutdown.go index 5b59b47ab1..facfe79232 100644 --- a/lnwire/shutdown.go +++ b/lnwire/shutdown.go @@ -103,7 +103,7 @@ func (s *Shutdown) Decode(r io.Reader, pver uint32) error { } var musigNonce ShutdownNonce - typeMap, err := tlvRecords.ExtractRecords(&musigNonce) + typeMap, err := tlvRecords.ExtractRecordsFromProducers(&musigNonce) if err != nil { return err } diff --git a/lnwire/typed_delivery_addr_test.go b/lnwire/typed_delivery_addr_test.go index 9d00bc8bd8..0b8274750e 100644 --- a/lnwire/typed_delivery_addr_test.go +++ b/lnwire/typed_delivery_addr_test.go @@ -15,13 +15,13 @@ func TestDeliveryAddressEncodeDecode(t *testing.T) { ) var extraData ExtraOpaqueData - err := extraData.PackRecords(&addr) + err := extraData.PackRecordsFromProducers(&addr) if err != nil { t.Fatal(err) } var addr2 DeliveryAddress - tlvs, err := extraData.ExtractRecords(&addr2) + tlvs, err := extraData.ExtractRecordsFromProducers(&addr2) if err != nil { t.Fatal(err) } diff --git a/lnwire/typed_lease_expiry_test.go b/lnwire/typed_lease_expiry_test.go index d5f797b200..071a8ae91f 100644 --- a/lnwire/typed_lease_expiry_test.go +++ b/lnwire/typed_lease_expiry_test.go @@ -14,10 +14,10 @@ func TestLeaseExpiryEncodeDecode(t *testing.T) { leaseExpiry := LeaseExpiry(1337) var extraData ExtraOpaqueData - require.NoError(t, extraData.PackRecords(&leaseExpiry)) + require.NoError(t, extraData.PackRecordsFromProducers(&leaseExpiry)) var leaseExpiry2 LeaseExpiry - tlvs, err := extraData.ExtractRecords(&leaseExpiry2) + tlvs, err := extraData.ExtractRecordsFromProducers(&leaseExpiry2) require.NoError(t, err) require.Contains(t, tlvs, LeaseExpiryRecordType) From 9142486b9be4dbf932095074f096bd8e9a6185be Mon Sep 17 00:00:00 2001 From: Elle Mouton <elle.mouton@gmail.com> Date: Fri, 29 Sep 2023 09:32:34 +0200 Subject: [PATCH 04/33] lnwire: use the RawFeatureVector record methods for ChannelType --- lnwire/accept_channel.go | 18 ++++++++++++---- lnwire/channel_type.go | 43 ------------------------------------- lnwire/channel_type_test.go | 19 +++++++++------- lnwire/features.go | 6 +++--- lnwire/open_channel.go | 20 ++++++++++++----- 5 files changed, 43 insertions(+), 63 deletions(-) diff --git a/lnwire/accept_channel.go b/lnwire/accept_channel.go index 3c99e00231..3c2f0b0fbc 100644 --- a/lnwire/accept_channel.go +++ b/lnwire/accept_channel.go @@ -142,7 +142,14 @@ var _ Message = (*AcceptChannel)(nil) func (a *AcceptChannel) Encode(w *bytes.Buffer, pver uint32) error { recordProducers := []tlv.RecordProducer{&a.UpfrontShutdownScript} if a.ChannelType != nil { - recordProducers = append(recordProducers, a.ChannelType) + recordProducers = append(recordProducers, + &RawFeatureVectorRecordProducer{ + RawFeatureVector: RawFeatureVector( + *a.ChannelType, + ), + Type: ChannelTypeRecordType, + }, + ) } if a.LeaseExpiry != nil { recordProducers = append(recordProducers, a.LeaseExpiry) @@ -257,14 +264,16 @@ func (a *AcceptChannel) Decode(r io.Reader, pver uint32) error { // Next we'll parse out the set of known records, keeping the raw tlv // bytes untouched to ensure we don't drop any bytes erroneously. var ( - chanType ChannelType + chanType = NewRawFeatureVectorRecordProducer( + ChannelTypeRecordType, + ) leaseExpiry LeaseExpiry localNonce = NewMusig2NonceRecordProducer( AcceptChanLocalNonceType, ) ) typeMap, err := tlvRecords.ExtractRecordsFromProducers( - &a.UpfrontShutdownScript, &chanType, &leaseExpiry, localNonce, + &a.UpfrontShutdownScript, chanType, &leaseExpiry, localNonce, ) if err != nil { return err @@ -272,7 +281,8 @@ func (a *AcceptChannel) Decode(r io.Reader, pver uint32) error { // Set the corresponding TLV types if they were included in the stream. if val, ok := typeMap[ChannelTypeRecordType]; ok && val == nil { - a.ChannelType = &chanType + channelType := ChannelType(chanType.RawFeatureVector) + a.ChannelType = &channelType } if val, ok := typeMap[LeaseExpiryRecordType]; ok && val == nil { a.LeaseExpiry = &leaseExpiry diff --git a/lnwire/channel_type.go b/lnwire/channel_type.go index a0696048be..3b8eaef781 100644 --- a/lnwire/channel_type.go +++ b/lnwire/channel_type.go @@ -1,8 +1,6 @@ package lnwire import ( - "io" - "github.com/lightningnetwork/lnd/tlv" ) @@ -15,44 +13,3 @@ const ( // ChannelType represents a specific channel type as a set of feature bits that // comprise it. type ChannelType RawFeatureVector - -// featureBitLen returns the length in bytes of the encoded feature bits. -func (c ChannelType) featureBitLen() uint64 { - fv := RawFeatureVector(c) - return uint64(fv.SerializeSize()) -} - -// Record returns a TLV record that can be used to encode/decode the channel -// type from a given TLV stream. -func (c *ChannelType) Record() tlv.Record { - return tlv.MakeDynamicRecord( - ChannelTypeRecordType, c, c.featureBitLen, channelTypeEncoder, - channelTypeDecoder, - ) -} - -// channelTypeEncoder is a custom TLV encoder for the ChannelType record. -func channelTypeEncoder(w io.Writer, val interface{}, buf *[8]byte) error { - if v, ok := val.(*ChannelType); ok { - // Encode the feature bits as a byte slice without its length - // prepended, as that's already taken care of by the TLV record. - fv := RawFeatureVector(*v) - return fv.encode(w, fv.SerializeSize(), 8) - } - - return tlv.NewTypeForEncodingErr(val, "lnwire.ChannelType") -} - -// channelTypeDecoder is a custom TLV decoder for the ChannelType record. -func channelTypeDecoder(r io.Reader, val interface{}, buf *[8]byte, l uint64) error { - if v, ok := val.(*ChannelType); ok { - fv := NewRawFeatureVector() - if err := fv.decode(r, int(l), 8); err != nil { - return err - } - *v = ChannelType(*fv) - return nil - } - - return tlv.NewTypeForEncodingErr(val, "lnwire.ChannelType") -} diff --git a/lnwire/channel_type_test.go b/lnwire/channel_type_test.go index 1d477151c4..56f07cb432 100644 --- a/lnwire/channel_type_test.go +++ b/lnwire/channel_type_test.go @@ -11,18 +11,21 @@ import ( func TestChannelTypeEncodeDecode(t *testing.T) { t.Parallel() - chanType := ChannelType(*NewRawFeatureVector( - StaticRemoteKeyRequired, - AnchorsZeroFeeHtlcTxRequired, - )) + record1 := RawFeatureVectorRecordProducer{ + RawFeatureVector: *NewRawFeatureVector( + StaticRemoteKeyRequired, + AnchorsZeroFeeHtlcTxRequired, + ), + Type: ChannelTypeRecordType, + } var extraData ExtraOpaqueData - require.NoError(t, extraData.PackRecordsFromProducers(&chanType)) + require.NoError(t, extraData.PackRecordsFromProducers(&record1)) - var chanType2 ChannelType - tlvs, err := extraData.ExtractRecordsFromProducers(&chanType2) + record2 := NewRawFeatureVectorRecordProducer(ChannelTypeRecordType) + tlvs, err := extraData.ExtractRecordsFromProducers(record2) require.NoError(t, err) require.Contains(t, tlvs, ChannelTypeRecordType) - require.Equal(t, chanType, chanType2) + require.Equal(t, record1.RawFeatureVector, record2.RawFeatureVector) } diff --git a/lnwire/features.go b/lnwire/features.go index 93d255c3d5..0a2cae5677 100644 --- a/lnwire/features.go +++ b/lnwire/features.go @@ -737,9 +737,9 @@ type RawFeatureVectorRecordProducer struct { Type tlv.Type } -// NewRawFeatureVectorRecord constructs a new RawFeatureVectorRecordProducer -// with the given TLV type. -func NewRawFeatureVectorRecord( +// NewRawFeatureVectorRecordProducer constructs a new +// RawFeatureVectorRecordProducer with the given TLV type. +func NewRawFeatureVectorRecordProducer( tlvType tlv.Type) *RawFeatureVectorRecordProducer { return &RawFeatureVectorRecordProducer{ diff --git a/lnwire/open_channel.go b/lnwire/open_channel.go index 90e5adbe76..45a095c69f 100644 --- a/lnwire/open_channel.go +++ b/lnwire/open_channel.go @@ -173,10 +173,17 @@ var _ Message = (*OpenChannel)(nil) // Encode serializes the target OpenChannel into the passed io.Writer // implementation. Serialization will observe the rules defined by the passed // protocol version. -func (o *OpenChannel) Encode(w *bytes.Buffer, pver uint32) error { +func (o *OpenChannel) Encode(w *bytes.Buffer, _ uint32) error { recordProducers := []tlv.RecordProducer{&o.UpfrontShutdownScript} if o.ChannelType != nil { - recordProducers = append(recordProducers, o.ChannelType) + recordProducers = append(recordProducers, + &RawFeatureVectorRecordProducer{ + RawFeatureVector: RawFeatureVector( + *o.ChannelType, + ), + Type: ChannelTypeRecordType, + }, + ) } if o.LeaseExpiry != nil { recordProducers = append(recordProducers, o.LeaseExpiry) @@ -312,14 +319,16 @@ func (o *OpenChannel) Decode(r io.Reader, pver uint32) error { // Next we'll parse out the set of known records, keeping the raw tlv // bytes untouched to ensure we don't drop any bytes erroneously. var ( - chanType ChannelType + chanType = NewRawFeatureVectorRecordProducer( + ChannelTypeRecordType, + ) leaseExpiry LeaseExpiry localNonce = NewMusig2NonceRecordProducer( OpenChanLocalNonceType, ) ) typeMap, err := tlvRecords.ExtractRecordsFromProducers( - &o.UpfrontShutdownScript, &chanType, &leaseExpiry, localNonce, + &o.UpfrontShutdownScript, chanType, &leaseExpiry, localNonce, ) if err != nil { return err @@ -327,7 +336,8 @@ func (o *OpenChannel) Decode(r io.Reader, pver uint32) error { // Set the corresponding TLV types if they were included in the stream. if val, ok := typeMap[ChannelTypeRecordType]; ok && val == nil { - o.ChannelType = &chanType + channelType := ChannelType(chanType.RawFeatureVector) + o.ChannelType = &channelType } if val, ok := typeMap[LeaseExpiryRecordType]; ok && val == nil { o.LeaseExpiry = &leaseExpiry From 806cd22f69ee8126250de442cc29063fbeb52432 Mon Sep 17 00:00:00 2001 From: Elle Mouton <elle.mouton@gmail.com> Date: Fri, 29 Sep 2023 09:41:11 +0200 Subject: [PATCH 05/33] lnwire: make ShortChannelID type re-usable --- lnwire/channel_ready.go | 22 ++++++++++++++----- lnwire/short_channel_id.go | 39 ++++++++++++++++++++------------- lnwire/short_channel_id_test.go | 28 +++++++++++++++-------- 3 files changed, 60 insertions(+), 29 deletions(-) diff --git a/lnwire/channel_ready.go b/lnwire/channel_ready.go index bfd15e08a4..12ecd746e4 100644 --- a/lnwire/channel_ready.go +++ b/lnwire/channel_ready.go @@ -9,6 +9,11 @@ import ( ) const ( + // ChanReadyAliasScidType is the TLV type of the experimental record of + // channel_ready to denote the alias being used in an option_scid_alias + // channel. + ChanReadyAliasScidType = tlv.Type(1) + // ChanReadyLocalNonceType is the tlv number associated with the local // nonce TLV record in the channel_ready message. ChanReadyLocalNonceType = tlv.Type(4) @@ -82,13 +87,15 @@ func (c *ChannelReady) Decode(r io.Reader, _ uint32) error { // Next we'll parse out the set of known records. For now, this is just // the AliasScidRecordType. var ( - aliasScid ShortChannelID + aliasScid = NewShortChannelIDRecordProducer( + ChanReadyAliasScidType, + ) localNonce = NewMusig2NonceRecordProducer( ChanReadyLocalNonceType, ) ) typeMap, err := tlvRecords.ExtractRecordsFromProducers( - &aliasScid, localNonce, + aliasScid, localNonce, ) if err != nil { return err @@ -96,8 +103,8 @@ func (c *ChannelReady) Decode(r io.Reader, _ uint32) error { // We'll only set AliasScid if the corresponding TLV type was included // in the stream. - if val, ok := typeMap[AliasScidRecordType]; ok && val == nil { - c.AliasScid = &aliasScid + if val, ok := typeMap[ChanReadyAliasScidType]; ok && val == nil { + c.AliasScid = &aliasScid.ShortChannelID } if val, ok := typeMap[ChanReadyLocalNonceType]; ok && val == nil { c.NextLocalNonce = &localNonce.Musig2Nonce @@ -127,7 +134,12 @@ func (c *ChannelReady) Encode(w *bytes.Buffer, _ uint32) error { // We'll only encode the AliasScid in a TLV segment if it exists. recordProducers := make([]tlv.RecordProducer, 0, 2) if c.AliasScid != nil { - recordProducers = append(recordProducers, c.AliasScid) + recordProducers = append(recordProducers, + &ShortChannelIDRecordProducer{ + ShortChannelID: *c.AliasScid, + Type: ChanReadyAliasScidType, + }, + ) } if c.NextLocalNonce != nil { recordProducers = append( diff --git a/lnwire/short_channel_id.go b/lnwire/short_channel_id.go index d4da518b76..d286a03221 100644 --- a/lnwire/short_channel_id.go +++ b/lnwire/short_channel_id.go @@ -7,11 +7,28 @@ import ( "github.com/lightningnetwork/lnd/tlv" ) -const ( - // AliasScidRecordType is the type of the experimental record to denote - // the alias being used in an option_scid_alias channel. - AliasScidRecordType tlv.Type = 1 -) +// ShortChannelIDRecordProducer wraps a ShortChannelID with the tlv type that it +// should be encoded with. +type ShortChannelIDRecordProducer struct { + ShortChannelID + Type tlv.Type +} + +// Record returns a TLV record that can be used to encode/decode a +// ShortChannelID to/from a TLV stream. +func (c *ShortChannelIDRecordProducer) Record() tlv.Record { + return tlv.MakeStaticRecord( + c.Type, &c.ShortChannelID, 8, EShortChannelID, DShortChannelID, + ) +} + +// NewShortChannelIDRecordProducer constructs a new ShortChannelIDRecordProducer +// with the given TLV type. +func NewShortChannelIDRecordProducer(t tlv.Type) *ShortChannelIDRecordProducer { + return &ShortChannelIDRecordProducer{ + Type: t, + } +} // ShortChannelID represents the set of data which is needed to retrieve all // necessary data to validate the channel existence. @@ -47,8 +64,8 @@ func NewShortChanIDFromInt(chanID uint64) ShortChannelID { // uint64 (8 bytes). func (c ShortChannelID) ToUint64() uint64 { // TODO(roasbeef): explicit error on overflow? - return ((uint64(c.BlockHeight) << 40) | (uint64(c.TxIndex) << 16) | - (uint64(c.TxPosition))) + return (uint64(c.BlockHeight) << 40) | (uint64(c.TxIndex) << 16) | + (uint64(c.TxPosition)) } // String generates a human-readable representation of the channel ID. @@ -56,14 +73,6 @@ func (c ShortChannelID) String() string { return fmt.Sprintf("%d:%d:%d", c.BlockHeight, c.TxIndex, c.TxPosition) } -// Record returns a TLV record that can be used to encode/decode a -// ShortChannelID to/from a TLV stream. -func (c *ShortChannelID) Record() tlv.Record { - return tlv.MakeStaticRecord( - AliasScidRecordType, c, 8, EShortChannelID, DShortChannelID, - ) -} - // IsDefault returns true if the ShortChannelID represents the zero value for // its type. func (c ShortChannelID) IsDefault() bool { diff --git a/lnwire/short_channel_id_test.go b/lnwire/short_channel_id_test.go index c440a0e370..37158b30c1 100644 --- a/lnwire/short_channel_id_test.go +++ b/lnwire/short_channel_id_test.go @@ -46,19 +46,29 @@ func TestShortChannelIDEncoding(t *testing.T) { func TestScidTypeEncodeDecode(t *testing.T) { t.Parallel() - aliasScid := ShortChannelID{ - BlockHeight: (1 << 24) - 1, - TxIndex: (1 << 24) - 1, - TxPosition: (1 << 16) - 1, + aliasScidRecordProducer := &ShortChannelIDRecordProducer{ + ShortChannelID: ShortChannelID{ + BlockHeight: (1 << 24) - 1, + TxIndex: (1 << 24) - 1, + TxPosition: (1 << 16) - 1, + }, + Type: ChanReadyAliasScidType, } var extraData ExtraOpaqueData - require.NoError(t, extraData.PackRecordsFromProducers(&aliasScid)) + require.NoError( + t, extraData.PackRecordsFromProducers(aliasScidRecordProducer), + ) + + aliasScid2RecordProducer := NewShortChannelIDRecordProducer( + ChanReadyAliasScidType, + ) - var aliasScid2 ShortChannelID - tlvs, err := extraData.ExtractRecordsFromProducers(&aliasScid2) + tlvs, err := extraData.ExtractRecordsFromProducers( + aliasScid2RecordProducer, + ) require.NoError(t, err) - require.Contains(t, tlvs, AliasScidRecordType) - require.Equal(t, aliasScid, aliasScid2) + require.Contains(t, tlvs, ChanReadyAliasScidType) + require.Equal(t, aliasScidRecordProducer, aliasScid2RecordProducer) } From 76d837d972d6a0f8717a6ba507d9abed55dacf75 Mon Sep 17 00:00:00 2001 From: Elle Mouton <elle.mouton@gmail.com> Date: Fri, 29 Sep 2023 09:51:27 +0200 Subject: [PATCH 06/33] lnwire: add btc and node announcement nonces to channel_ready --- lnwire/channel_ready.go | 56 +++++++++++++++++++++++++++++++++++++---- lnwire/lnwire_test.go | 20 +++++++-------- 2 files changed, 60 insertions(+), 16 deletions(-) diff --git a/lnwire/channel_ready.go b/lnwire/channel_ready.go index 12ecd746e4..028acc796d 100644 --- a/lnwire/channel_ready.go +++ b/lnwire/channel_ready.go @@ -9,11 +9,19 @@ import ( ) const ( + // ChanReadyAnnounceBtcNonce is the TLV type associated with the + // bitcoin nonce field in the channel_ready message. + ChanReadyAnnounceBtcNonce = tlv.Type(0) + // ChanReadyAliasScidType is the TLV type of the experimental record of // channel_ready to denote the alias being used in an option_scid_alias // channel. ChanReadyAliasScidType = tlv.Type(1) + // ChanReadyAnnounceNodeNonce is the TLV type associated with the node + // nonce field in the channel_ready message. + ChanReadyAnnounceNodeNonce = tlv.Type(2) + // ChanReadyLocalNonceType is the tlv number associated with the local // nonce TLV record in the channel_ready message. ChanReadyLocalNonceType = tlv.Type(4) @@ -38,6 +46,16 @@ type ChannelReady struct { // ShortChannelID for forwarding. AliasScid *ShortChannelID + // AnnouncementBitcoinNonce is an optional field that stores a public + // nonce that will be used along with the node's bitcoin key during + // signing of the ChannelAnnouncement2 message. + AnnouncementBitcoinNonce *Musig2Nonce + + // AnnouncementBitcoinNonce is an optional field that stores a public + // nonce that will be used along with the node's ID key during signing + // of the ChannelAnnouncement2 message. + AnnouncementNodeNonce *Musig2Nonce + // NextLocalNonce is an optional field that stores a local musig2 nonce. // This will only be populated if the simple taproot channels type was // negotiated. This is the local nonce that will be used by the sender @@ -84,8 +102,7 @@ func (c *ChannelReady) Decode(r io.Reader, _ uint32) error { return err } - // Next we'll parse out the set of known records. For now, this is just - // the AliasScidRecordType. + // Next we'll parse out the set of known records. var ( aliasScid = NewShortChannelIDRecordProducer( ChanReadyAliasScidType, @@ -93,15 +110,21 @@ func (c *ChannelReady) Decode(r io.Reader, _ uint32) error { localNonce = NewMusig2NonceRecordProducer( ChanReadyLocalNonceType, ) + btcNonce = NewMusig2NonceRecordProducer( + ChanReadyAnnounceBtcNonce, + ) + nodeNonce = NewMusig2NonceRecordProducer( + ChanReadyAnnounceNodeNonce, + ) ) typeMap, err := tlvRecords.ExtractRecordsFromProducers( - aliasScid, localNonce, + btcNonce, aliasScid, nodeNonce, localNonce, ) if err != nil { return err } - // We'll only set AliasScid if the corresponding TLV type was included + // We'll only set some fields if the corresponding TLV type was included // in the stream. if val, ok := typeMap[ChanReadyAliasScidType]; ok && val == nil { c.AliasScid = &aliasScid.ShortChannelID @@ -109,6 +132,12 @@ func (c *ChannelReady) Decode(r io.Reader, _ uint32) error { if val, ok := typeMap[ChanReadyLocalNonceType]; ok && val == nil { c.NextLocalNonce = &localNonce.Musig2Nonce } + if val, ok := typeMap[ChanReadyAnnounceBtcNonce]; ok && val == nil { + c.AnnouncementBitcoinNonce = &btcNonce.Musig2Nonce + } + if val, ok := typeMap[ChanReadyAnnounceNodeNonce]; ok && val == nil { + c.AnnouncementNodeNonce = &nodeNonce.Musig2Nonce + } if len(tlvRecords) != 0 { c.ExtraData = tlvRecords @@ -131,8 +160,17 @@ func (c *ChannelReady) Encode(w *bytes.Buffer, _ uint32) error { return err } - // We'll only encode the AliasScid in a TLV segment if it exists. + // We'll only encode the various optional fields in a TLV segment if + // they exists. recordProducers := make([]tlv.RecordProducer, 0, 2) + if c.AnnouncementBitcoinNonce != nil { + recordProducers = append( + recordProducers, &Musig2NonceRecordProducer{ + Type: ChanReadyAnnounceBtcNonce, + Musig2Nonce: *c.AnnouncementBitcoinNonce, + }, + ) + } if c.AliasScid != nil { recordProducers = append(recordProducers, &ShortChannelIDRecordProducer{ @@ -141,6 +179,14 @@ func (c *ChannelReady) Encode(w *bytes.Buffer, _ uint32) error { }, ) } + if c.AnnouncementNodeNonce != nil { + recordProducers = append( + recordProducers, &Musig2NonceRecordProducer{ + Type: ChanReadyAnnounceNodeNonce, + Musig2Nonce: *c.AnnouncementNodeNonce, + }, + ) + } if c.NextLocalNonce != nil { recordProducers = append( recordProducers, &Musig2NonceRecordProducer{ diff --git a/lnwire/lnwire_test.go b/lnwire/lnwire_test.go index f5c028581b..6ba572f2fc 100644 --- a/lnwire/lnwire_test.go +++ b/lnwire/lnwire_test.go @@ -630,25 +630,23 @@ func TestLightningWireProtocol(t *testing.T) { }, MsgChannelReady: func(v []reflect.Value, r *rand.Rand) { var c [32]byte - if _, err := r.Read(c[:]); err != nil { - t.Fatalf("unable to generate chan id: %v", err) - return - } + _, err := r.Read(c[:]) + require.NoError(t, err) pubKey, err := randPubKey() - if err != nil { - t.Fatalf("unable to generate key: %v", err) - return - } + require.NoError(t, err) - req := NewChannelReady(ChannelID(c), pubKey) + req := NewChannelReady(c, pubKey) if r.Int31()%2 == 0 { - scid := NewShortChanIDFromInt(uint64(r.Int63())) - req.AliasScid = &scid req.NextLocalNonce = randLocalNonce(r) } + if r.Int31()%2 == 0 { + req.AnnouncementBitcoinNonce = randLocalNonce(r) + req.AnnouncementNodeNonce = randLocalNonce(r) + } + v[0] = reflect.ValueOf(*req) }, MsgShutdown: func(v []reflect.Value, r *rand.Rand) { From afda72d814abe88ada7b386909599f025e46d47a Mon Sep 17 00:00:00 2001 From: Elle Mouton <elle.mouton@gmail.com> Date: Fri, 29 Sep 2023 10:22:25 +0200 Subject: [PATCH 07/33] lnwire: add AnnouncementSignatures2 message --- lnwire/announcement_signatures_2.go | 76 +++++++++++++++++++++++++++++ lnwire/lnwire.go | 19 ++++++++ lnwire/lnwire_test.go | 35 +++++++++++++ lnwire/message.go | 5 ++ lnwire/writer.go | 8 +++ 5 files changed, 143 insertions(+) create mode 100644 lnwire/announcement_signatures_2.go diff --git a/lnwire/announcement_signatures_2.go b/lnwire/announcement_signatures_2.go new file mode 100644 index 0000000000..069dd6ce74 --- /dev/null +++ b/lnwire/announcement_signatures_2.go @@ -0,0 +1,76 @@ +package lnwire + +import ( + "bytes" + "io" +) + +// AnnouncementSignatures2 is a direct message between two endpoints of a +// channel and serves as an opt-in mechanism to allow the announcement of +// a taproot channel to the rest of the network. It contains the necessary +// signatures by the sender to construct the channel_announcement_ message. +type AnnouncementSignatures2 struct { + // ChannelID is the unique description of the funding transaction. + // Channel id is better for users and debugging and short channel id is + // used for quick test on existence of the particular utxo inside the + // blockchain, because it contains information about block. + ChannelID ChannelID + + // ShortChannelID is the unique description of the funding transaction. + // It is constructed with the most significant 3 bytes as the block + // height, the next 3 bytes indicating the transaction index within the + // block, and the least significant two bytes indicating the output + // index which pays to the channel. + ShortChannelID ShortChannelID + + // PartialSignature is the combination of the partial Schnorr signature + // created for the node's bitcoin key with the partial signature created + // for the node's node ID key. + PartialSignature PartialSig + + // ExtraOpaqueData is the set of data that was appended to this + // message, some of which we may not actually know how to iterate or + // parse. By holding onto this data, we ensure that we're able to + // properly validate the set of signatures that cover these new fields, + // and ensure we're able to make upgrades to the network in a forwards + // compatible manner. + ExtraOpaqueData ExtraOpaqueData +} + +// A compile time check to ensure AnnouncementSignatures2 implements the +// lnwire.Message interface. +var _ Message = (*AnnouncementSignatures2)(nil) + +// Decode deserializes a serialized AnnounceSignatures stored in the passed +// io.Reader observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (a *AnnouncementSignatures2) Decode(r io.Reader, _ uint32) error { + return ReadElements(r, + &a.ChannelID, + &a.ShortChannelID, + &a.PartialSignature, + &a.ExtraOpaqueData, + ) +} + +// Encode serializes the target AnnounceSignatures into the passed io.Writer +// observing the protocol version specified. +// +// This is part of the lnwire.Message interface. +func (a *AnnouncementSignatures2) Encode(w *bytes.Buffer, _ uint32) error { + return WriteElements(w, + a.ChannelID, + a.ShortChannelID, + a.PartialSignature, + a.ExtraOpaqueData, + ) +} + +// MsgType returns the integer uniquely identifying this message type on the +// wire. +// +// This is part of the lnwire.Message interface. +func (a *AnnouncementSignatures2) MsgType() MessageType { + return MsgAnnouncementSignatures2 +} diff --git a/lnwire/lnwire.go b/lnwire/lnwire.go index 46257cce51..79fec5d3b0 100644 --- a/lnwire/lnwire.go +++ b/lnwire/lnwire.go @@ -457,6 +457,12 @@ func WriteElement(w *bytes.Buffer, element interface{}) error { return err } + case PartialSig: + sigBytes := e.Sig.Bytes() + if _, err := w.Write(sigBytes[:]); err != nil { + return err + } + case ExtraOpaqueData: return e.Encode(w) @@ -928,6 +934,19 @@ func ReadElement(r io.Reader, element interface{}) error { } *e = addrBytes[:length] + case *PartialSig: + var sBytes [32]byte + if _, err := io.ReadFull(r, sBytes[:]); err != nil { + return err + } + + var s btcec.ModNScalar + s.SetBytes(&sBytes) + + *e = PartialSig{ + Sig: s, + } + case *ExtraOpaqueData: return e.Decode(r) diff --git a/lnwire/lnwire_test.go b/lnwire/lnwire_test.go index 6ba572f2fc..219ba5a32c 100644 --- a/lnwire/lnwire_test.go +++ b/lnwire/lnwire_test.go @@ -1066,6 +1066,35 @@ func TestLightningWireProtocol(t *testing.T) { PaddingBytes: paddingBytes, } + v[0] = reflect.ValueOf(req) + }, + MsgAnnouncementSignatures2: func(v []reflect.Value, + r *rand.Rand) { + + req := AnnouncementSignatures2{ + ShortChannelID: NewShortChanIDFromInt( + uint64(r.Int63()), + ), + ExtraOpaqueData: make([]byte, 0), + } + + _, err := r.Read(req.ChannelID[:]) + require.NoError(t, err) + + partialSig, err := randPartialSig(r) + require.NoError(t, err) + + req.PartialSignature = *partialSig + + numExtraBytes := r.Int31n(1000) + if numExtraBytes > 0 { + req.ExtraOpaqueData = make( + []byte, numExtraBytes, + ) + _, err := r.Read(req.ExtraOpaqueData[:]) + require.NoError(t, err) + } + v[0] = reflect.ValueOf(req) }, } @@ -1254,6 +1283,12 @@ func TestLightningWireProtocol(t *testing.T) { return mainScenario(&m) }, }, + { + msgType: MsgAnnouncementSignatures2, + scenario: func(m AnnouncementSignatures2) bool { + return mainScenario(&m) + }, + }, } for _, test := range tests { var config *quick.Config diff --git a/lnwire/message.go b/lnwire/message.go index 02447b806c..e14db23f8e 100644 --- a/lnwire/message.go +++ b/lnwire/message.go @@ -46,6 +46,7 @@ const ( MsgNodeAnnouncement = 257 MsgChannelUpdate = 258 MsgAnnounceSignatures = 259 + MsgAnnouncementSignatures2 = 260 MsgQueryShortChanIDs = 261 MsgReplyShortChanIDsEnd = 262 MsgQueryChannelRange = 263 @@ -134,6 +135,8 @@ func (t MessageType) String() string { return "ReplyChannelRange" case MsgGossipTimestampRange: return "GossipTimestampRange" + case MsgAnnouncementSignatures2: + return "MsgAnnouncementSignatures2" default: return "<unknown>" } @@ -236,6 +239,8 @@ func makeEmptyMessage(msgType MessageType) (Message, error) { msg = &ReplyChannelRange{} case MsgGossipTimestampRange: msg = &GossipTimestampRange{} + case MsgAnnouncementSignatures2: + msg = &AnnouncementSignatures2{} default: // If the message is not within our custom range and has not // specifically been overridden, return an unknown message. diff --git a/lnwire/writer.go b/lnwire/writer.go index 671ebfdc00..c98bd9cbde 100644 --- a/lnwire/writer.go +++ b/lnwire/writer.go @@ -173,6 +173,14 @@ func WriteSigs(buf *bytes.Buffer, sigs []Sig) error { return nil } +// WritePartialSig appends the serialised partial signature to the provided +// buffer. +func WritePartialSig(buf *bytes.Buffer, sig PartialSig) error { + sigBytes := sig.Sig.Bytes() + + return WriteBytes(buf, sigBytes[:]) +} + // WriteFailCode appends the FailCode to the provided buffer. func WriteFailCode(buf *bytes.Buffer, e FailCode) error { return WriteUint16(buf, uint16(e)) From 1f4333492a6cabc037dbdd87e26c9c067be43b77 Mon Sep 17 00:00:00 2001 From: Elle Mouton <elle.mouton@gmail.com> Date: Fri, 29 Sep 2023 10:33:24 +0200 Subject: [PATCH 08/33] lnwire: add ChannelAnnouncement2 message --- lnwire/channel_announcement_2.go | 268 +++++++++++++++++++++++++++++++ lnwire/extra_bytes.go | 16 ++ lnwire/lnwire_test.go | 132 ++++++++++----- lnwire/message.go | 5 + 4 files changed, 378 insertions(+), 43 deletions(-) create mode 100644 lnwire/channel_announcement_2.go diff --git a/lnwire/channel_announcement_2.go b/lnwire/channel_announcement_2.go new file mode 100644 index 0000000000..2fcf05bfd6 --- /dev/null +++ b/lnwire/channel_announcement_2.go @@ -0,0 +1,268 @@ +package lnwire + +import ( + "bytes" + "io" + + "github.com/btcsuite/btcd/chaincfg" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/lightningnetwork/lnd/tlv" +) + +const ( + // ChanAnn2ChainHashType is the tlv number associated with the chain + // hash TLV record in the channel_announcement_2 message. + ChanAnn2ChainHashType = tlv.Type(0) + + // ChanAnn2FeaturesType is the tlv number associated with the features + // TLV record in the channel_announcement_2 message. + ChanAnn2FeaturesType = tlv.Type(2) + + // ChanAnn2SCIDType is the tlv number associated with the SCID TLV + // record in the channel_announcement_2 message. + ChanAnn2SCIDType = tlv.Type(4) + + // ChanAnn2CapacityType is the tlv number associated with the capacity + // TLV record in the channel_announcement_2 message. + ChanAnn2CapacityType = tlv.Type(6) + + // ChanAnn2NodeID1Type is the tlv number associated with the node ID 1 + // TLV record in the channel_announcement_2 message. + ChanAnn2NodeID1Type = tlv.Type(8) + + // ChanAnn2NodeID2Type is the tlv number associated with the node ID 2 + // record in the channel_announcement_2 message. + ChanAnn2NodeID2Type = tlv.Type(10) + + // ChanAnn2BtcKey1Type is the tlv number associated with the bitcoin ID + // 1 record in the channel_announcement_2 message. + ChanAnn2BtcKey1Type = tlv.Type(12) + + // ChanAnn2BtcKey2Type is the tlv number associated with the bitcoin ID + // 2 record in the channel_announcement_2 message. + ChanAnn2BtcKey2Type = tlv.Type(14) + + // ChanAnn2MerkleRootHashType is the tlv number associated with the + // merkle root hash record in the channel_announcement_2 message. + ChanAnn2MerkleRootHashType = tlv.Type(16) +) + +// ChannelAnnouncement2 message is used to announce the existence of a taproot +// channel between two peers in the network. +type ChannelAnnouncement2 struct { + // Signature is a Schnorr signature over the TLV stream of the message. + Signature Sig + + // ChainHash denotes the target chain that this channel was opened + // within. This value should be the genesis hash of the target chain. + ChainHash chainhash.Hash + + // Features is the feature vector that encodes the features supported + // by the target node. This field can be used to signal the type of the + // channel, or modifications to the fields that would normally follow + // this vector. + Features RawFeatureVector + + // ShortChannelID is the unique description of the funding transaction, + // or where exactly it's located within the target blockchain. + ShortChannelID ShortChannelID + + // Capacity is the number of satoshis of the capacity of this channel. + // It must be less than or equal to the value of the on-chain funding + // output. + Capacity uint64 + + // NodeID1 is the numerically-lesser public key ID of one of the channel + // operators. + NodeID1 [33]byte + + // NodeID2 is the numerically-greater public key ID of one of the + // channel operators. + NodeID2 [33]byte + + // BitcoinKey1 is the public key of the key used by Node1 in the + // construction of the on-chain funding transaction. This is an optional + // field and only needs to be set if the 4-of-4 MuSig construction was + // used in the creation of the message signature. + BitcoinKey1 *[33]byte + + // BitcoinKey2 is the public key of the key used by Node2 in the + // construction of the on-chain funding transaction. This is an optional + // field and only needs to be set if the 4-of-4 MuSig construction was + // used in the creation of the message signature. + BitcoinKey2 *[33]byte + + // MerkleRootHash is the hash used to create the optional tweak in the + // funding output. If this is not set but the bitcoin keys are, then + // the funding output is a pure 2-of-2 MuSig aggregate public key. + MerkleRootHash *[32]byte + + // ExtraOpaqueData is the set of data that was appended to this + // message, some of which we may not actually know how to iterate or + // parse. By holding onto this data, we ensure that we're able to + // properly validate the set of signatures that cover these new fields, + // and ensure we're able to make upgrades to the network in a forwards + // compatible manner. + ExtraOpaqueData ExtraOpaqueData +} + +// Decode deserializes a serialized AnnounceSignatures stored in the passed +// io.Reader observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (c *ChannelAnnouncement2) Decode(r io.Reader, _ uint32) error { + err := ReadElement(r, &c.Signature) + if err != nil { + return err + } + c.Signature.ForceSchnorr() + + // First extract into extra opaque data. + var tlvRecords ExtraOpaqueData + if err := ReadElements(r, &tlvRecords); err != nil { + return err + } + + featuresRecordProducer := NewRawFeatureVectorRecordProducer( + ChanAnn2FeaturesType, + ) + + scidRecordProducer := NewShortChannelIDRecordProducer( + ChanAnn2SCIDType, + ) + + var ( + chainHash, merkleRootHash [32]byte + btcKey1, btcKey2 [33]byte + ) + + records := []tlv.Record{ + tlv.MakePrimitiveRecord(ChanAnn2ChainHashType, &chainHash), + featuresRecordProducer.Record(), + scidRecordProducer.Record(), + tlv.MakePrimitiveRecord(ChanAnn2CapacityType, &c.Capacity), + tlv.MakePrimitiveRecord(ChanAnn2NodeID1Type, &c.NodeID1), + tlv.MakePrimitiveRecord(ChanAnn2NodeID2Type, &c.NodeID2), + tlv.MakePrimitiveRecord(ChanAnn2BtcKey1Type, &btcKey1), + tlv.MakePrimitiveRecord(ChanAnn2BtcKey2Type, &btcKey2), + tlv.MakePrimitiveRecord( + ChanAnn2MerkleRootHashType, &merkleRootHash, + ), + } + + typeMap, err := tlvRecords.ExtractRecords(records...) + if err != nil { + return err + } + + // By default, the chain-hash is the bitcoin mainnet genesis block hash. + c.ChainHash = *chaincfg.MainNetParams.GenesisHash + if _, ok := typeMap[ChanAnn2ChainHashType]; ok { + c.ChainHash = chainHash + } + + if _, ok := typeMap[ChanAnn2FeaturesType]; ok { + c.Features = featuresRecordProducer.RawFeatureVector + } + + if _, ok := typeMap[ChanAnn2SCIDType]; ok { + c.ShortChannelID = scidRecordProducer.ShortChannelID + } + + if _, ok := typeMap[ChanAnn2BtcKey1Type]; ok { + c.BitcoinKey1 = &btcKey1 + } + + if _, ok := typeMap[ChanAnn2BtcKey2Type]; ok { + c.BitcoinKey2 = &btcKey2 + } + + if _, ok := typeMap[ChanAnn2MerkleRootHashType]; ok { + c.MerkleRootHash = &merkleRootHash + } + + if len(tlvRecords) != 0 { + c.ExtraOpaqueData = tlvRecords + } + + return nil +} + +// Encode serializes the target AnnounceSignatures into the passed io.Writer +// observing the protocol version specified. +// +// This is part of the lnwire.Message interface. +func (c *ChannelAnnouncement2) Encode(w *bytes.Buffer, _ uint32) error { + var records []tlv.Record + + _, err := w.Write(c.Signature.RawBytes()) + if err != nil { + return err + } + + // The chain-hash record is only included if it is _not_ equal to the + // bitcoin mainnet genisis block hash. + if !c.ChainHash.IsEqual(chaincfg.MainNetParams.GenesisHash) { + chainHash := [32]byte(c.ChainHash) + records = append(records, tlv.MakePrimitiveRecord( + ChanAnn2ChainHashType, &chainHash, + )) + } + + featuresRecordProducer := &RawFeatureVectorRecordProducer{ + RawFeatureVector: c.Features, + Type: ChanAnn2FeaturesType, + } + + scidRecordProducer := &ShortChannelIDRecordProducer{ + ShortChannelID: c.ShortChannelID, + Type: ChanAnn2SCIDType, + } + + records = append(records, + featuresRecordProducer.Record(), + scidRecordProducer.Record(), + tlv.MakePrimitiveRecord(ChanAnn2CapacityType, &c.Capacity), + tlv.MakePrimitiveRecord(ChanAnn2NodeID1Type, &c.NodeID1), + tlv.MakePrimitiveRecord(ChanAnn2NodeID2Type, &c.NodeID2), + ) + + if c.BitcoinKey1 != nil && c.BitcoinKey2 != nil { + records = append(records, + tlv.MakePrimitiveRecord( + ChanAnn2BtcKey1Type, c.BitcoinKey1, + ), + tlv.MakePrimitiveRecord( + ChanAnn2BtcKey2Type, c.BitcoinKey2, + ), + ) + + if c.MerkleRootHash != nil { + records = append(records, + tlv.MakePrimitiveRecord( + ChanAnn2MerkleRootHashType, + c.MerkleRootHash, + ), + ) + } + } + + err = EncodeMessageExtraDataFromRecords(&c.ExtraOpaqueData, records...) + if err != nil { + return err + } + + return WriteBytes(w, c.ExtraOpaqueData) +} + +// MsgType returns the integer uniquely identifying this message type on the +// wire. +// +// This is part of the lnwire.Message interface. +func (c *ChannelAnnouncement2) MsgType() MessageType { + return MsgChannelAnnouncement2 +} + +// A compile time check to ensure ChannelAnnouncement2 implements the +// lnwire.Message interface. +var _ Message = (*ChannelAnnouncement2)(nil) diff --git a/lnwire/extra_bytes.go b/lnwire/extra_bytes.go index 8f5668efa0..b2a12a4602 100644 --- a/lnwire/extra_bytes.go +++ b/lnwire/extra_bytes.go @@ -139,3 +139,19 @@ func EncodeMessageExtraData(extraData *ExtraOpaqueData, // are all properly sorted. return extraData.PackRecordsFromProducers(recordProducers...) } + +// EncodeMessageExtraDataFromRecords encodes the given records into the given +// extraData. +func EncodeMessageExtraDataFromRecords(extraData *ExtraOpaqueData, + records ...tlv.Record) error { + + // Treat extraData as a mutable reference. + if extraData == nil { + return fmt.Errorf("extra data cannot be nil") + } + + // Pack in the series of TLV records into this message. The order we + // pass them in doesn't matter, as the method will ensure that things + // are all properly sorted. + return extraData.PackRecords(records...) +} diff --git a/lnwire/lnwire_test.go b/lnwire/lnwire_test.go index 219ba5a32c..8f03d4b445 100644 --- a/lnwire/lnwire_test.go +++ b/lnwire/lnwire_test.go @@ -18,6 +18,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcec/v2/ecdsa" "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/tor" @@ -26,17 +27,25 @@ import ( ) var ( - shaHash1Bytes, _ = hex.DecodeString("e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855") - shaHash1, _ = chainhash.NewHash(shaHash1Bytes) - outpoint1 = wire.NewOutPoint(shaHash1, 0) - - testRBytes, _ = hex.DecodeString("8ce2bc69281ce27da07e6683571319d18e949ddfa2965fb6caa1bf0314f882d7") - testSBytes, _ = hex.DecodeString("299105481d63e0f4bc2a88121167221b6700d72a0ead154c03be696a292d24ae") - testRScalar = new(btcec.ModNScalar) - testSScalar = new(btcec.ModNScalar) - _ = testRScalar.SetByteSlice(testRBytes) - _ = testSScalar.SetByteSlice(testSBytes) - testSig = ecdsa.NewSignature(testRScalar, testSScalar) + shaHash1Bytes, _ = hex.DecodeString("e3b0c44298fc1c149afbf4c8996fb" + + "92427ae41e4649b934ca495991b7852b855") + + shaHash1, _ = chainhash.NewHash(shaHash1Bytes) + outpoint1 = wire.NewOutPoint(shaHash1, 0) + + testRBytes, _ = hex.DecodeString("8ce2bc69281ce27da07e6683571" + + "319d18e949ddfa2965fb6caa1bf0314f882d7") + testSBytes, _ = hex.DecodeString("299105481d63e0f4bc2a" + + "88121167221b6700d72a0ead154c03be696a292d24ae") + testRScalar = new(btcec.ModNScalar) + testSScalar = new(btcec.ModNScalar) + _ = testRScalar.SetByteSlice(testRBytes) + _ = testSScalar.SetByteSlice(testSBytes) + testSig = ecdsa.NewSignature(testRScalar, testSScalar) + testSchnorrSigStr, _ = hex.DecodeString("04E7F9037658A92AFEB4F2" + + "5BAE5339E3DDCA81A353493827D26F16D92308E49E2A25E9220867" + + "8A2DF86970DA91B03A8AF8815A8A60498B358DAF560B347AA557") + testSchnorrSig, _ = NewSigFromSchnorrRawSignature(testSchnorrSigStr) ) const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" @@ -95,17 +104,15 @@ func randPubKey() (*btcec.PublicKey, error) { return priv.PubKey(), nil } -func randRawKey() ([33]byte, error) { +func randRawKey(t *testing.T) [33]byte { var n [33]byte priv, err := btcec.NewPrivateKey() - if err != nil { - return n, err - } + require.NoError(t, err) copy(n[:], priv.PubKey().SerializeCompressed()) - return n, nil + return n } func randDeliveryAddress(r *rand.Rand) (DeliveryAddress, error) { @@ -774,7 +781,13 @@ func TestLightningWireProtocol(t *testing.T) { MsgChannelAnnouncement: func(v []reflect.Value, r *rand.Rand) { var err error req := ChannelAnnouncement{ - ShortChannelID: NewShortChanIDFromInt(uint64(r.Int63())), + ShortChannelID: NewShortChanIDFromInt( + uint64(r.Int63()), + ), + NodeID1: randRawKey(t), + NodeID2: randRawKey(t), + BitcoinKey1: randRawKey(t), + BitcoinKey2: randRawKey(t), Features: randRawFeatureVector(r), ExtraOpaqueData: make([]byte, 0), } @@ -799,26 +812,6 @@ func TestLightningWireProtocol(t *testing.T) { return } - req.NodeID1, err = randRawKey() - if err != nil { - t.Fatalf("unable to generate key: %v", err) - return - } - req.NodeID2, err = randRawKey() - if err != nil { - t.Fatalf("unable to generate key: %v", err) - return - } - req.BitcoinKey1, err = randRawKey() - if err != nil { - t.Fatalf("unable to generate key: %v", err) - return - } - req.BitcoinKey2, err = randRawKey() - if err != nil { - t.Fatalf("unable to generate key: %v", err) - return - } if _, err := r.Read(req.ChainHash[:]); err != nil { t.Fatalf("unable to generate chain hash: %v", err) return @@ -840,6 +833,7 @@ func TestLightningWireProtocol(t *testing.T) { MsgNodeAnnouncement: func(v []reflect.Value, r *rand.Rand) { var err error req := NodeAnnouncement{ + NodeID: randRawKey(t), Features: randRawFeatureVector(r), Timestamp: uint32(r.Int31()), Alias: randAlias(r), @@ -856,12 +850,6 @@ func TestLightningWireProtocol(t *testing.T) { return } - req.NodeID, err = randRawKey() - if err != nil { - t.Fatalf("unable to generate key: %v", err) - return - } - req.Addresses, err = randAddrs(r) if err != nil { t.Fatalf("unable to generate addresses: %v", err) @@ -1095,6 +1083,58 @@ func TestLightningWireProtocol(t *testing.T) { require.NoError(t, err) } + v[0] = reflect.ValueOf(req) + }, + MsgChannelAnnouncement2: func(v []reflect.Value, r *rand.Rand) { + req := ChannelAnnouncement2{ + Signature: testSchnorrSig, + ShortChannelID: NewShortChanIDFromInt( + uint64(r.Int63()), + ), + Capacity: rand.Uint64(), + NodeID1: randRawKey(t), + NodeID2: randRawKey(t), + ExtraOpaqueData: make([]byte, 0), + } + + features := randRawFeatureVector(r) + req.Features = *features + + // Sometimes set chain hash to bitcoin mainnet genesis + // hash. + req.ChainHash = *chaincfg.MainNetParams.GenesisHash + if r.Int31()%2 == 0 { + _, err := r.Read(req.ChainHash[:]) + require.NoError(t, err) + } + + // Sometimes set the bitcoin keys. + if r.Int31()%2 == 0 { + btcKey1 := randRawKey(t) + req.BitcoinKey1 = &btcKey1 + + btcKey2 := randRawKey(t) + req.BitcoinKey2 = &btcKey2 + + // Occasionally also set the merkle root hash. + if r.Int31()%2 == 0 { + var merkleRootHash [32]byte + _, err := r.Read(merkleRootHash[:]) + require.NoError(t, err) + + req.MerkleRootHash = &merkleRootHash + } + } + + numExtraBytes := r.Int31n(1000) + if numExtraBytes > 0 { + req.ExtraOpaqueData = make( + []byte, numExtraBytes, + ) + _, err := r.Read(req.ExtraOpaqueData[:]) + require.NoError(t, err) + } + v[0] = reflect.ValueOf(req) }, } @@ -1289,6 +1329,12 @@ func TestLightningWireProtocol(t *testing.T) { return mainScenario(&m) }, }, + { + msgType: MsgChannelAnnouncement2, + scenario: func(m ChannelAnnouncement2) bool { + return mainScenario(&m) + }, + }, } for _, test := range tests { var config *quick.Config diff --git a/lnwire/message.go b/lnwire/message.go index e14db23f8e..116693b774 100644 --- a/lnwire/message.go +++ b/lnwire/message.go @@ -52,6 +52,7 @@ const ( MsgQueryChannelRange = 263 MsgReplyChannelRange = 264 MsgGossipTimestampRange = 265 + MsgChannelAnnouncement2 = 267 ) // ErrorEncodeMessage is used when failed to encode the message payload. @@ -137,6 +138,8 @@ func (t MessageType) String() string { return "GossipTimestampRange" case MsgAnnouncementSignatures2: return "MsgAnnouncementSignatures2" + case MsgChannelAnnouncement2: + return "ChannelAnnouncement2" default: return "<unknown>" } @@ -241,6 +244,8 @@ func makeEmptyMessage(msgType MessageType) (Message, error) { msg = &GossipTimestampRange{} case MsgAnnouncementSignatures2: msg = &AnnouncementSignatures2{} + case MsgChannelAnnouncement2: + msg = &ChannelAnnouncement2{} default: // If the message is not within our custom range and has not // specifically been overridden, return an unknown message. From 457cca4a72caa6eb44c6fe5b73e19285a4fdeb4b Mon Sep 17 00:00:00 2001 From: Elle Mouton <elle.mouton@gmail.com> Date: Fri, 29 Sep 2023 11:13:53 +0200 Subject: [PATCH 09/33] lnwire: introduce the BooleanRecordProducer --- lnwire/boolean.go | 53 ++++++++++++++++++++++++++++++++++++++++++ lnwire/boolean_test.go | 40 +++++++++++++++++++++++++++++++ 2 files changed, 93 insertions(+) create mode 100644 lnwire/boolean.go create mode 100644 lnwire/boolean_test.go diff --git a/lnwire/boolean.go b/lnwire/boolean.go new file mode 100644 index 0000000000..00a474f357 --- /dev/null +++ b/lnwire/boolean.go @@ -0,0 +1,53 @@ +package lnwire + +import ( + "fmt" + "io" + + "github.com/lightningnetwork/lnd/tlv" +) + +// BooleanRecordProducer wraps a boolean with the tlv type it should be encoded +// with. A boolean is by default false, if the TLV is present then the boolean +// is true. +type BooleanRecordProducer struct { + Bool bool + Type tlv.Type +} + +// Record returns the tlv record for the boolean entry. +func (b *BooleanRecordProducer) Record() tlv.Record { + return tlv.MakeStaticRecord( + b.Type, &b.Bool, 0, booleanEncoder, booleanDecoder, + ) +} + +// NewBooleanRecordProducer constructs a new BooleanRecordProducer. +func NewBooleanRecordProducer(t tlv.Type) *BooleanRecordProducer { + return &BooleanRecordProducer{ + Type: t, + } +} + +func booleanEncoder(_ io.Writer, val interface{}, _ *[8]byte) error { + if v, ok := val.(*bool); ok { + if !*v { + return fmt.Errorf("a boolean record should only be " + + "encoded if the value of the boolean is true") + } + + return nil + } + + return tlv.NewTypeForEncodingErr(val, "bool") +} + +func booleanDecoder(_ io.Reader, val interface{}, _ *[8]byte, _ uint64) error { + if v, ok := val.(*bool); ok { + *v = true + + return nil + } + + return tlv.NewTypeForEncodingErr(val, "bool") +} diff --git a/lnwire/boolean_test.go b/lnwire/boolean_test.go new file mode 100644 index 0000000000..6aa0565de4 --- /dev/null +++ b/lnwire/boolean_test.go @@ -0,0 +1,40 @@ +package lnwire + +import ( + "testing" + + "github.com/lightningnetwork/lnd/tlv" + "github.com/stretchr/testify/require" +) + +// TestBooleanRecord tests the encoding and decoding of a boolean tlv record. +func TestBooleanRecord(t *testing.T) { + t.Parallel() + + const recordType = tlv.Type(0) + + b1 := BooleanRecordProducer{ + Bool: false, + Type: recordType, + } + + var extraData ExtraOpaqueData + + // A false boolean should not be encoded. + require.ErrorContains(t, extraData.PackRecordsFromProducers(&b1), + "a boolean record should only be encoded if the value of "+ + "the boolean is true") + + b2 := BooleanRecordProducer{ + Bool: true, + Type: recordType, + } + require.NoError(t, extraData.PackRecordsFromProducers(&b2)) + + b3 := NewBooleanRecordProducer(recordType) + tlvs, err := extraData.ExtractRecordsFromProducers(b3) + require.NoError(t, err) + + require.Contains(t, tlvs, recordType) + require.Equal(t, b2.Bool, b3.Bool) +} From 69fb24d46edc616bb0a96c960bd1708b3ee74e21 Mon Sep 17 00:00:00 2001 From: Elle Mouton <elle.mouton@gmail.com> Date: Fri, 29 Sep 2023 11:37:39 +0200 Subject: [PATCH 10/33] lnwire: add ChannelUpdate2 --- lnwire/channel_update_2.go | 375 +++++++++++++++++++++++++++++++++++++ lnwire/lnwire_test.go | 68 +++++++ lnwire/message.go | 5 + 3 files changed, 448 insertions(+) create mode 100644 lnwire/channel_update_2.go diff --git a/lnwire/channel_update_2.go b/lnwire/channel_update_2.go new file mode 100644 index 0000000000..88452dd9a5 --- /dev/null +++ b/lnwire/channel_update_2.go @@ -0,0 +1,375 @@ +package lnwire + +import ( + "bytes" + "fmt" + "io" + + "github.com/btcsuite/btcd/chaincfg" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/lightningnetwork/lnd/tlv" +) + +const ( + // ChanUpdate2ChainHashType is the tlv number associated with the chain + // hash TLV record in the channel_update_2 message. + ChanUpdate2ChainHashType = tlv.Type(0) + + // ChanUpdate2SCIDType is the tlv number associated with the SCID TLV + // record in the channel_update_2 message. + ChanUpdate2SCIDType = tlv.Type(2) + + // ChanUpdate2BlockHeightType is the tlv number associated with the + // block height record in the channel_update_2 message. + ChanUpdate2BlockHeightType = tlv.Type(4) + + // ChanUpdate2DisableFlagsType is the tlv number associated with the + // disable flags record in the channel_update_2 message. + ChanUpdate2DisableFlagsType = tlv.Type(6) + + // ChanUpdate2DirectionType is the tlv number associated with the + // disable boolean TLV record in the channel_update_2 message. + ChanUpdate2DirectionType = tlv.Type(8) + + // ChanUpdate2CLTVExpiryDeltaType is the tlv number associated with the + // CLTV expiry delta TLV record in the channel_update_2 message. + ChanUpdate2CLTVExpiryDeltaType = tlv.Type(10) + + // ChanUpdate2HTLCMinMsatType is the tlv number associated with the htlc + // minimum msat record in the channel_update_2 message. + ChanUpdate2HTLCMinMsatType = tlv.Type(12) + + // ChanUpdate2HTLCMaxMsatType is the tlv number associated with the htlc + // maximum msat record in the channel_update_2 message. + ChanUpdate2HTLCMaxMsatType = tlv.Type(14) + + // ChanUpdate2FeeBaseMsatType is the tlv number associated with the fee + // base msat record in the channel_update_2 message. + ChanUpdate2FeeBaseMsatType = tlv.Type(16) + + // ChanUpdate2FeeProportionalMillionthsType is the tlv number associated + // with the fee proportional millionths record in the channel_update_2 + // message. + ChanUpdate2FeeProportionalMillionthsType = tlv.Type(18) + + defaultCltvExpiryDelta = uint16(80) + defaultHtlcMinMsat = MilliSatoshi(1) + defaultFeeBaseMsat = uint32(1000) + defaultFeeProportionalMillionths = uint32(1) +) + +// ChannelUpdate2 message is used after taproot channel has been initially +// announced. Each side independently announces its fees and minimum expiry for +// HTLCs and other parameters. Also this message is used to redeclare initially +// set channel parameters. +type ChannelUpdate2 struct { + // Signature is used to validate the announced data and prove the + // ownership of node id. + Signature Sig + + // ChainHash denotes the target chain that this channel was opened + // within. This value should be the genesis hash of the target chain. + // Along with the short channel ID, this uniquely identifies the + // channel globally in a blockchain. + ChainHash chainhash.Hash + + // ShortChannelID is the unique description of the funding transaction. + ShortChannelID ShortChannelID + + // BlockHeight allows ordering in the case of multiple announcements. We + // should ignore the message if block height is not greater than the + // last-received. The block height must always be greater or equal to + // the block height that the channel funding transaction was confirmed + // in. + BlockHeight uint32 + + // DisabledFlags is an optional bitfield that describes various reasons + // that the node is communicating that the channel should be considered + // disabled. + DisabledFlags ChanUpdateDisableFlags + + // Direction is false if this update was produced by node 1 of the + // channel announcement and true if it is from node 2. + Direction bool + + // CLTVExpiryDelta is the minimum number of blocks this node requires to + // be added to the expiry of HTLCs. This is a security parameter + // determined by the node operator. This value represents the required + // gap between the time locks of the incoming and outgoing HTLC's set + // to this node. + CLTVExpiryDelta uint16 + + // HTLCMinimumMsat is the minimum HTLC value which will be accepted. + HTLCMinimumMsat MilliSatoshi + + // HtlcMaximumMsat is the maximum HTLC value which will be accepted. + HTLCMaximumMsat MilliSatoshi + + // FeeBaseMsat is the base fee that must be used for incoming HTLC's to + // this particular channel. This value will be tacked onto the required + // for a payment independent of the size of the payment. + FeeBaseMsat uint32 + + // FeeProportionalMillionths is the fee rate that will be charged per + // millionth of a satoshi. + FeeProportionalMillionths uint32 + + // ExtraOpaqueData is the set of data that was appended to this message + // to fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraOpaqueData ExtraOpaqueData +} + +// Decode deserializes a serialized AnnounceSignatures stored in the passed +// io.Reader observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (c *ChannelUpdate2) Decode(r io.Reader, _ uint32) error { + err := ReadElement(r, &c.Signature) + if err != nil { + return err + } + c.Signature.ForceSchnorr() + + // First extract into extra opaque data. + var tlvRecords ExtraOpaqueData + if err := ReadElements(r, &tlvRecords); err != nil { + return err + } + + scidRecordProducer := NewShortChannelIDRecordProducer( + ChanUpdate2SCIDType, + ) + + directionRecordProducer := NewBooleanRecordProducer( + ChanUpdate2DirectionType, + ) + + var ( + chainHash [32]byte + htlcMin, htlcMax uint64 + disableFlags uint8 + ) + + records := []tlv.Record{ + tlv.MakePrimitiveRecord(ChanUpdate2ChainHashType, &chainHash), + scidRecordProducer.Record(), + tlv.MakePrimitiveRecord( + ChanUpdate2BlockHeightType, &c.BlockHeight, + ), + tlv.MakePrimitiveRecord( + ChanUpdate2DisableFlagsType, &disableFlags, + ), + directionRecordProducer.Record(), + tlv.MakePrimitiveRecord( + ChanUpdate2CLTVExpiryDeltaType, &c.CLTVExpiryDelta, + ), + tlv.MakePrimitiveRecord( + ChanUpdate2HTLCMinMsatType, &htlcMin, + ), + tlv.MakePrimitiveRecord( + ChanUpdate2HTLCMaxMsatType, &htlcMax, + ), + tlv.MakePrimitiveRecord( + ChanUpdate2FeeBaseMsatType, &c.FeeBaseMsat, + ), + tlv.MakePrimitiveRecord( + ChanUpdate2FeeProportionalMillionthsType, + &c.FeeProportionalMillionths, + ), + } + + typeMap, err := tlvRecords.ExtractRecords(records...) + if err != nil { + return err + } + + // By default, the chain-hash is the bitcoin mainnet genesis block hash. + c.ChainHash = *chaincfg.MainNetParams.GenesisHash + if _, ok := typeMap[ChanUpdate2ChainHashType]; ok { + c.ChainHash = chainHash + } + + if _, ok := typeMap[ChanUpdate2DisableFlagsType]; ok { + c.DisabledFlags = ChanUpdateDisableFlags(disableFlags) + } + + if _, ok := typeMap[ChanUpdate2SCIDType]; ok { + c.ShortChannelID = scidRecordProducer.ShortChannelID + } + + if _, ok := typeMap[ChanUpdate2DirectionType]; ok { + c.Direction = directionRecordProducer.Bool + } + + // If the CLTV expiry delta was not encoded, then set it to the default + // value. + if _, ok := typeMap[ChanUpdate2CLTVExpiryDeltaType]; !ok { + c.CLTVExpiryDelta = defaultCltvExpiryDelta + } + + c.HTLCMinimumMsat = defaultHtlcMinMsat + if _, ok := typeMap[ChanUpdate2HTLCMinMsatType]; ok { + c.HTLCMinimumMsat = MilliSatoshi(htlcMin) + } + + if _, ok := typeMap[ChanUpdate2HTLCMaxMsatType]; ok { + c.HTLCMaximumMsat = MilliSatoshi(htlcMax) + } + + // If the base fee was not encoded, then set it to the default value. + if _, ok := typeMap[ChanUpdate2FeeBaseMsatType]; !ok { + c.FeeBaseMsat = defaultFeeBaseMsat + } + + // If the proportional fee was not encoded, then set it to the default + // value. + if _, ok := typeMap[ChanUpdate2FeeProportionalMillionthsType]; !ok { + c.FeeProportionalMillionths = defaultFeeProportionalMillionths + } + + if len(tlvRecords) != 0 { + c.ExtraOpaqueData = tlvRecords + } + + return nil +} + +// Encode serializes the target AnnounceSignatures into the passed io.Writer +// observing the protocol version specified. +// +// This is part of the lnwire.Message interface. +func (c *ChannelUpdate2) Encode(w *bytes.Buffer, _ uint32) error { + _, err := w.Write(c.Signature.RawBytes()) + if err != nil { + return err + } + + var records []tlv.Record + + // The chain-hash record is only included if it is _not_ equal to the + // bitcoin mainnet genisis block hash. + if !c.ChainHash.IsEqual(chaincfg.MainNetParams.GenesisHash) { + chainHash := [32]byte(c.ChainHash) + records = append(records, tlv.MakePrimitiveRecord( + ChanUpdate2ChainHashType, &chainHash, + )) + } + + scidRecordProducer := &ShortChannelIDRecordProducer{ + ShortChannelID: c.ShortChannelID, + Type: ChanUpdate2SCIDType, + } + + records = append(records, + scidRecordProducer.Record(), + tlv.MakePrimitiveRecord( + ChanUpdate2BlockHeightType, &c.BlockHeight, + ), + ) + + // Only include the disable flags if any bit is set. + if !c.DisabledFlags.IsEnabled() { + disableFlags := uint8(c.DisabledFlags) + records = append(records, tlv.MakePrimitiveRecord( + ChanUpdate2DisableFlagsType, &disableFlags, + )) + } + + // We only need to encode the direction if the direction is set to 1. + if c.Direction { + directionRecordProducer := &BooleanRecordProducer{ + Bool: true, + Type: ChanUpdate2DirectionType, + } + records = append(records, directionRecordProducer.Record()) + } + + // We only encode the cltv expiry delta if it is not equal to the + // default. + if c.CLTVExpiryDelta != defaultCltvExpiryDelta { + records = append(records, tlv.MakePrimitiveRecord( + ChanUpdate2CLTVExpiryDeltaType, &c.CLTVExpiryDelta, + )) + } + + if c.HTLCMinimumMsat != defaultHtlcMinMsat { + var htlcMin = uint64(c.HTLCMinimumMsat) + records = append(records, tlv.MakePrimitiveRecord( + ChanUpdate2HTLCMinMsatType, &htlcMin, + )) + } + + var htlcMax = uint64(c.HTLCMaximumMsat) + records = append(records, tlv.MakePrimitiveRecord( + ChanUpdate2HTLCMaxMsatType, &htlcMax, + )) + + if c.FeeBaseMsat != defaultFeeBaseMsat { + records = append(records, tlv.MakePrimitiveRecord( + ChanUpdate2FeeBaseMsatType, &c.FeeBaseMsat, + )) + } + + if c.FeeProportionalMillionths != defaultFeeProportionalMillionths { + records = append(records, tlv.MakePrimitiveRecord( + ChanUpdate2FeeProportionalMillionthsType, + &c.FeeProportionalMillionths, + )) + } + + err = EncodeMessageExtraDataFromRecords(&c.ExtraOpaqueData, records...) + if err != nil { + return err + } + + return WriteBytes(w, c.ExtraOpaqueData) +} + +// MsgType returns the integer uniquely identifying this message type on the +// wire. +// +// This is part of the lnwire.Message interface. +func (c *ChannelUpdate2) MsgType() MessageType { + return MsgChannelUpdate2 +} + +// A compile time check to ensure ChannelUpdate2 implements the lnwire.Message +// interface. +var _ Message = (*ChannelUpdate2)(nil) + +// ChanUpdateDisableFlags is a bit vector that can be used to indicate various +// reasons for the channel being marked as disabled. +type ChanUpdateDisableFlags uint8 + +const ( + // ChanUpdateDisableIncoming is a bit indicates that a channel is + // disabled in the inbound direction meaning that the node broadcasting + // the update is communicating that they cannot receive funds. + ChanUpdateDisableIncoming ChanUpdateDisableFlags = 1 << iota + + // ChanUpdateDisableOutgoing is a bit indicates that a channel is + // disabled in the outbound direction meaning that the node broadcasting + // the update is communicating that they cannot send or route funds. + ChanUpdateDisableOutgoing = 2 +) + +// IncomingDisabled returns true if the ChanUpdateDisableIncoming bit is set. +func (c ChanUpdateDisableFlags) IncomingDisabled() bool { + return c&ChanUpdateDisableIncoming == ChanUpdateDisableIncoming +} + +// OutgoingDisabled returns true if the ChanUpdateDisableOutgoing bit is set. +func (c ChanUpdateDisableFlags) OutgoingDisabled() bool { + return c&ChanUpdateDisableOutgoing == ChanUpdateDisableOutgoing +} + +// IsEnabled returns true if none of the disable bits are set. +func (c ChanUpdateDisableFlags) IsEnabled() bool { + return c == 0 +} + +// String returns the bitfield flags as a string. +func (c ChanUpdateDisableFlags) String() string { + return fmt.Sprintf("%08b", c) +} diff --git a/lnwire/lnwire_test.go b/lnwire/lnwire_test.go index 8f03d4b445..c2ee1d41b1 100644 --- a/lnwire/lnwire_test.go +++ b/lnwire/lnwire_test.go @@ -1137,6 +1137,74 @@ func TestLightningWireProtocol(t *testing.T) { v[0] = reflect.ValueOf(req) }, + MsgChannelUpdate2: func(v []reflect.Value, r *rand.Rand) { + req := ChannelUpdate2{ + Signature: testSchnorrSig, + ShortChannelID: NewShortChanIDFromInt( + uint64(r.Int63()), + ), + BlockHeight: r.Uint32(), + HTLCMaximumMsat: MilliSatoshi(r.Uint64()), + ExtraOpaqueData: make([]byte, 0), + } + + // Sometimes set chain hash to bitcoin mainnet genesis + // hash. + req.ChainHash = *chaincfg.MainNetParams.GenesisHash + if r.Int31()%2 == 0 { + _, err := r.Read(req.ChainHash[:]) + require.NoError(t, err) + } + + // Sometimes use default htlc min msat. + req.HTLCMinimumMsat = defaultHtlcMinMsat + if r.Int31()%2 == 0 { + req.HTLCMinimumMsat = MilliSatoshi(r.Uint64()) + } + + // Sometimes set the cltv expiry delta to the default. + req.CLTVExpiryDelta = defaultCltvExpiryDelta + if r.Int31()%2 == 0 { + req.CLTVExpiryDelta = uint16(r.Int31()) + } + + // Sometimes use default fee base. + req.FeeBaseMsat = defaultFeeBaseMsat + if r.Int31()%2 == 0 { + req.FeeBaseMsat = r.Uint32() + } + + // Sometimes use default proportional fee. + req.FeeProportionalMillionths = + defaultFeeProportionalMillionths + if r.Int31()%2 == 0 { + req.FeeProportionalMillionths = r.Uint32() + } + + // Alternate between the two direction possibilities. + if r.Int31()%2 == 0 { + req.Direction = true + } + + // Sometimes set the incoming disabled flag. + if r.Int31()%2 == 0 { + req.DisabledFlags |= ChanUpdateDisableIncoming + } + + // Sometimes set the outgoing disabled flag. + if r.Int31()%2 == 0 { + req.DisabledFlags |= ChanUpdateDisableOutgoing + } + + numExtraBytes := r.Int31n(1000) + if numExtraBytes > 0 { + req.ExtraOpaqueData = make( + []byte, numExtraBytes, + ) + _, err := r.Read(req.ExtraOpaqueData[:]) + require.NoError(t, err) + } + }, } // With the above types defined, we'll now generate a slice of diff --git a/lnwire/message.go b/lnwire/message.go index 116693b774..db9b23ab87 100644 --- a/lnwire/message.go +++ b/lnwire/message.go @@ -53,6 +53,7 @@ const ( MsgReplyChannelRange = 264 MsgGossipTimestampRange = 265 MsgChannelAnnouncement2 = 267 + MsgChannelUpdate2 = 271 ) // ErrorEncodeMessage is used when failed to encode the message payload. @@ -140,6 +141,8 @@ func (t MessageType) String() string { return "MsgAnnouncementSignatures2" case MsgChannelAnnouncement2: return "ChannelAnnouncement2" + case MsgChannelUpdate2: + return "ChannelUpdate2" default: return "<unknown>" } @@ -246,6 +249,8 @@ func makeEmptyMessage(msgType MessageType) (Message, error) { msg = &AnnouncementSignatures2{} case MsgChannelAnnouncement2: msg = &ChannelAnnouncement2{} + case MsgChannelUpdate2: + msg = &ChannelUpdate2{} default: // If the message is not within our custom range and has not // specifically been overridden, return an unknown message. From efff6f283950cb3efa2c7d144ed0bdddee47fc6e Mon Sep 17 00:00:00 2001 From: Elle Mouton <elle.mouton@gmail.com> Date: Fri, 29 Sep 2023 11:56:58 +0200 Subject: [PATCH 11/33] lnwire: add NodeAnnouncement2 --- lnwire/lnwire_test.go | 73 +++++ lnwire/message.go | 5 + lnwire/node_announcement_2.go | 554 ++++++++++++++++++++++++++++++++++ 3 files changed, 632 insertions(+) create mode 100644 lnwire/node_announcement_2.go diff --git a/lnwire/lnwire_test.go b/lnwire/lnwire_test.go index c2ee1d41b1..21a12ede51 100644 --- a/lnwire/lnwire_test.go +++ b/lnwire/lnwire_test.go @@ -2,6 +2,7 @@ package lnwire import ( "bytes" + "encoding/base64" "encoding/binary" "encoding/hex" "fmt" @@ -1205,6 +1206,72 @@ func TestLightningWireProtocol(t *testing.T) { require.NoError(t, err) } }, + MsgNodeAnnouncement2: func(v []reflect.Value, r *rand.Rand) { + req := NodeAnnouncement2{ + Signature: testSchnorrSig, + BlockHeight: r.Uint32(), + NodeID: randRawKey(t), + ExtraOpaqueData: make([]byte, 0), + } + + features := randRawFeatureVector(r) + req.Features = *features + + // Sometimes set the colour field. + if r.Int31()%2 == 0 { + req.RGBColor = &color.RGBA{ + R: uint8(r.Int31()), + G: uint8(r.Int31()), + B: uint8(r.Int31()), + } + } + + n := r.Intn(33) + b := make([]byte, n) + _, err := rand.Read(b) + require.NoError(t, err) + + if n > 0 { + req.Alias = []byte( + base64.StdEncoding.EncodeToString(b), + ) + if len(req.Alias) > 32 { + req.Alias = req.Alias[:32] + } + } + + // Sometimes add some ipv4 addrs. + if r.Int31()%2 == 0 { + ipv4Addr, err := randTCP4Addr(r) + require.NoError(t, err) + req.Addresses = append(req.Addresses, ipv4Addr) + } + + // Sometimes add some ipv6 addrs. + if r.Int31()%2 == 0 { + ipv6Addr, err := randTCP6Addr(r) + require.NoError(t, err) + req.Addresses = append(req.Addresses, ipv6Addr) + } + + // Sometimes add some torv3 addrs. + if r.Int31()%2 == 0 { + ipv6Addr, err := randV3OnionAddr(r) + require.NoError(t, err) + req.Addresses = append(req.Addresses, ipv6Addr) + } + + numExtraBytes := r.Int31n(1000) + if numExtraBytes > 0 { + req.ExtraOpaqueData = make( + []byte, numExtraBytes, + ) + _, err := r.Read(req.ExtraOpaqueData[:]) + require.NoError(t, err) + } + + v[0] = reflect.ValueOf(req) + }, } // With the above types defined, we'll now generate a slice of @@ -1403,6 +1470,12 @@ func TestLightningWireProtocol(t *testing.T) { return mainScenario(&m) }, }, + { + msgType: MsgNodeAnnouncement2, + scenario: func(m NodeAnnouncement2) bool { + return mainScenario(&m) + }, + }, } for _, test := range tests { var config *quick.Config diff --git a/lnwire/message.go b/lnwire/message.go index db9b23ab87..cebc7f4020 100644 --- a/lnwire/message.go +++ b/lnwire/message.go @@ -53,6 +53,7 @@ const ( MsgReplyChannelRange = 264 MsgGossipTimestampRange = 265 MsgChannelAnnouncement2 = 267 + MsgNodeAnnouncement2 = 269 MsgChannelUpdate2 = 271 ) @@ -143,6 +144,8 @@ func (t MessageType) String() string { return "ChannelAnnouncement2" case MsgChannelUpdate2: return "ChannelUpdate2" + case MsgNodeAnnouncement2: + return "NodeAnnouncement2" default: return "<unknown>" } @@ -251,6 +254,8 @@ func makeEmptyMessage(msgType MessageType) (Message, error) { msg = &ChannelAnnouncement2{} case MsgChannelUpdate2: msg = &ChannelUpdate2{} + case MsgNodeAnnouncement2: + msg = &NodeAnnouncement2{} default: // If the message is not within our custom range and has not // specifically been overridden, return an unknown message. diff --git a/lnwire/node_announcement_2.go b/lnwire/node_announcement_2.go new file mode 100644 index 0000000000..2ecb3e22ba --- /dev/null +++ b/lnwire/node_announcement_2.go @@ -0,0 +1,554 @@ +package lnwire + +import ( + "bytes" + "encoding/binary" + "fmt" + "image/color" + "io" + "net" + "unicode/utf8" + + "github.com/lightningnetwork/lnd/tlv" + "github.com/lightningnetwork/lnd/tor" +) + +const ( + // NodeAnn2FeaturesType is the tlv number associated with the features + // vector TLV record in the node_announcement_2 message. + NodeAnn2FeaturesType = tlv.Type(0) + + // NodeAnn2RGBColorType is the tlv number associated with the color TLV + // record in the node_announcement_2 message. + NodeAnn2RGBColorType = tlv.Type(1) + + // NodeAnn2BlockHeightType is the tlv number associated with the block + // height TLV record in the node_announcement_2 message. + NodeAnn2BlockHeightType = tlv.Type(2) + + // NodeAnn2AliasType is the tlv number associated with the alias vector + // TLV record in the node_announcement_2 message. + NodeAnn2AliasType = tlv.Type(3) + + // NodeAnn2NodeIDType is the tlv number associated with the node ID TLV + // record in the node_announcement_2 message. + NodeAnn2NodeIDType = tlv.Type(4) + + // NodeAnn2IPV4AddrsType is the tlv number associated with the ipv4 + // addresses TLV record in the node_announcement_2 message. + NodeAnn2IPV4AddrsType = tlv.Type(5) + + // NodeAnn2IPV6AddrsType is the tlv number associated with the ipv6 + // addresses TLV record in the node_announcement_2 message. + NodeAnn2IPV6AddrsType = tlv.Type(7) + + // NodeAnn2TorV3AddrsType is the tlv number associated with the tor V3 + // addresses TLV record in the node_announcement_2 message. + NodeAnn2TorV3AddrsType = tlv.Type(9) +) + +// NodeAnnouncement2 message is used to announce the presence of a Lightning +// node and also to signal that the node is accepting incoming connections. +// Each NodeAnnouncement authenticating the advertised information within the +// announcement via a signature using the advertised node pubkey. +type NodeAnnouncement2 struct { + // Signature is used to prove the ownership of node id. + Signature Sig + + // Features is the list of protocol features this node supports. + Features RawFeatureVector + + // RGBColor is an optional field used to customize a node's appearance + // in maps and graphs. + RGBColor *color.RGBA + + // BlockHeight allows ordering in the case of multiple announcements. + BlockHeight uint32 + + // Alias is used to customize their node's appearance in maps and + // graphs. + Alias []byte + + // NodeID is a public key which is used as node identification. + NodeID [33]byte + + // Address are addresses on which the node is accepting incoming + // connections. + Addresses []net.Addr + + // ExtraOpaqueData is the set of data that was appended to this + // message, some of which we may not actually know how to iterate or + // parse. By holding onto this data, we ensure that we're able to + // properly validate the set of signatures that cover these new fields, + // and ensure we're able to make upgrades to the network in a forwards + // compatible manner. + ExtraOpaqueData ExtraOpaqueData +} + +// A compile time check to ensure NodeAnnouncement2 implements the +// lnwire.Message interface. +var _ Message = (*NodeAnnouncement2)(nil) + +// Decode deserializes a serialized AnnounceSignatures stored in the passed +// io.Reader observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (n *NodeAnnouncement2) Decode(r io.Reader, _ uint32) error { + err := ReadElement(r, &n.Signature) + if err != nil { + return err + } + n.Signature.ForceSchnorr() + + // First extract into extra opaque data. + var tlvRecords ExtraOpaqueData + if err := ReadElements(r, &tlvRecords); err != nil { + return err + } + + featuresRecordProducer := NewRawFeatureVectorRecordProducer( + NodeAnn2FeaturesType, + ) + + var ( + rbgColour color.RGBA + alias []byte + ipv4 IPV4Addrs + ipv6 IPV6Addrs + torV3 TorV3Addrs + ) + records := []tlv.Record{ + featuresRecordProducer.Record(), + tlv.MakeStaticRecord( + NodeAnn2RGBColorType, &rbgColour, 3, rgbEncoder, + rgbDecoder, + ), + tlv.MakePrimitiveRecord( + NodeAnn2BlockHeightType, &n.BlockHeight, + ), + tlv.MakePrimitiveRecord(NodeAnn2AliasType, &alias), + tlv.MakePrimitiveRecord(NodeAnn2NodeIDType, &n.NodeID), + tlv.MakeDynamicRecord( + NodeAnn2IPV4AddrsType, &ipv4, ipv4.EncodedSize, + ipv4AddrsEncoder, ipv4AddrsDecoder, + ), + tlv.MakeDynamicRecord( + NodeAnn2IPV6AddrsType, &ipv6, ipv6.EncodedSize, + ipv6AddrsEncoder, ipv6AddrsDecoder, + ), + tlv.MakeDynamicRecord( + NodeAnn2TorV3AddrsType, &torV3, torV3.EncodedSize, + torV3AddrsEncoder, torV3AddrsDecoder, + ), + } + + typeMap, err := tlvRecords.ExtractRecords(records...) + if err != nil { + return err + } + + if _, ok := typeMap[NodeAnn2FeaturesType]; ok { + n.Features = featuresRecordProducer.RawFeatureVector + } + + if _, ok := typeMap[NodeAnn2RGBColorType]; ok { + n.RGBColor = &rbgColour + } + + if _, ok := typeMap[NodeAnn2AliasType]; ok { + // TODO(elle): do this before we allocate the bytes for it + // somehow? + if len(alias) > 32 { + return fmt.Errorf("alias too large: max is %v, got %v", + 32, len(alias)) + } + + // Validate the alias. + if !utf8.ValidString(string(alias)) { + return fmt.Errorf("node alias has non-utf8 characters") + } + + n.Alias = alias + } + + if _, ok := typeMap[NodeAnn2IPV4AddrsType]; ok { + for _, addr := range ipv4 { + n.Addresses = append(n.Addresses, net.Addr(addr)) + } + } + + if _, ok := typeMap[NodeAnn2IPV6AddrsType]; ok { + for _, addr := range ipv6 { + n.Addresses = append(n.Addresses, net.Addr(addr)) + } + } + + if _, ok := typeMap[NodeAnn2TorV3AddrsType]; ok { + for _, addr := range torV3 { + n.Addresses = append(n.Addresses, net.Addr(addr)) + } + } + + if len(tlvRecords) != 0 { + n.ExtraOpaqueData = tlvRecords + } + + return nil +} + +// Encode serializes the target AnnounceSignatures into the passed io.Writer +// observing the protocol version specified. +// +// This is part of the lnwire.Message interface. +func (n *NodeAnnouncement2) Encode(w *bytes.Buffer, _ uint32) error { + _, err := w.Write(n.Signature.RawBytes()) + if err != nil { + return err + } + + featuresRecordProducer := &RawFeatureVectorRecordProducer{ + RawFeatureVector: n.Features, + Type: NodeAnn2FeaturesType, + } + + records := []tlv.Record{ + featuresRecordProducer.Record(), + } + + // Only encode the colour if it is specified. + if n.RGBColor != nil { + records = append(records, tlv.MakeStaticRecord( + NodeAnn2RGBColorType, n.RGBColor, 3, rgbEncoder, + rgbDecoder, + )) + } + + records = append(records, tlv.MakePrimitiveRecord( + NodeAnn2BlockHeightType, &n.BlockHeight, + )) + + if len(n.Alias) != 0 { + records = append(records, + tlv.MakePrimitiveRecord(NodeAnn2AliasType, &n.Alias), + ) + } + + records = append( + records, tlv.MakePrimitiveRecord(NodeAnn2NodeIDType, &n.NodeID), + ) + + // Iterate over the addresses and collect the various types. + var ( + ipv4 IPV4Addrs + ipv6 IPV6Addrs + torv3 TorV3Addrs + ) + for _, addr := range n.Addresses { + switch a := addr.(type) { + case *net.TCPAddr: + if a.IP.To4() != nil { + ipv4 = append(ipv4, a) + } else { + ipv6 = append(ipv6, a) + } + + case *tor.OnionAddr: + torv3 = append(torv3, a) + } + } + + if len(ipv4) > 0 { + records = append(records, tlv.MakeDynamicRecord( + NodeAnn2IPV4AddrsType, &ipv4, ipv4.EncodedSize, + ipv4AddrsEncoder, ipv4AddrsDecoder, + )) + } + + if len(ipv6) > 0 { + records = append(records, tlv.MakeDynamicRecord( + NodeAnn2IPV6AddrsType, &ipv6, ipv6.EncodedSize, + ipv6AddrsEncoder, ipv6AddrsDecoder, + )) + } + + if len(torv3) > 0 { + records = append(records, tlv.MakeDynamicRecord( + NodeAnn2TorV3AddrsType, &torv3, torv3.EncodedSize, + torV3AddrsEncoder, torV3AddrsDecoder, + )) + } + + err = EncodeMessageExtraDataFromRecords(&n.ExtraOpaqueData, records...) + if err != nil { + return err + } + + return WriteBytes(w, n.ExtraOpaqueData) +} + +// MsgType returns the integer uniquely identifying this message type on the +// wire. +// +// This is part of the lnwire.Message interface. +func (n *NodeAnnouncement2) MsgType() MessageType { + return MsgNodeAnnouncement2 +} + +func rgbEncoder(w io.Writer, val interface{}, _ *[8]byte) error { + if v, ok := val.(*color.RGBA); ok { + buf := bytes.NewBuffer(nil) + err := WriteColorRGBA(buf, *v) + if err != nil { + return err + } + + _, err = w.Write(buf.Bytes()) + + return err + } + + return tlv.NewTypeForEncodingErr(val, "color.RGBA") +} + +func rgbDecoder(r io.Reader, val interface{}, _ *[8]byte, l uint64) error { + if v, ok := val.(*color.RGBA); ok { + return ReadElements(r, &v.R, &v.G, &v.B) + } + + return tlv.NewTypeForDecodingErr(val, "color.RGBA", l, 3) +} + +// IPV4Addrs is a list of ipv4 addresses that can be encoded as a TLV record. +type IPV4Addrs []*net.TCPAddr + +// ipv4AddrEncodedSize is the number of bytes required to encode a single ipv4 +// address. Four bytes are used to encode the IP address and two bytes for the +// port number. +const ipv4AddrEncodedSize = 4 + 2 + +// EncodedSize returns the number of bytes required to encode an IPV4Addrs +// variable. +func (i *IPV4Addrs) EncodedSize() uint64 { + return uint64(len(*i) * ipv4AddrEncodedSize) +} + +func ipv4AddrsEncoder(w io.Writer, val interface{}, _ *[8]byte) error { + if v, ok := val.(*IPV4Addrs); ok { + for _, ip := range *v { + _, err := w.Write(ip.IP.To4()) + if err != nil { + return err + } + + var port [2]byte + binary.BigEndian.PutUint16(port[:], uint16(ip.Port)) + + _, err = w.Write(port[:]) + + return err + } + } + + return tlv.NewTypeForEncodingErr(val, "lnwire.IPV4Addrs") +} + +func ipv4AddrsDecoder(r io.Reader, val interface{}, _ *[8]byte, + l uint64) error { + + if v, ok := val.(*IPV4Addrs); ok { + if l%(ipv4AddrEncodedSize) != 0 { + return fmt.Errorf("invalid ipv4 list encoding") + } + + var ( + numAddrs = int(l / ipv4AddrEncodedSize) + addrs = make([]*net.TCPAddr, 0, numAddrs) + ip [4]byte + port [2]byte + ) + for len(addrs) < numAddrs { + _, err := r.Read(ip[:]) + if err != nil { + return err + } + + _, err = r.Read(port[:]) + if err != nil { + return err + } + + addrs = append(addrs, &net.TCPAddr{ + IP: ip[:], + Port: int(binary.BigEndian.Uint16(port[:])), + }) + } + + *v = addrs + + return nil + } + + return tlv.NewTypeForEncodingErr(val, "lnwire.IPV4Addrs") +} + +// IPV6Addrs is a list of ipv6 addresses that can be encoded as a TLV record. +type IPV6Addrs []*net.TCPAddr + +// ipv6AddrEncodedSize is the number of bytes required to encode a single ipv6 +// address. Sixteen bytes are used to encode the IP address and two bytes for +// the port number. +const ipv6AddrEncodedSize = 16 + 2 + +// EncodedSize returns the number of bytes required to encode an IPV6Addrs +// variable. +func (i *IPV6Addrs) EncodedSize() uint64 { + return uint64(len(*i) * ipv6AddrEncodedSize) +} + +func ipv6AddrsEncoder(w io.Writer, val interface{}, _ *[8]byte) error { + if v, ok := val.(*IPV6Addrs); ok { + for _, ip := range *v { + _, err := w.Write(ip.IP.To16()) + if err != nil { + return err + } + + var port [2]byte + binary.BigEndian.PutUint16(port[:], uint16(ip.Port)) + + _, err = w.Write(port[:]) + + return err + } + } + + return tlv.NewTypeForEncodingErr(val, "lnwire.IPV6Addrs") +} + +func ipv6AddrsDecoder(r io.Reader, val interface{}, _ *[8]byte, + l uint64) error { + + if v, ok := val.(*IPV6Addrs); ok { + if l%(ipv6AddrEncodedSize) != 0 { + return fmt.Errorf("invalid ipv6 list encoding") + } + + var ( + numAddrs = int(l / ipv6AddrEncodedSize) + addrs = make([]*net.TCPAddr, 0, numAddrs) + ip [16]byte + port [2]byte + ) + for len(addrs) < numAddrs { + _, err := r.Read(ip[:]) + if err != nil { + return err + } + + _, err = r.Read(port[:]) + if err != nil { + return err + } + + addrs = append(addrs, &net.TCPAddr{ + IP: ip[:], + Port: int(binary.BigEndian.Uint16(port[:])), + }) + } + + *v = addrs + + return nil + } + + return tlv.NewTypeForEncodingErr(val, "lnwire.IPV6Addrs") +} + +// TorV3Addrs is a list of tor v3 addresses that can be encoded as a TLV record. +type TorV3Addrs []*tor.OnionAddr + +// torV3AddrEncodedSize is the number of bytes required to encode a single tor +// v3 address. +const torV3AddrEncodedSize = tor.V3DecodedLen + 2 + +// EncodedSize returns the number of bytes required to encode an TorV3Addrs +// variable. +func (i *TorV3Addrs) EncodedSize() uint64 { + return uint64(len(*i) * torV3AddrEncodedSize) +} + +func torV3AddrsEncoder(w io.Writer, val interface{}, _ *[8]byte) error { + if v, ok := val.(*TorV3Addrs); ok { + for _, addr := range *v { + encodedHostLen := tor.V3Len - tor.OnionSuffixLen + host, err := tor.Base32Encoding.DecodeString( + addr.OnionService[:encodedHostLen], + ) + if err != nil { + return err + } + + if len(host) != tor.V3DecodedLen { + return fmt.Errorf("expected a tor v3 host "+ + "length of %d, got: %d", + tor.V2DecodedLen, len(host)) + } + + if _, err = w.Write(host); err != nil { + return err + } + + var port [2]byte + binary.BigEndian.PutUint16(port[:], uint16(addr.Port)) + + _, err = w.Write(port[:]) + + return err + } + } + + return tlv.NewTypeForEncodingErr(val, "lnwire.TorV3Addrs") +} + +func torV3AddrsDecoder(r io.Reader, val interface{}, _ *[8]byte, + l uint64) error { + + if v, ok := val.(*TorV3Addrs); ok { + if l%torV3AddrEncodedSize != 0 { + return fmt.Errorf("invalid tor v3 list encoding") + } + + var ( + numAddrs = int(l / torV3AddrEncodedSize) + addrs = make([]*tor.OnionAddr, 0, numAddrs) + ip [tor.V3DecodedLen]byte + p [2]byte + ) + for len(addrs) < numAddrs { + _, err := r.Read(ip[:]) + if err != nil { + return err + } + + _, err = r.Read(p[:]) + if err != nil { + return err + } + + onionService := tor.Base32Encoding.EncodeToString(ip[:]) + onionService += tor.OnionSuffix + port := int(binary.BigEndian.Uint16(p[:])) + + addrs = append(addrs, &tor.OnionAddr{ + OnionService: onionService, + Port: port, + }) + } + + *v = addrs + + return nil + } + + return tlv.NewTypeForEncodingErr(val, "lnwire.TorV3Addrs") +} From f811611c765aa286f6e08e3498d91d76078db60d Mon Sep 17 00:00:00 2001 From: Elle Mouton <elle.mouton@gmail.com> Date: Tue, 17 Oct 2023 08:27:39 +0200 Subject: [PATCH 12/33] multi: rename lnwire.ChannelAnnouncement Rename lnwire.ChannelAnnouncement to ChannelAnnouncement1. This is in preparation for the addition of a ChannelAnnouncement interface which will be implemented by both ChannelAnnouncement1 and ChannelAnnouncement2. --- discovery/gossiper.go | 40 +++++++++---------- discovery/gossiper_test.go | 33 ++++++++------- discovery/syncer.go | 2 +- discovery/syncer_test.go | 10 ++--- funding/manager.go | 16 ++++---- funding/manager_test.go | 12 +++--- ...ouncement.go => channel_announcement_1.go} | 20 +++++----- lnwire/lnwire_test.go | 4 +- lnwire/message.go | 4 +- lnwire/message_test.go | 4 +- netann/channel_announcement.go | 4 +- netann/channel_announcement_test.go | 2 +- netann/sign.go | 2 +- peer/brontide.go | 4 +- routing/ann_validation.go | 2 +- routing/router.go | 2 +- routing/validation_barrier.go | 18 ++++----- routing/validation_barrier_test.go | 4 +- 18 files changed, 91 insertions(+), 92 deletions(-) rename lnwire/{channel_announcement.go => channel_announcement_1.go} (87%) diff --git a/discovery/gossiper.go b/discovery/gossiper.go index eed3cee841..893d5d769e 100644 --- a/discovery/gossiper.go +++ b/discovery/gossiper.go @@ -420,7 +420,7 @@ type AuthenticatedGossiper struct { // prematureChannelUpdates is a map of ChannelUpdates we have received // that wasn't associated with any channel we know about. We store // them temporarily, such that we can reprocess them when a - // ChannelAnnouncement for the channel is received. + // ChannelAnnouncement1 for the channel is received. prematureChannelUpdates *lru.Cache[uint64, *cachedNetworkMsg] // networkMsgs is a channel that carries new network broadcasted @@ -818,9 +818,9 @@ func (d *AuthenticatedGossiper) ProcessRemoteAnnouncement(msg lnwire.Message, // To avoid inserting edges in the graph for our own channels that we // have already closed, we ignore such channel announcements coming // from the remote. - case *lnwire.ChannelAnnouncement: + case *lnwire.ChannelAnnouncement1: ownKey := d.selfKey.SerializeCompressed() - ownErr := fmt.Errorf("ignoring remote ChannelAnnouncement " + + ownErr := fmt.Errorf("ignoring remote ChannelAnnouncement1 " + "for own channel") if bytes.Equal(m.NodeID1[:], ownKey) || @@ -980,7 +980,7 @@ func (d *deDupedAnnouncements) addMsg(message networkMsg) { switch msg := message.msg.(type) { // Channel announcements are identified by the short channel id field. - case *lnwire.ChannelAnnouncement: + case *lnwire.ChannelAnnouncement1: deDupKey := msg.ShortChannelID sender := route.NewVertex(message.source) @@ -1554,7 +1554,7 @@ func (d *AuthenticatedGossiper) isRecentlyRejectedMsg(msg lnwire.Message, case *lnwire.ChannelUpdate: scid = m.ShortChannelID.ToUint64() - case *lnwire.ChannelAnnouncement: + case *lnwire.ChannelAnnouncement1: scid = m.ShortChannelID.ToUint64() default: @@ -1810,9 +1810,9 @@ func remotePubFromChanInfo(chanInfo *channeldb.ChannelEdgeInfo, // contains a proof, we can add this proof to our edge. We can end up in this // situation in the case where we create a channel, but for some reason fail // to receive the remote peer's proof, while the remote peer is able to fully -// assemble the proof and craft the ChannelAnnouncement. +// assemble the proof and craft the ChannelAnnouncement1. func (d *AuthenticatedGossiper) processRejectedEdge( - chanAnnMsg *lnwire.ChannelAnnouncement, + chanAnnMsg *lnwire.ChannelAnnouncement1, proof *channeldb.ChannelAuthProof) ([]networkMsg, error) { // First, we'll fetch the state of the channel as we know if from the @@ -1997,7 +1997,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement( // *creation* of a new channel within the network. This only advertises // the existence of a channel and not yet the routing policies in // either direction of the channel. - case *lnwire.ChannelAnnouncement: + case *lnwire.ChannelAnnouncement1: return d.handleChanAnnouncement(nMsg, msg, schedulerOp) // A new authenticated channel edge update has arrived. This indicates @@ -2151,7 +2151,7 @@ func (d *AuthenticatedGossiper) isMsgStale(msg lnwire.Message) bool { // updateChannel creates a new fully signed update for the channel, and updates // the underlying graph with the new state. func (d *AuthenticatedGossiper) updateChannel(info *channeldb.ChannelEdgeInfo, - edge *channeldb.ChannelEdgePolicy) (*lnwire.ChannelAnnouncement, + edge *channeldb.ChannelEdgePolicy) (*lnwire.ChannelAnnouncement1, *lnwire.ChannelUpdate, error) { // Parse the unsigned edge into a channel update. @@ -2188,10 +2188,10 @@ func (d *AuthenticatedGossiper) updateChannel(info *channeldb.ChannelEdgeInfo, // We'll also create the original channel announcement so the two can // be broadcast along side each other (if necessary), but only if we // have a full channel announcement for this channel. - var chanAnn *lnwire.ChannelAnnouncement + var chanAnn *lnwire.ChannelAnnouncement1 if info.AuthProof != nil { chanID := lnwire.NewShortChanIDFromInt(info.ChannelID) - chanAnn = &lnwire.ChannelAnnouncement{ + chanAnn = &lnwire.ChannelAnnouncement1{ ShortChannelID: chanID, NodeID1: info.NodeKey1Bytes, NodeID2: info.NodeKey2Bytes, @@ -2363,16 +2363,16 @@ func (d *AuthenticatedGossiper) handleNodeAnnouncement(nMsg *networkMsg, // handleChanAnnouncement processes a new channel announcement. func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, - ann *lnwire.ChannelAnnouncement, + ann *lnwire.ChannelAnnouncement1, ops []batch.SchedulerOption) ([]networkMsg, bool) { - log.Debugf("Processing ChannelAnnouncement: peer=%v, short_chan_id=%v", + log.Debugf("Processing ChannelAnnouncement1: peer=%v, short_chan_id=%v", nMsg.peer, ann.ShortChannelID.ToUint64()) // We'll ignore any channel announcements that target any chain other // than the set of chains we know of. if !bytes.Equal(ann.ChainHash[:], d.cfg.ChainHash[:]) { - err := fmt.Errorf("ignoring ChannelAnnouncement from chain=%v"+ + err := fmt.Errorf("ignoring ChannelAnnouncement1 from chain=%v"+ ", gossiper on chain=%v", ann.ChainHash, d.cfg.ChainHash) log.Errorf(err.Error()) @@ -2387,7 +2387,7 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, return nil, false } - // If this is a remote ChannelAnnouncement with an alias SCID, we'll + // If this is a remote ChannelAnnouncement1 with an alias SCID, we'll // reject the announcement. Since the router accepts alias SCIDs, // not erroring out would be a DoS vector. if nMsg.isRemote && d.cfg.IsAlias(ann.ShortChannelID) { @@ -2630,7 +2630,7 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, nMsg.err <- nil - log.Debugf("Processed ChannelAnnouncement: peer=%v, short_chan_id=%v", + log.Debugf("Processed ChannelAnnouncement1: peer=%v, short_chan_id=%v", nMsg.peer, ann.ShortChannelID.ToUint64()) return announcements, true @@ -2738,7 +2738,7 @@ func (d *AuthenticatedGossiper) handleChanUpdate(nMsg *networkMsg, } // We'll fallthrough to ensure we stash the update until we - // receive its corresponding ChannelAnnouncement. This is + // receive its corresponding ChannelAnnouncement1. This is // needed to ensure the edge exists in the graph before // applying the update. fallthrough @@ -2750,15 +2750,15 @@ func (d *AuthenticatedGossiper) handleChanUpdate(nMsg *networkMsg, // If the edge corresponding to this ChannelUpdate was not // found in the graph, this might be a channel in the process // of being opened, and we haven't processed our own - // ChannelAnnouncement yet, hence it is not not found in the + // ChannelAnnouncement yet, hence it is not found in the // graph. This usually gets resolved after the channel proofs // are exchanged and the channel is broadcasted to the rest of // the network, but in case this is a private channel this // won't ever happen. This can also happen in the case of a // zombie channel with a fresh update for which we don't have a - // ChannelAnnouncement for since we reject them. Because of + // ChannelAnnouncement1 for since we reject them. Because of // this, we temporarily add it to a map, and reprocess it after - // our own ChannelAnnouncement has been processed. + // our own ChannelAnnouncement1 has been processed. // // The shortChanID may be an alias, but it is fine to use here // since we don't have an edge in the graph and if the peer is diff --git a/discovery/gossiper_test.go b/discovery/gossiper_test.go index d364b80ecc..d0fbfc5bc8 100644 --- a/discovery/gossiper_test.go +++ b/discovery/gossiper_test.go @@ -464,7 +464,7 @@ type annBatch struct { nodeAnn1 *lnwire.NodeAnnouncement nodeAnn2 *lnwire.NodeAnnouncement - chanAnn *lnwire.ChannelAnnouncement + chanAnn *lnwire.ChannelAnnouncement1 chanUpdAnn1 *lnwire.ChannelUpdate chanUpdAnn2 *lnwire.ChannelUpdate @@ -623,9 +623,9 @@ func signUpdate(nodeKey *btcec.PrivateKey, a *lnwire.ChannelUpdate) error { func createAnnouncementWithoutProof(blockHeight uint32, key1, key2 *btcec.PublicKey, - extraBytes ...[]byte) *lnwire.ChannelAnnouncement { + extraBytes ...[]byte) *lnwire.ChannelAnnouncement1 { - a := &lnwire.ChannelAnnouncement{ + a := &lnwire.ChannelAnnouncement1{ ShortChannelID: lnwire.ShortChannelID{ BlockHeight: blockHeight, TxIndex: 0, @@ -645,13 +645,13 @@ func createAnnouncementWithoutProof(blockHeight uint32, } func createRemoteChannelAnnouncement(blockHeight uint32, - extraBytes ...[]byte) (*lnwire.ChannelAnnouncement, error) { + extraBytes ...[]byte) (*lnwire.ChannelAnnouncement1, error) { return createChannelAnnouncement(blockHeight, remoteKeyPriv1, remoteKeyPriv2, extraBytes...) } func createChannelAnnouncement(blockHeight uint32, key1, key2 *btcec.PrivateKey, - extraBytes ...[]byte) (*lnwire.ChannelAnnouncement, error) { + extraBytes ...[]byte) (*lnwire.ChannelAnnouncement1, error) { a := createAnnouncementWithoutProof(blockHeight, key1.PubKey(), key2.PubKey(), extraBytes...) @@ -1566,7 +1566,7 @@ out: // TestSignatureAnnouncementFullProofWhenRemoteProof tests that if a remote // proof is received when we already have the full proof, the gossiper will send -// the full proof (ChannelAnnouncement) to the remote peer. +// the full proof (ChannelAnnouncement1) to the remote peer. func TestSignatureAnnouncementFullProofWhenRemoteProof(t *testing.T) { t.Parallel() @@ -1728,7 +1728,7 @@ func TestSignatureAnnouncementFullProofWhenRemoteProof(t *testing.T) { } // Now give the gossiper the remote proof yet again. This should - // trigger a send of the full ChannelAnnouncement. + // trigger a send of the full ChannelAnnouncement1. select { case err = <-ctx.gossiper.ProcessRemoteAnnouncement( batch.remoteProofAnn, remotePeer, @@ -1741,10 +1741,9 @@ func TestSignatureAnnouncementFullProofWhenRemoteProof(t *testing.T) { // We expect the gossiper to send this message to the remote peer. select { case msg := <-sentToPeer: - _, ok := msg.(*lnwire.ChannelAnnouncement) - if !ok { - t.Fatalf("expected ChannelAnnouncement, instead got %T", msg) - } + _, ok := msg.(*lnwire.ChannelAnnouncement1) + require.Truef(t, ok, + "expected ChannelAnnouncement1, instead got %T", msg) case <-time.After(2 * time.Second): t.Fatal("did not send local proof to peer") } @@ -2058,7 +2057,7 @@ func TestForwardPrivateNodeAnnouncement(t *testing.T) { // Now, we'll attempt to forward the NodeAnnouncement for the same node // by opening a public channel on the network. We'll create a - // ChannelAnnouncement and hand it off to the gossiper in order to + // ChannelAnnouncement1 and hand it off to the gossiper in order to // process it. remoteChanAnn, err := createRemoteChannelAnnouncement(startingHeight - 1) require.NoError(t, err, "unable to create remote channel announcement") @@ -2360,8 +2359,8 @@ func TestProcessZombieEdgeNowLive(t *testing.T) { } // TestReceiveRemoteChannelUpdateFirst tests that if we receive a ChannelUpdate -// from the remote before we have processed our own ChannelAnnouncement, it will -// be reprocessed later, after our ChannelAnnouncement. +// from the remote before we have processed our own ChannelAnnouncement1, it will +// be reprocessed later, after our ChannelAnnouncement1. func TestReceiveRemoteChannelUpdateFirst(t *testing.T) { t.Parallel() @@ -2388,7 +2387,7 @@ func TestReceiveRemoteChannelUpdateFirst(t *testing.T) { } // Recreate the case where the remote node is sending us its ChannelUpdate - // before we have been able to process our own ChannelAnnouncement and + // before we have been able to process our own ChannelAnnouncement1 and // ChannelUpdate. errRemoteAnn := ctx.gossiper.ProcessRemoteAnnouncement( batch.chanUpdAnn2, remotePeer, @@ -2557,7 +2556,7 @@ func TestReceiveRemoteChannelUpdateFirst(t *testing.T) { } // TestExtraDataChannelAnnouncementValidation tests that we're able to properly -// validate a ChannelAnnouncement that includes opaque bytes that we don't +// validate a ChannelAnnouncement1 that includes opaque bytes that we don't // currently know of. func TestExtraDataChannelAnnouncementValidation(t *testing.T) { t.Parallel() @@ -2776,7 +2775,7 @@ func TestRetransmit(t *testing.T) { var chanAnn, chanUpd, nodeAnn int for _, msg := range anns { switch msg.(type) { - case *lnwire.ChannelAnnouncement: + case *lnwire.ChannelAnnouncement1: chanAnn++ case *lnwire.ChannelUpdate: chanUpd++ diff --git a/discovery/syncer.go b/discovery/syncer.go index 722519fb08..72174ea666 100644 --- a/discovery/syncer.go +++ b/discovery/syncer.go @@ -1314,7 +1314,7 @@ func (g *GossipSyncer) FilterGossipMsgs(msgs ...msgWithSenders) { // For each channel announcement message, we'll only send this // message if the channel updates for the channel are between // our time range. - case *lnwire.ChannelAnnouncement: + case *lnwire.ChannelAnnouncement1: // First, we'll check if the channel updates are in // this message batch. chanUpdates, ok := chanUpdateIndex[msg.ShortChannelID] diff --git a/discovery/syncer_test.go b/discovery/syncer_test.go index 70920f9f48..ef88e5245a 100644 --- a/discovery/syncer_test.go +++ b/discovery/syncer_test.go @@ -283,7 +283,7 @@ func TestGossipSyncerFilterGossipMsgsAllInMemory(t *testing.T) { }, { // Ann tuple below horizon. - msg: &lnwire.ChannelAnnouncement{ + msg: &lnwire.ChannelAnnouncement1{ ShortChannelID: lnwire.NewShortChanIDFromInt(10), }, }, @@ -295,7 +295,7 @@ func TestGossipSyncerFilterGossipMsgsAllInMemory(t *testing.T) { }, { // Ann tuple above horizon. - msg: &lnwire.ChannelAnnouncement{ + msg: &lnwire.ChannelAnnouncement1{ ShortChannelID: lnwire.NewShortChanIDFromInt(15), }, }, @@ -307,7 +307,7 @@ func TestGossipSyncerFilterGossipMsgsAllInMemory(t *testing.T) { }, { // Ann tuple beyond horizon. - msg: &lnwire.ChannelAnnouncement{ + msg: &lnwire.ChannelAnnouncement1{ ShortChannelID: lnwire.NewShortChanIDFromInt(20), }, }, @@ -320,7 +320,7 @@ func TestGossipSyncerFilterGossipMsgsAllInMemory(t *testing.T) { { // Ann w/o an update at all, the update in the DB will // be below the horizon. - msg: &lnwire.ChannelAnnouncement{ + msg: &lnwire.ChannelAnnouncement1{ ShortChannelID: lnwire.NewShortChanIDFromInt(25), }, }, @@ -683,7 +683,7 @@ func TestGossipSyncerReplyShortChanIDs(t *testing.T) { } queryReply := []lnwire.Message{ - &lnwire.ChannelAnnouncement{ + &lnwire.ChannelAnnouncement1{ ShortChannelID: lnwire.NewShortChanIDFromInt(20), }, &lnwire.ChannelUpdate{ diff --git a/funding/manager.go b/funding/manager.go index bee3f14817..3b158af901 100644 --- a/funding/manager.go +++ b/funding/manager.go @@ -3295,7 +3295,7 @@ func (f *Manager) receivedChannelReady(node *btcec.PublicKey, } // extractAnnounceParams extracts the various channel announcement and update -// parameters that will be needed to construct a ChannelAnnouncement and a +// parameters that will be needed to construct a ChannelAnnouncement1 and a // ChannelUpdate. func (f *Manager) extractAnnounceParams(c *channeldb.OpenChannel) ( lnwire.MilliSatoshi, lnwire.MilliSatoshi) { @@ -3324,7 +3324,7 @@ func (f *Manager) extractAnnounceParams(c *channeldb.OpenChannel) ( return fwdMinHTLC, fwdMaxHTLC } -// addToRouterGraph sends a ChannelAnnouncement and a ChannelUpdate to the +// addToRouterGraph sends a ChannelAnnouncement1 and a ChannelUpdate to the // gossiper so that the channel is added to the Router's internal graph. // These announcement messages are NOT broadcasted to the greater network, // only to the channel counter party. The proofs required to announce the @@ -3353,7 +3353,7 @@ func (f *Manager) addToRouterGraph(completeChan *channeldb.OpenChannel, "announcement: %v", err) } - // Send ChannelAnnouncement and ChannelUpdate to the gossiper to add + // Send ChannelAnnouncement1 and ChannelUpdate to the gossiper to add // to the Router's topology. errChan := f.cfg.SendAnnouncement( ann.chanAnn, discovery.ChannelCapacity(completeChan.Capacity), @@ -3366,7 +3366,7 @@ func (f *Manager) addToRouterGraph(completeChan *channeldb.OpenChannel, routing.ErrIgnored) { log.Debugf("Router rejected "+ - "ChannelAnnouncement: %v", err) + "ChannelAnnouncement1: %v", err) } else { return fmt.Errorf("error sending channel "+ "announcement: %v", err) @@ -4050,7 +4050,7 @@ func (f *Manager) ensureInitialForwardingPolicy(chanID lnwire.ChannelID, // chanAnnouncement encapsulates the two authenticated announcements that we // send out to the network after a new channel has been created locally. type chanAnnouncement struct { - chanAnn *lnwire.ChannelAnnouncement + chanAnn *lnwire.ChannelAnnouncement1 chanUpdateAnn *lnwire.ChannelUpdate chanProof *lnwire.AnnounceSignatures } @@ -4075,7 +4075,7 @@ func (f *Manager) newChanAnnouncement(localPubKey, // The unconditional section of the announcement is the ShortChannelID // itself which compactly encodes the location of the funding output // within the blockchain. - chanAnn := &lnwire.ChannelAnnouncement{ + chanAnn := &lnwire.ChannelAnnouncement1{ ShortChannelID: shortChanID, Features: lnwire.NewRawFeatureVector(), ChainHash: chainHash, @@ -4088,7 +4088,7 @@ func (f *Manager) newChanAnnouncement(localPubKey, // TODO(roasbeef): temp, remove after gossip 1.5 if chanType.IsTaproot() { log.Debugf("Applying taproot feature bit to "+ - "ChannelAnnouncement for %v", chanID) + "ChannelAnnouncement1 for %v", chanID) chanAnn.Features.Set( lnwire.SimpleTaprootChannelsRequiredStaging, @@ -4293,7 +4293,7 @@ func (f *Manager) announceChannel(localIDKey, remoteIDKey *btcec.PublicKey, } // We only send the channel proof announcement and the node announcement - // because addToRouterGraph previously sent the ChannelAnnouncement and + // because addToRouterGraph previously sent the ChannelAnnouncement1 and // the ChannelUpdate announcement messages. The channel proof and node // announcements are broadcast to the greater network. errChan := f.cfg.SendAnnouncement(ann.chanProof) diff --git a/funding/manager_test.go b/funding/manager_test.go index 2c77d377e6..62cc231acd 100644 --- a/funding/manager_test.go +++ b/funding/manager_test.go @@ -1144,7 +1144,7 @@ func assertAddedToRouterGraph(t *testing.T, alice, bob *testNode, } // assertChannelAnnouncements checks that alice and bob both sends the expected -// announcements (ChannelAnnouncement, ChannelUpdate) after the funding tx has +// announcements (ChannelAnnouncement1, ChannelUpdate) after the funding tx has // confirmed. The last arguments can be set if we expect the nodes to advertise // custom min_htlc values as part of their ChannelUpdate. We expect Alice to // advertise the value required by Bob and vice versa. If they are not set the @@ -1175,9 +1175,9 @@ func assertChannelAnnouncements(t *testing.T, alice, bob *testNode, // After the ChannelReady message is sent, Alice and Bob will each send // the following messages to their gossiper: - // 1) ChannelAnnouncement + // 1) ChannelAnnouncement1 // 2) ChannelUpdate - // The ChannelAnnouncement is kept locally, while the ChannelUpdate is + // The ChannelAnnouncement1 is kept locally, while the ChannelUpdate is // sent directly to the other peer, so the edge policies are known to // both peers. nodes := []*testNode{alice, bob} @@ -1195,7 +1195,7 @@ func assertChannelAnnouncements(t *testing.T, alice, bob *testNode, gotChannelUpdate := false for _, msg := range announcements { switch m := msg.(type) { - case *lnwire.ChannelAnnouncement: + case *lnwire.ChannelAnnouncement1: gotChannelAnnouncement = true case *lnwire.ChannelUpdate: @@ -1244,7 +1244,7 @@ func assertChannelAnnouncements(t *testing.T, alice, bob *testNode, require.Truef( t, gotChannelAnnouncement, - "ChannelAnnouncement from %d", j, + "ChannelAnnouncement1 from %d", j, ) require.Truef(t, gotChannelUpdate, "ChannelUpdate from %d", j) @@ -4548,7 +4548,7 @@ func testZeroConf(t *testing.T, chanType *lnwire.ChannelType) { // We'll assert that they both create new links. assertHandleChannelReady(t, alice, bob) - // We'll now assert that both sides send ChannelAnnouncement and + // We'll now assert that both sides send ChannelAnnouncement1 and // ChannelUpdate messages. assertChannelAnnouncements( t, alice, bob, fundingAmt, nil, nil, nil, nil, diff --git a/lnwire/channel_announcement.go b/lnwire/channel_announcement_1.go similarity index 87% rename from lnwire/channel_announcement.go rename to lnwire/channel_announcement_1.go index 2b34c0f990..86d3335614 100644 --- a/lnwire/channel_announcement.go +++ b/lnwire/channel_announcement_1.go @@ -7,10 +7,10 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" ) -// ChannelAnnouncement message is used to announce the existence of a channel +// ChannelAnnouncement1 message is used to announce the existence of a channel // between two peers in the overlay, which is propagated by the discovery // service over broadcast handler. -type ChannelAnnouncement struct { +type ChannelAnnouncement1 struct { // This signatures are used by nodes in order to create cross // references between node's channel and node. Requiring both nodes // to sign indicates they are both willing to route other payments via @@ -58,15 +58,15 @@ type ChannelAnnouncement struct { ExtraOpaqueData ExtraOpaqueData } -// A compile time check to ensure ChannelAnnouncement implements the +// A compile time check to ensure ChannelAnnouncement1 implements the // lnwire.Message interface. -var _ Message = (*ChannelAnnouncement)(nil) +var _ Message = (*ChannelAnnouncement1)(nil) -// Decode deserializes a serialized ChannelAnnouncement stored in the passed +// Decode deserializes a serialized ChannelAnnouncement1 stored in the passed // io.Reader observing the specified protocol version. // // This is part of the lnwire.Message interface. -func (a *ChannelAnnouncement) Decode(r io.Reader, pver uint32) error { +func (a *ChannelAnnouncement1) Decode(r io.Reader, _ uint32) error { return ReadElements(r, &a.NodeSig1, &a.NodeSig2, @@ -83,11 +83,11 @@ func (a *ChannelAnnouncement) Decode(r io.Reader, pver uint32) error { ) } -// Encode serializes the target ChannelAnnouncement into the passed io.Writer +// Encode serializes the target ChannelAnnouncement1 into the passed io.Writer // observing the protocol version specified. // // This is part of the lnwire.Message interface. -func (a *ChannelAnnouncement) Encode(w *bytes.Buffer, pver uint32) error { +func (a *ChannelAnnouncement1) Encode(w *bytes.Buffer, _ uint32) error { if err := WriteSig(w, a.NodeSig1); err != nil { return err } @@ -139,13 +139,13 @@ func (a *ChannelAnnouncement) Encode(w *bytes.Buffer, pver uint32) error { // wire. // // This is part of the lnwire.Message interface. -func (a *ChannelAnnouncement) MsgType() MessageType { +func (a *ChannelAnnouncement1) MsgType() MessageType { return MsgChannelAnnouncement } // DataToSign is used to retrieve part of the announcement message which should // be signed. -func (a *ChannelAnnouncement) DataToSign() ([]byte, error) { +func (a *ChannelAnnouncement1) DataToSign() ([]byte, error) { // We should not include the signatures itself. b := make([]byte, 0, MaxMsgBody) buf := bytes.NewBuffer(b) diff --git a/lnwire/lnwire_test.go b/lnwire/lnwire_test.go index 21a12ede51..e559d064cc 100644 --- a/lnwire/lnwire_test.go +++ b/lnwire/lnwire_test.go @@ -781,7 +781,7 @@ func TestLightningWireProtocol(t *testing.T) { }, MsgChannelAnnouncement: func(v []reflect.Value, r *rand.Rand) { var err error - req := ChannelAnnouncement{ + req := ChannelAnnouncement1{ ShortChannelID: NewShortChanIDFromInt( uint64(r.Int63()), ), @@ -1406,7 +1406,7 @@ func TestLightningWireProtocol(t *testing.T) { }, { msgType: MsgChannelAnnouncement, - scenario: func(m ChannelAnnouncement) bool { + scenario: func(m ChannelAnnouncement1) bool { return mainScenario(&m) }, }, diff --git a/lnwire/message.go b/lnwire/message.go index cebc7f4020..1a2a2ad9fe 100644 --- a/lnwire/message.go +++ b/lnwire/message.go @@ -115,7 +115,7 @@ func (t MessageType) String() string { case MsgError: return "Error" case MsgChannelAnnouncement: - return "ChannelAnnouncement" + return "ChannelAnnouncement1" case MsgChannelUpdate: return "ChannelUpdate" case MsgNodeAnnouncement: @@ -227,7 +227,7 @@ func makeEmptyMessage(msgType MessageType) (Message, error) { case MsgError: msg = &Error{} case MsgChannelAnnouncement: - msg = &ChannelAnnouncement{} + msg = &ChannelAnnouncement1{} case MsgChannelUpdate: msg = &ChannelUpdate{} case MsgNodeAnnouncement: diff --git a/lnwire/message_test.go b/lnwire/message_test.go index bbb434785f..540539b28a 100644 --- a/lnwire/message_test.go +++ b/lnwire/message_test.go @@ -645,11 +645,11 @@ func newMsgChannelReestablish(t testing.TB, } func newMsgChannelAnnouncement(t testing.TB, - r *rand.Rand) *lnwire.ChannelAnnouncement { + r *rand.Rand) *lnwire.ChannelAnnouncement1 { t.Helper() - msg := &lnwire.ChannelAnnouncement{ + msg := &lnwire.ChannelAnnouncement1{ ShortChannelID: lnwire.NewShortChanIDFromInt(uint64(r.Int63())), Features: rawFeatureVector(), NodeID1: randRawKey(t), diff --git a/netann/channel_announcement.go b/netann/channel_announcement.go index 0ae8d606d4..9acd9adbbb 100644 --- a/netann/channel_announcement.go +++ b/netann/channel_announcement.go @@ -14,14 +14,14 @@ import ( // peer's initial routing table upon connect. func CreateChanAnnouncement(chanProof *channeldb.ChannelAuthProof, chanInfo *channeldb.ChannelEdgeInfo, - e1, e2 *channeldb.ChannelEdgePolicy) (*lnwire.ChannelAnnouncement, + e1, e2 *channeldb.ChannelEdgePolicy) (*lnwire.ChannelAnnouncement1, *lnwire.ChannelUpdate, *lnwire.ChannelUpdate, error) { // First, using the parameters of the channel, along with the channel // authentication chanProof, we'll create re-create the original // authenticated channel announcement. chanID := lnwire.NewShortChanIDFromInt(chanInfo.ChannelID) - chanAnn := &lnwire.ChannelAnnouncement{ + chanAnn := &lnwire.ChannelAnnouncement1{ ShortChannelID: chanID, NodeID1: chanInfo.NodeKey1Bytes, NodeID2: chanInfo.NodeKey2Bytes, diff --git a/netann/channel_announcement_test.go b/netann/channel_announcement_test.go index bc2460b9af..0f35003dbf 100644 --- a/netann/channel_announcement_test.go +++ b/netann/channel_announcement_test.go @@ -24,7 +24,7 @@ func TestCreateChanAnnouncement(t *testing.T) { t.Fatalf("unable to encode features: %v", err) } - expChanAnn := &lnwire.ChannelAnnouncement{ + expChanAnn := &lnwire.ChannelAnnouncement1{ ChainHash: chainhash.Hash{0x1}, ShortChannelID: lnwire.ShortChannelID{BlockHeight: 1}, NodeID1: key, diff --git a/netann/sign.go b/netann/sign.go index 86634f6281..93bd8cdc9b 100644 --- a/netann/sign.go +++ b/netann/sign.go @@ -20,7 +20,7 @@ func SignAnnouncement(signer lnwallet.MessageSigner, keyLoc keychain.KeyLocator, ) switch m := msg.(type) { - case *lnwire.ChannelAnnouncement: + case *lnwire.ChannelAnnouncement1: data, err = m.DataToSign() case *lnwire.ChannelUpdate: data, err = m.DataToSign() diff --git a/peer/brontide.go b/peer/brontide.go index 86aa99aae1..dc05b2ef16 100644 --- a/peer/brontide.go +++ b/peer/brontide.go @@ -1725,7 +1725,7 @@ out: } case *lnwire.ChannelUpdate, - *lnwire.ChannelAnnouncement, + *lnwire.ChannelAnnouncement1, *lnwire.NodeAnnouncement, *lnwire.AnnounceSignatures, *lnwire.GossipTimestampRange, @@ -1978,7 +1978,7 @@ func messageSummary(msg lnwire.Message) string { return fmt.Sprintf("chan_id=%v, short_chan_id=%v", msg.ChannelID, msg.ShortChannelID.ToUint64()) - case *lnwire.ChannelAnnouncement: + case *lnwire.ChannelAnnouncement1: return fmt.Sprintf("chain_hash=%v, short_chan_id=%v", msg.ChainHash, msg.ShortChannelID.ToUint64()) diff --git a/routing/ann_validation.go b/routing/ann_validation.go index 35f0afd4d1..a60944bf8a 100644 --- a/routing/ann_validation.go +++ b/routing/ann_validation.go @@ -15,7 +15,7 @@ import ( // ValidateChannelAnn validates the channel announcement message and checks // that node signatures covers the announcement message, and that the bitcoin // signatures covers the node keys. -func ValidateChannelAnn(a *lnwire.ChannelAnnouncement) error { +func ValidateChannelAnn(a *lnwire.ChannelAnnouncement1) error { // First, we'll compute the digest (h) which is to be signed by each of // the keys included within the node announcement message. This hash // digest includes all the keys, so the (up to 4 signatures) will diff --git a/routing/router.go b/routing/router.go index 827e218641..6db9da6cdc 100644 --- a/routing/router.go +++ b/routing/router.go @@ -1567,7 +1567,7 @@ func (r *ChannelRouter) processUpdate(msg interface{}, // graph. If the passed ShortChannelID is an alias, then we'll // skip validation as it will not map to a legitimate tx. This // is not a DoS vector as only we can add an alias - // ChannelAnnouncement from the gossiper. + // ChannelAnnouncement1 from the gossiper. scid := lnwire.NewShortChanIDFromInt(msg.ChannelID) if r.cfg.AssumeChannelValid || r.cfg.IsAlias(scid) { if err := r.cfg.Graph.AddChannelEdge(msg, op...); err != nil { diff --git a/routing/validation_barrier.go b/routing/validation_barrier.go index a7c6561acd..5f898c2242 100644 --- a/routing/validation_barrier.go +++ b/routing/validation_barrier.go @@ -35,13 +35,13 @@ type ValidationBarrier struct { validationSemaphore chan struct{} // chanAnnFinSignal is map that keep track of all the pending - // ChannelAnnouncement like validation job going on. Once the job has + // ChannelAnnouncement1 like validation job going on. Once the job has // been completed, the channel will be closed unblocking any // dependants. chanAnnFinSignal map[lnwire.ShortChannelID]*validationSignals // chanEdgeDependencies tracks any channel edge updates which should - // wait until the completion of the ChannelAnnouncement before + // wait until the completion of the ChannelAnnouncement1 before // proceeding. This is a dependency, as we can't validate the update // before we validate the announcement which creates the channel // itself. @@ -49,7 +49,7 @@ type ValidationBarrier struct { // nodeAnnDependencies tracks any pending NodeAnnouncement validation // jobs which should wait until the completion of the - // ChannelAnnouncement before proceeding. + // ChannelAnnouncement1 before proceeding. nodeAnnDependencies map[route.Vertex]*validationSignals quit chan struct{} @@ -101,12 +101,12 @@ func (v *ValidationBarrier) InitJobDependencies(job interface{}) { // ChannelUpdates for the same channel, or NodeAnnouncements of nodes // that are involved in this channel. This goes for both the wire // type,s and also the types that we use within the database. - case *lnwire.ChannelAnnouncement: + case *lnwire.ChannelAnnouncement1: // We ensure that we only create a new announcement signal iff, // one doesn't already exist, as there may be duplicate // announcements. We'll close this signal once the - // ChannelAnnouncement has been validated. This will result in + // ChannelAnnouncement1 has been validated. This will result in // all the dependent jobs being unlocked so they can finish // execution themselves. if _, ok := v.chanAnnFinSignal[msg.ShortChannelID]; !ok { @@ -186,7 +186,7 @@ func (v *ValidationBarrier) WaitForDependants(job interface{}) error { switch msg := job.(type) { // Any ChannelUpdate or NodeAnnouncement jobs will need to wait on the - // completion of any active ChannelAnnouncement jobs related to them. + // completion of any active ChannelAnnouncement1 jobs related to them. case *channeldb.ChannelEdgePolicy: shortID := lnwire.NewShortChanIDFromInt(msg.ChannelID) signals, ok = v.chanEdgeDependencies[shortID] @@ -218,7 +218,7 @@ func (v *ValidationBarrier) WaitForDependants(job interface{}) error { case *lnwire.AnnounceSignatures: // TODO(roasbeef): need to wait on chan ann? case *channeldb.ChannelEdgeInfo: - case *lnwire.ChannelAnnouncement: + case *lnwire.ChannelAnnouncement1: } // Release the lock once the above read is finished. @@ -260,7 +260,7 @@ func (v *ValidationBarrier) SignalDependants(job interface{}, allow bool) { switch msg := job.(type) { - // If we've just finished executing a ChannelAnnouncement, then we'll + // If we've just finished executing a ChannelAnnouncement1, then we'll // close out the signal, and remove the signal from the map of active // ones. This will allow/deny any dependent jobs to continue execution. case *channeldb.ChannelEdgeInfo: @@ -274,7 +274,7 @@ func (v *ValidationBarrier) SignalDependants(job interface{}, allow bool) { } delete(v.chanAnnFinSignal, shortID) } - case *lnwire.ChannelAnnouncement: + case *lnwire.ChannelAnnouncement1: finSignals, ok := v.chanAnnFinSignal[msg.ShortChannelID] if ok { if allow { diff --git a/routing/validation_barrier_test.go b/routing/validation_barrier_test.go index 2eda0120fc..e56c95a472 100644 --- a/routing/validation_barrier_test.go +++ b/routing/validation_barrier_test.go @@ -73,9 +73,9 @@ func TestValidationBarrierQuit(t *testing.T) { // Create a set of unique channel announcements that we will prep for // validation. - anns := make([]*lnwire.ChannelAnnouncement, 0, numTasks) + anns := make([]*lnwire.ChannelAnnouncement1, 0, numTasks) for i := 0; i < numTasks; i++ { - anns = append(anns, &lnwire.ChannelAnnouncement{ + anns = append(anns, &lnwire.ChannelAnnouncement1{ ShortChannelID: lnwire.NewShortChanIDFromInt(uint64(i)), NodeID1: nodeIDFromInt(uint64(2 * i)), NodeID2: nodeIDFromInt(uint64(2*i + 1)), From 77d6c0122edf47d04be182701b20fda5462ce5da Mon Sep 17 00:00:00 2001 From: Elle Mouton <elle.mouton@gmail.com> Date: Tue, 17 Oct 2023 08:29:27 +0200 Subject: [PATCH 13/33] multi: rename channeldb.ChannelEdgeInfo to ChannelEdgeInfo1 This is in preparation for the addition of a ChannelEdgeInfo interface which will be implemented by ChannelEdgeInfo1 and the coming ChannelEdgeInfo2. --- autopilot/graph.go | 4 +- channeldb/channel_cache_test.go | 2 +- channeldb/graph.go | 106 +++++++++++++-------------- channeldb/graph_cache.go | 8 +- channeldb/graph_cache_test.go | 6 +- channeldb/graph_test.go | 42 +++++------ discovery/gossiper.go | 16 ++-- discovery/gossiper_test.go | 16 ++-- lnrpc/devrpc/dev_server.go | 2 +- lnrpc/invoicesrpc/addinvoice.go | 2 +- lnrpc/invoicesrpc/addinvoice_test.go | 30 ++++---- netann/chan_status_manager_test.go | 15 ++-- netann/channel_announcement.go | 2 +- netann/channel_announcement_test.go | 2 +- netann/channel_update.go | 6 +- netann/interface.go | 2 +- routing/localchans/manager.go | 4 +- routing/localchans/manager_test.go | 12 +-- routing/notifications.go | 4 +- routing/notifications_test.go | 8 +- routing/pathfind_test.go | 4 +- routing/router.go | 22 +++--- routing/router_test.go | 30 ++++---- routing/validation_barrier.go | 6 +- rpcserver.go | 8 +- server.go | 2 +- 26 files changed, 181 insertions(+), 180 deletions(-) diff --git a/autopilot/graph.go b/autopilot/graph.go index 0b062a5358..a049833250 100644 --- a/autopilot/graph.go +++ b/autopilot/graph.go @@ -87,7 +87,7 @@ func (d dbNode) Addrs() []net.Addr { // NOTE: Part of the autopilot.Node interface. func (d dbNode) ForEachChannel(cb func(ChannelEdge) error) error { return d.node.ForEachChannel(d.tx, func(tx kvdb.RTx, - ei *channeldb.ChannelEdgeInfo, ep, _ *channeldb.ChannelEdgePolicy) error { + ei *channeldb.ChannelEdgeInfo1, ep, _ *channeldb.ChannelEdgePolicy) error { // Skip channels for which no outgoing edge policy is available. // @@ -222,7 +222,7 @@ func (d *databaseChannelGraph) addRandChannel(node1, node2 *btcec.PublicKey, } chanID := randChanID() - edge := &channeldb.ChannelEdgeInfo{ + edge := &channeldb.ChannelEdgeInfo1{ ChannelID: chanID.ToUint64(), Capacity: capacity, } diff --git a/channeldb/channel_cache_test.go b/channeldb/channel_cache_test.go index d776c1318c..72927613e3 100644 --- a/channeldb/channel_cache_test.go +++ b/channeldb/channel_cache_test.go @@ -98,7 +98,7 @@ func assertHasChanEntries(t *testing.T, c *channelCache, start, end uint64) { // channelForInt generates a unique ChannelEdge given an integer. func channelForInt(i uint64) ChannelEdge { return ChannelEdge{ - Info: &ChannelEdgeInfo{ + Info: &ChannelEdgeInfo1{ ChannelID: i, }, } diff --git a/channeldb/graph.go b/channeldb/graph.go index 8367aaf68b..8f6dfb2382 100644 --- a/channeldb/graph.go +++ b/channeldb/graph.go @@ -231,7 +231,7 @@ func NewChannelGraph(db kvdb.Backend, rejectCacheSize, chanCacheSize int, return nil, err } - err = g.ForEachChannel(func(info *ChannelEdgeInfo, + err = g.ForEachChannel(func(info *ChannelEdgeInfo1, policy1, policy2 *ChannelEdgePolicy) error { g.graphCache.AddChannel(info, policy1, policy2) @@ -411,7 +411,7 @@ func (c *ChannelGraph) NewPathFindTx() (kvdb.RTx, error) { // NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer // for that particular channel edge routing policy will be passed into the // callback. -func (c *ChannelGraph) ForEachChannel(cb func(*ChannelEdgeInfo, +func (c *ChannelGraph) ForEachChannel(cb func(*ChannelEdgeInfo1, *ChannelEdgePolicy, *ChannelEdgePolicy) error) error { return c.db.View(func(tx kvdb.RTx) error { @@ -483,7 +483,7 @@ func (c *ChannelGraph) ForEachNodeChannel(tx kvdb.RTx, node route.Vertex, return err } - dbCallback := func(tx kvdb.RTx, e *ChannelEdgeInfo, p1, + dbCallback := func(tx kvdb.RTx, e *ChannelEdgeInfo1, p1, p2 *ChannelEdgePolicy) error { var cachedInPolicy *CachedEdgePolicy @@ -558,7 +558,7 @@ func (c *ChannelGraph) ForEachNodeCached(cb func(node route.Vertex, channels := make(map[uint64]*DirectedChannel) err := node.ForEachChannel(tx, func(tx kvdb.RTx, - e *ChannelEdgeInfo, p1 *ChannelEdgePolicy, + e *ChannelEdgeInfo1, p1 *ChannelEdgePolicy, p2 *ChannelEdgePolicy) error { toNodeCallback := func() route.Vertex { @@ -967,7 +967,7 @@ func (c *ChannelGraph) deleteLightningNode(nodes kvdb.RwBucket, // involved in creation of the channel, and the set of features that the channel // supports. The chanPoint and chanID are used to uniquely identify the edge // globally within the database. -func (c *ChannelGraph) AddChannelEdge(edge *ChannelEdgeInfo, +func (c *ChannelGraph) AddChannelEdge(edge *ChannelEdgeInfo1, op ...batch.SchedulerOption) error { var alreadyExists bool @@ -1010,7 +1010,7 @@ func (c *ChannelGraph) AddChannelEdge(edge *ChannelEdgeInfo, // addChannelEdge is the private form of AddChannelEdge that allows callers to // utilize an existing db transaction. -func (c *ChannelGraph) addChannelEdge(tx kvdb.RwTx, edge *ChannelEdgeInfo) error { +func (c *ChannelGraph) addChannelEdge(tx kvdb.RwTx, edge *ChannelEdgeInfo1) error { // Construct the channel's primary key which is the 8-byte channel ID. var chanKey [8]byte binary.BigEndian.PutUint64(chanKey[:], edge.ChannelID) @@ -1220,7 +1220,7 @@ func (c *ChannelGraph) HasChannelEdge( // In order to maintain this constraints, we return an error in the scenario // that an edge info hasn't yet been created yet, but someone attempts to update // it. -func (c *ChannelGraph) UpdateChannelEdge(edge *ChannelEdgeInfo) error { +func (c *ChannelGraph) UpdateChannelEdge(edge *ChannelEdgeInfo1) error { // Construct the channel's primary key which is the 8-byte channel ID. var chanKey [8]byte binary.BigEndian.PutUint64(chanKey[:], edge.ChannelID) @@ -1265,12 +1265,12 @@ const ( // with the current UTXO state. A slice of channels that have been closed by // the target block are returned if the function succeeds without error. func (c *ChannelGraph) PruneGraph(spentOutputs []*wire.OutPoint, - blockHash *chainhash.Hash, blockHeight uint32) ([]*ChannelEdgeInfo, error) { + blockHash *chainhash.Hash, blockHeight uint32) ([]*ChannelEdgeInfo1, error) { c.cacheMu.Lock() defer c.cacheMu.Unlock() - var chansClosed []*ChannelEdgeInfo + var chansClosed []*ChannelEdgeInfo1 err := kvdb.Update(c.db, func(tx kvdb.RwTx) error { // First grab the edges bucket which houses the information @@ -1518,7 +1518,7 @@ func (c *ChannelGraph) pruneGraphNodes(nodes kvdb.RwBucket, // set to the last prune height valid for the remaining chain. // Channels that were removed from the graph resulting from the // disconnected block are returned. -func (c *ChannelGraph) DisconnectBlockAtHeight(height uint32) ([]*ChannelEdgeInfo, +func (c *ChannelGraph) DisconnectBlockAtHeight(height uint32) ([]*ChannelEdgeInfo1, error) { // Every channel having a ShortChannelID starting at 'height' @@ -1541,7 +1541,7 @@ func (c *ChannelGraph) DisconnectBlockAtHeight(height uint32) ([]*ChannelEdgeInf defer c.cacheMu.Unlock() // Keep track of the channels that are removed from the graph. - var removedChans []*ChannelEdgeInfo + var removedChans []*ChannelEdgeInfo1 if err := kvdb.Update(c.db, func(tx kvdb.RwTx) error { edges, err := tx.CreateTopLevelBucket(edgeBucket) @@ -1848,7 +1848,7 @@ func (c *ChannelGraph) HighestChanID() (uint64, error) { // edge as well as each of the known advertised edge policies. type ChannelEdge struct { // Info contains all the static information describing the channel. - Info *ChannelEdgeInfo + Info *ChannelEdgeInfo1 // Policy1 points to the "first" edge policy of the channel containing // the dynamic information required to properly route through the edge. @@ -2438,7 +2438,7 @@ func (c *ChannelGraph) delChannelEdge(edges, edgeIndex, chanIndex, zombieIndex, // the channel. If the channel were to be marked zombie again, it would be // marked with the correct lagging channel since we received an update from only // one side. -func makeZombiePubkeys(info *ChannelEdgeInfo, +func makeZombiePubkeys(info *ChannelEdgeInfo1, e1, e2 *ChannelEdgePolicy) ([33]byte, [33]byte) { switch { @@ -2747,7 +2747,7 @@ func (l *LightningNode) isPublic(tx kvdb.RTx, sourcePubKey []byte) (bool, error) // used to terminate the check early. nodeIsPublic := false errDone := errors.New("done") - err := l.ForEachChannel(tx, func(_ kvdb.RTx, info *ChannelEdgeInfo, + err := l.ForEachChannel(tx, func(_ kvdb.RTx, info *ChannelEdgeInfo1, _, _ *ChannelEdgePolicy) error { // If this edge doesn't extend to the source node, we'll @@ -2858,7 +2858,7 @@ func (n *graphCacheNode) Features() *lnwire.FeatureVector { // // Unknown policies are passed into the callback as nil values. func (n *graphCacheNode) ForEachChannel(tx kvdb.RTx, - cb func(kvdb.RTx, *ChannelEdgeInfo, *ChannelEdgePolicy, + cb func(kvdb.RTx, *ChannelEdgeInfo1, *ChannelEdgePolicy, *ChannelEdgePolicy) error) error { return nodeTraversal(tx, n.pubKeyBytes[:], nil, cb) @@ -2919,7 +2919,7 @@ func (c *ChannelGraph) HasLightningNode(nodePub [33]byte) (time.Time, bool, erro // nodeTraversal is used to traverse all channels of a node given by its // public key and passes channel information into the specified callback. func nodeTraversal(tx kvdb.RTx, nodePub []byte, db kvdb.Backend, - cb func(kvdb.RTx, *ChannelEdgeInfo, *ChannelEdgePolicy, *ChannelEdgePolicy) error) error { + cb func(kvdb.RTx, *ChannelEdgeInfo1, *ChannelEdgePolicy, *ChannelEdgePolicy) error) error { traversal := func(tx kvdb.RTx) error { nodes := tx.ReadBucket(nodeBucket) @@ -3017,7 +3017,7 @@ func nodeTraversal(tx kvdb.RTx, nodePub []byte, db kvdb.Backend, // be nil and a fresh transaction will be created to execute the graph // traversal. func (l *LightningNode) ForEachChannel(tx kvdb.RTx, - cb func(kvdb.RTx, *ChannelEdgeInfo, *ChannelEdgePolicy, + cb func(kvdb.RTx, *ChannelEdgeInfo1, *ChannelEdgePolicy, *ChannelEdgePolicy) error) error { nodePub := l.PubKeyBytes[:] @@ -3026,13 +3026,13 @@ func (l *LightningNode) ForEachChannel(tx kvdb.RTx, return nodeTraversal(tx, nodePub, db, cb) } -// ChannelEdgeInfo represents a fully authenticated channel along with all its +// ChannelEdgeInfo1 represents a fully authenticated channel along with all its // unique attributes. Once an authenticated channel announcement has been -// processed on the network, then an instance of ChannelEdgeInfo encapsulating +// processed on the network, then an instance of ChannelEdgeInfo1 encapsulating // the channels attributes is stored. The other portions relevant to routing // policy of a channel are stored within a ChannelEdgePolicy for each direction // of the channel. -type ChannelEdgeInfo struct { +type ChannelEdgeInfo1 struct { // ChannelID is the unique channel ID for the channel. The first 3 // bytes are the block height, the next 3 the index within the block, // and the last 2 bytes are the output index for the channel. @@ -3090,8 +3090,8 @@ type ChannelEdgeInfo struct { } // AddNodeKeys is a setter-like method that can be used to replace the set of -// keys for the target ChannelEdgeInfo. -func (c *ChannelEdgeInfo) AddNodeKeys(nodeKey1, nodeKey2, bitcoinKey1, +// keys for the target ChannelEdgeInfo1. +func (c *ChannelEdgeInfo1) AddNodeKeys(nodeKey1, nodeKey2, bitcoinKey1, bitcoinKey2 *btcec.PublicKey) { c.nodeKey1 = nodeKey1 @@ -3114,7 +3114,7 @@ func (c *ChannelEdgeInfo) AddNodeKeys(nodeKey1, nodeKey2, bitcoinKey1, // // NOTE: By having this method to access an attribute, we ensure we only need // to fully deserialize the pubkey if absolutely necessary. -func (c *ChannelEdgeInfo) NodeKey1() (*btcec.PublicKey, error) { +func (c *ChannelEdgeInfo1) NodeKey1() (*btcec.PublicKey, error) { if c.nodeKey1 != nil { return c.nodeKey1, nil } @@ -3136,7 +3136,7 @@ func (c *ChannelEdgeInfo) NodeKey1() (*btcec.PublicKey, error) { // // NOTE: By having this method to access an attribute, we ensure we only need // to fully deserialize the pubkey if absolutely necessary. -func (c *ChannelEdgeInfo) NodeKey2() (*btcec.PublicKey, error) { +func (c *ChannelEdgeInfo1) NodeKey2() (*btcec.PublicKey, error) { if c.nodeKey2 != nil { return c.nodeKey2, nil } @@ -3156,7 +3156,7 @@ func (c *ChannelEdgeInfo) NodeKey2() (*btcec.PublicKey, error) { // // NOTE: By having this method to access an attribute, we ensure we only need // to fully deserialize the pubkey if absolutely necessary. -func (c *ChannelEdgeInfo) BitcoinKey1() (*btcec.PublicKey, error) { +func (c *ChannelEdgeInfo1) BitcoinKey1() (*btcec.PublicKey, error) { if c.bitcoinKey1 != nil { return c.bitcoinKey1, nil } @@ -3176,7 +3176,7 @@ func (c *ChannelEdgeInfo) BitcoinKey1() (*btcec.PublicKey, error) { // // NOTE: By having this method to access an attribute, we ensure we only need // to fully deserialize the pubkey if absolutely necessary. -func (c *ChannelEdgeInfo) BitcoinKey2() (*btcec.PublicKey, error) { +func (c *ChannelEdgeInfo1) BitcoinKey2() (*btcec.PublicKey, error) { if c.bitcoinKey2 != nil { return c.bitcoinKey2, nil } @@ -3192,7 +3192,7 @@ func (c *ChannelEdgeInfo) BitcoinKey2() (*btcec.PublicKey, error) { // OtherNodeKeyBytes returns the node key bytes of the other end of // the channel. -func (c *ChannelEdgeInfo) OtherNodeKeyBytes(thisNodeKey []byte) ( +func (c *ChannelEdgeInfo1) OtherNodeKeyBytes(thisNodeKey []byte) ( [33]byte, error) { switch { @@ -3209,7 +3209,7 @@ func (c *ChannelEdgeInfo) OtherNodeKeyBytes(thisNodeKey []byte) ( // the target node in the channel. This is useful when one knows the pubkey of // one of the nodes, and wishes to obtain the full LightningNode for the other // end of the channel. -func (c *ChannelEdgeInfo) FetchOtherNode(tx kvdb.RTx, +func (c *ChannelEdgeInfo1) FetchOtherNode(tx kvdb.RTx, thisNodeKey []byte) (*LightningNode, error) { // Ensure that the node passed in is actually a member of the channel. @@ -3514,10 +3514,10 @@ func (c *ChannelEdgePolicy) ComputeFeeFromIncoming( // information for the channel itself is returned as well as two structs that // contain the routing policies for the channel in either direction. func (c *ChannelGraph) FetchChannelEdgesByOutpoint(op *wire.OutPoint, -) (*ChannelEdgeInfo, *ChannelEdgePolicy, *ChannelEdgePolicy, error) { +) (*ChannelEdgeInfo1, *ChannelEdgePolicy, *ChannelEdgePolicy, error) { var ( - edgeInfo *ChannelEdgeInfo + edgeInfo *ChannelEdgeInfo1 policy1 *ChannelEdgePolicy policy2 *ChannelEdgePolicy ) @@ -3599,12 +3599,12 @@ func (c *ChannelGraph) FetchChannelEdgesByOutpoint(op *wire.OutPoint, // // ErrZombieEdge an be returned if the edge is currently marked as a zombie // within the database. In this case, the ChannelEdgePolicy's will be nil, and -// the ChannelEdgeInfo will only include the public keys of each node. +// the ChannelEdgeInfo1 will only include the public keys of each node. func (c *ChannelGraph) FetchChannelEdgesByID(chanID uint64, -) (*ChannelEdgeInfo, *ChannelEdgePolicy, *ChannelEdgePolicy, error) { +) (*ChannelEdgeInfo1, *ChannelEdgePolicy, *ChannelEdgePolicy, error) { var ( - edgeInfo *ChannelEdgeInfo + edgeInfo *ChannelEdgeInfo1 policy1 *ChannelEdgePolicy policy2 *ChannelEdgePolicy channelID [8]byte @@ -3657,7 +3657,7 @@ func (c *ChannelGraph) FetchChannelEdgesByID(chanID uint64, // populate the edge info with the public keys of each // party as this is the only information we have about // it and return an error signaling so. - edgeInfo = &ChannelEdgeInfo{ + edgeInfo = &ChannelEdgeInfo1{ NodeKey1Bytes: pubKey1, NodeKey2Bytes: pubKey2, } @@ -4309,7 +4309,7 @@ func deserializeLightningNode(r io.Reader) (LightningNode, error) { return node, nil } -func putChanEdgeInfo(edgeIndex kvdb.RwBucket, edgeInfo *ChannelEdgeInfo, chanID [8]byte) error { +func putChanEdgeInfo(edgeIndex kvdb.RwBucket, edgeInfo *ChannelEdgeInfo1, chanID [8]byte) error { var b bytes.Buffer if _, err := b.Write(edgeInfo.NodeKey1Bytes[:]); err != nil { @@ -4376,58 +4376,58 @@ func putChanEdgeInfo(edgeIndex kvdb.RwBucket, edgeInfo *ChannelEdgeInfo, chanID } func fetchChanEdgeInfo(edgeIndex kvdb.RBucket, - chanID []byte) (ChannelEdgeInfo, error) { + chanID []byte) (ChannelEdgeInfo1, error) { edgeInfoBytes := edgeIndex.Get(chanID) if edgeInfoBytes == nil { - return ChannelEdgeInfo{}, ErrEdgeNotFound + return ChannelEdgeInfo1{}, ErrEdgeNotFound } edgeInfoReader := bytes.NewReader(edgeInfoBytes) return deserializeChanEdgeInfo(edgeInfoReader) } -func deserializeChanEdgeInfo(r io.Reader) (ChannelEdgeInfo, error) { +func deserializeChanEdgeInfo(r io.Reader) (ChannelEdgeInfo1, error) { var ( err error - edgeInfo ChannelEdgeInfo + edgeInfo ChannelEdgeInfo1 ) if _, err := io.ReadFull(r, edgeInfo.NodeKey1Bytes[:]); err != nil { - return ChannelEdgeInfo{}, err + return ChannelEdgeInfo1{}, err } if _, err := io.ReadFull(r, edgeInfo.NodeKey2Bytes[:]); err != nil { - return ChannelEdgeInfo{}, err + return ChannelEdgeInfo1{}, err } if _, err := io.ReadFull(r, edgeInfo.BitcoinKey1Bytes[:]); err != nil { - return ChannelEdgeInfo{}, err + return ChannelEdgeInfo1{}, err } if _, err := io.ReadFull(r, edgeInfo.BitcoinKey2Bytes[:]); err != nil { - return ChannelEdgeInfo{}, err + return ChannelEdgeInfo1{}, err } edgeInfo.Features, err = wire.ReadVarBytes(r, 0, 900, "features") if err != nil { - return ChannelEdgeInfo{}, err + return ChannelEdgeInfo1{}, err } proof := &ChannelAuthProof{} proof.NodeSig1Bytes, err = wire.ReadVarBytes(r, 0, 80, "sigs") if err != nil { - return ChannelEdgeInfo{}, err + return ChannelEdgeInfo1{}, err } proof.NodeSig2Bytes, err = wire.ReadVarBytes(r, 0, 80, "sigs") if err != nil { - return ChannelEdgeInfo{}, err + return ChannelEdgeInfo1{}, err } proof.BitcoinSig1Bytes, err = wire.ReadVarBytes(r, 0, 80, "sigs") if err != nil { - return ChannelEdgeInfo{}, err + return ChannelEdgeInfo1{}, err } proof.BitcoinSig2Bytes, err = wire.ReadVarBytes(r, 0, 80, "sigs") if err != nil { - return ChannelEdgeInfo{}, err + return ChannelEdgeInfo1{}, err } if !proof.IsEmpty() { @@ -4436,17 +4436,17 @@ func deserializeChanEdgeInfo(r io.Reader) (ChannelEdgeInfo, error) { edgeInfo.ChannelPoint = wire.OutPoint{} if err := readOutpoint(r, &edgeInfo.ChannelPoint); err != nil { - return ChannelEdgeInfo{}, err + return ChannelEdgeInfo1{}, err } if err := binary.Read(r, byteOrder, &edgeInfo.Capacity); err != nil { - return ChannelEdgeInfo{}, err + return ChannelEdgeInfo1{}, err } if err := binary.Read(r, byteOrder, &edgeInfo.ChannelID); err != nil { - return ChannelEdgeInfo{}, err + return ChannelEdgeInfo1{}, err } if _, err := io.ReadFull(r, edgeInfo.ChainHash[:]); err != nil { - return ChannelEdgeInfo{}, err + return ChannelEdgeInfo1{}, err } // We'll try and see if there are any opaque bytes left, if not, then @@ -4458,7 +4458,7 @@ func deserializeChanEdgeInfo(r io.Reader) (ChannelEdgeInfo, error) { case err == io.ErrUnexpectedEOF: case err == io.EOF: case err != nil: - return ChannelEdgeInfo{}, err + return ChannelEdgeInfo1{}, err } return edgeInfo, nil diff --git a/channeldb/graph_cache.go b/channeldb/graph_cache.go index 1aae21d06d..6cb67431fc 100644 --- a/channeldb/graph_cache.go +++ b/channeldb/graph_cache.go @@ -27,7 +27,7 @@ type GraphCacheNode interface { // error, then the iteration is halted with the error propagated back up // to the caller. ForEachChannel(kvdb.RTx, - func(kvdb.RTx, *ChannelEdgeInfo, *ChannelEdgePolicy, + func(kvdb.RTx, *ChannelEdgeInfo1, *ChannelEdgePolicy, *ChannelEdgePolicy) error) error } @@ -223,7 +223,7 @@ func (c *GraphCache) AddNode(tx kvdb.RTx, node GraphCacheNode) error { c.AddNodeFeatures(node) return node.ForEachChannel( - tx, func(tx kvdb.RTx, info *ChannelEdgeInfo, + tx, func(tx kvdb.RTx, info *ChannelEdgeInfo1, outPolicy *ChannelEdgePolicy, inPolicy *ChannelEdgePolicy) error { @@ -238,7 +238,7 @@ func (c *GraphCache) AddNode(tx kvdb.RTx, node GraphCacheNode) error { // and policy 2 does not matter, the directionality is extracted from the info // and policy flags automatically. The policy will be set as the outgoing policy // on one node and the incoming policy on the peer's side. -func (c *GraphCache) AddChannel(info *ChannelEdgeInfo, +func (c *GraphCache) AddChannel(info *ChannelEdgeInfo1, policy1 *ChannelEdgePolicy, policy2 *ChannelEdgePolicy) { if info == nil { @@ -380,7 +380,7 @@ func (c *GraphCache) removeChannelIfFound(node route.Vertex, chanID uint64) { // UpdateChannel updates the channel edge information for a specific edge. We // expect the edge to already exist and be known. If it does not yet exist, this // call is a no-op. -func (c *GraphCache) UpdateChannel(info *ChannelEdgeInfo) { +func (c *GraphCache) UpdateChannel(info *ChannelEdgeInfo1) { c.mtx.Lock() defer c.mtx.Unlock() diff --git a/channeldb/graph_cache_test.go b/channeldb/graph_cache_test.go index b408ec36d8..2d0faddfc5 100644 --- a/channeldb/graph_cache_test.go +++ b/channeldb/graph_cache_test.go @@ -28,7 +28,7 @@ type node struct { pubKey route.Vertex features *lnwire.FeatureVector - edgeInfos []*ChannelEdgeInfo + edgeInfos []*ChannelEdgeInfo1 outPolicies []*ChannelEdgePolicy inPolicies []*ChannelEdgePolicy } @@ -41,7 +41,7 @@ func (n *node) Features() *lnwire.FeatureVector { } func (n *node) ForEachChannel(tx kvdb.RTx, - cb func(kvdb.RTx, *ChannelEdgeInfo, *ChannelEdgePolicy, + cb func(kvdb.RTx, *ChannelEdgeInfo1, *ChannelEdgePolicy, *ChannelEdgePolicy) error) error { for idx := range n.edgeInfos { @@ -89,7 +89,7 @@ func TestGraphCacheAddNode(t *testing.T) { node := &node{ pubKey: nodeA, features: lnwire.EmptyFeatureVector(), - edgeInfos: []*ChannelEdgeInfo{{ + edgeInfos: []*ChannelEdgeInfo1{{ ChannelID: 1000, // Those are direction independent! NodeKey1Bytes: pubKey1, diff --git a/channeldb/graph_test.go b/channeldb/graph_test.go index 0afc9ea198..5685f6f2a0 100644 --- a/channeldb/graph_test.go +++ b/channeldb/graph_test.go @@ -327,7 +327,7 @@ func TestEdgeInsertionDeletion(t *testing.T) { require.NoError(t, err, "unable to generate node key") node2Pub, err := node2.PubKey() require.NoError(t, err, "unable to generate node key") - edgeInfo := ChannelEdgeInfo{ + edgeInfo := ChannelEdgeInfo1{ ChannelID: chanID, ChainHash: key, AuthProof: &ChannelAuthProof{ @@ -387,7 +387,7 @@ func TestEdgeInsertionDeletion(t *testing.T) { } func createEdge(height, txIndex uint32, txPosition uint16, outPointIndex uint32, - node1, node2 *LightningNode) (ChannelEdgeInfo, lnwire.ShortChannelID) { + node1, node2 *LightningNode) (ChannelEdgeInfo1, lnwire.ShortChannelID) { shortChanID := lnwire.ShortChannelID{ BlockHeight: height, @@ -401,7 +401,7 @@ func createEdge(height, txIndex uint32, txPosition uint16, outPointIndex uint32, node1Pub, _ := node1.PubKey() node2Pub, _ := node2.PubKey() - edgeInfo := ChannelEdgeInfo{ + edgeInfo := ChannelEdgeInfo1{ ChannelID: shortChanID.ToUint64(), ChainHash: key, AuthProof: &ChannelAuthProof{ @@ -556,8 +556,8 @@ func TestDisconnectBlockAtHeight(t *testing.T) { } } -func assertEdgeInfoEqual(t *testing.T, e1 *ChannelEdgeInfo, - e2 *ChannelEdgeInfo) { +func assertEdgeInfoEqual(t *testing.T, e1 *ChannelEdgeInfo1, + e2 *ChannelEdgeInfo1) { if e1.ChannelID != e2.ChannelID { t.Fatalf("chan id's don't match: %v vs %v", e1.ChannelID, @@ -618,8 +618,8 @@ func assertEdgeInfoEqual(t *testing.T, e1 *ChannelEdgeInfo, } } -func createChannelEdge(db kvdb.Backend, node1, node2 *LightningNode) (*ChannelEdgeInfo, - *ChannelEdgePolicy, *ChannelEdgePolicy) { +func createChannelEdge(db kvdb.Backend, node1, node2 *LightningNode) ( + *ChannelEdgeInfo1, *ChannelEdgePolicy, *ChannelEdgePolicy) { var ( firstNode *LightningNode @@ -643,7 +643,7 @@ func createChannelEdge(db kvdb.Backend, node1, node2 *LightningNode) (*ChannelEd // Add the new edge to the database, this should proceed without any // errors. - edgeInfo := &ChannelEdgeInfo{ + edgeInfo := &ChannelEdgeInfo1{ ChannelID: chanID, ChainHash: key, AuthProof: &ChannelAuthProof{ @@ -822,7 +822,7 @@ func assertNodeNotInCache(t *testing.T, g *ChannelGraph, n route.Vertex) { } func assertEdgeWithNoPoliciesInCache(t *testing.T, g *ChannelGraph, - e *ChannelEdgeInfo) { + e *ChannelEdgeInfo1) { // Let's check the internal view first. require.NotEmpty(t, g.graphCache.nodeChannels[e.NodeKey1Bytes]) @@ -900,7 +900,7 @@ func assertNoEdge(t *testing.T, g *ChannelGraph, chanID uint64) { } func assertEdgeWithPolicyInCache(t *testing.T, g *ChannelGraph, - e *ChannelEdgeInfo, p *ChannelEdgePolicy, policy1 bool) { + e *ChannelEdgeInfo1, p *ChannelEdgePolicy, policy1 bool) { // Check the internal state first. c1, ok := g.graphCache.nodeChannels[e.NodeKey1Bytes][e.ChannelID] @@ -1043,7 +1043,7 @@ func TestGraphTraversal(t *testing.T) { // Iterate through all the known channels within the graph DB, once // again if the map is empty that indicates that all edges have // properly been reached. - err = graph.ForEachChannel(func(ei *ChannelEdgeInfo, _ *ChannelEdgePolicy, + err = graph.ForEachChannel(func(ei *ChannelEdgeInfo1, _ *ChannelEdgePolicy, _ *ChannelEdgePolicy) error { delete(chanIndex, ei.ChannelID) @@ -1056,7 +1056,7 @@ func TestGraphTraversal(t *testing.T) { // outgoing channels for a particular node. numNodeChans := 0 firstNode, secondNode := nodeList[0], nodeList[1] - err = firstNode.ForEachChannel(nil, func(_ kvdb.RTx, _ *ChannelEdgeInfo, + err = firstNode.ForEachChannel(nil, func(_ kvdb.RTx, _ *ChannelEdgeInfo1, outEdge, inEdge *ChannelEdgePolicy) error { // All channels between first and second node should have fully @@ -1129,7 +1129,7 @@ func TestGraphTraversalCacheable(t *testing.T) { err = graph.db.View(func(tx kvdb.RTx) error { for _, node := range nodes { err := node.ForEachChannel( - tx, func(tx kvdb.RTx, info *ChannelEdgeInfo, + tx, func(tx kvdb.RTx, info *ChannelEdgeInfo1, policy *ChannelEdgePolicy, policy2 *ChannelEdgePolicy) error { @@ -1249,7 +1249,7 @@ func fillTestGraph(t require.TestingT, graph *ChannelGraph, numNodes, Index: 0, } - edgeInfo := ChannelEdgeInfo{ + edgeInfo := ChannelEdgeInfo1{ ChannelID: chanID, ChainHash: key, AuthProof: &ChannelAuthProof{ @@ -1313,7 +1313,7 @@ func assertPruneTip(t *testing.T, graph *ChannelGraph, blockHash *chainhash.Hash func assertNumChans(t *testing.T, graph *ChannelGraph, n int) { numChans := 0 - if err := graph.ForEachChannel(func(*ChannelEdgeInfo, *ChannelEdgePolicy, + if err := graph.ForEachChannel(func(*ChannelEdgeInfo1, *ChannelEdgePolicy, *ChannelEdgePolicy) error { numChans++ @@ -1430,7 +1430,7 @@ func TestGraphPruning(t *testing.T) { channelPoints = append(channelPoints, &op) - edgeInfo := ChannelEdgeInfo{ + edgeInfo := ChannelEdgeInfo1{ ChannelID: chanID, ChainHash: key, AuthProof: &ChannelAuthProof{ @@ -2280,7 +2280,7 @@ func TestIncompleteChannelPolicies(t *testing.T) { // Ensure that channel is reported with unknown policies. checkPolicies := func(node *LightningNode, expectedIn, expectedOut bool) { calls := 0 - err := node.ForEachChannel(nil, func(_ kvdb.RTx, _ *ChannelEdgeInfo, + err := node.ForEachChannel(nil, func(_ kvdb.RTx, _ *ChannelEdgeInfo1, outEdge, inEdge *ChannelEdgePolicy) error { if !expectedOut && outEdge != nil { @@ -2698,7 +2698,7 @@ func TestNodeIsPublic(t *testing.T) { // After creating all of our nodes and edges, we'll add them to each // participant's graph. nodes := []*LightningNode{aliceNode, bobNode, carolNode} - edges := []*ChannelEdgeInfo{&aliceBobEdge, &bobCarolEdge} + edges := []*ChannelEdgeInfo1{&aliceBobEdge, &bobCarolEdge} dbs := []kvdb.Backend{aliceGraph.db, bobGraph.db, carolGraph.db} graphs := []*ChannelGraph{aliceGraph, bobGraph, carolGraph} for i, graph := range graphs { @@ -3326,7 +3326,7 @@ func TestBatchedAddChannelEdge(t *testing.T) { // Create a third edge, this with a block height of 155. edgeInfo3, _ := createEdge(height-1, 0, 0, 2, node1, node2) - edges := []ChannelEdgeInfo{edgeInfo, edgeInfo2, edgeInfo3} + edges := []ChannelEdgeInfo1{edgeInfo, edgeInfo2, edgeInfo3} errChan := make(chan error, len(edges)) errTimeout := errors.New("timeout adding batched channel") @@ -3334,7 +3334,7 @@ func TestBatchedAddChannelEdge(t *testing.T) { var wg sync.WaitGroup for _, edge := range edges { wg.Add(1) - go func(edge ChannelEdgeInfo) { + go func(edge ChannelEdgeInfo1) { defer wg.Done() select { @@ -3443,7 +3443,7 @@ func BenchmarkForEachChannel(b *testing.B) { for _, n := range nodes { err := n.ForEachChannel( tx, func(tx kvdb.RTx, - info *ChannelEdgeInfo, + info *ChannelEdgeInfo1, policy *ChannelEdgePolicy, policy2 *ChannelEdgePolicy) error { diff --git a/discovery/gossiper.go b/discovery/gossiper.go index 893d5d769e..e267c9d404 100644 --- a/discovery/gossiper.go +++ b/discovery/gossiper.go @@ -527,7 +527,7 @@ func New(cfg Config, selfKeyDesc *keychain.KeyDescriptor) *AuthenticatedGossiper // EdgeWithInfo contains the information that is required to update an edge. type EdgeWithInfo struct { // Info describes the channel. - Info *channeldb.ChannelEdgeInfo + Info *channeldb.ChannelEdgeInfo1 // Edge describes the policy in one direction of the channel. Edge *channeldb.ChannelEdgePolicy @@ -1574,7 +1574,7 @@ func (d *AuthenticatedGossiper) retransmitStaleAnns(now time.Time) error { // Iterate over all of our channels and check if any of them fall // within the prune interval or re-broadcast interval. type updateTuple struct { - info *channeldb.ChannelEdgeInfo + info *channeldb.ChannelEdgeInfo1 edge *channeldb.ChannelEdgePolicy } @@ -1584,7 +1584,7 @@ func (d *AuthenticatedGossiper) retransmitStaleAnns(now time.Time) error { ) err := d.cfg.Router.ForAllOutgoingChannels(func( _ kvdb.RTx, - info *channeldb.ChannelEdgeInfo, + info *channeldb.ChannelEdgeInfo1, edge *channeldb.ChannelEdgePolicy) error { // If there's no auth proof attached to this edge, it means @@ -1789,8 +1789,8 @@ func (d *AuthenticatedGossiper) processChanPolicyUpdate( } // remotePubFromChanInfo returns the public key of the remote peer given a -// ChannelEdgeInfo that describe a channel we have with them. -func remotePubFromChanInfo(chanInfo *channeldb.ChannelEdgeInfo, +// ChannelEdgeInfo1 that describe a channel we have with them. +func remotePubFromChanInfo(chanInfo *channeldb.ChannelEdgeInfo1, chanFlags lnwire.ChanUpdateChanFlags) [33]byte { var remotePubKey [33]byte @@ -2022,7 +2022,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement( // processZombieUpdate determines whether the provided channel update should // resurrect a given zombie edge. func (d *AuthenticatedGossiper) processZombieUpdate( - chanInfo *channeldb.ChannelEdgeInfo, msg *lnwire.ChannelUpdate) error { + chanInfo *channeldb.ChannelEdgeInfo1, msg *lnwire.ChannelUpdate) error { // The least-significant bit in the flag on the channel update tells us // which edge is being updated. @@ -2150,7 +2150,7 @@ func (d *AuthenticatedGossiper) isMsgStale(msg lnwire.Message) bool { // updateChannel creates a new fully signed update for the channel, and updates // the underlying graph with the new state. -func (d *AuthenticatedGossiper) updateChannel(info *channeldb.ChannelEdgeInfo, +func (d *AuthenticatedGossiper) updateChannel(info *channeldb.ChannelEdgeInfo1, edge *channeldb.ChannelEdgePolicy) (*lnwire.ChannelAnnouncement1, *lnwire.ChannelUpdate, error) { @@ -2465,7 +2465,7 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, return nil, false } - edge := &channeldb.ChannelEdgeInfo{ + edge := &channeldb.ChannelEdgeInfo1{ ChannelID: ann.ShortChannelID.ToUint64(), ChainHash: ann.ChainHash, NodeKey1Bytes: ann.NodeID1, diff --git a/discovery/gossiper_test.go b/discovery/gossiper_test.go index d0fbfc5bc8..b3313732ca 100644 --- a/discovery/gossiper_test.go +++ b/discovery/gossiper_test.go @@ -91,7 +91,7 @@ type mockGraphSource struct { mu sync.Mutex nodes []channeldb.LightningNode - infos map[uint64]channeldb.ChannelEdgeInfo + infos map[uint64]channeldb.ChannelEdgeInfo1 edges map[uint64][]channeldb.ChannelEdgePolicy zombies map[uint64][][33]byte chansToReject map[uint64]struct{} @@ -100,7 +100,7 @@ type mockGraphSource struct { func newMockRouter(height uint32) *mockGraphSource { return &mockGraphSource{ bestHeight: height, - infos: make(map[uint64]channeldb.ChannelEdgeInfo), + infos: make(map[uint64]channeldb.ChannelEdgeInfo1), edges: make(map[uint64][]channeldb.ChannelEdgePolicy), zombies: make(map[uint64][][33]byte), chansToReject: make(map[uint64]struct{}), @@ -119,7 +119,7 @@ func (r *mockGraphSource) AddNode(node *channeldb.LightningNode, return nil } -func (r *mockGraphSource) AddEdge(info *channeldb.ChannelEdgeInfo, +func (r *mockGraphSource) AddEdge(info *channeldb.ChannelEdgeInfo1, _ ...batch.SchedulerOption) error { r.mu.Lock() @@ -190,7 +190,7 @@ func (r *mockGraphSource) ForEachNode(func(node *channeldb.LightningNode) error) } func (r *mockGraphSource) ForAllOutgoingChannels(cb func(tx kvdb.RTx, - i *channeldb.ChannelEdgeInfo, + i *channeldb.ChannelEdgeInfo1, c *channeldb.ChannelEdgePolicy) error) error { r.mu.Lock() @@ -221,13 +221,13 @@ func (r *mockGraphSource) ForAllOutgoingChannels(cb func(tx kvdb.RTx, return nil } -func (r *mockGraphSource) ForEachChannel(func(chanInfo *channeldb.ChannelEdgeInfo, +func (r *mockGraphSource) ForEachChannel(func(chanInfo *channeldb.ChannelEdgeInfo1, e1, e2 *channeldb.ChannelEdgePolicy) error) error { return nil } func (r *mockGraphSource) GetChannelByID(chanID lnwire.ShortChannelID) ( - *channeldb.ChannelEdgeInfo, + *channeldb.ChannelEdgeInfo1, *channeldb.ChannelEdgePolicy, *channeldb.ChannelEdgePolicy, error) { @@ -242,7 +242,7 @@ func (r *mockGraphSource) GetChannelByID(chanID lnwire.ShortChannelID) ( return nil, nil, nil, channeldb.ErrEdgeNotFound } - return &channeldb.ChannelEdgeInfo{ + return &channeldb.ChannelEdgeInfo1{ NodeKey1Bytes: pubKeys[0], NodeKey2Bytes: pubKeys[1], }, nil, nil, channeldb.ErrZombieEdge @@ -3442,7 +3442,7 @@ out: var edgesToUpdate []EdgeWithInfo err = ctx.router.ForAllOutgoingChannels(func( _ kvdb.RTx, - info *channeldb.ChannelEdgeInfo, + info *channeldb.ChannelEdgeInfo1, edge *channeldb.ChannelEdgePolicy) error { edge.TimeLockDelta = uint16(newTimeLockDelta) diff --git a/lnrpc/devrpc/dev_server.go b/lnrpc/devrpc/dev_server.go index 462328ea33..8cd88c34c0 100644 --- a/lnrpc/devrpc/dev_server.go +++ b/lnrpc/devrpc/dev_server.go @@ -261,7 +261,7 @@ func (s *Server) ImportGraph(ctx context.Context, for _, rpcEdge := range graph.Edges { rpcEdge := rpcEdge - edge := &channeldb.ChannelEdgeInfo{ + edge := &channeldb.ChannelEdgeInfo1{ ChannelID: rpcEdge.ChannelId, ChainHash: *s.cfg.ActiveNetParams.GenesisHash, Capacity: btcutil.Amount(rpcEdge.Capacity), diff --git a/lnrpc/invoicesrpc/addinvoice.go b/lnrpc/invoicesrpc/addinvoice.go index f300e43804..5396579d70 100644 --- a/lnrpc/invoicesrpc/addinvoice.go +++ b/lnrpc/invoicesrpc/addinvoice.go @@ -627,7 +627,7 @@ type SelectHopHintsCfg struct { // FetchChannelEdgesByID attempts to lookup the two directed edges for // the channel identified by the channel ID. - FetchChannelEdgesByID func(chanID uint64) (*channeldb.ChannelEdgeInfo, + FetchChannelEdgesByID func(chanID uint64) (*channeldb.ChannelEdgeInfo1, *channeldb.ChannelEdgePolicy, *channeldb.ChannelEdgePolicy, error) diff --git a/lnrpc/invoicesrpc/addinvoice_test.go b/lnrpc/invoicesrpc/addinvoice_test.go index 2dcb516a6d..0cc1f9c516 100644 --- a/lnrpc/invoicesrpc/addinvoice_test.go +++ b/lnrpc/invoicesrpc/addinvoice_test.go @@ -51,7 +51,7 @@ func (h *hopHintsConfigMock) FetchAllChannels() ([]*channeldb.OpenChannel, // FetchChannelEdgesByID attempts to lookup the two directed edges for // the channel identified by the channel ID. func (h *hopHintsConfigMock) FetchChannelEdgesByID(chanID uint64) ( - *channeldb.ChannelEdgeInfo, *channeldb.ChannelEdgePolicy, + *channeldb.ChannelEdgeInfo1, *channeldb.ChannelEdgePolicy, *channeldb.ChannelEdgePolicy, error) { args := h.Mock.Called(chanID) @@ -64,7 +64,7 @@ func (h *hopHintsConfigMock) FetchChannelEdgesByID(chanID uint64) ( return nil, nil, nil, err } - edgeInfo := args.Get(0).(*channeldb.ChannelEdgeInfo) + edgeInfo := args.Get(0).(*channeldb.ChannelEdgeInfo1) policy1 := args.Get(1).(*channeldb.ChannelEdgePolicy) policy2 := args.Get(2).(*channeldb.ChannelEdgePolicy) @@ -215,7 +215,7 @@ var shouldIncludeChannelTestCases = []struct { h.Mock.On( "FetchChannelEdgesByID", mock.Anything, ).Once().Return( - &channeldb.ChannelEdgeInfo{}, + &channeldb.ChannelEdgeInfo1{}, &channeldb.ChannelEdgePolicy{}, &channeldb.ChannelEdgePolicy{}, nil, ) @@ -253,7 +253,7 @@ var shouldIncludeChannelTestCases = []struct { h.Mock.On( "FetchChannelEdgesByID", mock.Anything, ).Once().Return( - &channeldb.ChannelEdgeInfo{}, + &channeldb.ChannelEdgeInfo1{}, &channeldb.ChannelEdgePolicy{}, &channeldb.ChannelEdgePolicy{}, nil, ) @@ -294,7 +294,7 @@ var shouldIncludeChannelTestCases = []struct { h.Mock.On( "FetchChannelEdgesByID", mock.Anything, ).Once().Return( - &channeldb.ChannelEdgeInfo{ + &channeldb.ChannelEdgeInfo1{ NodeKey1Bytes: selectedPolicy, }, &channeldb.ChannelEdgePolicy{ @@ -342,7 +342,7 @@ var shouldIncludeChannelTestCases = []struct { h.Mock.On( "FetchChannelEdgesByID", mock.Anything, ).Once().Return( - &channeldb.ChannelEdgeInfo{}, + &channeldb.ChannelEdgeInfo1{}, &channeldb.ChannelEdgePolicy{}, &channeldb.ChannelEdgePolicy{ FeeBaseMSat: 1000, @@ -387,7 +387,7 @@ var shouldIncludeChannelTestCases = []struct { h.Mock.On( "FetchChannelEdgesByID", mock.Anything, ).Once().Return( - &channeldb.ChannelEdgeInfo{}, + &channeldb.ChannelEdgeInfo1{}, &channeldb.ChannelEdgePolicy{}, &channeldb.ChannelEdgePolicy{ FeeBaseMSat: 1000, @@ -554,7 +554,7 @@ var populateHopHintsTestCases = []struct { h.Mock.On( "FetchChannelEdgesByID", mock.Anything, ).Once().Return( - &channeldb.ChannelEdgeInfo{}, + &channeldb.ChannelEdgeInfo1{}, &channeldb.ChannelEdgePolicy{}, &channeldb.ChannelEdgePolicy{}, nil, ) @@ -604,7 +604,7 @@ var populateHopHintsTestCases = []struct { h.Mock.On( "FetchChannelEdgesByID", mock.Anything, ).Once().Return( - &channeldb.ChannelEdgeInfo{}, + &channeldb.ChannelEdgeInfo1{}, &channeldb.ChannelEdgePolicy{}, &channeldb.ChannelEdgePolicy{}, nil, ) @@ -655,7 +655,7 @@ var populateHopHintsTestCases = []struct { h.Mock.On( "FetchChannelEdgesByID", mock.Anything, ).Once().Return( - &channeldb.ChannelEdgeInfo{}, + &channeldb.ChannelEdgeInfo1{}, &channeldb.ChannelEdgePolicy{}, &channeldb.ChannelEdgePolicy{}, nil, ) @@ -688,7 +688,7 @@ var populateHopHintsTestCases = []struct { h.Mock.On( "FetchChannelEdgesByID", mock.Anything, ).Once().Return( - &channeldb.ChannelEdgeInfo{}, + &channeldb.ChannelEdgeInfo1{}, &channeldb.ChannelEdgePolicy{}, &channeldb.ChannelEdgePolicy{}, nil, ) @@ -705,7 +705,7 @@ var populateHopHintsTestCases = []struct { h.Mock.On( "FetchChannelEdgesByID", mock.Anything, ).Once().Return( - &channeldb.ChannelEdgeInfo{}, + &channeldb.ChannelEdgeInfo1{}, &channeldb.ChannelEdgePolicy{}, &channeldb.ChannelEdgePolicy{}, nil, ) @@ -742,7 +742,7 @@ var populateHopHintsTestCases = []struct { h.Mock.On( "FetchChannelEdgesByID", mock.Anything, ).Once().Return( - &channeldb.ChannelEdgeInfo{}, + &channeldb.ChannelEdgeInfo1{}, &channeldb.ChannelEdgePolicy{}, &channeldb.ChannelEdgePolicy{}, nil, ) @@ -759,7 +759,7 @@ var populateHopHintsTestCases = []struct { h.Mock.On( "FetchChannelEdgesByID", mock.Anything, ).Once().Return( - &channeldb.ChannelEdgeInfo{}, + &channeldb.ChannelEdgeInfo1{}, &channeldb.ChannelEdgePolicy{}, &channeldb.ChannelEdgePolicy{}, nil, ) @@ -797,7 +797,7 @@ var populateHopHintsTestCases = []struct { h.Mock.On( "FetchChannelEdgesByID", mock.Anything, ).Once().Return( - &channeldb.ChannelEdgeInfo{}, + &channeldb.ChannelEdgeInfo1{}, &channeldb.ChannelEdgePolicy{}, &channeldb.ChannelEdgePolicy{}, nil, ) diff --git a/netann/chan_status_manager_test.go b/netann/chan_status_manager_test.go index e068184f89..65e0bf70bd 100644 --- a/netann/chan_status_manager_test.go +++ b/netann/chan_status_manager_test.go @@ -65,8 +65,9 @@ func createChannel(t *testing.T) *channeldb.OpenChannel { // our `pubkey` with the direction bit set appropriately in the policies. Our // update will be created with the disabled bit set if startEnabled is false. func createEdgePolicies(t *testing.T, channel *channeldb.OpenChannel, - pubkey *btcec.PublicKey, startEnabled bool) (*channeldb.ChannelEdgeInfo, - *channeldb.ChannelEdgePolicy, *channeldb.ChannelEdgePolicy) { + pubkey *btcec.PublicKey, startEnabled bool) ( + *channeldb.ChannelEdgeInfo1, *channeldb.ChannelEdgePolicy, + *channeldb.ChannelEdgePolicy) { var ( pubkey1 [33]byte @@ -98,7 +99,7 @@ func createEdgePolicies(t *testing.T, channel *channeldb.OpenChannel, // bit. dir2 |= lnwire.ChanUpdateDirection - return &channeldb.ChannelEdgeInfo{ + return &channeldb.ChannelEdgeInfo1{ ChannelPoint: channel.FundingOutpoint, NodeKey1Bytes: pubkey1, NodeKey2Bytes: pubkey2, @@ -120,7 +121,7 @@ func createEdgePolicies(t *testing.T, channel *channeldb.OpenChannel, type mockGraph struct { mu sync.Mutex channels []*channeldb.OpenChannel - chanInfos map[wire.OutPoint]*channeldb.ChannelEdgeInfo + chanInfos map[wire.OutPoint]*channeldb.ChannelEdgeInfo1 chanPols1 map[wire.OutPoint]*channeldb.ChannelEdgePolicy chanPols2 map[wire.OutPoint]*channeldb.ChannelEdgePolicy sidToCid map[lnwire.ShortChannelID]wire.OutPoint @@ -133,7 +134,7 @@ func newMockGraph(t *testing.T, numChannels int, g := &mockGraph{ channels: make([]*channeldb.OpenChannel, 0, numChannels), - chanInfos: make(map[wire.OutPoint]*channeldb.ChannelEdgeInfo), + chanInfos: make(map[wire.OutPoint]*channeldb.ChannelEdgeInfo1), chanPols1: make(map[wire.OutPoint]*channeldb.ChannelEdgePolicy), chanPols2: make(map[wire.OutPoint]*channeldb.ChannelEdgePolicy), sidToCid: make(map[lnwire.ShortChannelID]wire.OutPoint), @@ -159,7 +160,7 @@ func (g *mockGraph) FetchAllOpenChannels() ([]*channeldb.OpenChannel, error) { } func (g *mockGraph) FetchChannelEdgesByOutpoint( - op *wire.OutPoint) (*channeldb.ChannelEdgeInfo, + op *wire.OutPoint) (*channeldb.ChannelEdgeInfo1, *channeldb.ChannelEdgePolicy, *channeldb.ChannelEdgePolicy, error) { g.mu.Lock() @@ -247,7 +248,7 @@ func (g *mockGraph) addChannel(channel *channeldb.OpenChannel) { } func (g *mockGraph) addEdgePolicy(c *channeldb.OpenChannel, - info *channeldb.ChannelEdgeInfo, + info *channeldb.ChannelEdgeInfo1, pol1, pol2 *channeldb.ChannelEdgePolicy) { g.mu.Lock() diff --git a/netann/channel_announcement.go b/netann/channel_announcement.go index 9acd9adbbb..7ee3ba05ac 100644 --- a/netann/channel_announcement.go +++ b/netann/channel_announcement.go @@ -13,7 +13,7 @@ import ( // structs for announcing new channels to other peers, or simply syncing up a // peer's initial routing table upon connect. func CreateChanAnnouncement(chanProof *channeldb.ChannelAuthProof, - chanInfo *channeldb.ChannelEdgeInfo, + chanInfo *channeldb.ChannelEdgeInfo1, e1, e2 *channeldb.ChannelEdgePolicy) (*lnwire.ChannelAnnouncement1, *lnwire.ChannelUpdate, *lnwire.ChannelUpdate, error) { diff --git a/netann/channel_announcement_test.go b/netann/channel_announcement_test.go index 0f35003dbf..d8126c9cc0 100644 --- a/netann/channel_announcement_test.go +++ b/netann/channel_announcement_test.go @@ -45,7 +45,7 @@ func TestCreateChanAnnouncement(t *testing.T) { BitcoinSig1Bytes: expChanAnn.BitcoinSig1.ToSignatureBytes(), BitcoinSig2Bytes: expChanAnn.BitcoinSig2.ToSignatureBytes(), } - chanInfo := &channeldb.ChannelEdgeInfo{ + chanInfo := &channeldb.ChannelEdgeInfo1{ ChainHash: expChanAnn.ChainHash, ChannelID: expChanAnn.ShortChannelID.ToUint64(), ChannelPoint: wire.OutPoint{Index: 1}, diff --git a/netann/channel_update.go b/netann/channel_update.go index b6555f37b1..9880a462b3 100644 --- a/netann/channel_update.go +++ b/netann/channel_update.go @@ -84,7 +84,7 @@ func SignChannelUpdate(signer lnwallet.MessageSigner, keyLoc keychain.KeyLocator // // NOTE: The passed policies can be nil. func ExtractChannelUpdate(ownerPubKey []byte, - info *channeldb.ChannelEdgeInfo, + info *channeldb.ChannelEdgeInfo1, policies ...*channeldb.ChannelEdgePolicy) ( *lnwire.ChannelUpdate, error) { @@ -117,7 +117,7 @@ func ExtractChannelUpdate(ownerPubKey []byte, // UnsignedChannelUpdateFromEdge reconstructs an unsigned ChannelUpdate from the // given edge info and policy. -func UnsignedChannelUpdateFromEdge(info *channeldb.ChannelEdgeInfo, +func UnsignedChannelUpdateFromEdge(info *channeldb.ChannelEdgeInfo1, policy *channeldb.ChannelEdgePolicy) *lnwire.ChannelUpdate { return &lnwire.ChannelUpdate{ @@ -137,7 +137,7 @@ func UnsignedChannelUpdateFromEdge(info *channeldb.ChannelEdgeInfo, // ChannelUpdateFromEdge reconstructs a signed ChannelUpdate from the given edge // info and policy. -func ChannelUpdateFromEdge(info *channeldb.ChannelEdgeInfo, +func ChannelUpdateFromEdge(info *channeldb.ChannelEdgeInfo1, policy *channeldb.ChannelEdgePolicy) (*lnwire.ChannelUpdate, error) { update := UnsignedChannelUpdateFromEdge(info, policy) diff --git a/netann/interface.go b/netann/interface.go index 79c0c114f2..68afdd32e9 100644 --- a/netann/interface.go +++ b/netann/interface.go @@ -18,6 +18,6 @@ type DB interface { type ChannelGraph interface { // FetchChannelEdgesByOutpoint returns the channel edge info and most // recent channel edge policies for a given outpoint. - FetchChannelEdgesByOutpoint(*wire.OutPoint) (*channeldb.ChannelEdgeInfo, + FetchChannelEdgesByOutpoint(*wire.OutPoint) (*channeldb.ChannelEdgeInfo1, *channeldb.ChannelEdgePolicy, *channeldb.ChannelEdgePolicy, error) } diff --git a/routing/localchans/manager.go b/routing/localchans/manager.go index 3a4e2416ab..b450e14a14 100644 --- a/routing/localchans/manager.go +++ b/routing/localchans/manager.go @@ -31,7 +31,7 @@ type Manager struct { // ForAllOutgoingChannels is required to iterate over all our local // channels. ForAllOutgoingChannels func(cb func(kvdb.RTx, - *channeldb.ChannelEdgeInfo, + *channeldb.ChannelEdgeInfo1, *channeldb.ChannelEdgePolicy) error) error // FetchChannel is used to query local channel parameters. Optionally an @@ -73,7 +73,7 @@ func (r *Manager) UpdatePolicy(newSchema routing.ChannelPolicy, // otherwise we'll collect them all. err := r.ForAllOutgoingChannels(func( tx kvdb.RTx, - info *channeldb.ChannelEdgeInfo, + info *channeldb.ChannelEdgeInfo1, edge *channeldb.ChannelEdgePolicy) error { // If we have a channel filter, and this channel isn't a part diff --git a/routing/localchans/manager_test.go b/routing/localchans/manager_test.go index b14093a01d..ca59012dcb 100644 --- a/routing/localchans/manager_test.go +++ b/routing/localchans/manager_test.go @@ -22,7 +22,7 @@ func TestManager(t *testing.T) { t.Parallel() type channel struct { - edgeInfo *channeldb.ChannelEdgeInfo + edgeInfo *channeldb.ChannelEdgeInfo1 } var ( @@ -107,7 +107,7 @@ func TestManager(t *testing.T) { } forAllOutgoingChannels := func(cb func(kvdb.RTx, - *channeldb.ChannelEdgeInfo, + *channeldb.ChannelEdgeInfo1, *channeldb.ChannelEdgePolicy) error) error { for _, c := range channelSet { @@ -166,7 +166,7 @@ func TestManager(t *testing.T) { newPolicy: newPolicy, channelSet: []channel{ { - edgeInfo: &channeldb.ChannelEdgeInfo{ + edgeInfo: &channeldb.ChannelEdgeInfo1{ Capacity: chanCap, ChannelPoint: chanPointValid, }, @@ -183,7 +183,7 @@ func TestManager(t *testing.T) { newPolicy: newPolicy, channelSet: []channel{ { - edgeInfo: &channeldb.ChannelEdgeInfo{ + edgeInfo: &channeldb.ChannelEdgeInfo1{ Capacity: chanCap, ChannelPoint: chanPointValid, }, @@ -200,7 +200,7 @@ func TestManager(t *testing.T) { newPolicy: newPolicy, channelSet: []channel{ { - edgeInfo: &channeldb.ChannelEdgeInfo{ + edgeInfo: &channeldb.ChannelEdgeInfo1{ Capacity: chanCap, ChannelPoint: chanPointValid, }, @@ -221,7 +221,7 @@ func TestManager(t *testing.T) { newPolicy: noMaxHtlcPolicy, channelSet: []channel{ { - edgeInfo: &channeldb.ChannelEdgeInfo{ + edgeInfo: &channeldb.ChannelEdgeInfo1{ Capacity: chanCap, ChannelPoint: chanPointValid, }, diff --git a/routing/notifications.go b/routing/notifications.go index 387bd9126d..85aeb838bb 100644 --- a/routing/notifications.go +++ b/routing/notifications.go @@ -212,7 +212,7 @@ type ClosedChanSummary struct { // createCloseSummaries takes in a slice of channels closed at the target block // height and creates a slice of summaries which of each channel closure. func createCloseSummaries(blockHeight uint32, - closedChans ...*channeldb.ChannelEdgeInfo) []*ClosedChanSummary { + closedChans ...*channeldb.ChannelEdgeInfo1) []*ClosedChanSummary { closeSummaries := make([]*ClosedChanSummary, len(closedChans)) for i, closedChan := range closedChans { @@ -333,7 +333,7 @@ func addToTopologyChange(graph *channeldb.ChannelGraph, update *TopologyChange, // We ignore initial channel announcements as we'll only send out // updates once the individual edges themselves have been updated. - case *channeldb.ChannelEdgeInfo: + case *channeldb.ChannelEdgeInfo1: return nil // Any new ChannelUpdateAnnouncements will generate a corresponding diff --git a/routing/notifications_test.go b/routing/notifications_test.go index 658df54233..aee5581c57 100644 --- a/routing/notifications_test.go +++ b/routing/notifications_test.go @@ -417,7 +417,7 @@ func TestEdgeUpdateNotification(t *testing.T) { // Finally, to conclude our test set up, we'll create a channel // update to announce the created channel between the two nodes. - edge := &channeldb.ChannelEdgeInfo{ + edge := &channeldb.ChannelEdgeInfo1{ ChannelID: chanID.ToUint64(), NodeKey1Bytes: node1.PubKeyBytes, NodeKey2Bytes: node2.PubKeyBytes, @@ -599,7 +599,7 @@ func TestNodeUpdateNotification(t *testing.T) { testFeaturesBuf := new(bytes.Buffer) require.NoError(t, testFeatures.Encode(testFeaturesBuf)) - edge := &channeldb.ChannelEdgeInfo{ + edge := &channeldb.ChannelEdgeInfo1{ ChannelID: chanID.ToUint64(), NodeKey1Bytes: node1.PubKeyBytes, NodeKey2Bytes: node2.PubKeyBytes, @@ -785,7 +785,7 @@ func TestNotificationCancellation(t *testing.T) { // to the client. ntfnClient.Cancel() - edge := &channeldb.ChannelEdgeInfo{ + edge := &channeldb.ChannelEdgeInfo1{ ChannelID: chanID.ToUint64(), NodeKey1Bytes: node1.PubKeyBytes, NodeKey2Bytes: node2.PubKeyBytes, @@ -856,7 +856,7 @@ func TestChannelCloseNotification(t *testing.T) { // Finally, to conclude our test set up, we'll create a channel // announcement to announce the created channel between the two nodes. - edge := &channeldb.ChannelEdgeInfo{ + edge := &channeldb.ChannelEdgeInfo1{ ChannelID: chanID.ToUint64(), NodeKey1Bytes: node1.PubKeyBytes, NodeKey2Bytes: node2.PubKeyBytes, diff --git a/routing/pathfind_test.go b/routing/pathfind_test.go index c016451bd6..f97c116689 100644 --- a/routing/pathfind_test.go +++ b/routing/pathfind_test.go @@ -345,7 +345,7 @@ func parseTestGraph(t *testing.T, useCache bool, path string) ( // We first insert the existence of the edge between the two // nodes. - edgeInfo := channeldb.ChannelEdgeInfo{ + edgeInfo := channeldb.ChannelEdgeInfo1{ ChannelID: edge.ChannelID, AuthProof: &testAuthProof, ChannelPoint: fundingPoint, @@ -657,7 +657,7 @@ func createTestGraphFromChannels(t *testing.T, useCache bool, // We first insert the existence of the edge between the two // nodes. - edgeInfo := channeldb.ChannelEdgeInfo{ + edgeInfo := channeldb.ChannelEdgeInfo1{ ChannelID: channelID, AuthProof: &testAuthProof, ChannelPoint: *fundingPoint, diff --git a/routing/router.go b/routing/router.go index 6db9da6cdc..f271a3d5dc 100644 --- a/routing/router.go +++ b/routing/router.go @@ -135,7 +135,7 @@ type ChannelGraphSource interface { // AddEdge is used to add edge/channel to the topology of the router, // after all information about channel will be gathered this // edge/channel might be used in construction of payment path. - AddEdge(edge *channeldb.ChannelEdgeInfo, + AddEdge(edge *channeldb.ChannelEdgeInfo1, op ...batch.SchedulerOption) error // AddProof updates the channel edge info with proof which is needed to @@ -176,7 +176,7 @@ type ChannelGraphSource interface { // emanating from the "source" node which is the center of the // star-graph. ForAllOutgoingChannels(cb func(tx kvdb.RTx, - c *channeldb.ChannelEdgeInfo, + c *channeldb.ChannelEdgeInfo1, e *channeldb.ChannelEdgePolicy) error) error // CurrentBlockHeight returns the block height from POV of the router @@ -185,7 +185,7 @@ type ChannelGraphSource interface { // GetChannelByID return the channel by the channel id. GetChannelByID(chanID lnwire.ShortChannelID) ( - *channeldb.ChannelEdgeInfo, *channeldb.ChannelEdgePolicy, + *channeldb.ChannelEdgeInfo1, *channeldb.ChannelEdgePolicy, *channeldb.ChannelEdgePolicy, error) // FetchLightningNode attempts to look up a target node by its identity @@ -897,14 +897,14 @@ func (r *ChannelRouter) pruneZombieChans() error { log.Infof("Examining channel graph for zombie channels") // A helper method to detect if the channel belongs to this node - isSelfChannelEdge := func(info *channeldb.ChannelEdgeInfo) bool { + isSelfChannelEdge := func(info *channeldb.ChannelEdgeInfo1) bool { return info.NodeKey1Bytes == r.selfNode.PubKeyBytes || info.NodeKey2Bytes == r.selfNode.PubKeyBytes } // First, we'll collect all the channels which are eligible for garbage // collection due to being zombies. - filterPruneChans := func(info *channeldb.ChannelEdgeInfo, + filterPruneChans := func(info *channeldb.ChannelEdgeInfo1, e1, e2 *channeldb.ChannelEdgePolicy) error { // Exit early in case this channel is already marked to be pruned @@ -1539,8 +1539,8 @@ func (r *ChannelRouter) processUpdate(msg interface{}, log.Tracef("Updated vertex data for node=%x", msg.PubKeyBytes) r.stats.incNumNodeUpdates() - case *channeldb.ChannelEdgeInfo: - log.Debugf("Received ChannelEdgeInfo for channel %v", + case *channeldb.ChannelEdgeInfo1: + log.Debugf("Received ChannelEdgeInfo1 for channel %v", msg.ChannelID) // Prior to processing the announcement we first check if we @@ -2720,7 +2720,7 @@ func (r *ChannelRouter) AddNode(node *channeldb.LightningNode, // in construction of payment path. // // NOTE: This method is part of the ChannelGraphSource interface. -func (r *ChannelRouter) AddEdge(edge *channeldb.ChannelEdgeInfo, +func (r *ChannelRouter) AddEdge(edge *channeldb.ChannelEdgeInfo1, op ...batch.SchedulerOption) error { rMsg := &routingMsg{ @@ -2787,7 +2787,7 @@ func (r *ChannelRouter) SyncedHeight() uint32 { // // NOTE: This method is part of the ChannelGraphSource interface. func (r *ChannelRouter) GetChannelByID(chanID lnwire.ShortChannelID) ( - *channeldb.ChannelEdgeInfo, + *channeldb.ChannelEdgeInfo1, *channeldb.ChannelEdgePolicy, *channeldb.ChannelEdgePolicy, error) { @@ -2822,10 +2822,10 @@ func (r *ChannelRouter) ForEachNode( // // NOTE: This method is part of the ChannelGraphSource interface. func (r *ChannelRouter) ForAllOutgoingChannels(cb func(kvdb.RTx, - *channeldb.ChannelEdgeInfo, *channeldb.ChannelEdgePolicy) error) error { + *channeldb.ChannelEdgeInfo1, *channeldb.ChannelEdgePolicy) error) error { return r.selfNode.ForEachChannel(nil, func(tx kvdb.RTx, - c *channeldb.ChannelEdgeInfo, + c *channeldb.ChannelEdgeInfo1, e, _ *channeldb.ChannelEdgePolicy) error { if e == nil { diff --git a/routing/router_test.go b/routing/router_test.go index ea1fec2870..2b490a851f 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -1241,7 +1241,7 @@ func TestAddProof(t *testing.T) { ctx.chain.addBlock(fundingBlock, chanID.BlockHeight, chanID.BlockHeight) // After utxo was recreated adding the edge without the proof. - edge := &channeldb.ChannelEdgeInfo{ + edge := &channeldb.ChannelEdgeInfo1{ ChannelID: chanID.ToUint64(), NodeKey1Bytes: node1.PubKeyBytes, NodeKey2Bytes: node2.PubKeyBytes, @@ -1330,7 +1330,7 @@ func TestIgnoreChannelEdgePolicyForUnknownChannel(t *testing.T) { } ctx.chain.addBlock(fundingBlock, chanID.BlockHeight, chanID.BlockHeight) - edge := &channeldb.ChannelEdgeInfo{ + edge := &channeldb.ChannelEdgeInfo1{ ChannelID: chanID.ToUint64(), NodeKey1Bytes: pub1, NodeKey2Bytes: pub2, @@ -1408,7 +1408,7 @@ func TestAddEdgeUnknownVertexes(t *testing.T) { } ctx.chain.addBlock(fundingBlock, chanID.BlockHeight, chanID.BlockHeight) - edge := &channeldb.ChannelEdgeInfo{ + edge := &channeldb.ChannelEdgeInfo1{ ChannelID: chanID.ToUint64(), NodeKey1Bytes: pub1, NodeKey2Bytes: pub2, @@ -1507,7 +1507,7 @@ func TestAddEdgeUnknownVertexes(t *testing.T) { } ctx.chain.addBlock(fundingBlock, chanID.BlockHeight, chanID.BlockHeight) - edge = &channeldb.ChannelEdgeInfo{ + edge = &channeldb.ChannelEdgeInfo1{ ChannelID: chanID.ToUint64(), AuthProof: nil, } @@ -1710,7 +1710,7 @@ func TestWakeUpOnStaleBranch(t *testing.T) { node2, err := createTestNode() require.NoError(t, err, "unable to create test node") - edge1 := &channeldb.ChannelEdgeInfo{ + edge1 := &channeldb.ChannelEdgeInfo1{ ChannelID: chanID1, NodeKey1Bytes: node1.PubKeyBytes, NodeKey2Bytes: node2.PubKeyBytes, @@ -1728,7 +1728,7 @@ func TestWakeUpOnStaleBranch(t *testing.T) { t.Fatalf("unable to add edge: %v", err) } - edge2 := &channeldb.ChannelEdgeInfo{ + edge2 := &channeldb.ChannelEdgeInfo1{ ChannelID: chanID2, NodeKey1Bytes: node1.PubKeyBytes, NodeKey2Bytes: node2.PubKeyBytes, @@ -1918,7 +1918,7 @@ func TestDisconnectedBlocks(t *testing.T) { node2, err := createTestNode() require.NoError(t, err, "unable to create test node") - edge1 := &channeldb.ChannelEdgeInfo{ + edge1 := &channeldb.ChannelEdgeInfo1{ ChannelID: chanID1, NodeKey1Bytes: node1.PubKeyBytes, NodeKey2Bytes: node2.PubKeyBytes, @@ -1938,7 +1938,7 @@ func TestDisconnectedBlocks(t *testing.T) { t.Fatalf("unable to add edge: %v", err) } - edge2 := &channeldb.ChannelEdgeInfo{ + edge2 := &channeldb.ChannelEdgeInfo1{ ChannelID: chanID2, NodeKey1Bytes: node1.PubKeyBytes, NodeKey2Bytes: node2.PubKeyBytes, @@ -2069,7 +2069,7 @@ func TestRouterChansClosedOfflinePruneGraph(t *testing.T) { require.NoError(t, err, "unable to create test node") node2, err := createTestNode() require.NoError(t, err, "unable to create test node") - edge1 := &channeldb.ChannelEdgeInfo{ + edge1 := &channeldb.ChannelEdgeInfo1{ ChannelID: chanID1.ToUint64(), NodeKey1Bytes: node1.PubKeyBytes, NodeKey2Bytes: node2.PubKeyBytes, @@ -2508,7 +2508,7 @@ func TestIsStaleNode(t *testing.T) { } ctx.chain.addBlock(fundingBlock, chanID.BlockHeight, chanID.BlockHeight) - edge := &channeldb.ChannelEdgeInfo{ + edge := &channeldb.ChannelEdgeInfo1{ ChannelID: chanID.ToUint64(), NodeKey1Bytes: pub1, NodeKey2Bytes: pub2, @@ -2584,7 +2584,7 @@ func TestIsKnownEdge(t *testing.T) { } ctx.chain.addBlock(fundingBlock, chanID.BlockHeight, chanID.BlockHeight) - edge := &channeldb.ChannelEdgeInfo{ + edge := &channeldb.ChannelEdgeInfo1{ ChannelID: chanID.ToUint64(), NodeKey1Bytes: pub1, NodeKey2Bytes: pub2, @@ -2640,7 +2640,7 @@ func TestIsStaleEdgePolicy(t *testing.T) { t.Fatalf("router failed to detect fresh edge policy") } - edge := &channeldb.ChannelEdgeInfo{ + edge := &channeldb.ChannelEdgeInfo1{ ChannelID: chanID.ToUint64(), NodeKey1Bytes: pub1, NodeKey2Bytes: pub2, @@ -3265,7 +3265,7 @@ const ( // newChannelEdgeInfo is a helper function used to create a new channel edge, // possibly skipping adding it to parts of the chain/state as well. func newChannelEdgeInfo(ctx *testCtx, fundingHeight uint32, - ecm edgeCreationModifier) (*channeldb.ChannelEdgeInfo, error) { + ecm edgeCreationModifier) (*channeldb.ChannelEdgeInfo1, error) { node1, err := createTestNode() if err != nil { @@ -3284,7 +3284,7 @@ func newChannelEdgeInfo(ctx *testCtx, fundingHeight uint32, return nil, fmt.Errorf("unable to create edge: %w", err) } - edge := &channeldb.ChannelEdgeInfo{ + edge := &channeldb.ChannelEdgeInfo1{ ChannelID: chanID.ToUint64(), NodeKey1Bytes: node1.PubKeyBytes, NodeKey2Bytes: node2.PubKeyBytes, @@ -3315,7 +3315,7 @@ func newChannelEdgeInfo(ctx *testCtx, fundingHeight uint32, } func assertChanChainRejection(t *testing.T, ctx *testCtx, - edge *channeldb.ChannelEdgeInfo, failCode errorCode) { + edge *channeldb.ChannelEdgeInfo1, failCode errorCode) { t.Helper() diff --git a/routing/validation_barrier.go b/routing/validation_barrier.go index 5f898c2242..21e148fcfa 100644 --- a/routing/validation_barrier.go +++ b/routing/validation_barrier.go @@ -125,7 +125,7 @@ func (v *ValidationBarrier) InitJobDependencies(job interface{}) { v.nodeAnnDependencies[route.Vertex(msg.NodeID1)] = signals v.nodeAnnDependencies[route.Vertex(msg.NodeID2)] = signals } - case *channeldb.ChannelEdgeInfo: + case *channeldb.ChannelEdgeInfo1: shortID := lnwire.NewShortChanIDFromInt(msg.ChannelID) if _, ok := v.chanAnnFinSignal[shortID]; !ok { @@ -217,7 +217,7 @@ func (v *ValidationBarrier) WaitForDependants(job interface{}) error { // return directly. case *lnwire.AnnounceSignatures: // TODO(roasbeef): need to wait on chan ann? - case *channeldb.ChannelEdgeInfo: + case *channeldb.ChannelEdgeInfo1: case *lnwire.ChannelAnnouncement1: } @@ -263,7 +263,7 @@ func (v *ValidationBarrier) SignalDependants(job interface{}, allow bool) { // If we've just finished executing a ChannelAnnouncement1, then we'll // close out the signal, and remove the signal from the map of active // ones. This will allow/deny any dependent jobs to continue execution. - case *channeldb.ChannelEdgeInfo: + case *channeldb.ChannelEdgeInfo1: shortID := lnwire.NewShortChanIDFromInt(msg.ChannelID) finSignals, ok := v.chanAnnFinSignal[shortID] if ok { diff --git a/rpcserver.go b/rpcserver.go index 85a156cf9b..bb4f2f4b53 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -5921,7 +5921,7 @@ func (r *rpcServer) DescribeGraph(ctx context.Context, // Next, for each active channel we know of within the graph, create a // similar response which details both the edge information as well as // the routing policies of th nodes connecting the two edges. - err = graph.ForEachChannel(func(edgeInfo *channeldb.ChannelEdgeInfo, + err = graph.ForEachChannel(func(edgeInfo *channeldb.ChannelEdgeInfo1, c1, c2 *channeldb.ChannelEdgePolicy) error { // Do not include unannounced channels unless specifically @@ -5977,7 +5977,7 @@ func marshalExtraOpaqueData(data []byte) map[uint64][]byte { return records } -func marshalDbEdge(edgeInfo *channeldb.ChannelEdgeInfo, +func marshalDbEdge(edgeInfo *channeldb.ChannelEdgeInfo1, c1, c2 *channeldb.ChannelEdgePolicy) *lnrpc.ChannelEdge { // Make sure the policies match the node they belong to. c1 should point @@ -6152,7 +6152,7 @@ func (r *rpcServer) GetNodeInfo(ctx context.Context, ) if err := node.ForEachChannel(nil, func(_ kvdb.RTx, - edge *channeldb.ChannelEdgeInfo, + edge *channeldb.ChannelEdgeInfo1, c1, c2 *channeldb.ChannelEdgePolicy) error { numChannels++ @@ -6763,7 +6763,7 @@ func (r *rpcServer) FeeReport(ctx context.Context, } var feeReports []*lnrpc.ChannelFeeReport - err = selfNode.ForEachChannel(nil, func(_ kvdb.RTx, chanInfo *channeldb.ChannelEdgeInfo, + err = selfNode.ForEachChannel(nil, func(_ kvdb.RTx, chanInfo *channeldb.ChannelEdgeInfo1, edgePolicy, _ *channeldb.ChannelEdgePolicy) error { // Self node should always have policies for its channels. diff --git a/server.go b/server.go index c47f79757f..4ff46499d4 100644 --- a/server.go +++ b/server.go @@ -3096,7 +3096,7 @@ func (s *server) establishPersistentConnections() error { selfPub := s.identityECDH.PubKey().SerializeCompressed() err = sourceNode.ForEachChannel(nil, func( tx kvdb.RTx, - chanInfo *channeldb.ChannelEdgeInfo, + chanInfo *channeldb.ChannelEdgeInfo1, policy, _ *channeldb.ChannelEdgePolicy) error { // If the remote party has announced the channel to us, but we From 04735a059b22df1d70962ca6f8d4f05fbefac6d4 Mon Sep 17 00:00:00 2001 From: Elle Mouton <elle.mouton@gmail.com> Date: Tue, 17 Oct 2023 08:38:05 +0200 Subject: [PATCH 14/33] multi: rename ChannelAuthProof to ChannelAuthProof1 This is in preparation for the addition of a ChannelAuthProof interface which will be implemented by ChannelAuthProof1 and the coming ChannelAuthProof2. --- channeldb/graph.go | 18 +++++++++--------- channeldb/graph_test.go | 10 +++++----- discovery/gossiper.go | 8 ++++---- discovery/gossiper_test.go | 2 +- netann/channel_announcement.go | 2 +- netann/channel_announcement_test.go | 2 +- routing/notifications_test.go | 8 ++++---- routing/pathfind_test.go | 2 +- routing/router.go | 4 ++-- routing/router_test.go | 10 +++++----- 10 files changed, 33 insertions(+), 33 deletions(-) diff --git a/channeldb/graph.go b/channeldb/graph.go index 8f6dfb2382..b372dacd0b 100644 --- a/channeldb/graph.go +++ b/channeldb/graph.go @@ -3068,7 +3068,7 @@ type ChannelEdgeInfo1 struct { // AuthProof is the authentication proof for this channel. This proof // contains a set of signatures binding four identities, which attests // to the legitimacy of the advertised channel. - AuthProof *ChannelAuthProof + AuthProof *ChannelAuthProof1 // ChannelPoint is the funding outpoint of the channel. This can be // used to uniquely identify the channel within the channel graph. @@ -3255,14 +3255,14 @@ func (c *ChannelEdgeInfo1) FetchOtherNode(tx kvdb.RTx, return targetNode, err } -// ChannelAuthProof is the authentication proof (the signature portion) for a +// ChannelAuthProof1 is the authentication proof (the signature portion) for a // channel. Using the four signatures contained in the struct, and some // auxiliary knowledge (the funding script, node identities, and outpoint) nodes // on the network are able to validate the authenticity and existence of a // channel. Each of these signatures signs the following digest: chanID || // nodeID1 || nodeID2 || bitcoinKey1|| bitcoinKey2 || 2-byte-feature-len || // features. -type ChannelAuthProof struct { +type ChannelAuthProof1 struct { // nodeSig1 is a cached instance of the first node signature. nodeSig1 *ecdsa.Signature @@ -3298,7 +3298,7 @@ type ChannelAuthProof struct { // // NOTE: By having this method to access an attribute, we ensure we only need // to fully deserialize the signature if absolutely necessary. -func (c *ChannelAuthProof) Node1Sig() (*ecdsa.Signature, error) { +func (c *ChannelAuthProof1) Node1Sig() (*ecdsa.Signature, error) { if c.nodeSig1 != nil { return c.nodeSig1, nil } @@ -3319,7 +3319,7 @@ func (c *ChannelAuthProof) Node1Sig() (*ecdsa.Signature, error) { // // NOTE: By having this method to access an attribute, we ensure we only need // to fully deserialize the signature if absolutely necessary. -func (c *ChannelAuthProof) Node2Sig() (*ecdsa.Signature, error) { +func (c *ChannelAuthProof1) Node2Sig() (*ecdsa.Signature, error) { if c.nodeSig2 != nil { return c.nodeSig2, nil } @@ -3339,7 +3339,7 @@ func (c *ChannelAuthProof) Node2Sig() (*ecdsa.Signature, error) { // // NOTE: By having this method to access an attribute, we ensure we only need // to fully deserialize the signature if absolutely necessary. -func (c *ChannelAuthProof) BitcoinSig1() (*ecdsa.Signature, error) { +func (c *ChannelAuthProof1) BitcoinSig1() (*ecdsa.Signature, error) { if c.bitcoinSig1 != nil { return c.bitcoinSig1, nil } @@ -3359,7 +3359,7 @@ func (c *ChannelAuthProof) BitcoinSig1() (*ecdsa.Signature, error) { // // NOTE: By having this method to access an attribute, we ensure we only need // to fully deserialize the signature if absolutely necessary. -func (c *ChannelAuthProof) BitcoinSig2() (*ecdsa.Signature, error) { +func (c *ChannelAuthProof1) BitcoinSig2() (*ecdsa.Signature, error) { if c.bitcoinSig2 != nil { return c.bitcoinSig2, nil } @@ -3376,7 +3376,7 @@ func (c *ChannelAuthProof) BitcoinSig2() (*ecdsa.Signature, error) { // IsEmpty check is the authentication proof is empty Proof is empty if at // least one of the signatures are equal to nil. -func (c *ChannelAuthProof) IsEmpty() bool { +func (c *ChannelAuthProof1) IsEmpty() bool { return len(c.NodeSig1Bytes) == 0 || len(c.NodeSig2Bytes) == 0 || len(c.BitcoinSig1Bytes) == 0 || @@ -4411,7 +4411,7 @@ func deserializeChanEdgeInfo(r io.Reader) (ChannelEdgeInfo1, error) { return ChannelEdgeInfo1{}, err } - proof := &ChannelAuthProof{} + proof := &ChannelAuthProof1{} proof.NodeSig1Bytes, err = wire.ReadVarBytes(r, 0, 80, "sigs") if err != nil { diff --git a/channeldb/graph_test.go b/channeldb/graph_test.go index 5685f6f2a0..2cdce5f2e9 100644 --- a/channeldb/graph_test.go +++ b/channeldb/graph_test.go @@ -330,7 +330,7 @@ func TestEdgeInsertionDeletion(t *testing.T) { edgeInfo := ChannelEdgeInfo1{ ChannelID: chanID, ChainHash: key, - AuthProof: &ChannelAuthProof{ + AuthProof: &ChannelAuthProof1{ NodeSig1Bytes: testSig.Serialize(), NodeSig2Bytes: testSig.Serialize(), BitcoinSig1Bytes: testSig.Serialize(), @@ -404,7 +404,7 @@ func createEdge(height, txIndex uint32, txPosition uint16, outPointIndex uint32, edgeInfo := ChannelEdgeInfo1{ ChannelID: shortChanID.ToUint64(), ChainHash: key, - AuthProof: &ChannelAuthProof{ + AuthProof: &ChannelAuthProof1{ NodeSig1Bytes: testSig.Serialize(), NodeSig2Bytes: testSig.Serialize(), BitcoinSig1Bytes: testSig.Serialize(), @@ -646,7 +646,7 @@ func createChannelEdge(db kvdb.Backend, node1, node2 *LightningNode) ( edgeInfo := &ChannelEdgeInfo1{ ChannelID: chanID, ChainHash: key, - AuthProof: &ChannelAuthProof{ + AuthProof: &ChannelAuthProof1{ NodeSig1Bytes: testSig.Serialize(), NodeSig2Bytes: testSig.Serialize(), BitcoinSig1Bytes: testSig.Serialize(), @@ -1252,7 +1252,7 @@ func fillTestGraph(t require.TestingT, graph *ChannelGraph, numNodes, edgeInfo := ChannelEdgeInfo1{ ChannelID: chanID, ChainHash: key, - AuthProof: &ChannelAuthProof{ + AuthProof: &ChannelAuthProof1{ NodeSig1Bytes: testSig.Serialize(), NodeSig2Bytes: testSig.Serialize(), BitcoinSig1Bytes: testSig.Serialize(), @@ -1433,7 +1433,7 @@ func TestGraphPruning(t *testing.T) { edgeInfo := ChannelEdgeInfo1{ ChannelID: chanID, ChainHash: key, - AuthProof: &ChannelAuthProof{ + AuthProof: &ChannelAuthProof1{ NodeSig1Bytes: testSig.Serialize(), NodeSig2Bytes: testSig.Serialize(), BitcoinSig1Bytes: testSig.Serialize(), diff --git a/discovery/gossiper.go b/discovery/gossiper.go index e267c9d404..4891f916de 100644 --- a/discovery/gossiper.go +++ b/discovery/gossiper.go @@ -1813,7 +1813,7 @@ func remotePubFromChanInfo(chanInfo *channeldb.ChannelEdgeInfo1, // assemble the proof and craft the ChannelAnnouncement1. func (d *AuthenticatedGossiper) processRejectedEdge( chanAnnMsg *lnwire.ChannelAnnouncement1, - proof *channeldb.ChannelAuthProof) ([]networkMsg, error) { + proof *channeldb.ChannelAuthProof1) ([]networkMsg, error) { // First, we'll fetch the state of the channel as we know if from the // database. @@ -2428,7 +2428,7 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, // If this is a remote channel announcement, then we'll validate all // the signatures within the proof as it should be well formed. - var proof *channeldb.ChannelAuthProof + var proof *channeldb.ChannelAuthProof1 if nMsg.isRemote { if err := routing.ValidateChannelAnn(ann); err != nil { err := fmt.Errorf("unable to validate announcement: "+ @@ -2448,7 +2448,7 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, // If the proof checks out, then we'll save the proof itself to // the database so we can fetch it later when gossiping with // other nodes. - proof = &channeldb.ChannelAuthProof{ + proof = &channeldb.ChannelAuthProof1{ NodeSig1Bytes: ann.NodeSig1.ToSignatureBytes(), NodeSig2Bytes: ann.NodeSig2.ToSignatureBytes(), BitcoinSig1Bytes: ann.BitcoinSig1.ToSignatureBytes(), @@ -3217,7 +3217,7 @@ func (d *AuthenticatedGossiper) handleAnnSig(nMsg *networkMsg, // We now have both halves of the channel announcement proof, then // we'll reconstruct the initial announcement so we can validate it // shortly below. - var dbProof channeldb.ChannelAuthProof + var dbProof channeldb.ChannelAuthProof1 if isFirstNode { dbProof.NodeSig1Bytes = ann.NodeSignature.ToSignatureBytes() dbProof.NodeSig2Bytes = oppProof.NodeSignature.ToSignatureBytes() diff --git a/discovery/gossiper_test.go b/discovery/gossiper_test.go index b3313732ca..26a7015657 100644 --- a/discovery/gossiper_test.go +++ b/discovery/gossiper_test.go @@ -168,7 +168,7 @@ func (r *mockGraphSource) CurrentBlockHeight() (uint32, error) { } func (r *mockGraphSource) AddProof(chanID lnwire.ShortChannelID, - proof *channeldb.ChannelAuthProof) error { + proof *channeldb.ChannelAuthProof1) error { r.mu.Lock() defer r.mu.Unlock() diff --git a/netann/channel_announcement.go b/netann/channel_announcement.go index 7ee3ba05ac..9cce4e34e4 100644 --- a/netann/channel_announcement.go +++ b/netann/channel_announcement.go @@ -12,7 +12,7 @@ import ( // function is used to transform out database structs into the corresponding wire // structs for announcing new channels to other peers, or simply syncing up a // peer's initial routing table upon connect. -func CreateChanAnnouncement(chanProof *channeldb.ChannelAuthProof, +func CreateChanAnnouncement(chanProof *channeldb.ChannelAuthProof1, chanInfo *channeldb.ChannelEdgeInfo1, e1, e2 *channeldb.ChannelEdgePolicy) (*lnwire.ChannelAnnouncement1, *lnwire.ChannelUpdate, *lnwire.ChannelUpdate, error) { diff --git a/netann/channel_announcement_test.go b/netann/channel_announcement_test.go index d8126c9cc0..9287b9b76d 100644 --- a/netann/channel_announcement_test.go +++ b/netann/channel_announcement_test.go @@ -39,7 +39,7 @@ func TestCreateChanAnnouncement(t *testing.T) { ExtraOpaqueData: []byte{0x1}, } - chanProof := &channeldb.ChannelAuthProof{ + chanProof := &channeldb.ChannelAuthProof1{ NodeSig1Bytes: expChanAnn.NodeSig1.ToSignatureBytes(), NodeSig2Bytes: expChanAnn.NodeSig2.ToSignatureBytes(), BitcoinSig1Bytes: expChanAnn.BitcoinSig1.ToSignatureBytes(), diff --git a/routing/notifications_test.go b/routing/notifications_test.go index aee5581c57..63f03923ff 100644 --- a/routing/notifications_test.go +++ b/routing/notifications_test.go @@ -421,7 +421,7 @@ func TestEdgeUpdateNotification(t *testing.T) { ChannelID: chanID.ToUint64(), NodeKey1Bytes: node1.PubKeyBytes, NodeKey2Bytes: node2.PubKeyBytes, - AuthProof: &channeldb.ChannelAuthProof{ + AuthProof: &channeldb.ChannelAuthProof1{ NodeSig1Bytes: testSig.Serialize(), NodeSig2Bytes: testSig.Serialize(), BitcoinSig1Bytes: testSig.Serialize(), @@ -603,7 +603,7 @@ func TestNodeUpdateNotification(t *testing.T) { ChannelID: chanID.ToUint64(), NodeKey1Bytes: node1.PubKeyBytes, NodeKey2Bytes: node2.PubKeyBytes, - AuthProof: &channeldb.ChannelAuthProof{ + AuthProof: &channeldb.ChannelAuthProof1{ NodeSig1Bytes: testSig.Serialize(), NodeSig2Bytes: testSig.Serialize(), BitcoinSig1Bytes: testSig.Serialize(), @@ -789,7 +789,7 @@ func TestNotificationCancellation(t *testing.T) { ChannelID: chanID.ToUint64(), NodeKey1Bytes: node1.PubKeyBytes, NodeKey2Bytes: node2.PubKeyBytes, - AuthProof: &channeldb.ChannelAuthProof{ + AuthProof: &channeldb.ChannelAuthProof1{ NodeSig1Bytes: testSig.Serialize(), NodeSig2Bytes: testSig.Serialize(), BitcoinSig1Bytes: testSig.Serialize(), @@ -860,7 +860,7 @@ func TestChannelCloseNotification(t *testing.T) { ChannelID: chanID.ToUint64(), NodeKey1Bytes: node1.PubKeyBytes, NodeKey2Bytes: node2.PubKeyBytes, - AuthProof: &channeldb.ChannelAuthProof{ + AuthProof: &channeldb.ChannelAuthProof1{ NodeSig1Bytes: testSig.Serialize(), NodeSig2Bytes: testSig.Serialize(), BitcoinSig1Bytes: testSig.Serialize(), diff --git a/routing/pathfind_test.go b/routing/pathfind_test.go index f97c116689..3731f3d9e0 100644 --- a/routing/pathfind_test.go +++ b/routing/pathfind_test.go @@ -96,7 +96,7 @@ var ( _ = testSScalar.SetByteSlice(testSBytes) testSig = ecdsa.NewSignature(testRScalar, testSScalar) - testAuthProof = channeldb.ChannelAuthProof{ + testAuthProof = channeldb.ChannelAuthProof1{ NodeSig1Bytes: testSig.Serialize(), NodeSig2Bytes: testSig.Serialize(), BitcoinSig1Bytes: testSig.Serialize(), diff --git a/routing/router.go b/routing/router.go index f271a3d5dc..783e2139da 100644 --- a/routing/router.go +++ b/routing/router.go @@ -141,7 +141,7 @@ type ChannelGraphSource interface { // AddProof updates the channel edge info with proof which is needed to // properly announce the edge to the rest of the network. AddProof(chanID lnwire.ShortChannelID, - proof *channeldb.ChannelAuthProof) error + proof *channeldb.ChannelAuthProof1) error // UpdateEdge is used to update edge information, without this message // edge considered as not fully constructed. @@ -2841,7 +2841,7 @@ func (r *ChannelRouter) ForAllOutgoingChannels(cb func(kvdb.RTx, // // NOTE: This method is part of the ChannelGraphSource interface. func (r *ChannelRouter) AddProof(chanID lnwire.ShortChannelID, - proof *channeldb.ChannelAuthProof) error { + proof *channeldb.ChannelAuthProof1) error { info, _, _, err := r.cfg.Graph.FetchChannelEdgesByID(chanID.ToUint64()) if err != nil { diff --git a/routing/router_test.go b/routing/router_test.go index 2b490a851f..614f347767 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -1714,7 +1714,7 @@ func TestWakeUpOnStaleBranch(t *testing.T) { ChannelID: chanID1, NodeKey1Bytes: node1.PubKeyBytes, NodeKey2Bytes: node2.PubKeyBytes, - AuthProof: &channeldb.ChannelAuthProof{ + AuthProof: &channeldb.ChannelAuthProof1{ NodeSig1Bytes: testSig.Serialize(), NodeSig2Bytes: testSig.Serialize(), BitcoinSig1Bytes: testSig.Serialize(), @@ -1732,7 +1732,7 @@ func TestWakeUpOnStaleBranch(t *testing.T) { ChannelID: chanID2, NodeKey1Bytes: node1.PubKeyBytes, NodeKey2Bytes: node2.PubKeyBytes, - AuthProof: &channeldb.ChannelAuthProof{ + AuthProof: &channeldb.ChannelAuthProof1{ NodeSig1Bytes: testSig.Serialize(), NodeSig2Bytes: testSig.Serialize(), BitcoinSig1Bytes: testSig.Serialize(), @@ -1924,7 +1924,7 @@ func TestDisconnectedBlocks(t *testing.T) { NodeKey2Bytes: node2.PubKeyBytes, BitcoinKey1Bytes: node1.PubKeyBytes, BitcoinKey2Bytes: node2.PubKeyBytes, - AuthProof: &channeldb.ChannelAuthProof{ + AuthProof: &channeldb.ChannelAuthProof1{ NodeSig1Bytes: testSig.Serialize(), NodeSig2Bytes: testSig.Serialize(), BitcoinSig1Bytes: testSig.Serialize(), @@ -1944,7 +1944,7 @@ func TestDisconnectedBlocks(t *testing.T) { NodeKey2Bytes: node2.PubKeyBytes, BitcoinKey1Bytes: node1.PubKeyBytes, BitcoinKey2Bytes: node2.PubKeyBytes, - AuthProof: &channeldb.ChannelAuthProof{ + AuthProof: &channeldb.ChannelAuthProof1{ NodeSig1Bytes: testSig.Serialize(), NodeSig2Bytes: testSig.Serialize(), BitcoinSig1Bytes: testSig.Serialize(), @@ -2073,7 +2073,7 @@ func TestRouterChansClosedOfflinePruneGraph(t *testing.T) { ChannelID: chanID1.ToUint64(), NodeKey1Bytes: node1.PubKeyBytes, NodeKey2Bytes: node2.PubKeyBytes, - AuthProof: &channeldb.ChannelAuthProof{ + AuthProof: &channeldb.ChannelAuthProof1{ NodeSig1Bytes: testSig.Serialize(), NodeSig2Bytes: testSig.Serialize(), BitcoinSig1Bytes: testSig.Serialize(), From 4a138334f07d0e3506f4866aac406fa72b9b8d45 Mon Sep 17 00:00:00 2001 From: Elle Mouton <elle.mouton@gmail.com> Date: Tue, 17 Oct 2023 09:15:58 +0200 Subject: [PATCH 15/33] lnwire: add the ChannelAnnouncement interface --- ...ouncement_1.go => channel_announcement.go} | 35 +++++++++++++++++++ lnwire/channel_announcement_2.go | 35 +++++++++++++++++++ lnwire/interfaces.go | 24 +++++++++++++ 3 files changed, 94 insertions(+) rename lnwire/{channel_announcement_1.go => channel_announcement.go} (82%) create mode 100644 lnwire/interfaces.go diff --git a/lnwire/channel_announcement_1.go b/lnwire/channel_announcement.go similarity index 82% rename from lnwire/channel_announcement_1.go rename to lnwire/channel_announcement.go index 86d3335614..aca1e1119e 100644 --- a/lnwire/channel_announcement_1.go +++ b/lnwire/channel_announcement.go @@ -184,3 +184,38 @@ func (a *ChannelAnnouncement1) DataToSign() ([]byte, error) { return buf.Bytes(), nil } + +// Node1KeyBytes returns the bytes representing the public key of node 1 in the +// channel. +// +// NOTE: This is part of the ChannelAnnouncement interface. +func (a *ChannelAnnouncement1) Node1KeyBytes() [33]byte { + return a.NodeID1 +} + +// Node2KeyBytes returns the bytes representing the public key of node 2 in the +// channel. +// +// NOTE: This is part of the ChannelAnnouncement interface. +func (a *ChannelAnnouncement1) Node2KeyBytes() [33]byte { + return a.NodeID2 +} + +// GetChainHash returns the hash of the chain which this channel's funding +// transaction is confirmed in. +// +// NOTE: This is part of the ChannelAnnouncement interface. +func (a *ChannelAnnouncement1) GetChainHash() chainhash.Hash { + return a.ChainHash +} + +// SCID returns the short channel ID of the channel. +// +// NOTE: This is part of the ChannelAnnouncement interface. +func (a *ChannelAnnouncement1) SCID() ShortChannelID { + return a.ShortChannelID +} + +// A compile-time check to ensure that ChannelAnnouncement1 implements the +// ChannelAnnouncement interface. +var _ ChannelAnnouncement = (*ChannelAnnouncement1)(nil) diff --git a/lnwire/channel_announcement_2.go b/lnwire/channel_announcement_2.go index 2fcf05bfd6..68d2ae3bd7 100644 --- a/lnwire/channel_announcement_2.go +++ b/lnwire/channel_announcement_2.go @@ -266,3 +266,38 @@ func (c *ChannelAnnouncement2) MsgType() MessageType { // A compile time check to ensure ChannelAnnouncement2 implements the // lnwire.Message interface. var _ Message = (*ChannelAnnouncement2)(nil) + +// Node1KeyBytes returns the bytes representing the public key of node 1 in the +// channel. +// +// NOTE: This is part of the ChannelAnnouncement interface. +func (c *ChannelAnnouncement2) Node1KeyBytes() [33]byte { + return c.NodeID1 +} + +// Node2KeyBytes returns the bytes representing the public key of node 2 in the +// channel. +// +// NOTE: This is part of the ChannelAnnouncement interface. +func (c *ChannelAnnouncement2) Node2KeyBytes() [33]byte { + return c.NodeID2 +} + +// GetChainHash returns the hash of the chain which this channel's funding +// transaction is confirmed in. +// +// NOTE: This is part of the ChannelAnnouncement interface. +func (c *ChannelAnnouncement2) GetChainHash() chainhash.Hash { + return c.ChainHash +} + +// SCID returns the short channel ID of the channel. +// +// NOTE: This is part of the ChannelAnnouncement interface. +func (c *ChannelAnnouncement2) SCID() ShortChannelID { + return c.ShortChannelID +} + +// A compile-time check to ensure that ChannelAnnouncement2 implements the +// ChannelAnnouncement interface. +var _ ChannelAnnouncement = (*ChannelAnnouncement2)(nil) diff --git a/lnwire/interfaces.go b/lnwire/interfaces.go new file mode 100644 index 0000000000..d50fad3862 --- /dev/null +++ b/lnwire/interfaces.go @@ -0,0 +1,24 @@ +package lnwire + +import "github.com/btcsuite/btcd/chaincfg/chainhash" + +// ChannelAnnouncement is an interface that must be satisfied by any message +// used to announce and prove the existence of a channel. +type ChannelAnnouncement interface { + // SCID returns the short channel ID of the channel. + SCID() ShortChannelID + + // GetChainHash returns the hash of the chain which this channel's + // funding transaction is confirmed in. + GetChainHash() chainhash.Hash + + // Node1KeyBytes returns the bytes representing the public key of node + // 1 in the channel. + Node1KeyBytes() [33]byte + + // Node2KeyBytes returns the bytes representing the public key of node + // 2 in the channel. + Node2KeyBytes() [33]byte + + Message +} From 04d9eed68c080ca419a6c34ababc235d855e23c7 Mon Sep 17 00:00:00 2001 From: Elle Mouton <elle.mouton@gmail.com> Date: Tue, 17 Oct 2023 11:36:45 +0200 Subject: [PATCH 16/33] multi: use ChannelAnnouncement interface where possible --- discovery/gossiper.go | 219 ++++++++++++++++++------------- discovery/syncer.go | 16 +-- funding/manager.go | 2 +- lnwire/channel_announcement_2.go | 31 ++++- lnwire/msg_hash.go | 13 ++ netann/channel_announcement.go | 2 +- peer/brontide.go | 6 +- routing/ann_validation.go | 71 +++++++++- routing/validation_barrier.go | 25 ++-- 9 files changed, 266 insertions(+), 119 deletions(-) create mode 100644 lnwire/msg_hash.go diff --git a/discovery/gossiper.go b/discovery/gossiper.go index 4891f916de..089e5319e4 100644 --- a/discovery/gossiper.go +++ b/discovery/gossiper.go @@ -818,13 +818,18 @@ func (d *AuthenticatedGossiper) ProcessRemoteAnnouncement(msg lnwire.Message, // To avoid inserting edges in the graph for our own channels that we // have already closed, we ignore such channel announcements coming // from the remote. - case *lnwire.ChannelAnnouncement1: + case lnwire.ChannelAnnouncement: ownKey := d.selfKey.SerializeCompressed() - ownErr := fmt.Errorf("ignoring remote ChannelAnnouncement1 " + - "for own channel") + ownErr := fmt.Errorf("ignoring remote %s for own channel", + m.MsgType()) + + var ( + node1ID = m.Node1KeyBytes() + node2ID = m.Node2KeyBytes() + ) - if bytes.Equal(m.NodeID1[:], ownKey) || - bytes.Equal(m.NodeID2[:], ownKey) { + if bytes.Equal(node1ID[:], ownKey) || + bytes.Equal(node2ID[:], ownKey) { log.Warn(ownErr) errChan <- ownErr @@ -980,8 +985,8 @@ func (d *deDupedAnnouncements) addMsg(message networkMsg) { switch msg := message.msg.(type) { // Channel announcements are identified by the short channel id field. - case *lnwire.ChannelAnnouncement1: - deDupKey := msg.ShortChannelID + case lnwire.ChannelAnnouncement: + deDupKey := msg.SCID() sender := route.NewVertex(message.source) mws, ok := d.channelAnnouncements[deDupKey] @@ -1554,8 +1559,8 @@ func (d *AuthenticatedGossiper) isRecentlyRejectedMsg(msg lnwire.Message, case *lnwire.ChannelUpdate: scid = m.ShortChannelID.ToUint64() - case *lnwire.ChannelAnnouncement1: - scid = m.ShortChannelID.ToUint64() + case lnwire.ChannelAnnouncement: + scid = m.SCID().ToUint64() default: return false @@ -1812,14 +1817,14 @@ func remotePubFromChanInfo(chanInfo *channeldb.ChannelEdgeInfo1, // to receive the remote peer's proof, while the remote peer is able to fully // assemble the proof and craft the ChannelAnnouncement1. func (d *AuthenticatedGossiper) processRejectedEdge( - chanAnnMsg *lnwire.ChannelAnnouncement1, + chanAnnMsg lnwire.ChannelAnnouncement, proof *channeldb.ChannelAuthProof1) ([]networkMsg, error) { + scid := chanAnnMsg.SCID() + // First, we'll fetch the state of the channel as we know if from the // database. - chanInfo, e1, e2, err := d.cfg.Router.GetChannelByID( - chanAnnMsg.ShortChannelID, - ) + chanInfo, e1, e2, err := d.cfg.Router.GetChannelByID(scid) if err != nil { return nil, err } @@ -1849,18 +1854,17 @@ func (d *AuthenticatedGossiper) processRejectedEdge( err = routing.ValidateChannelAnn(chanAnn) if err != nil { err := fmt.Errorf("assembled channel announcement proof "+ - "for shortChanID=%v isn't valid: %v", - chanAnnMsg.ShortChannelID, err) + "for shortChanID=%v isn't valid: %v", scid, err) log.Error(err) return nil, err } // If everything checks out, then we'll add the fully assembled proof // to the database. - err = d.cfg.Router.AddProof(chanAnnMsg.ShortChannelID, proof) + err = d.cfg.Router.AddProof(scid, proof) if err != nil { - err := fmt.Errorf("unable add proof to shortChanID=%v: %v", - chanAnnMsg.ShortChannelID, err) + err := fmt.Errorf("unable add proof to shortChanID=%v: %w", + scid, err) log.Error(err) return nil, err } @@ -1997,7 +2001,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement( // *creation* of a new channel within the network. This only advertises // the existence of a channel and not yet the routing policies in // either direction of the channel. - case *lnwire.ChannelAnnouncement1: + case lnwire.ChannelAnnouncement: return d.handleChanAnnouncement(nMsg, msg, schedulerOp) // A new authenticated channel edge update has arrived. This indicates @@ -2151,7 +2155,7 @@ func (d *AuthenticatedGossiper) isMsgStale(msg lnwire.Message) bool { // updateChannel creates a new fully signed update for the channel, and updates // the underlying graph with the new state. func (d *AuthenticatedGossiper) updateChannel(info *channeldb.ChannelEdgeInfo1, - edge *channeldb.ChannelEdgePolicy) (*lnwire.ChannelAnnouncement1, + edge *channeldb.ChannelEdgePolicy) (lnwire.ChannelAnnouncement, *lnwire.ChannelUpdate, error) { // Parse the unsigned edge into a channel update. @@ -2174,7 +2178,9 @@ func (d *AuthenticatedGossiper) updateChannel(info *channeldb.ChannelEdgeInfo1, // To ensure that our signature is valid, we'll verify it ourself // before committing it to the slice returned. - err = routing.ValidateChannelUpdateAnn(d.selfKey, info.Capacity, chanUpdate) + err = routing.ValidateChannelUpdateAnn( + d.selfKey, info.Capacity, chanUpdate, + ) if err != nil { return nil, nil, fmt.Errorf("generated invalid channel "+ "update sig: %v", err) @@ -2363,23 +2369,26 @@ func (d *AuthenticatedGossiper) handleNodeAnnouncement(nMsg *networkMsg, // handleChanAnnouncement processes a new channel announcement. func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, - ann *lnwire.ChannelAnnouncement1, + ann lnwire.ChannelAnnouncement, ops []batch.SchedulerOption) ([]networkMsg, bool) { + var ( + scid = ann.SCID() + chainHash = ann.GetChainHash() + ) + log.Debugf("Processing ChannelAnnouncement1: peer=%v, short_chan_id=%v", - nMsg.peer, ann.ShortChannelID.ToUint64()) + nMsg.peer, scid.ToUint64()) // We'll ignore any channel announcements that target any chain other // than the set of chains we know of. - if !bytes.Equal(ann.ChainHash[:], d.cfg.ChainHash[:]) { - err := fmt.Errorf("ignoring ChannelAnnouncement1 from chain=%v"+ - ", gossiper on chain=%v", ann.ChainHash, - d.cfg.ChainHash) + if !bytes.Equal(chainHash[:], d.cfg.ChainHash[:]) { + err := fmt.Errorf("ignoring %s from chain=%v, gossiper on "+ + "chain=%v", ann.MsgType(), chainHash, d.cfg.ChainHash) log.Errorf(err.Error()) key := newRejectCacheKey( - ann.ShortChannelID.ToUint64(), - sourceToPub(nMsg.source), + scid.ToUint64(), sourceToPub(nMsg.source), ) _, _ = d.recentRejects.Put(key, &cachedReject{}) @@ -2390,14 +2399,12 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, // If this is a remote ChannelAnnouncement1 with an alias SCID, we'll // reject the announcement. Since the router accepts alias SCIDs, // not erroring out would be a DoS vector. - if nMsg.isRemote && d.cfg.IsAlias(ann.ShortChannelID) { - err := fmt.Errorf("ignoring remote alias channel=%v", - ann.ShortChannelID) + if nMsg.isRemote && d.cfg.IsAlias(scid) { + err := fmt.Errorf("ignoring remote alias channel=%v", scid) log.Errorf(err.Error()) key := newRejectCacheKey( - ann.ShortChannelID.ToUint64(), - sourceToPub(nMsg.source), + scid.ToUint64(), sourceToPub(nMsg.source), ) _, _ = d.recentRejects.Put(key, &cachedReject{}) @@ -2408,11 +2415,10 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, // If the advertised inclusionary block is beyond our knowledge of the // chain tip, then we'll ignore it for now. d.Lock() - if nMsg.isRemote && d.isPremature(ann.ShortChannelID, 0, nMsg) { + if nMsg.isRemote && d.isPremature(scid, 0, nMsg) { log.Warnf("Announcement for chan_id=(%v), is premature: "+ "advertises height %v, only height %v is known", - ann.ShortChannelID.ToUint64(), - ann.ShortChannelID.BlockHeight, d.bestHeight) + scid.ToUint64(), scid.BlockHeight, d.bestHeight) d.Unlock() nMsg.err <- nil return nil, false @@ -2421,7 +2427,7 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, // At this point, we'll now ask the router if this is a zombie/known // edge. If so we can skip all the processing below. - if d.cfg.Router.IsKnownEdge(ann.ShortChannelID) { + if d.cfg.Router.IsKnownEdge(scid) { nMsg.err <- nil return nil, true } @@ -2435,8 +2441,7 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, "%v", err) key := newRejectCacheKey( - ann.ShortChannelID.ToUint64(), - sourceToPub(nMsg.source), + scid.ToUint64(), sourceToPub(nMsg.source), ) _, _ = d.recentRejects.Put(key, &cachedReject{}) @@ -2448,49 +2453,30 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, // If the proof checks out, then we'll save the proof itself to // the database so we can fetch it later when gossiping with // other nodes. - proof = &channeldb.ChannelAuthProof1{ - NodeSig1Bytes: ann.NodeSig1.ToSignatureBytes(), - NodeSig2Bytes: ann.NodeSig2.ToSignatureBytes(), - BitcoinSig1Bytes: ann.BitcoinSig1.ToSignatureBytes(), - BitcoinSig2Bytes: ann.BitcoinSig2.ToSignatureBytes(), + var err error + proof, err = buildChanProof(ann) + if err != nil { + err := fmt.Errorf("unable to build channel "+ + "announcement proof: %v", err) + log.Error(err) + nMsg.err <- err + + return nil, false } } // With the proof validated (if necessary), we can now store it within // the database for our path finding and syncing needs. - var featureBuf bytes.Buffer - if err := ann.Features.Encode(&featureBuf); err != nil { - log.Errorf("unable to encode features: %v", err) + edge, err := buildEdgeInfo(ann, nMsg.optionalMsgFields, proof) + if err != nil { + log.Errorf("unable to build edge info from announcement: %v", + err) nMsg.err <- err - return nil, false - } - - edge := &channeldb.ChannelEdgeInfo1{ - ChannelID: ann.ShortChannelID.ToUint64(), - ChainHash: ann.ChainHash, - NodeKey1Bytes: ann.NodeID1, - NodeKey2Bytes: ann.NodeID2, - BitcoinKey1Bytes: ann.BitcoinKey1, - BitcoinKey2Bytes: ann.BitcoinKey2, - AuthProof: proof, - Features: featureBuf.Bytes(), - ExtraOpaqueData: ann.ExtraOpaqueData, - } - // If there were any optional message fields provided, we'll include - // them in its serialized disk representation now. - if nMsg.optionalMsgFields != nil { - if nMsg.optionalMsgFields.capacity != nil { - edge.Capacity = *nMsg.optionalMsgFields.capacity - } - if nMsg.optionalMsgFields.channelPoint != nil { - cp := *nMsg.optionalMsgFields.channelPoint - edge.ChannelPoint = cp - } + return nil, false } - log.Debugf("Adding edge for short_chan_id: %v", - ann.ShortChannelID.ToUint64()) + log.Debugf("Adding edge for short_chan_id: %v", scid.ToUint64()) // We will add the edge to the channel router. If the nodes present in // this channel are not present in the database, a partial node will be @@ -2500,13 +2486,13 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, // channel ID. We do this to ensure no other goroutine has read the // database and is now making decisions based on this DB state, before // it writes to the DB. - d.channelMtx.Lock(ann.ShortChannelID.ToUint64()) - err := d.cfg.Router.AddEdge(edge, ops...) + d.channelMtx.Lock(scid.ToUint64()) + err = d.cfg.Router.AddEdge(edge, ops...) if err != nil { log.Debugf("Router rejected edge for short_chan_id(%v): %v", - ann.ShortChannelID.ToUint64(), err) + scid.ToUint64(), err) - defer d.channelMtx.Unlock(ann.ShortChannelID.ToUint64()) + defer d.channelMtx.Unlock(scid.ToUint64()) // If the edge was rejected due to already being known, then it // may be the case that this new message has a fresh channel @@ -2517,7 +2503,7 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, anns, rErr := d.processRejectedEdge(ann, proof) if rErr != nil { key := newRejectCacheKey( - ann.ShortChannelID.ToUint64(), + scid.ToUint64(), sourceToPub(nMsg.source), ) cr := &cachedReject{} @@ -2542,8 +2528,7 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, } else { // Otherwise, this is just a regular rejected edge. key := newRejectCacheKey( - ann.ShortChannelID.ToUint64(), - sourceToPub(nMsg.source), + scid.ToUint64(), sourceToPub(nMsg.source), ) _, _ = d.recentRejects.Put(key, &cachedReject{}) } @@ -2553,17 +2538,15 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, } // If err is nil, release the lock immediately. - d.channelMtx.Unlock(ann.ShortChannelID.ToUint64()) + d.channelMtx.Unlock(scid.ToUint64()) - log.Debugf("Finish adding edge for short_chan_id: %v", - ann.ShortChannelID.ToUint64()) + log.Debugf("Finish adding edge for short_chan_id: %v", scid.ToUint64()) // If we earlier received any ChannelUpdates for this channel, we can // now process them, as the channel is added to the graph. - shortChanID := ann.ShortChannelID.ToUint64() var channelUpdates []*processedNetworkMsg - earlyChanUpdates, err := d.prematureChannelUpdates.Get(shortChanID) + earlyChanUpdates, err := d.prematureChannelUpdates.Get(scid.ToUint64()) if err == nil { // There was actually an entry in the map, so we'll accumulate // it. We don't worry about deletion, since it'll eventually @@ -2630,8 +2613,8 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, nMsg.err <- nil - log.Debugf("Processed ChannelAnnouncement1: peer=%v, short_chan_id=%v", - nMsg.peer, ann.ShortChannelID.ToUint64()) + log.Debugf("Processed %s: peer=%v, short_chan_id=%v", + nMsg.msg.MsgType(), nMsg.peer, scid.ToUint64()) return announcements, true } @@ -3340,3 +3323,63 @@ func (d *AuthenticatedGossiper) handleAnnSig(nMsg *networkMsg, nMsg.err <- nil return announcements, true } + +func buildChanProof(ann lnwire.ChannelAnnouncement) ( + *channeldb.ChannelAuthProof1, error) { + + switch a := ann.(type) { + case *lnwire.ChannelAnnouncement1: + return &channeldb.ChannelAuthProof1{ + NodeSig1Bytes: a.NodeSig1.ToSignatureBytes(), + NodeSig2Bytes: a.NodeSig2.ToSignatureBytes(), + BitcoinSig1Bytes: a.BitcoinSig1.ToSignatureBytes(), + BitcoinSig2Bytes: a.BitcoinSig2.ToSignatureBytes(), + }, nil + default: + return nil, fmt.Errorf("unhandled lnwire.ChannelAnnouncement "+ + "implementation: %T", a) + } +} + +func buildEdgeInfo(ann lnwire.ChannelAnnouncement, opts *optionalMsgFields, + proof *channeldb.ChannelAuthProof1) (*channeldb.ChannelEdgeInfo1, + error) { + + switch a := ann.(type) { + case *lnwire.ChannelAnnouncement1: + var featureBuf bytes.Buffer + if err := a.Features.Encode(&featureBuf); err != nil { + return nil, err + } + + edge := &channeldb.ChannelEdgeInfo1{ + ChannelID: a.ShortChannelID.ToUint64(), + ChainHash: a.ChainHash, + NodeKey1Bytes: a.NodeID1, + NodeKey2Bytes: a.NodeID2, + BitcoinKey1Bytes: a.BitcoinKey1, + BitcoinKey2Bytes: a.BitcoinKey2, + Features: featureBuf.Bytes(), + AuthProof: proof, + ExtraOpaqueData: a.ExtraOpaqueData, + } + + // If there were any optional message fields provided, we'll + // include them in its serialized disk representation now. + if opts != nil { + if opts.capacity != nil { + edge.Capacity = *opts.capacity + } + if opts.channelPoint != nil { + cp := *opts.channelPoint + edge.ChannelPoint = cp + } + } + + return edge, nil + + default: + return nil, fmt.Errorf("unhandled lnwire.ChannelAnnouncement "+ + "implementation: %T", a) + } +} diff --git a/discovery/syncer.go b/discovery/syncer.go index 72174ea666..c8a727b384 100644 --- a/discovery/syncer.go +++ b/discovery/syncer.go @@ -1314,20 +1314,20 @@ func (g *GossipSyncer) FilterGossipMsgs(msgs ...msgWithSenders) { // For each channel announcement message, we'll only send this // message if the channel updates for the channel are between // our time range. - case *lnwire.ChannelAnnouncement1: + case lnwire.ChannelAnnouncement: + scid := msg.SCID() + // First, we'll check if the channel updates are in // this message batch. - chanUpdates, ok := chanUpdateIndex[msg.ShortChannelID] + chanUpdates, ok := chanUpdateIndex[scid] if !ok { // If not, we'll attempt to query the database // to see if we know of the updates. - chanUpdates, err = g.cfg.channelSeries.FetchChanUpdates( - g.cfg.chainHash, msg.ShortChannelID, - ) + chanUpdates, err = g.cfg.channelSeries. + FetchChanUpdates(g.cfg.chainHash, scid) if err != nil { - log.Warnf("no channel updates found for "+ - "short_chan_id=%v", - msg.ShortChannelID) + log.Warnf("no channel updates found "+ + "for short_chan_id=%v", scid) continue } } diff --git a/funding/manager.go b/funding/manager.go index 3b158af901..c82d81d321 100644 --- a/funding/manager.go +++ b/funding/manager.go @@ -4050,7 +4050,7 @@ func (f *Manager) ensureInitialForwardingPolicy(chanID lnwire.ChannelID, // chanAnnouncement encapsulates the two authenticated announcements that we // send out to the network after a new channel has been created locally. type chanAnnouncement struct { - chanAnn *lnwire.ChannelAnnouncement1 + chanAnn lnwire.ChannelAnnouncement chanUpdateAnn *lnwire.ChannelUpdate chanProof *lnwire.AnnounceSignatures } diff --git a/lnwire/channel_announcement_2.go b/lnwire/channel_announcement_2.go index 68d2ae3bd7..3cfed664ca 100644 --- a/lnwire/channel_announcement_2.go +++ b/lnwire/channel_announcement_2.go @@ -193,15 +193,36 @@ func (c *ChannelAnnouncement2) Decode(r io.Reader, _ uint32) error { // // This is part of the lnwire.Message interface. func (c *ChannelAnnouncement2) Encode(w *bytes.Buffer, _ uint32) error { - var records []tlv.Record - _, err := w.Write(c.Signature.RawBytes()) if err != nil { return err } + _, err = c.DataToSign() + if err != nil { + return err + } + + return WriteBytes(w, c.ExtraOpaqueData) +} + +// DigestToSign computes the digest of the message to be signed. +func (c *ChannelAnnouncement2) DigestToSign() (*chainhash.Hash, error) { + data, err := c.DataToSign() + if err != nil { + return nil, err + } + hash := MsgHash( + "channel_announcement_2", "announcement_signature", data, + ) + + return hash, nil +} + +func (c *ChannelAnnouncement2) DataToSign() ([]byte, error) { // The chain-hash record is only included if it is _not_ equal to the // bitcoin mainnet genisis block hash. + var records []tlv.Record if !c.ChainHash.IsEqual(chaincfg.MainNetParams.GenesisHash) { chainHash := [32]byte(c.ChainHash) records = append(records, tlv.MakePrimitiveRecord( @@ -247,12 +268,12 @@ func (c *ChannelAnnouncement2) Encode(w *bytes.Buffer, _ uint32) error { } } - err = EncodeMessageExtraDataFromRecords(&c.ExtraOpaqueData, records...) + err := EncodeMessageExtraDataFromRecords(&c.ExtraOpaqueData, records...) if err != nil { - return err + return nil, err } - return WriteBytes(w, c.ExtraOpaqueData) + return c.ExtraOpaqueData, nil } // MsgType returns the integer uniquely identifying this message type on the diff --git a/lnwire/msg_hash.go b/lnwire/msg_hash.go new file mode 100644 index 0000000000..a3f05b8db5 --- /dev/null +++ b/lnwire/msg_hash.go @@ -0,0 +1,13 @@ +package lnwire + +import "github.com/btcsuite/btcd/chaincfg/chainhash" + +const MsgHashTag = "lightning" + +func MsgHash(msgName, fieldName string, msg []byte) *chainhash.Hash { + tag := []byte(MsgHashTag) + tag = append(tag, []byte(msgName)...) + tag = append(tag, []byte(fieldName)...) + + return chainhash.TaggedHash(tag, msg) +} diff --git a/netann/channel_announcement.go b/netann/channel_announcement.go index 9cce4e34e4..e383286990 100644 --- a/netann/channel_announcement.go +++ b/netann/channel_announcement.go @@ -14,7 +14,7 @@ import ( // peer's initial routing table upon connect. func CreateChanAnnouncement(chanProof *channeldb.ChannelAuthProof1, chanInfo *channeldb.ChannelEdgeInfo1, - e1, e2 *channeldb.ChannelEdgePolicy) (*lnwire.ChannelAnnouncement1, + e1, e2 *channeldb.ChannelEdgePolicy) (lnwire.ChannelAnnouncement, *lnwire.ChannelUpdate, *lnwire.ChannelUpdate, error) { // First, using the parameters of the channel, along with the channel diff --git a/peer/brontide.go b/peer/brontide.go index dc05b2ef16..349cdab794 100644 --- a/peer/brontide.go +++ b/peer/brontide.go @@ -1725,7 +1725,7 @@ out: } case *lnwire.ChannelUpdate, - *lnwire.ChannelAnnouncement1, + lnwire.ChannelAnnouncement, *lnwire.NodeAnnouncement, *lnwire.AnnounceSignatures, *lnwire.GossipTimestampRange, @@ -1978,9 +1978,9 @@ func messageSummary(msg lnwire.Message) string { return fmt.Sprintf("chan_id=%v, short_chan_id=%v", msg.ChannelID, msg.ShortChannelID.ToUint64()) - case *lnwire.ChannelAnnouncement1: + case lnwire.ChannelAnnouncement: return fmt.Sprintf("chain_hash=%v, short_chan_id=%v", - msg.ChainHash, msg.ShortChannelID.ToUint64()) + msg.GetChainHash(), msg.SCID().ToUint64()) case *lnwire.ChannelUpdate: return fmt.Sprintf("chain_hash=%v, short_chan_id=%v, "+ diff --git a/routing/ann_validation.go b/routing/ann_validation.go index a60944bf8a..5e3b7616c4 100644 --- a/routing/ann_validation.go +++ b/routing/ann_validation.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/btcsuite/btcd/btcec/v2" + "github.com/btcsuite/btcd/btcec/v2/schnorr/musig2" "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/davecgh/go-spew/spew" @@ -12,10 +13,24 @@ import ( "github.com/lightningnetwork/lnd/lnwire" ) -// ValidateChannelAnn validates the channel announcement message and checks +// ValidateChannelAnn validates the signature(s) of a channel announcement +// message. +func ValidateChannelAnn(a lnwire.ChannelAnnouncement) error { + switch ann := a.(type) { + case *lnwire.ChannelAnnouncement1: + return validateChannelAnn1(ann) + case *lnwire.ChannelAnnouncement2: + return validateChannelAnn2(ann) + default: + return fmt.Errorf("unhandled lnwire.ChannelAnnouncement "+ + "implementation: %T", ann) + } +} + +// validateChannelAnn1 validates the channel announcement message and checks // that node signatures covers the announcement message, and that the bitcoin // signatures covers the node keys. -func ValidateChannelAnn(a *lnwire.ChannelAnnouncement1) error { +func validateChannelAnn1(a *lnwire.ChannelAnnouncement1) error { // First, we'll compute the digest (h) which is to be signed by each of // the keys included within the node announcement message. This hash // digest includes all the keys, so the (up to 4 signatures) will @@ -82,7 +97,59 @@ func ValidateChannelAnn(a *lnwire.ChannelAnnouncement1) error { } return nil +} + +// validateChannelAnn2 validates the channel announcement 2 message and checks +// that schnorr signature is valid. +func validateChannelAnn2(a *lnwire.ChannelAnnouncement2) error { + dataHash, err := a.DigestToSign() + if err != nil { + return err + } + + sig, err := a.Signature.ToSignature() + if err != nil { + return err + } + + nodeKey1, err := btcec.ParsePubKey(a.NodeID1[:]) + if err != nil { + return err + } + nodeKey2, err := btcec.ParsePubKey(a.NodeID2[:]) + if err != nil { + return err + } + + keys := []*btcec.PublicKey{ + nodeKey1, nodeKey2, + } + + if a.BitcoinKey1 != nil && a.BitcoinKey2 != nil { + bitcoinKey1, err := btcec.ParsePubKey(a.BitcoinKey1[:]) + if err != nil { + return err + } + + bitcoinKey2, err := btcec.ParsePubKey(a.BitcoinKey2[:]) + if err != nil { + return err + } + + keys = append(keys, bitcoinKey1, bitcoinKey2) + } + + aggKey, _, _, err := musig2.AggregateKeys(keys, true) + if err != nil { + return err + } + + if !sig.Verify(dataHash.CloneBytes(), aggKey.FinalKey) { + return fmt.Errorf("invalid sig") + } + + return nil } // ValidateNodeAnn validates the node announcement by ensuring that the diff --git a/routing/validation_barrier.go b/routing/validation_barrier.go index 21e148fcfa..ef3109d25b 100644 --- a/routing/validation_barrier.go +++ b/routing/validation_barrier.go @@ -101,7 +101,8 @@ func (v *ValidationBarrier) InitJobDependencies(job interface{}) { // ChannelUpdates for the same channel, or NodeAnnouncements of nodes // that are involved in this channel. This goes for both the wire // type,s and also the types that we use within the database. - case *lnwire.ChannelAnnouncement1: + case lnwire.ChannelAnnouncement: + scid := msg.SCID() // We ensure that we only create a new announcement signal iff, // one doesn't already exist, as there may be duplicate @@ -109,7 +110,7 @@ func (v *ValidationBarrier) InitJobDependencies(job interface{}) { // ChannelAnnouncement1 has been validated. This will result in // all the dependent jobs being unlocked so they can finish // execution themselves. - if _, ok := v.chanAnnFinSignal[msg.ShortChannelID]; !ok { + if _, ok := v.chanAnnFinSignal[scid]; !ok { // We'll create the channel that we close after we // validate this announcement. All dependants will // point to this same channel, so they'll be unblocked @@ -119,11 +120,11 @@ func (v *ValidationBarrier) InitJobDependencies(job interface{}) { deny: make(chan struct{}), } - v.chanAnnFinSignal[msg.ShortChannelID] = signals - v.chanEdgeDependencies[msg.ShortChannelID] = signals + v.chanAnnFinSignal[scid] = signals + v.chanEdgeDependencies[scid] = signals - v.nodeAnnDependencies[route.Vertex(msg.NodeID1)] = signals - v.nodeAnnDependencies[route.Vertex(msg.NodeID2)] = signals + v.nodeAnnDependencies[msg.Node1KeyBytes()] = signals + v.nodeAnnDependencies[msg.Node2KeyBytes()] = signals } case *channeldb.ChannelEdgeInfo1: @@ -218,7 +219,7 @@ func (v *ValidationBarrier) WaitForDependants(job interface{}) error { case *lnwire.AnnounceSignatures: // TODO(roasbeef): need to wait on chan ann? case *channeldb.ChannelEdgeInfo1: - case *lnwire.ChannelAnnouncement1: + case lnwire.ChannelAnnouncement: } // Release the lock once the above read is finished. @@ -274,18 +275,20 @@ func (v *ValidationBarrier) SignalDependants(job interface{}, allow bool) { } delete(v.chanAnnFinSignal, shortID) } - case *lnwire.ChannelAnnouncement1: - finSignals, ok := v.chanAnnFinSignal[msg.ShortChannelID] + case lnwire.ChannelAnnouncement: + scid := msg.SCID() + + finSignals, ok := v.chanAnnFinSignal[scid] if ok { if allow { close(finSignals.allow) } else { close(finSignals.deny) } - delete(v.chanAnnFinSignal, msg.ShortChannelID) + delete(v.chanAnnFinSignal, scid) } - delete(v.chanEdgeDependencies, msg.ShortChannelID) + delete(v.chanEdgeDependencies, scid) // For all other job types, we'll delete the tracking entries from the // map, as if we reach this point, then all dependants have already From e66196a5b1cdb5f6a55343396a58beda850a9eb7 Mon Sep 17 00:00:00 2001 From: Elle Mouton <elle.mouton@gmail.com> Date: Tue, 17 Oct 2023 12:08:53 +0200 Subject: [PATCH 17/33] multi: remove kvdb.Backend from channeldb.LightningNode --- autopilot/graph.go | 11 ++++- channeldb/graph.go | 65 +++++++++++------------------ channeldb/graph_test.go | 91 ++++++++++++++++++----------------------- routing/router.go | 17 ++++---- rpcserver.go | 57 +++++++++++++------------- server.go | 2 +- 6 files changed, 110 insertions(+), 133 deletions(-) diff --git a/autopilot/graph.go b/autopilot/graph.go index a049833250..8a277be3eb 100644 --- a/autopilot/graph.go +++ b/autopilot/graph.go @@ -53,6 +53,8 @@ func ChannelGraphFromDatabase(db *channeldb.ChannelGraph) ChannelGraph { // channeldb.LightningNode. The wrapper method implement the autopilot.Node // interface. type dbNode struct { + db kvdb.Backend + tx kvdb.RTx node *channeldb.LightningNode @@ -86,8 +88,9 @@ func (d dbNode) Addrs() []net.Addr { // // NOTE: Part of the autopilot.Node interface. func (d dbNode) ForEachChannel(cb func(ChannelEdge) error) error { - return d.node.ForEachChannel(d.tx, func(tx kvdb.RTx, - ei *channeldb.ChannelEdgeInfo1, ep, _ *channeldb.ChannelEdgePolicy) error { + return d.node.ForEachChannel(d.db, d.tx, func(tx kvdb.RTx, + ei *channeldb.ChannelEdgeInfo1, ep, + _ *channeldb.ChannelEdgePolicy) error { // Skip channels for which no outgoing edge policy is available. // @@ -104,6 +107,7 @@ func (d dbNode) ForEachChannel(cb func(ChannelEdge) error) error { ChanID: lnwire.NewShortChanIDFromInt(ep.ChannelID), Capacity: ei.Capacity, Peer: dbNode{ + db: d.db, tx: tx, node: ep.Node, }, @@ -128,6 +132,7 @@ func (d *databaseChannelGraph) ForEachNode(cb func(Node) error) error { } node := dbNode{ + db: d.db.DB(), tx: tx, node: n, } @@ -267,6 +272,7 @@ func (d *databaseChannelGraph) addRandChannel(node1, node2 *btcec.PublicKey, Capacity: capacity, Peer: dbNode{ node: vertex1, + db: d.db.DB(), }, }, &ChannelEdge{ @@ -274,6 +280,7 @@ func (d *databaseChannelGraph) addRandChannel(node1, node2 *btcec.PublicKey, Capacity: capacity, Peer: dbNode{ node: vertex2, + db: d.db.DB(), }, }, nil diff --git a/channeldb/graph.go b/channeldb/graph.go index b372dacd0b..f94fb8af4f 100644 --- a/channeldb/graph.go +++ b/channeldb/graph.go @@ -249,6 +249,10 @@ func NewChannelGraph(db kvdb.Backend, rejectCacheSize, chanCacheSize int, return g, nil } +func (c *ChannelGraph) DB() kvdb.Backend { + return c.db +} + // channelMapKey is the key structure used for storing channel edge policies. type channelMapKey struct { nodeKey route.Vertex @@ -557,7 +561,7 @@ func (c *ChannelGraph) ForEachNodeCached(cb func(node route.Vertex, return c.ForEachNode(func(tx kvdb.RTx, node *LightningNode) error { channels := make(map[uint64]*DirectedChannel) - err := node.ForEachChannel(tx, func(tx kvdb.RTx, + err := node.ForEachChannel(c.db, tx, func(tx kvdb.RTx, e *ChannelEdgeInfo1, p1 *ChannelEdgePolicy, p2 *ChannelEdgePolicy) error { @@ -679,7 +683,6 @@ func (c *ChannelGraph) ForEachNode( if err != nil { return err } - node.db = c.db // Execute the callback, the transaction will abort if // this returns an error. @@ -777,7 +780,6 @@ func (c *ChannelGraph) sourceNode(nodes kvdb.RBucket) (*LightningNode, error) { if err != nil { return nil, err } - node.db = c.db return &node, nil } @@ -1186,8 +1188,9 @@ func (c *ChannelGraph) HasChannelEdge( return ErrGraphNodeNotFound } - e1, e2, err := fetchChanEdgePolicies(edgeIndex, edges, nodes, - channelID[:], c.db) + e1, e2, err := fetchChanEdgePolicies( + edgeIndex, edges, nodes, channelID[:], + ) if err != nil { return err } @@ -1943,7 +1946,7 @@ func (c *ChannelGraph) ChanUpdatesInHorizon(startTime, // With the static information obtained, we'll now // fetch the dynamic policy info. edge1, edge2, err := fetchChanEdgePolicies( - edgeIndex, edges, nodes, chanID, c.db, + edgeIndex, edges, nodes, chanID, ) if err != nil { chanID := byteOrder.Uint64(chanID) @@ -2035,7 +2038,6 @@ func (c *ChannelGraph) NodeUpdatesInHorizon(startTime, if err != nil { return err } - node.db = c.db nodesInHorizon = append(nodesInHorizon, node) } @@ -2275,7 +2277,7 @@ func (c *ChannelGraph) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) { // With the static information obtained, we'll now // fetch the dynamic policy info. edge1, edge2, err := fetchChanEdgePolicies( - edgeIndex, edges, nodes, cidBytes[:], c.db, + edgeIndex, edges, nodes, cidBytes[:], ) if err != nil { return err @@ -2356,7 +2358,7 @@ func (c *ChannelGraph) delChannelEdge(edges, edgeIndex, chanIndex, zombieIndex, // times. cid := byteOrder.Uint64(chanID) edge1, edge2, err := fetchChanEdgePolicies( - edgeIndex, edges, nodes, chanID, nil, + edgeIndex, edges, nodes, chanID, ) if err != nil { return err @@ -2657,8 +2659,6 @@ type LightningNode struct { // compatible manner. ExtraOpaqueData []byte - db kvdb.Backend - // TODO(roasbeef): discovery will need storage to keep it's last IP // address and re-announce if interface changes? @@ -2740,14 +2740,16 @@ func (l *LightningNode) NodeAnnouncement(signed bool) (*lnwire.NodeAnnouncement, // isPublic determines whether the node is seen as public within the graph from // the source node's point of view. An existing database transaction can also be // specified. -func (l *LightningNode) isPublic(tx kvdb.RTx, sourcePubKey []byte) (bool, error) { +func (l *LightningNode) isPublic(db kvdb.Backend, tx kvdb.RTx, + sourcePubKey []byte) (bool, error) { + // In order to determine whether this node is publicly advertised within // the graph, we'll need to look at all of its edges and check whether // they extend to any other node than the source node. errDone will be // used to terminate the check early. nodeIsPublic := false errDone := errors.New("done") - err := l.ForEachChannel(tx, func(_ kvdb.RTx, info *ChannelEdgeInfo1, + err := l.ForEachChannel(db, tx, func(_ kvdb.RTx, info *ChannelEdgeInfo1, _, _ *ChannelEdgePolicy) error { // If this edge doesn't extend to the source node, we'll @@ -2807,7 +2809,6 @@ func (c *ChannelGraph) FetchLightningNode(nodePub route.Vertex) ( if err != nil { return err } - n.db = c.db node = &n @@ -2919,7 +2920,8 @@ func (c *ChannelGraph) HasLightningNode(nodePub [33]byte) (time.Time, bool, erro // nodeTraversal is used to traverse all channels of a node given by its // public key and passes channel information into the specified callback. func nodeTraversal(tx kvdb.RTx, nodePub []byte, db kvdb.Backend, - cb func(kvdb.RTx, *ChannelEdgeInfo1, *ChannelEdgePolicy, *ChannelEdgePolicy) error) error { + cb func(kvdb.RTx, *ChannelEdgeInfo1, *ChannelEdgePolicy, + *ChannelEdgePolicy) error) error { traversal := func(tx kvdb.RTx) error { nodes := tx.ReadBucket(nodeBucket) @@ -3016,12 +3018,11 @@ func nodeTraversal(tx kvdb.RTx, nodePub []byte, db kvdb.Backend, // should be passed as the first argument. Otherwise the first argument should // be nil and a fresh transaction will be created to execute the graph // traversal. -func (l *LightningNode) ForEachChannel(tx kvdb.RTx, +func (l *LightningNode) ForEachChannel(db kvdb.Backend, tx kvdb.RTx, cb func(kvdb.RTx, *ChannelEdgeInfo1, *ChannelEdgePolicy, *ChannelEdgePolicy) error) error { nodePub := l.PubKeyBytes[:] - db := l.db return nodeTraversal(tx, nodePub, db, cb) } @@ -3236,7 +3237,6 @@ func (c *ChannelEdgeInfo1) FetchOtherNode(tx kvdb.RTx, if err != nil { return err } - node.db = c.db targetNode = &node @@ -3447,8 +3447,6 @@ type ChannelEdgePolicy struct { // and ensure we're able to make upgrades to the network in a forwards // compatible manner. ExtraOpaqueData []byte - - db kvdb.Backend } // Signature is a channel announcement signature, which is needed for proper @@ -3570,7 +3568,7 @@ func (c *ChannelGraph) FetchChannelEdgesByOutpoint(op *wire.OutPoint, // we'll fetch the routing policies for each for the directed // edges. e1, e2, err := fetchChanEdgePolicies( - edgeIndex, edges, nodes, chanID, c.db, + edgeIndex, edges, nodes, chanID, ) if err != nil { return err @@ -3675,7 +3673,7 @@ func (c *ChannelGraph) FetchChannelEdgesByID(chanID uint64, // Then we'll attempt to fetch the accompanying policies of this // edge. e1, e2, err := fetchChanEdgePolicies( - edgeIndex, edges, nodes, channelID[:], c.db, + edgeIndex, edges, nodes, channelID[:], ) if err != nil { return err @@ -3718,7 +3716,7 @@ func (c *ChannelGraph) IsPublicNode(pubKey [33]byte) (bool, error) { return err } - nodeIsPublic, err = node.isPublic(tx, ourPubKey) + nodeIsPublic, err = node.isPublic(c.db, tx, ourPubKey) return err }, func() { nodeIsPublic = false @@ -3834,11 +3832,6 @@ func (c *ChannelGraph) ChannelView() ([]EdgePoint, error) { return edgePoints, nil } -// NewChannelEdgePolicy returns a new blank ChannelEdgePolicy. -func (c *ChannelGraph) NewChannelEdgePolicy() *ChannelEdgePolicy { - return &ChannelEdgePolicy{db: c.db} -} - // MarkEdgeZombie attempts to mark a channel identified by its channel ID as a // zombie. This method is used on an ad-hoc basis, when channels need to be // marked as zombies outside the normal pruning cycle. @@ -4618,8 +4611,8 @@ func fetchChanEdgePolicy(edges kvdb.RBucket, chanID []byte, } func fetchChanEdgePolicies(edgeIndex kvdb.RBucket, edges kvdb.RBucket, - nodes kvdb.RBucket, chanID []byte, - db kvdb.Backend) (*ChannelEdgePolicy, *ChannelEdgePolicy, error) { + nodes kvdb.RBucket, chanID []byte) (*ChannelEdgePolicy, + *ChannelEdgePolicy, error) { edgeInfo := edgeIndex.Get(chanID) if edgeInfo == nil { @@ -4635,13 +4628,6 @@ func fetchChanEdgePolicies(edgeIndex kvdb.RBucket, edges kvdb.RBucket, return nil, nil, err } - // As we may have a single direction of the edge but not the other, - // only fill in the database pointers if the edge is found. - if edge1 != nil { - edge1.db = db - edge1.Node.db = db - } - // Similarly, the second node is contained within the latter // half of the edge information. node2Pub := edgeInfo[33:66] @@ -4650,11 +4636,6 @@ func fetchChanEdgePolicies(edgeIndex kvdb.RBucket, edges kvdb.RBucket, return nil, nil, err } - if edge2 != nil { - edge2.db = db - edge2.Node.db = db - } - return edge1, edge2, nil } diff --git a/channeldb/graph_test.go b/channeldb/graph_test.go index 2cdce5f2e9..fd5f24ad83 100644 --- a/channeldb/graph_test.go +++ b/channeldb/graph_test.go @@ -95,7 +95,6 @@ func createLightningNode(db kvdb.Backend, priv *btcec.PrivateKey) (*LightningNod Alias: "kek" + string(pub[:]), Features: testFeatures, Addresses: testAddrs, - db: db, } copy(n.PubKeyBytes[:], priv.PubKey().SerializeCompressed()) @@ -129,7 +128,6 @@ func TestNodeInsertionAndDeletion(t *testing.T) { Addresses: testAddrs, ExtraOpaqueData: []byte("extra new data"), PubKeyBytes: testPub, - db: graph.db, } // First, insert the node into the graph DB. This should succeed @@ -207,7 +205,6 @@ func TestPartialNode(t *testing.T) { HaveNodeAnnouncement: false, LastUpdate: time.Unix(0, 0), PubKeyBytes: testPub, - db: graph.db, } if err := compareNodes(node, dbNode); err != nil { @@ -674,7 +671,6 @@ func createChannelEdge(db kvdb.Backend, node1, node2 *LightningNode) ( FeeProportionalMillionths: 3452352, Node: secondNode, ExtraOpaqueData: []byte("new unknown feature2"), - db: db, } edge2 := &ChannelEdgePolicy{ SigBytes: testSig.Serialize(), @@ -689,7 +685,6 @@ func createChannelEdge(db kvdb.Backend, node1, node2 *LightningNode) ( FeeProportionalMillionths: 90392423, Node: firstNode, ExtraOpaqueData: []byte("new unknown feature1"), - db: db, } return edgeInfo, edge1, edge2 @@ -995,7 +990,6 @@ func newEdgePolicy(chanID uint64, db kvdb.Backend, MaxHTLC: lnwire.MilliSatoshi(prand.Int63()), FeeBaseMSat: lnwire.MilliSatoshi(prand.Int63()), FeeProportionalMillionths: lnwire.MilliSatoshi(prand.Int63()), - db: db, } } @@ -1056,30 +1050,31 @@ func TestGraphTraversal(t *testing.T) { // outgoing channels for a particular node. numNodeChans := 0 firstNode, secondNode := nodeList[0], nodeList[1] - err = firstNode.ForEachChannel(nil, func(_ kvdb.RTx, _ *ChannelEdgeInfo1, - outEdge, inEdge *ChannelEdgePolicy) error { + err = firstNode.ForEachChannel(graph.DB(), nil, + func(_ kvdb.RTx, _ *ChannelEdgeInfo1, + outEdge, inEdge *ChannelEdgePolicy) error { - // All channels between first and second node should have fully - // (both sides) specified policies. - if inEdge == nil || outEdge == nil { - return fmt.Errorf("channel policy not present") - } + // All channels between first and second node should have fully + // (both sides) specified policies. + if inEdge == nil || outEdge == nil { + return fmt.Errorf("channel policy not present") + } - // Each should indicate that it's outgoing (pointed - // towards the second node). - if !bytes.Equal(outEdge.Node.PubKeyBytes[:], secondNode.PubKeyBytes[:]) { - return fmt.Errorf("wrong outgoing edge") - } + // Each should indicate that it's outgoing (pointed + // towards the second node). + if !bytes.Equal(outEdge.Node.PubKeyBytes[:], secondNode.PubKeyBytes[:]) { + return fmt.Errorf("wrong outgoing edge") + } - // The incoming edge should also indicate that it's pointing to - // the origin node. - if !bytes.Equal(inEdge.Node.PubKeyBytes[:], firstNode.PubKeyBytes[:]) { - return fmt.Errorf("wrong outgoing edge") - } + // The incoming edge should also indicate that it's pointing to + // the origin node. + if !bytes.Equal(inEdge.Node.PubKeyBytes[:], firstNode.PubKeyBytes[:]) { + return fmt.Errorf("wrong outgoing edge") + } - numNodeChans++ - return nil - }) + numNodeChans++ + return nil + }) require.NoError(t, err) require.Equal(t, numChannels, numNodeChans) } @@ -2280,29 +2275,30 @@ func TestIncompleteChannelPolicies(t *testing.T) { // Ensure that channel is reported with unknown policies. checkPolicies := func(node *LightningNode, expectedIn, expectedOut bool) { calls := 0 - err := node.ForEachChannel(nil, func(_ kvdb.RTx, _ *ChannelEdgeInfo1, - outEdge, inEdge *ChannelEdgePolicy) error { + err := node.ForEachChannel(graph.DB(), nil, + func(_ kvdb.RTx, _ *ChannelEdgeInfo1, + outEdge, inEdge *ChannelEdgePolicy) error { - if !expectedOut && outEdge != nil { - t.Fatalf("Expected no outgoing policy") - } + if !expectedOut && outEdge != nil { + t.Fatalf("Expected no outgoing policy") + } - if expectedOut && outEdge == nil { - t.Fatalf("Expected an outgoing policy") - } + if expectedOut && outEdge == nil { + t.Fatalf("Expected an outgoing policy") + } - if !expectedIn && inEdge != nil { - t.Fatalf("Expected no incoming policy") - } + if !expectedIn && inEdge != nil { + t.Fatalf("Expected no incoming policy") + } - if expectedIn && inEdge == nil { - t.Fatalf("Expected an incoming policy") - } + if expectedIn && inEdge == nil { + t.Fatalf("Expected an incoming policy") + } - calls++ + calls++ - return nil - }) + return nil + }) if err != nil { t.Fatalf("unable to scan channels: %v", err) } @@ -2703,7 +2699,6 @@ func TestNodeIsPublic(t *testing.T) { graphs := []*ChannelGraph{aliceGraph, bobGraph, carolGraph} for i, graph := range graphs { for _, node := range nodes { - node.db = dbs[i] if err := graph.AddLightningNode(node); err != nil { t.Fatalf("unable to add node: %v", err) } @@ -3140,10 +3135,6 @@ func compareNodes(a, b *LightningNode) error { return fmt.Errorf("Alias doesn't match: expected %#v, \n "+ "got %#v", a.Alias, b.Alias) } - if !reflect.DeepEqual(a.db, b.db) { - return fmt.Errorf("db doesn't match: expected %#v, \n "+ - "got %#v", a.db, b.db) - } if !reflect.DeepEqual(a.HaveNodeAnnouncement, b.HaveNodeAnnouncement) { return fmt.Errorf("HaveNodeAnnouncement doesn't match: expected %#v, \n "+ "got %#v", a.HaveNodeAnnouncement, b.HaveNodeAnnouncement) @@ -3203,10 +3194,6 @@ func compareEdgePolicies(a, b *ChannelEdgePolicy) error { if err := compareNodes(a.Node, b.Node); err != nil { return err } - if !reflect.DeepEqual(a.db, b.db) { - return fmt.Errorf("db doesn't match: expected %#v, \n "+ - "got %#v", a.db, b.db) - } return nil } diff --git a/routing/router.go b/routing/router.go index 783e2139da..b72761146e 100644 --- a/routing/router.go +++ b/routing/router.go @@ -2824,16 +2824,17 @@ func (r *ChannelRouter) ForEachNode( func (r *ChannelRouter) ForAllOutgoingChannels(cb func(kvdb.RTx, *channeldb.ChannelEdgeInfo1, *channeldb.ChannelEdgePolicy) error) error { - return r.selfNode.ForEachChannel(nil, func(tx kvdb.RTx, - c *channeldb.ChannelEdgeInfo1, - e, _ *channeldb.ChannelEdgePolicy) error { + return r.selfNode.ForEachChannel(r.cfg.Graph.DB(), nil, + func(tx kvdb.RTx, c *channeldb.ChannelEdgeInfo1, + e, _ *channeldb.ChannelEdgePolicy) error { - if e == nil { - return fmt.Errorf("channel from self node has no policy") - } + if e == nil { + return fmt.Errorf("channel from self node " + + "has no policy") + } - return cb(tx, c, e) - }) + return cb(tx, c, e) + }) } // AddProof updates the channel edge info with proof which is needed to diff --git a/rpcserver.go b/rpcserver.go index bb4f2f4b53..39d8507f7d 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -6151,7 +6151,7 @@ func (r *rpcServer) GetNodeInfo(ctx context.Context, channels []*lnrpc.ChannelEdge ) - if err := node.ForEachChannel(nil, func(_ kvdb.RTx, + if err := node.ForEachChannel(graph.DB(), nil, func(_ kvdb.RTx, edge *channeldb.ChannelEdgeInfo1, c1, c2 *channeldb.ChannelEdgePolicy) error { @@ -6763,34 +6763,35 @@ func (r *rpcServer) FeeReport(ctx context.Context, } var feeReports []*lnrpc.ChannelFeeReport - err = selfNode.ForEachChannel(nil, func(_ kvdb.RTx, chanInfo *channeldb.ChannelEdgeInfo1, - edgePolicy, _ *channeldb.ChannelEdgePolicy) error { - - // Self node should always have policies for its channels. - if edgePolicy == nil { - return fmt.Errorf("no policy for outgoing channel %v ", - chanInfo.ChannelID) - } - - // We'll compute the effective fee rate by converting from a - // fixed point fee rate to a floating point fee rate. The fee - // rate field in the database the amount of mSAT charged per - // 1mil mSAT sent, so will divide by this to get the proper fee - // rate. - feeRateFixedPoint := edgePolicy.FeeProportionalMillionths - feeRate := float64(feeRateFixedPoint) / feeBase - - // TODO(roasbeef): also add stats for revenue for each channel - feeReports = append(feeReports, &lnrpc.ChannelFeeReport{ - ChanId: chanInfo.ChannelID, - ChannelPoint: chanInfo.ChannelPoint.String(), - BaseFeeMsat: int64(edgePolicy.FeeBaseMSat), - FeePerMil: int64(feeRateFixedPoint), - FeeRate: feeRate, - }) + err = selfNode.ForEachChannel(channelGraph.DB(), nil, + func(_ kvdb.RTx, chanInfo *channeldb.ChannelEdgeInfo1, + edgePolicy, _ *channeldb.ChannelEdgePolicy) error { + + // Self node should always have policies for its channels. + if edgePolicy == nil { + return fmt.Errorf("no policy for outgoing channel %v ", + chanInfo.ChannelID) + } - return nil - }) + // We'll compute the effective fee rate by converting from a + // fixed point fee rate to a floating point fee rate. The fee + // rate field in the database the amount of mSAT charged per + // 1mil mSAT sent, so will divide by this to get the proper fee + // rate. + feeRateFixedPoint := edgePolicy.FeeProportionalMillionths + feeRate := float64(feeRateFixedPoint) / feeBase + + // TODO(roasbeef): also add stats for revenue for each channel + feeReports = append(feeReports, &lnrpc.ChannelFeeReport{ + ChanId: chanInfo.ChannelID, + ChannelPoint: chanInfo.ChannelPoint.String(), + BaseFeeMsat: int64(edgePolicy.FeeBaseMSat), + FeePerMil: int64(feeRateFixedPoint), + FeeRate: feeRate, + }) + + return nil + }) if err != nil { return nil, err } diff --git a/server.go b/server.go index 4ff46499d4..4aefe24df4 100644 --- a/server.go +++ b/server.go @@ -3094,7 +3094,7 @@ func (s *server) establishPersistentConnections() error { // TODO(roasbeef): instead iterate over link nodes and query graph for // each of the nodes. selfPub := s.identityECDH.PubKey().SerializeCompressed() - err = sourceNode.ForEachChannel(nil, func( + err = sourceNode.ForEachChannel(s.graphDB.DB(), nil, func( tx kvdb.RTx, chanInfo *channeldb.ChannelEdgeInfo1, policy, _ *channeldb.ChannelEdgePolicy) error { From c461d7c9959246a76a35f5cfaaf6d7d6793b5087 Mon Sep 17 00:00:00 2001 From: Elle Mouton <elle.mouton@gmail.com> Date: Tue, 17 Oct 2023 12:28:14 +0200 Subject: [PATCH 18/33] multi: remove kvdb.Backend from ChannelEdgeInfo --- autopilot/graph.go | 6 +++--- channeldb/graph.go | 38 ++++++++++++++++------------------- channeldb/graph_cache.go | 20 +++++++++--------- channeldb/graph_cache_test.go | 6 +++--- channeldb/graph_test.go | 22 ++++++++++---------- routing/router.go | 2 +- rpcserver.go | 7 ++++--- server.go | 4 ++-- 8 files changed, 50 insertions(+), 55 deletions(-) diff --git a/autopilot/graph.go b/autopilot/graph.go index 8a277be3eb..9fa14da27e 100644 --- a/autopilot/graph.go +++ b/autopilot/graph.go @@ -88,8 +88,8 @@ func (d dbNode) Addrs() []net.Addr { // // NOTE: Part of the autopilot.Node interface. func (d dbNode) ForEachChannel(cb func(ChannelEdge) error) error { - return d.node.ForEachChannel(d.db, d.tx, func(tx kvdb.RTx, - ei *channeldb.ChannelEdgeInfo1, ep, + return d.node.ForEachChannel(d.db, d.tx, func(db kvdb.Backend, + tx kvdb.RTx, ei *channeldb.ChannelEdgeInfo1, ep, _ *channeldb.ChannelEdgePolicy) error { // Skip channels for which no outgoing edge policy is available. @@ -107,7 +107,7 @@ func (d dbNode) ForEachChannel(cb func(ChannelEdge) error) error { ChanID: lnwire.NewShortChanIDFromInt(ep.ChannelID), Capacity: ei.Capacity, Peer: dbNode{ - db: d.db, + db: db, tx: tx, node: ep.Node, }, diff --git a/channeldb/graph.go b/channeldb/graph.go index f94fb8af4f..80e9451623 100644 --- a/channeldb/graph.go +++ b/channeldb/graph.go @@ -487,8 +487,8 @@ func (c *ChannelGraph) ForEachNodeChannel(tx kvdb.RTx, node route.Vertex, return err } - dbCallback := func(tx kvdb.RTx, e *ChannelEdgeInfo1, p1, - p2 *ChannelEdgePolicy) error { + dbCallback := func(_ kvdb.Backend, tx kvdb.RTx, e *ChannelEdgeInfo1, + p1, p2 *ChannelEdgePolicy) error { var cachedInPolicy *CachedEdgePolicy if p2 != nil { @@ -561,8 +561,8 @@ func (c *ChannelGraph) ForEachNodeCached(cb func(node route.Vertex, return c.ForEachNode(func(tx kvdb.RTx, node *LightningNode) error { channels := make(map[uint64]*DirectedChannel) - err := node.ForEachChannel(c.db, tx, func(tx kvdb.RTx, - e *ChannelEdgeInfo1, p1 *ChannelEdgePolicy, + err := node.ForEachChannel(c.db, tx, func(_ kvdb.Backend, + tx kvdb.RTx, e *ChannelEdgeInfo1, p1 *ChannelEdgePolicy, p2 *ChannelEdgePolicy) error { toNodeCallback := func() route.Vertex { @@ -1941,7 +1941,6 @@ func (c *ChannelGraph) ChanUpdatesInHorizon(startTime, return fmt.Errorf("unable to fetch info for "+ "edge with chan_id=%v: %v", chanID, err) } - edgeInfo.db = c.db // With the static information obtained, we'll now // fetch the dynamic policy info. @@ -2272,7 +2271,6 @@ func (c *ChannelGraph) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) { case err != nil: return err } - edgeInfo.db = c.db // With the static information obtained, we'll now // fetch the dynamic policy info. @@ -2749,8 +2747,8 @@ func (l *LightningNode) isPublic(db kvdb.Backend, tx kvdb.RTx, // used to terminate the check early. nodeIsPublic := false errDone := errors.New("done") - err := l.ForEachChannel(db, tx, func(_ kvdb.RTx, info *ChannelEdgeInfo1, - _, _ *ChannelEdgePolicy) error { + err := l.ForEachChannel(db, tx, func(_ kvdb.Backend, _ kvdb.RTx, + info *ChannelEdgeInfo1, _, _ *ChannelEdgePolicy) error { // If this edge doesn't extend to the source node, we'll // terminate our search as we can now conclude that the node is @@ -2858,11 +2856,11 @@ func (n *graphCacheNode) Features() *lnwire.FeatureVector { // halted with the error propagated back up to the caller. // // Unknown policies are passed into the callback as nil values. -func (n *graphCacheNode) ForEachChannel(tx kvdb.RTx, - cb func(kvdb.RTx, *ChannelEdgeInfo1, *ChannelEdgePolicy, +func (n *graphCacheNode) ForEachChannel(db kvdb.Backend, tx kvdb.RTx, + cb func(kvdb.Backend, kvdb.RTx, *ChannelEdgeInfo1, *ChannelEdgePolicy, *ChannelEdgePolicy) error) error { - return nodeTraversal(tx, n.pubKeyBytes[:], nil, cb) + return nodeTraversal(tx, n.pubKeyBytes[:], db, cb) } var _ GraphCacheNode = (*graphCacheNode)(nil) @@ -2920,7 +2918,7 @@ func (c *ChannelGraph) HasLightningNode(nodePub [33]byte) (time.Time, bool, erro // nodeTraversal is used to traverse all channels of a node given by its // public key and passes channel information into the specified callback. func nodeTraversal(tx kvdb.RTx, nodePub []byte, db kvdb.Backend, - cb func(kvdb.RTx, *ChannelEdgeInfo1, *ChannelEdgePolicy, + cb func(kvdb.Backend, kvdb.RTx, *ChannelEdgeInfo1, *ChannelEdgePolicy, *ChannelEdgePolicy) error) error { traversal := func(tx kvdb.RTx) error { @@ -2963,7 +2961,6 @@ func nodeTraversal(tx kvdb.RTx, nodePub []byte, db kvdb.Backend, if err != nil { return err } - edgeInfo.db = db outgoingPolicy, err := fetchChanEdgePolicy( edges, chanID, nodePub, nodes, @@ -2985,7 +2982,10 @@ func nodeTraversal(tx kvdb.RTx, nodePub []byte, db kvdb.Backend, } // Finally, we execute the callback. - err = cb(tx, &edgeInfo, outgoingPolicy, incomingPolicy) + err = cb( + db, tx, &edgeInfo, outgoingPolicy, + incomingPolicy, + ) if err != nil { return err } @@ -3019,7 +3019,7 @@ func nodeTraversal(tx kvdb.RTx, nodePub []byte, db kvdb.Backend, // be nil and a fresh transaction will be created to execute the graph // traversal. func (l *LightningNode) ForEachChannel(db kvdb.Backend, tx kvdb.RTx, - cb func(kvdb.RTx, *ChannelEdgeInfo1, *ChannelEdgePolicy, + cb func(kvdb.Backend, kvdb.RTx, *ChannelEdgeInfo1, *ChannelEdgePolicy, *ChannelEdgePolicy) error) error { nodePub := l.PubKeyBytes[:] @@ -3086,8 +3086,6 @@ type ChannelEdgeInfo1 struct { // and ensure we're able to make upgrades to the network in a forwards // compatible manner. ExtraOpaqueData []byte - - db kvdb.Backend } // AddNodeKeys is a setter-like method that can be used to replace the set of @@ -3210,7 +3208,7 @@ func (c *ChannelEdgeInfo1) OtherNodeKeyBytes(thisNodeKey []byte) ( // the target node in the channel. This is useful when one knows the pubkey of // one of the nodes, and wishes to obtain the full LightningNode for the other // end of the channel. -func (c *ChannelEdgeInfo1) FetchOtherNode(tx kvdb.RTx, +func (c *ChannelEdgeInfo1) FetchOtherNode(db kvdb.Backend, tx kvdb.RTx, thisNodeKey []byte) (*LightningNode, error) { // Ensure that the node passed in is actually a member of the channel. @@ -3247,7 +3245,7 @@ func (c *ChannelEdgeInfo1) FetchOtherNode(tx kvdb.RTx, // otherwise we can use the existing db transaction. var err error if tx == nil { - err = kvdb.View(c.db, fetchNodeFunc, func() { targetNode = nil }) + err = kvdb.View(db, fetchNodeFunc, func() { targetNode = nil }) } else { err = fetchNodeFunc(tx) } @@ -3562,7 +3560,6 @@ func (c *ChannelGraph) FetchChannelEdgesByOutpoint(op *wire.OutPoint, return err } edgeInfo = &edge - edgeInfo.db = c.db // Once we have the information about the channels' parameters, // we'll fetch the routing policies for each for the directed @@ -3668,7 +3665,6 @@ func (c *ChannelGraph) FetchChannelEdgesByID(chanID uint64, } edgeInfo = &edge - edgeInfo.db = c.db // Then we'll attempt to fetch the accompanying policies of this // edge. diff --git a/channeldb/graph_cache.go b/channeldb/graph_cache.go index 6cb67431fc..944f4825b6 100644 --- a/channeldb/graph_cache.go +++ b/channeldb/graph_cache.go @@ -26,9 +26,9 @@ type GraphCacheNode interface { // incoming edge *from* the connecting node. If the callback returns an // error, then the iteration is halted with the error propagated back up // to the caller. - ForEachChannel(kvdb.RTx, - func(kvdb.RTx, *ChannelEdgeInfo1, *ChannelEdgePolicy, - *ChannelEdgePolicy) error) error + ForEachChannel(kvdb.Backend, kvdb.RTx, + func(kvdb.Backend, kvdb.RTx, *ChannelEdgeInfo1, + *ChannelEdgePolicy, *ChannelEdgePolicy) error) error } // CachedEdgePolicy is a struct that only caches the information of a @@ -222,16 +222,14 @@ func (c *GraphCache) AddNodeFeatures(node GraphCacheNode) { func (c *GraphCache) AddNode(tx kvdb.RTx, node GraphCacheNode) error { c.AddNodeFeatures(node) - return node.ForEachChannel( - tx, func(tx kvdb.RTx, info *ChannelEdgeInfo1, - outPolicy *ChannelEdgePolicy, - inPolicy *ChannelEdgePolicy) error { + return node.ForEachChannel(nil, tx, func(_ kvdb.Backend, tx kvdb.RTx, + info *ChannelEdgeInfo1, outPolicy *ChannelEdgePolicy, + inPolicy *ChannelEdgePolicy) error { - c.AddChannel(info, outPolicy, inPolicy) + c.AddChannel(info, outPolicy, inPolicy) - return nil - }, - ) + return nil + }) } // AddChannel adds a non-directed channel, meaning that the order of policy 1 diff --git a/channeldb/graph_cache_test.go b/channeldb/graph_cache_test.go index 2d0faddfc5..9c5671be26 100644 --- a/channeldb/graph_cache_test.go +++ b/channeldb/graph_cache_test.go @@ -40,13 +40,13 @@ func (n *node) Features() *lnwire.FeatureVector { return n.features } -func (n *node) ForEachChannel(tx kvdb.RTx, - cb func(kvdb.RTx, *ChannelEdgeInfo1, *ChannelEdgePolicy, +func (n *node) ForEachChannel(db kvdb.Backend, tx kvdb.RTx, + cb func(kvdb.Backend, kvdb.RTx, *ChannelEdgeInfo1, *ChannelEdgePolicy, *ChannelEdgePolicy) error) error { for idx := range n.edgeInfos { err := cb( - tx, n.edgeInfos[idx], n.outPolicies[idx], + db, tx, n.edgeInfos[idx], n.outPolicies[idx], n.inPolicies[idx], ) if err != nil { diff --git a/channeldb/graph_test.go b/channeldb/graph_test.go index fd5f24ad83..c37495003f 100644 --- a/channeldb/graph_test.go +++ b/channeldb/graph_test.go @@ -1051,7 +1051,7 @@ func TestGraphTraversal(t *testing.T) { numNodeChans := 0 firstNode, secondNode := nodeList[0], nodeList[1] err = firstNode.ForEachChannel(graph.DB(), nil, - func(_ kvdb.RTx, _ *ChannelEdgeInfo1, + func(_ kvdb.Backend, _ kvdb.RTx, _ *ChannelEdgeInfo1, outEdge, inEdge *ChannelEdgePolicy) error { // All channels between first and second node should have fully @@ -1124,11 +1124,13 @@ func TestGraphTraversalCacheable(t *testing.T) { err = graph.db.View(func(tx kvdb.RTx) error { for _, node := range nodes { err := node.ForEachChannel( - tx, func(tx kvdb.RTx, info *ChannelEdgeInfo1, - policy *ChannelEdgePolicy, - policy2 *ChannelEdgePolicy) error { + graph.db, tx, func(_ kvdb.Backend, _ kvdb.RTx, + info *ChannelEdgeInfo1, + _ *ChannelEdgePolicy, + _ *ChannelEdgePolicy) error { delete(chanIndex, info.ChannelID) + return nil }, ) @@ -2276,7 +2278,7 @@ func TestIncompleteChannelPolicies(t *testing.T) { checkPolicies := func(node *LightningNode, expectedIn, expectedOut bool) { calls := 0 err := node.ForEachChannel(graph.DB(), nil, - func(_ kvdb.RTx, _ *ChannelEdgeInfo1, + func(_ kvdb.Backend, _ kvdb.RTx, _ *ChannelEdgeInfo1, outEdge, inEdge *ChannelEdgePolicy) error { if !expectedOut && outEdge != nil { @@ -2695,16 +2697,14 @@ func TestNodeIsPublic(t *testing.T) { // participant's graph. nodes := []*LightningNode{aliceNode, bobNode, carolNode} edges := []*ChannelEdgeInfo1{&aliceBobEdge, &bobCarolEdge} - dbs := []kvdb.Backend{aliceGraph.db, bobGraph.db, carolGraph.db} graphs := []*ChannelGraph{aliceGraph, bobGraph, carolGraph} - for i, graph := range graphs { + for _, graph := range graphs { for _, node := range nodes { if err := graph.AddLightningNode(node); err != nil { t.Fatalf("unable to add node: %v", err) } } for _, edge := range edges { - edge.db = dbs[i] if err := graph.AddChannelEdge(edge); err != nil { t.Fatalf("unable to add edge: %v", err) } @@ -2764,7 +2764,7 @@ func TestNodeIsPublic(t *testing.T) { // that allows it to be advertised. Within Alice's graph, we'll // completely remove the edge as it is not possible for her to know of // it without it being advertised. - for i, graph := range graphs { + for _, graph := range graphs { err := graph.DeleteChannelEdges( false, true, bobCarolEdge.ChannelID, ) @@ -2777,7 +2777,6 @@ func TestNodeIsPublic(t *testing.T) { } bobCarolEdge.AuthProof = nil - bobCarolEdge.db = dbs[i] if err := graph.AddChannelEdge(&bobCarolEdge); err != nil { t.Fatalf("unable to add edge: %v", err) } @@ -3429,7 +3428,8 @@ func BenchmarkForEachChannel(b *testing.B) { err = graph.db.View(func(tx kvdb.RTx) error { for _, n := range nodes { err := n.ForEachChannel( - tx, func(tx kvdb.RTx, + graph.db, tx, func(_ kvdb.Backend, + _ kvdb.RTx, info *ChannelEdgeInfo1, policy *ChannelEdgePolicy, policy2 *ChannelEdgePolicy) error { diff --git a/routing/router.go b/routing/router.go index b72761146e..6fb07c749b 100644 --- a/routing/router.go +++ b/routing/router.go @@ -2825,7 +2825,7 @@ func (r *ChannelRouter) ForAllOutgoingChannels(cb func(kvdb.RTx, *channeldb.ChannelEdgeInfo1, *channeldb.ChannelEdgePolicy) error) error { return r.selfNode.ForEachChannel(r.cfg.Graph.DB(), nil, - func(tx kvdb.RTx, c *channeldb.ChannelEdgeInfo1, + func(_ kvdb.Backend, tx kvdb.RTx, c *channeldb.ChannelEdgeInfo1, e, _ *channeldb.ChannelEdgePolicy) error { if e == nil { diff --git a/rpcserver.go b/rpcserver.go index 39d8507f7d..3aed256fc5 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -6151,8 +6151,8 @@ func (r *rpcServer) GetNodeInfo(ctx context.Context, channels []*lnrpc.ChannelEdge ) - if err := node.ForEachChannel(graph.DB(), nil, func(_ kvdb.RTx, - edge *channeldb.ChannelEdgeInfo1, + if err := node.ForEachChannel(graph.DB(), nil, func(_ kvdb.Backend, + _ kvdb.RTx, edge *channeldb.ChannelEdgeInfo1, c1, c2 *channeldb.ChannelEdgePolicy) error { numChannels++ @@ -6764,7 +6764,8 @@ func (r *rpcServer) FeeReport(ctx context.Context, var feeReports []*lnrpc.ChannelFeeReport err = selfNode.ForEachChannel(channelGraph.DB(), nil, - func(_ kvdb.RTx, chanInfo *channeldb.ChannelEdgeInfo1, + func(_ kvdb.Backend, _ kvdb.RTx, + chanInfo *channeldb.ChannelEdgeInfo1, edgePolicy, _ *channeldb.ChannelEdgePolicy) error { // Self node should always have policies for its channels. diff --git a/server.go b/server.go index 4aefe24df4..7ac07cdab8 100644 --- a/server.go +++ b/server.go @@ -3095,7 +3095,7 @@ func (s *server) establishPersistentConnections() error { // each of the nodes. selfPub := s.identityECDH.PubKey().SerializeCompressed() err = sourceNode.ForEachChannel(s.graphDB.DB(), nil, func( - tx kvdb.RTx, + db kvdb.Backend, tx kvdb.RTx, chanInfo *channeldb.ChannelEdgeInfo1, policy, _ *channeldb.ChannelEdgePolicy) error { @@ -3109,7 +3109,7 @@ func (s *server) establishPersistentConnections() error { // We'll now fetch the peer opposite from us within this // channel so we can queue up a direct connection to them. - channelPeer, err := chanInfo.FetchOtherNode(tx, selfPub) + channelPeer, err := chanInfo.FetchOtherNode(db, tx, selfPub) if err != nil { return fmt.Errorf("unable to fetch channel peer for "+ "ChannelPoint(%v): %v", chanInfo.ChannelPoint, From 711f41d19ce7cec1d393cade15b858744345401b Mon Sep 17 00:00:00 2001 From: Elle Mouton <elle.mouton@gmail.com> Date: Tue, 17 Oct 2023 13:07:21 +0200 Subject: [PATCH 19/33] channeldb: add ChannelEdgeInfo interface --- channeldb/graph.go | 171 +++++++++++++++++++++++++++++++++ channeldb/models/interfaces.go | 52 ++++++++++ 2 files changed, 223 insertions(+) create mode 100644 channeldb/models/interfaces.go diff --git a/channeldb/graph.go b/channeldb/graph.go index 80e9451623..86a5b5e0fb 100644 --- a/channeldb/graph.go +++ b/channeldb/graph.go @@ -22,6 +22,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/aliasmgr" "github.com/lightningnetwork/lnd/batch" + "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" @@ -3088,6 +3089,164 @@ type ChannelEdgeInfo1 struct { ExtraOpaqueData []byte } +// Copy returns a copy of the ChannelEdgeInfo. +// +// NOTE: this is part of the models.ChannelEdgeInfo interface. +func (c *ChannelEdgeInfo1) Copy() models.ChannelEdgeInfo { + return &ChannelEdgeInfo1{ + ChannelID: c.ChannelID, + ChainHash: c.ChainHash, + NodeKey1Bytes: c.NodeKey1Bytes, + NodeKey2Bytes: c.NodeKey2Bytes, + BitcoinKey1Bytes: c.BitcoinKey1Bytes, + BitcoinKey2Bytes: c.BitcoinKey2Bytes, + Features: c.Features, + AuthProof: c.AuthProof, + ChannelPoint: c.ChannelPoint, + Capacity: c.Capacity, + ExtraOpaqueData: c.ExtraOpaqueData, + } +} + +// Node1Bytes returns bytes of the public key of node 1. +// +// NOTE: this is part of the models.ChannelEdgeInfo interface. +func (c *ChannelEdgeInfo1) Node1Bytes() [33]byte { + return c.NodeKey1Bytes +} + +// Node2Bytes returns bytes of the public key of node 2. +// +// NOTE: this is part of the models.ChannelEdgeInfo interface. +func (c *ChannelEdgeInfo1) Node2Bytes() [33]byte { + return c.NodeKey2Bytes +} + +// GetChainHash returns the hash of the genesis block of the chain that the edge +// is on. +// +// NOTE: this is part of the models.ChannelEdgeInfo interface. +func (c *ChannelEdgeInfo1) GetChainHash() chainhash.Hash { + return c.ChainHash +} + +// GetChanID returns the channel ID. +// +// NOTE: this is part of the models.ChannelEdgeInfo interface. +func (c *ChannelEdgeInfo1) GetChanID() uint64 { + return c.ChannelID +} + +// GetAuthProof returns the ChannelAuthProof for the edge. +// +// NOTE: this is part of the models.ChannelEdgeInfo interface. +func (c *ChannelEdgeInfo1) GetAuthProof() models.ChannelAuthProof { + // Cant just return AuthProof cause you then run into the + // nil interface gotcha. + if c.AuthProof == nil { + return nil + } + + return c.AuthProof +} + +// GetCapacity returns the capacity of the channel. +// +// NOTE: this is part of the models.ChannelEdgeInfo interface. +func (c *ChannelEdgeInfo1) GetCapacity() btcutil.Amount { + return c.Capacity +} + +// SetAuthProof sets the proof of the channel. +// +// NOTE: this is part of the models.ChannelEdgeInfo interface. +func (c *ChannelEdgeInfo1) SetAuthProof(proof models.ChannelAuthProof) error { + if proof == nil { + c.AuthProof = nil + + return nil + } + + p, ok := proof.(*ChannelAuthProof1) + if !ok { + return fmt.Errorf("expected type ChannelAuthProof1 for "+ + "ChannelEdgeInfo1, got %T", proof) + } + + c.AuthProof = p + + return nil +} + +// GetChanPoint returns the outpoint of the funding transaction of the channel. +// +// NOTE: this is part of the models.ChannelEdgeInfo interface. +func (c *ChannelEdgeInfo1) GetChanPoint() wire.OutPoint { + return c.ChannelPoint +} + +// FundingScript returns the pk script for the funding output of the +// channel. +// +// NOTE: this is part of the models.ChannelEdgeInfo interface. +func (c *ChannelEdgeInfo1) FundingScript() ([]byte, error) { + legacyFundingScript := func() ([]byte, error) { + witnessScript, err := input.GenMultiSigScript( + c.BitcoinKey1Bytes[:], c.BitcoinKey2Bytes[:], + ) + if err != nil { + return nil, err + } + pkScript, err := input.WitnessScriptHash(witnessScript) + if err != nil { + return nil, err + } + + return pkScript, nil + } + + if len(c.Features) == 0 { + return legacyFundingScript() + } + + // In order to make the correct funding script, we'll need to parse the + // chanFeatures bytes into a feature vector we can interact with. + rawFeatures := lnwire.NewRawFeatureVector() + err := rawFeatures.Decode(bytes.NewReader(c.Features)) + if err != nil { + return nil, fmt.Errorf("unable to parse chan feature "+ + "bits: %w", err) + } + + chanFeatureBits := lnwire.NewFeatureVector( + rawFeatures, lnwire.Features, + ) + if chanFeatureBits.HasFeature( + lnwire.SimpleTaprootChannelsOptionalStaging, + ) { + + pubKey1, err := btcec.ParsePubKey(c.BitcoinKey1Bytes[:]) + if err != nil { + return nil, err + } + pubKey2, err := btcec.ParsePubKey(c.BitcoinKey2Bytes[:]) + if err != nil { + return nil, err + } + + fundingScript, _, err := input.GenTaprootFundingScript( + pubKey1, pubKey2, 0, + ) + if err != nil { + return nil, err + } + + return fundingScript, nil + } + + return legacyFundingScript() +} + // AddNodeKeys is a setter-like method that can be used to replace the set of // keys for the target ChannelEdgeInfo1. func (c *ChannelEdgeInfo1) AddNodeKeys(nodeKey1, nodeKey2, bitcoinKey1, @@ -3113,6 +3272,8 @@ func (c *ChannelEdgeInfo1) AddNodeKeys(nodeKey1, nodeKey2, bitcoinKey1, // // NOTE: By having this method to access an attribute, we ensure we only need // to fully deserialize the pubkey if absolutely necessary. +// +// NOTE: this is part of the models.ChannelEdgeInfo interface. func (c *ChannelEdgeInfo1) NodeKey1() (*btcec.PublicKey, error) { if c.nodeKey1 != nil { return c.nodeKey1, nil @@ -3135,6 +3296,8 @@ func (c *ChannelEdgeInfo1) NodeKey1() (*btcec.PublicKey, error) { // // NOTE: By having this method to access an attribute, we ensure we only need // to fully deserialize the pubkey if absolutely necessary. +// +// NOTE: this is part of the models.ChannelEdgeInfo interface. func (c *ChannelEdgeInfo1) NodeKey2() (*btcec.PublicKey, error) { if c.nodeKey2 != nil { return c.nodeKey2, nil @@ -3253,6 +3416,10 @@ func (c *ChannelEdgeInfo1) FetchOtherNode(db kvdb.Backend, tx kvdb.RTx, return targetNode, err } +// A compile-time check to ensure that ChannelEdgeInfo1 implements +// modesl.ChannelEdgeInfo. +var _ models.ChannelEdgeInfo = (*ChannelEdgeInfo1)(nil) + // ChannelAuthProof1 is the authentication proof (the signature portion) for a // channel. Using the four signatures contained in the struct, and some // auxiliary knowledge (the funding script, node identities, and outpoint) nodes @@ -3290,6 +3457,10 @@ type ChannelAuthProof1 struct { BitcoinSig2Bytes []byte } +// A compile-time check to ensure that ChannelAuthProof1 implements the +// models.ChannelAuthProof interface. +var _ models.ChannelAuthProof = (*ChannelAuthProof1)(nil) + // Node1Sig is the signature using the identity key of the node that is first // in a lexicographical ordering of the serialized public keys of the two nodes // that created the channel. diff --git a/channeldb/models/interfaces.go b/channeldb/models/interfaces.go new file mode 100644 index 0000000000..4a3d82fc98 --- /dev/null +++ b/channeldb/models/interfaces.go @@ -0,0 +1,52 @@ +package models + +import ( + "github.com/btcsuite/btcd/btcec/v2" + "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/wire" +) + +type ChannelEdgeInfo interface { //nolint:interfacebloat + // GetChainHash returns the hash of the genesis block of the chain that + // the edge is on. + GetChainHash() chainhash.Hash + + // GetChanID returns the channel ID. + GetChanID() uint64 + + // GetAuthProof returns the ChannelAuthProof for the edge. + GetAuthProof() ChannelAuthProof + + // GetCapacity returns the capacity of the channel. + GetCapacity() btcutil.Amount + + // SetAuthProof sets the proof of the channel. + SetAuthProof(ChannelAuthProof) error + + // NodeKey1 returns the public key of node 1. + NodeKey1() (*btcec.PublicKey, error) + + // NodeKey2 returns the public key of node 2. + NodeKey2() (*btcec.PublicKey, error) + + // Node1Bytes returns bytes of the public key of node 1. + Node1Bytes() [33]byte + + // Node2Bytes returns bytes the public key of node 2. + Node2Bytes() [33]byte + + // GetChanPoint returns the outpoint of the funding transaction of the + // channel. + GetChanPoint() wire.OutPoint + + // FundingScript returns the pk script for the funding output of the + // channel. + FundingScript() ([]byte, error) + + // Copy returns a copy of the ChannelEdgeInfo. + Copy() ChannelEdgeInfo +} + +type ChannelAuthProof interface { +} From 53ffc11230d592a3e24c25a090e31f4fd5fbadfe Mon Sep 17 00:00:00 2001 From: Elle Mouton <elle.mouton@gmail.com> Date: Tue, 17 Oct 2023 13:17:38 +0200 Subject: [PATCH 20/33] multi: use the ChannelEdgeInfo interface everywhere --- autopilot/graph.go | 5 +- channeldb/graph.go | 412 +++++++++++++++------------ channeldb/graph_cache.go | 56 ++-- channeldb/graph_cache_test.go | 5 +- channeldb/graph_test.go | 129 +++++---- discovery/chan_series.go | 12 +- discovery/gossiper.go | 249 +++++++++------- discovery/gossiper_test.go | 77 ++--- lnrpc/invoicesrpc/addinvoice.go | 10 +- lnrpc/invoicesrpc/addinvoice_test.go | 10 +- netann/chan_status_manager.go | 3 +- netann/chan_status_manager_test.go | 3 +- netann/channel_announcement.go | 26 +- netann/channel_update.go | 12 +- netann/interface.go | 3 +- peer/brontide.go | 13 +- routing/localchans/manager.go | 16 +- routing/localchans/manager_test.go | 2 +- routing/notifications.go | 15 +- routing/router.go | 177 +++++------- routing/router_test.go | 4 +- routing/validation_barrier.go | 16 +- rpcserver.go | 113 +++++--- server.go | 13 +- 24 files changed, 765 insertions(+), 616 deletions(-) diff --git a/autopilot/graph.go b/autopilot/graph.go index 9fa14da27e..e30639fea8 100644 --- a/autopilot/graph.go +++ b/autopilot/graph.go @@ -12,6 +12,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2/ecdsa" "github.com/btcsuite/btcd/btcutil" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" @@ -89,7 +90,7 @@ func (d dbNode) Addrs() []net.Addr { // NOTE: Part of the autopilot.Node interface. func (d dbNode) ForEachChannel(cb func(ChannelEdge) error) error { return d.node.ForEachChannel(d.db, d.tx, func(db kvdb.Backend, - tx kvdb.RTx, ei *channeldb.ChannelEdgeInfo1, ep, + tx kvdb.RTx, ei models.ChannelEdgeInfo, ep, _ *channeldb.ChannelEdgePolicy) error { // Skip channels for which no outgoing edge policy is available. @@ -105,7 +106,7 @@ func (d dbNode) ForEachChannel(cb func(ChannelEdge) error) error { edge := ChannelEdge{ ChanID: lnwire.NewShortChanIDFromInt(ep.ChannelID), - Capacity: ei.Capacity, + Capacity: ei.GetCapacity(), Peer: dbNode{ db: db, tx: tx, diff --git a/channeldb/graph.go b/channeldb/graph.go index 86a5b5e0fb..f9e749ae86 100644 --- a/channeldb/graph.go +++ b/channeldb/graph.go @@ -2,7 +2,6 @@ package channeldb import ( "bytes" - "crypto/sha256" "encoding/binary" "errors" "fmt" @@ -18,7 +17,6 @@ import ( "github.com/btcsuite/btcd/btcec/v2/ecdsa" "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" - "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/aliasmgr" "github.com/lightningnetwork/lnd/batch" @@ -232,7 +230,7 @@ func NewChannelGraph(db kvdb.Backend, rejectCacheSize, chanCacheSize int, return nil, err } - err = g.ForEachChannel(func(info *ChannelEdgeInfo1, + err = g.ForEachChannel(func(info models.ChannelEdgeInfo, policy1, policy2 *ChannelEdgePolicy) error { g.graphCache.AddChannel(info, policy1, policy2) @@ -416,7 +414,7 @@ func (c *ChannelGraph) NewPathFindTx() (kvdb.RTx, error) { // NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer // for that particular channel edge routing policy will be passed into the // callback. -func (c *ChannelGraph) ForEachChannel(cb func(*ChannelEdgeInfo1, +func (c *ChannelGraph) ForEachChannel(cb func(models.ChannelEdgeInfo, *ChannelEdgePolicy, *ChannelEdgePolicy) error) error { return c.db.View(func(tx kvdb.RTx) error { @@ -450,16 +448,16 @@ func (c *ChannelGraph) ForEachChannel(cb func(*ChannelEdgeInfo1, } policy1 := channelMap[channelMapKey{ - nodeKey: info.NodeKey1Bytes, + nodeKey: info.Node1Bytes(), chanID: chanID, }] policy2 := channelMap[channelMapKey{ - nodeKey: info.NodeKey2Bytes, + nodeKey: info.Node2Bytes(), chanID: chanID, }] - return cb(&info, policy1, policy2) + return cb(info, policy1, policy2) }) }, func() {}) } @@ -488,8 +486,8 @@ func (c *ChannelGraph) ForEachNodeChannel(tx kvdb.RTx, node route.Vertex, return err } - dbCallback := func(_ kvdb.Backend, tx kvdb.RTx, e *ChannelEdgeInfo1, - p1, p2 *ChannelEdgePolicy) error { + dbCallback := func(_ kvdb.Backend, tx kvdb.RTx, + e models.ChannelEdgeInfo, p1, p2 *ChannelEdgePolicy) error { var cachedInPolicy *CachedEdgePolicy if p2 != nil { @@ -499,16 +497,16 @@ func (c *ChannelGraph) ForEachNodeChannel(tx kvdb.RTx, node route.Vertex, } directedChannel := &DirectedChannel{ - ChannelID: e.ChannelID, - IsNode1: node == e.NodeKey1Bytes, - OtherNode: e.NodeKey2Bytes, - Capacity: e.Capacity, + ChannelID: e.GetChanID(), + IsNode1: node == e.Node1Bytes(), + OtherNode: e.Node2Bytes(), + Capacity: e.GetCapacity(), OutPolicySet: p1 != nil, InPolicy: cachedInPolicy, } - if node == e.NodeKey2Bytes { - directedChannel.OtherNode = e.NodeKey1Bytes + if node == e.Node2Bytes() { + directedChannel.OtherNode = e.Node1Bytes() } return cb(directedChannel) @@ -563,8 +561,8 @@ func (c *ChannelGraph) ForEachNodeCached(cb func(node route.Vertex, channels := make(map[uint64]*DirectedChannel) err := node.ForEachChannel(c.db, tx, func(_ kvdb.Backend, - tx kvdb.RTx, e *ChannelEdgeInfo1, p1 *ChannelEdgePolicy, - p2 *ChannelEdgePolicy) error { + tx kvdb.RTx, e models.ChannelEdgeInfo, + p1 *ChannelEdgePolicy, p2 *ChannelEdgePolicy) error { toNodeCallback := func() route.Vertex { return node.PubKeyBytes @@ -584,19 +582,20 @@ func (c *ChannelGraph) ForEachNodeCached(cb func(node route.Vertex, } directedChannel := &DirectedChannel{ - ChannelID: e.ChannelID, - IsNode1: node.PubKeyBytes == e.NodeKey1Bytes, - OtherNode: e.NodeKey2Bytes, - Capacity: e.Capacity, + ChannelID: e.GetChanID(), + IsNode1: node.PubKeyBytes == + e.Node1Bytes(), + OtherNode: e.Node2Bytes(), + Capacity: e.GetCapacity(), OutPolicySet: p1 != nil, InPolicy: cachedInPolicy, } - if node.PubKeyBytes == e.NodeKey2Bytes { - directedChannel.OtherNode = e.NodeKey1Bytes + if node.PubKeyBytes == e.Node2Bytes() { + directedChannel.OtherNode = e.Node1Bytes() } - channels[e.ChannelID] = directedChannel + channels[e.GetChanID()] = directedChannel return nil }) @@ -970,7 +969,7 @@ func (c *ChannelGraph) deleteLightningNode(nodes kvdb.RwBucket, // involved in creation of the channel, and the set of features that the channel // supports. The chanPoint and chanID are used to uniquely identify the edge // globally within the database. -func (c *ChannelGraph) AddChannelEdge(edge *ChannelEdgeInfo1, +func (c *ChannelGraph) AddChannelEdge(edge models.ChannelEdgeInfo, op ...batch.SchedulerOption) error { var alreadyExists bool @@ -997,8 +996,8 @@ func (c *ChannelGraph) AddChannelEdge(edge *ChannelEdgeInfo1, case alreadyExists: return ErrEdgeAlreadyExist default: - c.rejectCache.remove(edge.ChannelID) - c.chanCache.remove(edge.ChannelID) + c.rejectCache.remove(edge.GetChanID()) + c.chanCache.remove(edge.GetChanID()) return nil } }, @@ -1013,10 +1012,19 @@ func (c *ChannelGraph) AddChannelEdge(edge *ChannelEdgeInfo1, // addChannelEdge is the private form of AddChannelEdge that allows callers to // utilize an existing db transaction. -func (c *ChannelGraph) addChannelEdge(tx kvdb.RwTx, edge *ChannelEdgeInfo1) error { +func (c *ChannelGraph) addChannelEdge(tx kvdb.RwTx, + edge models.ChannelEdgeInfo) error { + + var ( + chanID = edge.GetChanID() + node1Bytes = edge.Node1Bytes() + node2Bytes = edge.Node2Bytes() + chanPoint = edge.GetChanPoint() + ) + // Construct the channel's primary key which is the 8-byte channel ID. var chanKey [8]byte - binary.BigEndian.PutUint64(chanKey[:], edge.ChannelID) + binary.BigEndian.PutUint64(chanKey[:], edge.GetChanID()) nodes, err := tx.CreateTopLevelBucket(nodeBucket) if err != nil { @@ -1049,33 +1057,33 @@ func (c *ChannelGraph) addChannelEdge(tx kvdb.RwTx, edge *ChannelEdgeInfo1) erro // both nodes already exist in the channel graph. If either node // doesn't, then we'll insert a "shell" node that just includes its // public key, so subsequent validation and queries can work properly. - _, node1Err := fetchLightningNode(nodes, edge.NodeKey1Bytes[:]) + _, node1Err := fetchLightningNode(nodes, node1Bytes[:]) switch { case node1Err == ErrGraphNodeNotFound: node1Shell := LightningNode{ - PubKeyBytes: edge.NodeKey1Bytes, + PubKeyBytes: node1Bytes, HaveNodeAnnouncement: false, } err := addLightningNode(tx, &node1Shell) if err != nil { return fmt.Errorf("unable to create shell node "+ - "for: %x", edge.NodeKey1Bytes) + "for: %x", node1Bytes) } case node1Err != nil: return err } - _, node2Err := fetchLightningNode(nodes, edge.NodeKey2Bytes[:]) + _, node2Err := fetchLightningNode(nodes, node2Bytes[:]) switch { case node2Err == ErrGraphNodeNotFound: node2Shell := LightningNode{ - PubKeyBytes: edge.NodeKey2Bytes, + PubKeyBytes: node2Bytes, HaveNodeAnnouncement: false, } err := addLightningNode(tx, &node2Shell) if err != nil { return fmt.Errorf("unable to create shell node "+ - "for: %x", edge.NodeKey2Bytes) + "for: %x", node2Bytes) } case node2Err != nil: return err @@ -1091,11 +1099,11 @@ func (c *ChannelGraph) addChannelEdge(tx kvdb.RwTx, edge *ChannelEdgeInfo1) erro // Mark edge policies for both sides as unknown. This is to enable // efficient incoming channel lookup for a node. keys := []*[33]byte{ - &edge.NodeKey1Bytes, - &edge.NodeKey2Bytes, + &node1Bytes, + &node2Bytes, } for _, key := range keys { - err := putChanEdgePolicyUnknown(edges, edge.ChannelID, key[:]) + err := putChanEdgePolicyUnknown(edges, chanID, key[:]) if err != nil { return err } @@ -1104,7 +1112,7 @@ func (c *ChannelGraph) addChannelEdge(tx kvdb.RwTx, edge *ChannelEdgeInfo1) erro // Finally we add it to the channel index which maps channel points // (outpoints) to the shorter channel ID's. var b bytes.Buffer - if err := writeOutpoint(&b, &edge.ChannelPoint); err != nil { + if err := writeOutpoint(&b, &chanPoint); err != nil { return err } return chanIndex.Put(b.Bytes(), chanKey[:]) @@ -1224,10 +1232,10 @@ func (c *ChannelGraph) HasChannelEdge( // In order to maintain this constraints, we return an error in the scenario // that an edge info hasn't yet been created yet, but someone attempts to update // it. -func (c *ChannelGraph) UpdateChannelEdge(edge *ChannelEdgeInfo1) error { +func (c *ChannelGraph) UpdateChannelEdge(edge models.ChannelEdgeInfo) error { // Construct the channel's primary key which is the 8-byte channel ID. var chanKey [8]byte - binary.BigEndian.PutUint64(chanKey[:], edge.ChannelID) + binary.BigEndian.PutUint64(chanKey[:], edge.GetChanID()) return kvdb.Update(c.db, func(tx kvdb.RwTx) error { edges := tx.ReadWriteBucket(edgeBucket) @@ -1269,12 +1277,13 @@ const ( // with the current UTXO state. A slice of channels that have been closed by // the target block are returned if the function succeeds without error. func (c *ChannelGraph) PruneGraph(spentOutputs []*wire.OutPoint, - blockHash *chainhash.Hash, blockHeight uint32) ([]*ChannelEdgeInfo1, error) { + blockHash *chainhash.Hash, blockHeight uint32) ( + []models.ChannelEdgeInfo, error) { c.cacheMu.Lock() defer c.cacheMu.Unlock() - var chansClosed []*ChannelEdgeInfo1 + var chansClosed []models.ChannelEdgeInfo err := kvdb.Update(c.db, func(tx kvdb.RwTx) error { // First grab the edges bucket which houses the information @@ -1341,7 +1350,7 @@ func (c *ChannelGraph) PruneGraph(spentOutputs []*wire.OutPoint, return err } - chansClosed = append(chansClosed, &edgeInfo) + chansClosed = append(chansClosed, edgeInfo) } metaBucket, err := tx.CreateTopLevelBucket(graphMetaBucket) @@ -1380,8 +1389,8 @@ func (c *ChannelGraph) PruneGraph(spentOutputs []*wire.OutPoint, } for _, channel := range chansClosed { - c.rejectCache.remove(channel.ChannelID) - c.chanCache.remove(channel.ChannelID) + c.rejectCache.remove(channel.GetChanID()) + c.chanCache.remove(channel.GetChanID()) } if c.graphCache != nil { @@ -1460,17 +1469,17 @@ func (c *ChannelGraph) pruneGraphNodes(nodes kvdb.RwBucket, // edge info. We'll use this scan to populate our reference count map // above. err = edgeIndex.ForEach(func(chanID, edgeInfoBytes []byte) error { - // The first 66 bytes of the edge info contain the pubkeys of - // the nodes that this edge attaches. We'll extract them, and - // add them to the ref count map. - var node1, node2 [33]byte - copy(node1[:], edgeInfoBytes[:33]) - copy(node2[:], edgeInfoBytes[33:]) + edge, err := deserializeChanEdgeInfo( + bytes.NewReader(edgeInfoBytes), + ) + if err != nil { + return err + } // With the nodes extracted, we'll increase the ref count of // each of the nodes. - nodeRefCounts[node1]++ - nodeRefCounts[node2]++ + nodeRefCounts[edge.Node1Bytes()]++ + nodeRefCounts[edge.Node2Bytes()]++ return nil }) @@ -1522,8 +1531,8 @@ func (c *ChannelGraph) pruneGraphNodes(nodes kvdb.RwBucket, // set to the last prune height valid for the remaining chain. // Channels that were removed from the graph resulting from the // disconnected block are returned. -func (c *ChannelGraph) DisconnectBlockAtHeight(height uint32) ([]*ChannelEdgeInfo1, - error) { +func (c *ChannelGraph) DisconnectBlockAtHeight(height uint32) ( + []models.ChannelEdgeInfo, error) { // Every channel having a ShortChannelID starting at 'height' // will no longer be confirmed. @@ -1545,7 +1554,7 @@ func (c *ChannelGraph) DisconnectBlockAtHeight(height uint32) ([]*ChannelEdgeInf defer c.cacheMu.Unlock() // Keep track of the channels that are removed from the graph. - var removedChans []*ChannelEdgeInfo1 + var removedChans []models.ChannelEdgeInfo if err := kvdb.Update(c.db, func(tx kvdb.RwTx) error { edges, err := tx.CreateTopLevelBucket(edgeBucket) @@ -1588,7 +1597,7 @@ func (c *ChannelGraph) DisconnectBlockAtHeight(height uint32) ([]*ChannelEdgeInf } keys = append(keys, k) - removedChans = append(removedChans, &edgeInfo) + removedChans = append(removedChans, edgeInfo) } for _, k := range keys { @@ -1643,8 +1652,8 @@ func (c *ChannelGraph) DisconnectBlockAtHeight(height uint32) ([]*ChannelEdgeInf } for _, channel := range removedChans { - c.rejectCache.remove(channel.ChannelID) - c.chanCache.remove(channel.ChannelID) + c.rejectCache.remove(channel.GetChanID()) + c.chanCache.remove(channel.GetChanID()) } return removedChans, nil @@ -1852,7 +1861,7 @@ func (c *ChannelGraph) HighestChanID() (uint64, error) { // edge as well as each of the known advertised edge policies. type ChannelEdge struct { // Info contains all the static information describing the channel. - Info *ChannelEdgeInfo1 + Info models.ChannelEdgeInfo // Policy1 points to the "first" edge policy of the channel containing // the dynamic information required to properly route through the edge. @@ -1959,7 +1968,7 @@ func (c *ChannelGraph) ChanUpdatesInHorizon(startTime, // edges to be returned. edgesSeen[chanIDInt] = struct{}{} channel := ChannelEdge{ - Info: &edgeInfo, + Info: edgeInfo, Policy1: edge1, Policy2: edge2, } @@ -2182,7 +2191,7 @@ func (c *ChannelGraph) FilterChannelRange(startHeight, return err } - if edgeInfo.AuthProof == nil { + if edgeInfo.GetAuthProof() == nil { continue } @@ -2283,7 +2292,7 @@ func (c *ChannelGraph) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) { } chanEdges = append(chanEdges, ChannelEdge{ - Info: &edgeInfo, + Info: edgeInfo, Policy1: edge1, Policy2: edge2, }) @@ -2345,10 +2354,15 @@ func (c *ChannelGraph) delChannelEdge(edges, edgeIndex, chanIndex, zombieIndex, return err } + var ( + node1Bytes = edgeInfo.Node1Bytes() + node2Bytes = edgeInfo.Node2Bytes() + chanPoint = edgeInfo.GetChanPoint() + ) + if c.graphCache != nil { c.graphCache.RemoveChannel( - edgeInfo.NodeKey1Bytes, edgeInfo.NodeKey2Bytes, - edgeInfo.ChannelID, + node1Bytes, node2Bytes, edgeInfo.GetChanID(), ) } @@ -2375,13 +2389,13 @@ func (c *ChannelGraph) delChannelEdge(edges, edgeIndex, chanIndex, zombieIndex, // With the latter half constructed, copy over the first public key to // delete the edge in this direction, then the second to delete the // edge in the opposite direction. - copy(edgeKey[:33], edgeInfo.NodeKey1Bytes[:]) + copy(edgeKey[:33], node1Bytes[:]) if edges.Get(edgeKey[:]) != nil { if err := edges.Delete(edgeKey[:]); err != nil { return err } } - copy(edgeKey[:33], edgeInfo.NodeKey2Bytes[:]) + copy(edgeKey[:33], node2Bytes[:]) if edges.Get(edgeKey[:]) != nil { if err := edges.Delete(edgeKey[:]); err != nil { return err @@ -2399,7 +2413,7 @@ func (c *ChannelGraph) delChannelEdge(edges, edgeIndex, chanIndex, zombieIndex, return err } var b bytes.Buffer - if err := writeOutpoint(&b, &edgeInfo.ChannelPoint); err != nil { + if err := writeOutpoint(&b, &chanPoint); err != nil { return err } if err := chanIndex.Delete(b.Bytes()); err != nil { @@ -2413,9 +2427,9 @@ func (c *ChannelGraph) delChannelEdge(edges, edgeIndex, chanIndex, zombieIndex, return nil } - nodeKey1, nodeKey2 := edgeInfo.NodeKey1Bytes, edgeInfo.NodeKey2Bytes + nodeKey1, nodeKey2 := node1Bytes, node2Bytes if strictZombie { - nodeKey1, nodeKey2 = makeZombiePubkeys(&edgeInfo, edge1, edge2) + nodeKey1, nodeKey2 = makeZombiePubkeys(edgeInfo, edge1, edge2) } return markEdgeZombie( @@ -2439,27 +2453,32 @@ func (c *ChannelGraph) delChannelEdge(edges, edgeIndex, chanIndex, zombieIndex, // the channel. If the channel were to be marked zombie again, it would be // marked with the correct lagging channel since we received an update from only // one side. -func makeZombiePubkeys(info *ChannelEdgeInfo1, +func makeZombiePubkeys(info models.ChannelEdgeInfo, e1, e2 *ChannelEdgePolicy) ([33]byte, [33]byte) { + var ( + node1Bytes = info.Node1Bytes() + node2Bytes = info.Node2Bytes() + ) + switch { // If we don't have either edge policy, we'll return both pubkeys so // that the channel can be resurrected by either party. case e1 == nil && e2 == nil: - return info.NodeKey1Bytes, info.NodeKey2Bytes + return node1Bytes, node2Bytes // If we're missing edge1, or if both edges are present but edge1 is // older, we'll return edge1's pubkey and a blank pubkey for edge2. This // means that only an update from edge1 will be able to resurrect the // channel. case e1 == nil || (e2 != nil && e1.LastUpdate.Before(e2.LastUpdate)): - return info.NodeKey1Bytes, [33]byte{} + return node1Bytes, [33]byte{} // Otherwise, we're missing edge2 or edge2 is the older side, so we // return a blank pubkey for edge1. In this case, only an update from // edge2 can resurect the channel. default: - return [33]byte{}, info.NodeKey2Bytes + return [33]byte{}, node2Bytes } } @@ -2578,17 +2597,26 @@ func updateEdgePolicy(tx kvdb.RwTx, edge *ChannelEdgePolicy, return false, ErrEdgeNotFound } + edgeInfo, err := deserializeChanEdgeInfo(bytes.NewReader(nodeInfo)) + if err != nil { + return false, err + } + // Depending on the flags value passed above, either the first // or second edge policy is being updated. - var fromNode, toNode []byte - var isUpdate1 bool + var ( + fromNode, toNode []byte + isUpdate1 bool + node1Bytes = edgeInfo.Node1Bytes() + node2Bytes = edgeInfo.Node2Bytes() + ) if edge.ChannelFlags&lnwire.ChanUpdateDirection == 0 { - fromNode = nodeInfo[:33] - toNode = nodeInfo[33:66] + fromNode = node1Bytes[:] + toNode = node2Bytes[:] isUpdate1 = true } else { - fromNode = nodeInfo[33:66] - toNode = nodeInfo[:33] + fromNode = node2Bytes[:] + toNode = node1Bytes[:] isUpdate1 = false } @@ -2749,14 +2777,15 @@ func (l *LightningNode) isPublic(db kvdb.Backend, tx kvdb.RTx, nodeIsPublic := false errDone := errors.New("done") err := l.ForEachChannel(db, tx, func(_ kvdb.Backend, _ kvdb.RTx, - info *ChannelEdgeInfo1, _, _ *ChannelEdgePolicy) error { + info models.ChannelEdgeInfo, _, _ *ChannelEdgePolicy) error { // If this edge doesn't extend to the source node, we'll // terminate our search as we can now conclude that the node is // publicly advertised within the graph due to the local node // knowing of the current edge. - if !bytes.Equal(info.NodeKey1Bytes[:], sourcePubKey) && - !bytes.Equal(info.NodeKey2Bytes[:], sourcePubKey) { + node1Bytes, node2Bytes := info.Node1Bytes(), info.Node2Bytes() + if !bytes.Equal(node1Bytes[:], sourcePubKey) && + !bytes.Equal(node2Bytes[:], sourcePubKey) { nodeIsPublic = true return errDone @@ -2764,7 +2793,7 @@ func (l *LightningNode) isPublic(db kvdb.Backend, tx kvdb.RTx, // Since the edge _does_ extend to the source node, we'll also // need to ensure that this is a public edge. - if info.AuthProof != nil { + if info.GetAuthProof() != nil { nodeIsPublic = true return errDone } @@ -2858,8 +2887,8 @@ func (n *graphCacheNode) Features() *lnwire.FeatureVector { // // Unknown policies are passed into the callback as nil values. func (n *graphCacheNode) ForEachChannel(db kvdb.Backend, tx kvdb.RTx, - cb func(kvdb.Backend, kvdb.RTx, *ChannelEdgeInfo1, *ChannelEdgePolicy, - *ChannelEdgePolicy) error) error { + cb func(kvdb.Backend, kvdb.RTx, models.ChannelEdgeInfo, + *ChannelEdgePolicy, *ChannelEdgePolicy) error) error { return nodeTraversal(tx, n.pubKeyBytes[:], db, cb) } @@ -2919,8 +2948,8 @@ func (c *ChannelGraph) HasLightningNode(nodePub [33]byte) (time.Time, bool, erro // nodeTraversal is used to traverse all channels of a node given by its // public key and passes channel information into the specified callback. func nodeTraversal(tx kvdb.RTx, nodePub []byte, db kvdb.Backend, - cb func(kvdb.Backend, kvdb.RTx, *ChannelEdgeInfo1, *ChannelEdgePolicy, - *ChannelEdgePolicy) error) error { + cb func(kvdb.Backend, kvdb.RTx, models.ChannelEdgeInfo, + *ChannelEdgePolicy, *ChannelEdgePolicy) error) error { traversal := func(tx kvdb.RTx) error { nodes := tx.ReadBucket(nodeBucket) @@ -2970,9 +2999,19 @@ func nodeTraversal(tx kvdb.RTx, nodePub []byte, db kvdb.Backend, return err } - otherNode, err := edgeInfo.OtherNodeKeyBytes(nodePub) - if err != nil { - return err + var ( + otherNode [33]byte + node1Bytes = edgeInfo.Node1Bytes() + node2Bytes = edgeInfo.Node2Bytes() + ) + switch { + case bytes.Equal(node1Bytes[:], nodePub): + otherNode = node2Bytes + case bytes.Equal(node2Bytes[:], nodePub): + otherNode = node1Bytes + default: + return fmt.Errorf("node not participating in " + + "this channel") } incomingPolicy, err := fetchChanEdgePolicy( @@ -2984,7 +3023,7 @@ func nodeTraversal(tx kvdb.RTx, nodePub []byte, db kvdb.Backend, // Finally, we execute the callback. err = cb( - db, tx, &edgeInfo, outgoingPolicy, + db, tx, edgeInfo, outgoingPolicy, incomingPolicy, ) if err != nil { @@ -3020,8 +3059,8 @@ func nodeTraversal(tx kvdb.RTx, nodePub []byte, db kvdb.Backend, // be nil and a fresh transaction will be created to execute the graph // traversal. func (l *LightningNode) ForEachChannel(db kvdb.Backend, tx kvdb.RTx, - cb func(kvdb.Backend, kvdb.RTx, *ChannelEdgeInfo1, *ChannelEdgePolicy, - *ChannelEdgePolicy) error) error { + cb func(kvdb.Backend, kvdb.RTx, models.ChannelEdgeInfo, + *ChannelEdgePolicy, *ChannelEdgePolicy) error) error { nodePub := l.PubKeyBytes[:] @@ -3352,35 +3391,25 @@ func (c *ChannelEdgeInfo1) BitcoinKey2() (*btcec.PublicKey, error) { return key, nil } -// OtherNodeKeyBytes returns the node key bytes of the other end of -// the channel. -func (c *ChannelEdgeInfo1) OtherNodeKeyBytes(thisNodeKey []byte) ( - [33]byte, error) { - - switch { - case bytes.Equal(c.NodeKey1Bytes[:], thisNodeKey): - return c.NodeKey2Bytes, nil - case bytes.Equal(c.NodeKey2Bytes[:], thisNodeKey): - return c.NodeKey1Bytes, nil - default: - return [33]byte{}, fmt.Errorf("node not participating in this channel") - } -} - // FetchOtherNode attempts to fetch the full LightningNode that's opposite of // the target node in the channel. This is useful when one knows the pubkey of // one of the nodes, and wishes to obtain the full LightningNode for the other // end of the channel. -func (c *ChannelEdgeInfo1) FetchOtherNode(db kvdb.Backend, tx kvdb.RTx, +func FetchOtherNode(db kvdb.Backend, tx kvdb.RTx, edge models.ChannelEdgeInfo, thisNodeKey []byte) (*LightningNode, error) { + var ( + targetNodeBytes [33]byte + node1Bytes = edge.Node1Bytes() + node2Bytes = edge.Node2Bytes() + ) + // Ensure that the node passed in is actually a member of the channel. - var targetNodeBytes [33]byte switch { - case bytes.Equal(c.NodeKey1Bytes[:], thisNodeKey): - targetNodeBytes = c.NodeKey2Bytes - case bytes.Equal(c.NodeKey2Bytes[:], thisNodeKey): - targetNodeBytes = c.NodeKey1Bytes + case bytes.Equal(node1Bytes[:], thisNodeKey): + targetNodeBytes = node2Bytes + case bytes.Equal(node2Bytes[:], thisNodeKey): + targetNodeBytes = node1Bytes default: return nil, fmt.Errorf("node not participating in this channel") } @@ -3681,10 +3710,10 @@ func (c *ChannelEdgePolicy) ComputeFeeFromIncoming( // information for the channel itself is returned as well as two structs that // contain the routing policies for the channel in either direction. func (c *ChannelGraph) FetchChannelEdgesByOutpoint(op *wire.OutPoint, -) (*ChannelEdgeInfo1, *ChannelEdgePolicy, *ChannelEdgePolicy, error) { +) (models.ChannelEdgeInfo, *ChannelEdgePolicy, *ChannelEdgePolicy, error) { var ( - edgeInfo *ChannelEdgeInfo1 + edgeInfo models.ChannelEdgeInfo policy1 *ChannelEdgePolicy policy2 *ChannelEdgePolicy ) @@ -3730,7 +3759,7 @@ func (c *ChannelGraph) FetchChannelEdgesByOutpoint(op *wire.OutPoint, if err != nil { return err } - edgeInfo = &edge + edgeInfo = edge // Once we have the information about the channels' parameters, // we'll fetch the routing policies for each for the directed @@ -3765,12 +3794,12 @@ func (c *ChannelGraph) FetchChannelEdgesByOutpoint(op *wire.OutPoint, // // ErrZombieEdge an be returned if the edge is currently marked as a zombie // within the database. In this case, the ChannelEdgePolicy's will be nil, and -// the ChannelEdgeInfo1 will only include the public keys of each node. +// the ChannelEdgeInfo will only include the public keys of each node. func (c *ChannelGraph) FetchChannelEdgesByID(chanID uint64, -) (*ChannelEdgeInfo1, *ChannelEdgePolicy, *ChannelEdgePolicy, error) { +) (models.ChannelEdgeInfo, *ChannelEdgePolicy, *ChannelEdgePolicy, error) { var ( - edgeInfo *ChannelEdgeInfo1 + edgeInfo models.ChannelEdgeInfo policy1 *ChannelEdgePolicy policy2 *ChannelEdgePolicy channelID [8]byte @@ -3835,7 +3864,7 @@ func (c *ChannelGraph) FetchChannelEdgesByID(chanID uint64, return err } - edgeInfo = &edge + edgeInfo = edge // Then we'll attempt to fetch the accompanying policies of this // edge. @@ -3895,26 +3924,6 @@ func (c *ChannelGraph) IsPublicNode(pubKey [33]byte) (bool, error) { return nodeIsPublic, nil } -// genMultiSigP2WSH generates the p2wsh'd multisig script for 2 of 2 pubkeys. -func genMultiSigP2WSH(aPub, bPub []byte) ([]byte, error) { - witnessScript, err := input.GenMultiSigScript(aPub, bPub) - if err != nil { - return nil, err - } - - // With the witness script generated, we'll now turn it into a p2wsh - // script: - // * OP_0 <sha256(script)> - bldr := txscript.NewScriptBuilder( - txscript.WithScriptAllocSize(input.P2WSHSize), - ) - bldr.AddOp(txscript.OP_0) - scriptHash := sha256.Sum256(witnessScript) - bldr.AddData(scriptHash[:]) - - return bldr.Script() -} - // EdgePoint couples the outpoint of a channel with the funding script that it // creates. The FilteredChainView will use this to watch for spends of this // edge point on chain. We require both of these values as depending on the @@ -3975,10 +3984,7 @@ func (c *ChannelGraph) ChannelView() ([]EdgePoint, error) { return err } - pkScript, err := genMultiSigP2WSH( - edgeInfo.BitcoinKey1Bytes[:], - edgeInfo.BitcoinKey2Bytes[:], - ) + pkScript, err := edgeInfo.FundingScript() if err != nil { return err } @@ -4469,23 +4475,44 @@ func deserializeLightningNode(r io.Reader) (LightningNode, error) { return node, nil } -func putChanEdgeInfo(edgeIndex kvdb.RwBucket, edgeInfo *ChannelEdgeInfo1, chanID [8]byte) error { +// putChanEdgeInfo encodes and writes the given edge to the edge index bucket. +// The encoding used will depend on the channel type. +func putChanEdgeInfo(edgeIndex kvdb.RwBucket, edgeInfo models.ChannelEdgeInfo, + chanID [8]byte) error { + var b bytes.Buffer - if _, err := b.Write(edgeInfo.NodeKey1Bytes[:]); err != nil { + switch info := edgeInfo.(type) { + case *ChannelEdgeInfo1: + err := serializeChanEdgeInfo1(&b, info, chanID) + if err != nil { + return err + } + default: + return fmt.Errorf("unhandled implementation of "+ + "ChannelEdgeInfo: %T", edgeInfo) + } + + return edgeIndex.Put(chanID[:], b.Bytes()) +} + +func serializeChanEdgeInfo1(w io.Writer, edgeInfo *ChannelEdgeInfo1, + chanID [8]byte) error { + + if _, err := w.Write(edgeInfo.NodeKey1Bytes[:]); err != nil { return err } - if _, err := b.Write(edgeInfo.NodeKey2Bytes[:]); err != nil { + if _, err := w.Write(edgeInfo.NodeKey2Bytes[:]); err != nil { return err } - if _, err := b.Write(edgeInfo.BitcoinKey1Bytes[:]); err != nil { + if _, err := w.Write(edgeInfo.BitcoinKey1Bytes[:]); err != nil { return err } - if _, err := b.Write(edgeInfo.BitcoinKey2Bytes[:]); err != nil { + if _, err := w.Write(edgeInfo.BitcoinKey2Bytes[:]); err != nil { return err } - if err := wire.WriteVarBytes(&b, 0, edgeInfo.Features); err != nil { + if err := wire.WriteVarBytes(w, 0, edgeInfo.Features); err != nil { return err } @@ -4498,96 +4525,102 @@ func putChanEdgeInfo(edgeIndex kvdb.RwBucket, edgeInfo *ChannelEdgeInfo1, chanID bitcoinSig2 = authProof.BitcoinSig2Bytes } - if err := wire.WriteVarBytes(&b, 0, nodeSig1); err != nil { + if err := wire.WriteVarBytes(w, 0, nodeSig1); err != nil { return err } - if err := wire.WriteVarBytes(&b, 0, nodeSig2); err != nil { + if err := wire.WriteVarBytes(w, 0, nodeSig2); err != nil { return err } - if err := wire.WriteVarBytes(&b, 0, bitcoinSig1); err != nil { + if err := wire.WriteVarBytes(w, 0, bitcoinSig1); err != nil { return err } - if err := wire.WriteVarBytes(&b, 0, bitcoinSig2); err != nil { + if err := wire.WriteVarBytes(w, 0, bitcoinSig2); err != nil { return err } - if err := writeOutpoint(&b, &edgeInfo.ChannelPoint); err != nil { + if err := writeOutpoint(w, &edgeInfo.ChannelPoint); err != nil { return err } - if err := binary.Write(&b, byteOrder, uint64(edgeInfo.Capacity)); err != nil { + err := binary.Write(w, byteOrder, uint64(edgeInfo.Capacity)) + if err != nil { return err } - if _, err := b.Write(chanID[:]); err != nil { + if _, err := w.Write(chanID[:]); err != nil { return err } - if _, err := b.Write(edgeInfo.ChainHash[:]); err != nil { + if _, err := w.Write(edgeInfo.ChainHash[:]); err != nil { return err } if len(edgeInfo.ExtraOpaqueData) > MaxAllowedExtraOpaqueBytes { return ErrTooManyExtraOpaqueBytes(len(edgeInfo.ExtraOpaqueData)) } - err := wire.WriteVarBytes(&b, 0, edgeInfo.ExtraOpaqueData) + err = wire.WriteVarBytes(w, 0, edgeInfo.ExtraOpaqueData) if err != nil { return err } - return edgeIndex.Put(chanID[:], b.Bytes()) + return wire.WriteVarBytes(w, 0, edgeInfo.ExtraOpaqueData) } func fetchChanEdgeInfo(edgeIndex kvdb.RBucket, - chanID []byte) (ChannelEdgeInfo1, error) { + chanID []byte) (models.ChannelEdgeInfo, error) { edgeInfoBytes := edgeIndex.Get(chanID) if edgeInfoBytes == nil { - return ChannelEdgeInfo1{}, ErrEdgeNotFound + return nil, ErrEdgeNotFound } edgeInfoReader := bytes.NewReader(edgeInfoBytes) + return deserializeChanEdgeInfo(edgeInfoReader) } -func deserializeChanEdgeInfo(r io.Reader) (ChannelEdgeInfo1, error) { +func deserializeChanEdgeInfo(reader io.Reader) (models.ChannelEdgeInfo, error) { + return deserializeChanEdgeInfo1(reader) +} + +func deserializeChanEdgeInfo1(r io.Reader) (*ChannelEdgeInfo1, error) { var ( err error edgeInfo ChannelEdgeInfo1 ) if _, err := io.ReadFull(r, edgeInfo.NodeKey1Bytes[:]); err != nil { - return ChannelEdgeInfo1{}, err + return nil, err } if _, err := io.ReadFull(r, edgeInfo.NodeKey2Bytes[:]); err != nil { - return ChannelEdgeInfo1{}, err + return nil, err } if _, err := io.ReadFull(r, edgeInfo.BitcoinKey1Bytes[:]); err != nil { - return ChannelEdgeInfo1{}, err + return nil, err } if _, err := io.ReadFull(r, edgeInfo.BitcoinKey2Bytes[:]); err != nil { - return ChannelEdgeInfo1{}, err + return nil, err } edgeInfo.Features, err = wire.ReadVarBytes(r, 0, 900, "features") if err != nil { - return ChannelEdgeInfo1{}, err + return nil, err } proof := &ChannelAuthProof1{} proof.NodeSig1Bytes, err = wire.ReadVarBytes(r, 0, 80, "sigs") if err != nil { - return ChannelEdgeInfo1{}, err + return nil, err } proof.NodeSig2Bytes, err = wire.ReadVarBytes(r, 0, 80, "sigs") if err != nil { - return ChannelEdgeInfo1{}, err + return nil, err } proof.BitcoinSig1Bytes, err = wire.ReadVarBytes(r, 0, 80, "sigs") if err != nil { - return ChannelEdgeInfo1{}, err + return nil, err } proof.BitcoinSig2Bytes, err = wire.ReadVarBytes(r, 0, 80, "sigs") if err != nil { - return ChannelEdgeInfo1{}, err + return nil, err } if !proof.IsEmpty() { @@ -4596,17 +4629,17 @@ func deserializeChanEdgeInfo(r io.Reader) (ChannelEdgeInfo1, error) { edgeInfo.ChannelPoint = wire.OutPoint{} if err := readOutpoint(r, &edgeInfo.ChannelPoint); err != nil { - return ChannelEdgeInfo1{}, err + return nil, err } if err := binary.Read(r, byteOrder, &edgeInfo.Capacity); err != nil { - return ChannelEdgeInfo1{}, err + return nil, err } if err := binary.Read(r, byteOrder, &edgeInfo.ChannelID); err != nil { - return ChannelEdgeInfo1{}, err + return nil, err } if _, err := io.ReadFull(r, edgeInfo.ChainHash[:]); err != nil { - return ChannelEdgeInfo1{}, err + return nil, err } // We'll try and see if there are any opaque bytes left, if not, then @@ -4618,10 +4651,10 @@ func deserializeChanEdgeInfo(r io.Reader) (ChannelEdgeInfo1, error) { case err == io.ErrUnexpectedEOF: case err == io.EOF: case err != nil: - return ChannelEdgeInfo1{}, err + return nil, err } - return edgeInfo, nil + return &edgeInfo, nil } func putChanEdgePolicy(edges, nodes kvdb.RwBucket, edge *ChannelEdgePolicy, @@ -4781,24 +4814,29 @@ func fetchChanEdgePolicies(edgeIndex kvdb.RBucket, edges kvdb.RBucket, nodes kvdb.RBucket, chanID []byte) (*ChannelEdgePolicy, *ChannelEdgePolicy, error) { - edgeInfo := edgeIndex.Get(chanID) - if edgeInfo == nil { + edgeInfoBytes := edgeIndex.Get(chanID) + if edgeInfoBytes == nil { return nil, nil, ErrEdgeNotFound } + edgeInfo, err := deserializeChanEdgeInfo(bytes.NewReader(edgeInfoBytes)) + if err != nil { + return nil, nil, err + } + // The first node is contained within the first half of the edge // information. We only propagate the error here and below if it's // something other than edge non-existence. - node1Pub := edgeInfo[:33] - edge1, err := fetchChanEdgePolicy(edges, chanID, node1Pub, nodes) + node1Pub := edgeInfo.Node1Bytes() + edge1, err := fetchChanEdgePolicy(edges, chanID, node1Pub[:], nodes) if err != nil { return nil, nil, err } // Similarly, the second node is contained within the latter // half of the edge information. - node2Pub := edgeInfo[33:66] - edge2, err := fetchChanEdgePolicy(edges, chanID, node2Pub, nodes) + node2Pub := edgeInfo.Node2Bytes() + edge2, err := fetchChanEdgePolicy(edges, chanID, node2Pub[:], nodes) if err != nil { return nil, nil, err } diff --git a/channeldb/graph_cache.go b/channeldb/graph_cache.go index 944f4825b6..e4cd9e5c53 100644 --- a/channeldb/graph_cache.go +++ b/channeldb/graph_cache.go @@ -5,6 +5,7 @@ import ( "sync" "github.com/btcsuite/btcd/btcutil" + "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" @@ -27,7 +28,7 @@ type GraphCacheNode interface { // error, then the iteration is halted with the error propagated back up // to the caller. ForEachChannel(kvdb.Backend, kvdb.RTx, - func(kvdb.Backend, kvdb.RTx, *ChannelEdgeInfo1, + func(kvdb.Backend, kvdb.RTx, models.ChannelEdgeInfo, *ChannelEdgePolicy, *ChannelEdgePolicy) error) error } @@ -223,7 +224,7 @@ func (c *GraphCache) AddNode(tx kvdb.RTx, node GraphCacheNode) error { c.AddNodeFeatures(node) return node.ForEachChannel(nil, tx, func(_ kvdb.Backend, tx kvdb.RTx, - info *ChannelEdgeInfo1, outPolicy *ChannelEdgePolicy, + info models.ChannelEdgeInfo, outPolicy *ChannelEdgePolicy, inPolicy *ChannelEdgePolicy) error { c.AddChannel(info, outPolicy, inPolicy) @@ -236,7 +237,7 @@ func (c *GraphCache) AddNode(tx kvdb.RTx, node GraphCacheNode) error { // and policy 2 does not matter, the directionality is extracted from the info // and policy flags automatically. The policy will be set as the outgoing policy // on one node and the incoming policy on the peer's side. -func (c *GraphCache) AddChannel(info *ChannelEdgeInfo1, +func (c *GraphCache) AddChannel(info models.ChannelEdgeInfo, policy1 *ChannelEdgePolicy, policy2 *ChannelEdgePolicy) { if info == nil { @@ -251,33 +252,33 @@ func (c *GraphCache) AddChannel(info *ChannelEdgeInfo1, // Create the edge entry for both nodes. c.mtx.Lock() - c.updateOrAddEdge(info.NodeKey1Bytes, &DirectedChannel{ - ChannelID: info.ChannelID, + c.updateOrAddEdge(info.Node1Bytes(), &DirectedChannel{ + ChannelID: info.GetChanID(), IsNode1: true, - OtherNode: info.NodeKey2Bytes, - Capacity: info.Capacity, + OtherNode: info.Node2Bytes(), + Capacity: info.GetCapacity(), }) - c.updateOrAddEdge(info.NodeKey2Bytes, &DirectedChannel{ - ChannelID: info.ChannelID, + c.updateOrAddEdge(info.Node2Bytes(), &DirectedChannel{ + ChannelID: info.GetChanID(), IsNode1: false, - OtherNode: info.NodeKey1Bytes, - Capacity: info.Capacity, + OtherNode: info.Node1Bytes(), + Capacity: info.GetCapacity(), }) c.mtx.Unlock() // The policy's node is always the to_node. So if policy 1 has to_node // of node 2 then we have the policy 1 as seen from node 1. if policy1 != nil { - fromNode, toNode := info.NodeKey1Bytes, info.NodeKey2Bytes - if policy1.Node.PubKeyBytes != info.NodeKey2Bytes { + fromNode, toNode := info.Node1Bytes(), info.Node2Bytes() + if policy1.Node.PubKeyBytes != info.Node2Bytes() { fromNode, toNode = toNode, fromNode } isEdge1 := policy1.ChannelFlags&lnwire.ChanUpdateDirection == 0 c.UpdatePolicy(policy1, fromNode, toNode, isEdge1) } if policy2 != nil { - fromNode, toNode := info.NodeKey2Bytes, info.NodeKey1Bytes - if policy2.Node.PubKeyBytes != info.NodeKey1Bytes { + fromNode, toNode := info.Node2Bytes(), info.Node1Bytes() + if policy2.Node.PubKeyBytes != info.Node1Bytes() { fromNode, toNode = toNode, fromNode } isEdge1 := policy2.ChannelFlags&lnwire.ChanUpdateDirection == 0 @@ -378,28 +379,35 @@ func (c *GraphCache) removeChannelIfFound(node route.Vertex, chanID uint64) { // UpdateChannel updates the channel edge information for a specific edge. We // expect the edge to already exist and be known. If it does not yet exist, this // call is a no-op. -func (c *GraphCache) UpdateChannel(info *ChannelEdgeInfo1) { +func (c *GraphCache) UpdateChannel(info models.ChannelEdgeInfo) { c.mtx.Lock() defer c.mtx.Unlock() - if len(c.nodeChannels[info.NodeKey1Bytes]) == 0 || - len(c.nodeChannels[info.NodeKey2Bytes]) == 0 { + var ( + node1Bytes = info.Node1Bytes() + node2Bytes = info.Node2Bytes() + chanID = info.GetChanID() + capacity = info.GetCapacity() + ) + + if len(c.nodeChannels[node1Bytes]) == 0 || + len(c.nodeChannels[node2Bytes]) == 0 { return } - channel, ok := c.nodeChannels[info.NodeKey1Bytes][info.ChannelID] + channel, ok := c.nodeChannels[node1Bytes][chanID] if ok { // We only expect to be called when the channel is already // known. - channel.Capacity = info.Capacity - channel.OtherNode = info.NodeKey2Bytes + channel.Capacity = capacity + channel.OtherNode = node2Bytes } - channel, ok = c.nodeChannels[info.NodeKey2Bytes][info.ChannelID] + channel, ok = c.nodeChannels[node2Bytes][chanID] if ok { - channel.Capacity = info.Capacity - channel.OtherNode = info.NodeKey1Bytes + channel.Capacity = capacity + channel.OtherNode = node1Bytes } } diff --git a/channeldb/graph_cache_test.go b/channeldb/graph_cache_test.go index 9c5671be26..7436451ec8 100644 --- a/channeldb/graph_cache_test.go +++ b/channeldb/graph_cache_test.go @@ -4,6 +4,7 @@ import ( "encoding/hex" "testing" + "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" @@ -41,8 +42,8 @@ func (n *node) Features() *lnwire.FeatureVector { } func (n *node) ForEachChannel(db kvdb.Backend, tx kvdb.RTx, - cb func(kvdb.Backend, kvdb.RTx, *ChannelEdgeInfo1, *ChannelEdgePolicy, - *ChannelEdgePolicy) error) error { + cb func(kvdb.Backend, kvdb.RTx, models.ChannelEdgeInfo, + *ChannelEdgePolicy, *ChannelEdgePolicy) error) error { for idx := range n.edgeInfos { err := cb( diff --git a/channeldb/graph_test.go b/channeldb/graph_test.go index c37495003f..22a570b62a 100644 --- a/channeldb/graph_test.go +++ b/channeldb/graph_test.go @@ -22,6 +22,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/davecgh/go-spew/spew" + "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" @@ -506,10 +507,10 @@ func TestDisconnectBlockAtHeight(t *testing.T) { t.Fatalf("expected two edges to be removed from graph, "+ "only %d were", len(removed)) } - if removed[0].ChannelID != edgeInfo.ChannelID { + if removed[0].GetChanID() != edgeInfo.ChannelID { t.Fatalf("expected edge to be removed from graph") } - if removed[1].ChannelID != edgeInfo2.ChannelID { + if removed[1].GetChanID() != edgeInfo2.ChannelID { t.Fatalf("expected edge to be removed from graph") } @@ -553,7 +554,19 @@ func TestDisconnectBlockAtHeight(t *testing.T) { } } -func assertEdgeInfoEqual(t *testing.T, e1 *ChannelEdgeInfo1, +func assertEdgeInfoEqual(t *testing.T, e1, e2 models.ChannelEdgeInfo) { + switch edge1 := e1.(type) { + case *ChannelEdgeInfo1: + edge2, ok := e2.(*ChannelEdgeInfo1) + require.True(t, ok) + + assertEdgeInfo1Equal(t, edge1, edge2) + default: + t.Fatalf("unhandled ChannelEdgeInfo type: %T", e1) + } +} + +func assertEdgeInfo1Equal(t *testing.T, e1 *ChannelEdgeInfo1, e2 *ChannelEdgeInfo1) { if e1.ChannelID != e2.ChannelID { @@ -817,49 +830,54 @@ func assertNodeNotInCache(t *testing.T, g *ChannelGraph, n route.Vertex) { } func assertEdgeWithNoPoliciesInCache(t *testing.T, g *ChannelGraph, - e *ChannelEdgeInfo1) { + e models.ChannelEdgeInfo) { + + var ( + node1Bytes = e.Node1Bytes() + node2Bytes = e.Node2Bytes() + ) // Let's check the internal view first. - require.NotEmpty(t, g.graphCache.nodeChannels[e.NodeKey1Bytes]) - require.NotEmpty(t, g.graphCache.nodeChannels[e.NodeKey2Bytes]) + require.NotEmpty(t, g.graphCache.nodeChannels[node1Bytes]) + require.NotEmpty(t, g.graphCache.nodeChannels[node2Bytes]) expectedNode1Channel := &DirectedChannel{ - ChannelID: e.ChannelID, + ChannelID: e.GetChanID(), IsNode1: true, - OtherNode: e.NodeKey2Bytes, - Capacity: e.Capacity, + OtherNode: node2Bytes, + Capacity: e.GetCapacity(), OutPolicySet: false, InPolicy: nil, } require.Contains( - t, g.graphCache.nodeChannels[e.NodeKey1Bytes], e.ChannelID, + t, g.graphCache.nodeChannels[node1Bytes], e.GetChanID(), ) require.Equal( t, expectedNode1Channel, - g.graphCache.nodeChannels[e.NodeKey1Bytes][e.ChannelID], + g.graphCache.nodeChannels[node1Bytes][e.GetChanID()], ) expectedNode2Channel := &DirectedChannel{ - ChannelID: e.ChannelID, + ChannelID: e.GetChanID(), IsNode1: false, - OtherNode: e.NodeKey1Bytes, - Capacity: e.Capacity, + OtherNode: node1Bytes, + Capacity: e.GetCapacity(), OutPolicySet: false, InPolicy: nil, } require.Contains( - t, g.graphCache.nodeChannels[e.NodeKey2Bytes], e.ChannelID, + t, g.graphCache.nodeChannels[node2Bytes], e.GetChanID(), ) require.Equal( t, expectedNode2Channel, - g.graphCache.nodeChannels[e.NodeKey2Bytes][e.ChannelID], + g.graphCache.nodeChannels[node2Bytes][e.GetChanID()], ) // The external view should reflect this as well. var foundChannel *DirectedChannel err := g.graphCache.ForEachChannel( - e.NodeKey1Bytes, func(c *DirectedChannel) error { - if c.ChannelID == e.ChannelID { + node1Bytes, func(c *DirectedChannel) error { + if c.ChannelID == e.GetChanID() { foundChannel = c } @@ -871,8 +889,8 @@ func assertEdgeWithNoPoliciesInCache(t *testing.T, g *ChannelGraph, require.Equal(t, expectedNode1Channel, foundChannel) err = g.graphCache.ForEachChannel( - e.NodeKey2Bytes, func(c *DirectedChannel) error { - if c.ChannelID == e.ChannelID { + node2Bytes, func(c *DirectedChannel) error { + if c.ChannelID == e.GetChanID() { foundChannel = c } @@ -895,10 +913,16 @@ func assertNoEdge(t *testing.T, g *ChannelGraph, chanID uint64) { } func assertEdgeWithPolicyInCache(t *testing.T, g *ChannelGraph, - e *ChannelEdgeInfo1, p *ChannelEdgePolicy, policy1 bool) { + e models.ChannelEdgeInfo, p *ChannelEdgePolicy, policy1 bool) { + + var ( + node1Bytes = e.Node1Bytes() + node2Bytes = e.Node2Bytes() + chanID = e.GetChanID() + ) // Check the internal state first. - c1, ok := g.graphCache.nodeChannels[e.NodeKey1Bytes][e.ChannelID] + c1, ok := g.graphCache.nodeChannels[node1Bytes][chanID] require.True(t, ok) if policy1 { @@ -911,7 +935,7 @@ func assertEdgeWithPolicyInCache(t *testing.T, g *ChannelGraph, ) } - c2, ok := g.graphCache.nodeChannels[e.NodeKey2Bytes][e.ChannelID] + c2, ok := g.graphCache.nodeChannels[node2Bytes][chanID] require.True(t, ok) if policy1 { @@ -930,14 +954,14 @@ func assertEdgeWithPolicyInCache(t *testing.T, g *ChannelGraph, c2Ext *DirectedChannel ) require.NoError(t, g.graphCache.ForEachChannel( - e.NodeKey1Bytes, func(c *DirectedChannel) error { + node1Bytes, func(c *DirectedChannel) error { c1Ext = c return nil }, )) require.NoError(t, g.graphCache.ForEachChannel( - e.NodeKey2Bytes, func(c *DirectedChannel) error { + node2Bytes, func(c *DirectedChannel) error { c2Ext = c return nil @@ -954,7 +978,7 @@ func assertEdgeWithPolicyInCache(t *testing.T, g *ChannelGraph, c2Ext.InPolicy.FeeProportionalMillionths, ) require.Equal( - t, route.Vertex(e.NodeKey2Bytes), + t, route.Vertex(node2Bytes), c2Ext.InPolicy.ToNodePubKey(), ) require.Equal(t, testFeatures, c2Ext.InPolicy.ToNodeFeatures) @@ -964,7 +988,7 @@ func assertEdgeWithPolicyInCache(t *testing.T, g *ChannelGraph, c1Ext.InPolicy.FeeProportionalMillionths, ) require.Equal( - t, route.Vertex(e.NodeKey1Bytes), + t, route.Vertex(node1Bytes), c1Ext.InPolicy.ToNodePubKey(), ) require.Equal(t, testFeatures, c1Ext.InPolicy.ToNodeFeatures) @@ -1037,10 +1061,10 @@ func TestGraphTraversal(t *testing.T) { // Iterate through all the known channels within the graph DB, once // again if the map is empty that indicates that all edges have // properly been reached. - err = graph.ForEachChannel(func(ei *ChannelEdgeInfo1, _ *ChannelEdgePolicy, - _ *ChannelEdgePolicy) error { + err = graph.ForEachChannel(func(ei models.ChannelEdgeInfo, + _ *ChannelEdgePolicy, _ *ChannelEdgePolicy) error { - delete(chanIndex, ei.ChannelID) + delete(chanIndex, ei.GetChanID()) return nil }) require.NoError(t, err) @@ -1051,28 +1075,37 @@ func TestGraphTraversal(t *testing.T) { numNodeChans := 0 firstNode, secondNode := nodeList[0], nodeList[1] err = firstNode.ForEachChannel(graph.DB(), nil, - func(_ kvdb.Backend, _ kvdb.RTx, _ *ChannelEdgeInfo1, + func(_ kvdb.Backend, _ kvdb.RTx, _ models.ChannelEdgeInfo, outEdge, inEdge *ChannelEdgePolicy) error { - // All channels between first and second node should have fully - // (both sides) specified policies. + // All channels between first and second node should + // have fully (both sides) specified policies. if inEdge == nil || outEdge == nil { return fmt.Errorf("channel policy not present") } // Each should indicate that it's outgoing (pointed // towards the second node). - if !bytes.Equal(outEdge.Node.PubKeyBytes[:], secondNode.PubKeyBytes[:]) { + if !bytes.Equal( + outEdge.Node.PubKeyBytes[:], + secondNode.PubKeyBytes[:], + ) { + return fmt.Errorf("wrong outgoing edge") } - // The incoming edge should also indicate that it's pointing to - // the origin node. - if !bytes.Equal(inEdge.Node.PubKeyBytes[:], firstNode.PubKeyBytes[:]) { + // The incoming edge should also indicate that it's + // pointing to the origin node. + if !bytes.Equal( + inEdge.Node.PubKeyBytes[:], + firstNode.PubKeyBytes[:], + ) { + return fmt.Errorf("wrong outgoing edge") } numNodeChans++ + return nil }) require.NoError(t, err) @@ -1125,11 +1158,11 @@ func TestGraphTraversalCacheable(t *testing.T) { for _, node := range nodes { err := node.ForEachChannel( graph.db, tx, func(_ kvdb.Backend, _ kvdb.RTx, - info *ChannelEdgeInfo1, + info models.ChannelEdgeInfo, _ *ChannelEdgePolicy, _ *ChannelEdgePolicy) error { - delete(chanIndex, info.ChannelID) + delete(chanIndex, info.GetChanID()) return nil }, @@ -1310,8 +1343,8 @@ func assertPruneTip(t *testing.T, graph *ChannelGraph, blockHash *chainhash.Hash func assertNumChans(t *testing.T, graph *ChannelGraph, n int) { numChans := 0 - if err := graph.ForEachChannel(func(*ChannelEdgeInfo1, *ChannelEdgePolicy, - *ChannelEdgePolicy) error { + if err := graph.ForEachChannel(func(models.ChannelEdgeInfo, + *ChannelEdgePolicy, *ChannelEdgePolicy) error { numChans++ return nil @@ -1447,9 +1480,7 @@ func TestGraphPruning(t *testing.T) { t.Fatalf("unable to add node: %v", err) } - pkScript, err := genMultiSigP2WSH( - edgeInfo.BitcoinKey1Bytes[:], edgeInfo.BitcoinKey2Bytes[:], - ) + pkScript, err := edgeInfo.FundingScript() if err != nil { t.Fatalf("unable to gen multi-sig p2wsh: %v", err) } @@ -2278,8 +2309,9 @@ func TestIncompleteChannelPolicies(t *testing.T) { checkPolicies := func(node *LightningNode, expectedIn, expectedOut bool) { calls := 0 err := node.ForEachChannel(graph.DB(), nil, - func(_ kvdb.Backend, _ kvdb.RTx, _ *ChannelEdgeInfo1, - outEdge, inEdge *ChannelEdgePolicy) error { + func(_ kvdb.Backend, _ kvdb.RTx, + _ models.ChannelEdgeInfo, outEdge, + inEdge *ChannelEdgePolicy) error { if !expectedOut && outEdge != nil { t.Fatalf("Expected no outgoing policy") @@ -3430,7 +3462,7 @@ func BenchmarkForEachChannel(b *testing.B) { err := n.ForEachChannel( graph.db, tx, func(_ kvdb.Backend, _ kvdb.RTx, - info *ChannelEdgeInfo1, + info models.ChannelEdgeInfo, policy *ChannelEdgePolicy, policy2 *ChannelEdgePolicy) error { @@ -3439,7 +3471,8 @@ func BenchmarkForEachChannel(b *testing.B) { // compiler is going to optimize // this away, and we get bogus // results. - totalCapacity += info.Capacity + capacity := info.GetCapacity() + totalCapacity += capacity maxHTLCs += policy.MaxHTLC maxHTLCs += policy2.MaxHTLC diff --git a/discovery/chan_series.go b/discovery/chan_series.go index 0fe819d4ef..891c660e9e 100644 --- a/discovery/chan_series.go +++ b/discovery/chan_series.go @@ -118,13 +118,13 @@ func (c *ChanSeries) UpdatesInHorizon(chain chainhash.Hash, // If the channel hasn't been fully advertised yet, or is a // private channel, then we'll skip it as we can't construct a // full authentication proof if one is requested. - if channel.Info.AuthProof == nil { + if channel.Info.GetAuthProof() == nil { continue } chanAnn, edge1, edge2, err := netann.CreateChanAnnouncement( - channel.Info.AuthProof, channel.Info, channel.Policy1, - channel.Policy2, + channel.Info.GetAuthProof(), channel.Info, + channel.Policy1, channel.Policy2, ) if err != nil { return nil, err @@ -260,13 +260,13 @@ func (c *ChanSeries) FetchChanAnns(chain chainhash.Hash, // If the channel doesn't have an authentication proof, then we // won't send it over as it may not yet be finalized, or be a // non-advertised channel. - if channel.Info.AuthProof == nil { + if channel.Info.GetAuthProof() == nil { continue } chanAnn, edge1, edge2, err := netann.CreateChanAnnouncement( - channel.Info.AuthProof, channel.Info, channel.Policy1, - channel.Policy2, + channel.Info.GetAuthProof(), channel.Info, + channel.Policy1, channel.Policy2, ) if err != nil { return nil, err diff --git a/discovery/gossiper.go b/discovery/gossiper.go index 089e5319e4..de0ef16be9 100644 --- a/discovery/gossiper.go +++ b/discovery/gossiper.go @@ -19,6 +19,7 @@ import ( "github.com/lightningnetwork/lnd/batch" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnpeer" @@ -527,7 +528,7 @@ func New(cfg Config, selfKeyDesc *keychain.KeyDescriptor) *AuthenticatedGossiper // EdgeWithInfo contains the information that is required to update an edge. type EdgeWithInfo struct { // Info describes the channel. - Info *channeldb.ChannelEdgeInfo1 + Info models.ChannelEdgeInfo // Edge describes the policy in one direction of the channel. Edge *channeldb.ChannelEdgePolicy @@ -1579,7 +1580,7 @@ func (d *AuthenticatedGossiper) retransmitStaleAnns(now time.Time) error { // Iterate over all of our channels and check if any of them fall // within the prune interval or re-broadcast interval. type updateTuple struct { - info *channeldb.ChannelEdgeInfo1 + info models.ChannelEdgeInfo edge *channeldb.ChannelEdgePolicy } @@ -1588,8 +1589,7 @@ func (d *AuthenticatedGossiper) retransmitStaleAnns(now time.Time) error { edgesToUpdate []updateTuple ) err := d.cfg.Router.ForAllOutgoingChannels(func( - _ kvdb.RTx, - info *channeldb.ChannelEdgeInfo1, + _ kvdb.RTx, info models.ChannelEdgeInfo, edge *channeldb.ChannelEdgePolicy) error { // If there's no auth proof attached to this edge, it means @@ -1597,9 +1597,9 @@ func (d *AuthenticatedGossiper) retransmitStaleAnns(now time.Time) error { // the greater network, so avoid sending channel updates for // this channel to not leak its // existence. - if info.AuthProof == nil { + if info.GetAuthProof() == nil { log.Debugf("Skipping retransmission of channel "+ - "without AuthProof: %v", info.ChannelID) + "without AuthProof: %v", info.GetChanID()) return nil } @@ -1615,7 +1615,9 @@ func (d *AuthenticatedGossiper) retransmitStaleAnns(now time.Time) error { // We'll make sure we support the new max_htlc field if // not already present. edge.MessageFlags |= lnwire.ChanUpdateRequiredMaxHtlc - edge.MaxHTLC = lnwire.NewMSatFromSatoshis(info.Capacity) + edge.MaxHTLC = lnwire.NewMSatFromSatoshis( + info.GetCapacity(), + ) edgesToUpdate = append(edgesToUpdate, updateTuple{ info: info, @@ -1733,15 +1735,15 @@ func (d *AuthenticatedGossiper) processChanPolicyUpdate( // We'll avoid broadcasting any updates for private channels to // avoid directly giving away their existence. Instead, we'll // send the update directly to the remote party. - if edgeInfo.Info.AuthProof == nil { + if edgeInfo.Info.GetAuthProof() == nil { // If AuthProof is nil and an alias was found for this // ChannelID (meaning the option-scid-alias feature was // negotiated), we'll replace the ShortChannelID in the // update with the peer's alias. We do this after // updateChannel so that the alias isn't persisted to // the database. - op := &edgeInfo.Info.ChannelPoint - chanID := lnwire.NewChanIDFromOutPoint(op) + op := edgeInfo.Info.GetChanPoint() + chanID := lnwire.NewChanIDFromOutPoint(&op) var defaultAlias lnwire.ShortChannelID foundAlias, _ := d.cfg.GetAlias(chanID) @@ -1795,15 +1797,15 @@ func (d *AuthenticatedGossiper) processChanPolicyUpdate( // remotePubFromChanInfo returns the public key of the remote peer given a // ChannelEdgeInfo1 that describe a channel we have with them. -func remotePubFromChanInfo(chanInfo *channeldb.ChannelEdgeInfo1, +func remotePubFromChanInfo(chanInfo models.ChannelEdgeInfo, chanFlags lnwire.ChanUpdateChanFlags) [33]byte { var remotePubKey [33]byte switch { case chanFlags&lnwire.ChanUpdateDirection == 0: - remotePubKey = chanInfo.NodeKey2Bytes + remotePubKey = chanInfo.Node2Bytes() case chanFlags&lnwire.ChanUpdateDirection == 1: - remotePubKey = chanInfo.NodeKey1Bytes + remotePubKey = chanInfo.Node1Bytes() } return remotePubKey @@ -1831,7 +1833,7 @@ func (d *AuthenticatedGossiper) processRejectedEdge( // The edge is in the graph, and has a proof attached, then we'll just // reject it as normal. - if chanInfo.AuthProof != nil { + if chanInfo.GetAuthProof() != nil { return nil, nil } @@ -2026,7 +2028,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement( // processZombieUpdate determines whether the provided channel update should // resurrect a given zombie edge. func (d *AuthenticatedGossiper) processZombieUpdate( - chanInfo *channeldb.ChannelEdgeInfo1, msg *lnwire.ChannelUpdate) error { + chanInfo models.ChannelEdgeInfo, msg *lnwire.ChannelUpdate) error { // The least-significant bit in the flag on the channel update tells us // which edge is being updated. @@ -2039,9 +2041,9 @@ func (d *AuthenticatedGossiper) processZombieUpdate( // will only have the pubkey of the node with the oldest timestamp. var pubKey *btcec.PublicKey switch { - case isNode1 && chanInfo.NodeKey1Bytes != emptyPubkey: + case isNode1 && chanInfo.Node1Bytes() != emptyPubkey: pubKey, _ = chanInfo.NodeKey1() - case !isNode1 && chanInfo.NodeKey2Bytes != emptyPubkey: + case !isNode1 && chanInfo.Node2Bytes() != emptyPubkey: pubKey, _ = chanInfo.NodeKey2() } if pubKey == nil { @@ -2058,7 +2060,7 @@ func (d *AuthenticatedGossiper) processZombieUpdate( // With the signature valid, we'll proceed to mark the // edge as live and wait for the channel announcement to // come through again. - baseScid := lnwire.NewShortChanIDFromInt(chanInfo.ChannelID) + baseScid := lnwire.NewShortChanIDFromInt(chanInfo.GetChanID()) err = d.cfg.Router.MarkEdgeLive(baseScid) if err != nil { return fmt.Errorf("unable to remove edge with "+ @@ -2102,14 +2104,14 @@ func (d *AuthenticatedGossiper) isMsgStale(msg lnwire.Message) bool { } if err != nil { log.Debugf("Unable to retrieve channel=%v from graph: "+ - "%v", chanInfo.ChannelID, err) + "%v", chanInfo.GetChanID(), err) return false } // If the proof exists in the graph, then we have successfully // received the remote proof and assembled the full proof, so we // can safely delete the local proof from the database. - return chanInfo.AuthProof != nil + return chanInfo.GetAuthProof() != nil case *lnwire.ChannelUpdate: _, p1, p2, err := d.cfg.Router.GetChannelByID(msg.ShortChannelID) @@ -2154,12 +2156,14 @@ func (d *AuthenticatedGossiper) isMsgStale(msg lnwire.Message) bool { // updateChannel creates a new fully signed update for the channel, and updates // the underlying graph with the new state. -func (d *AuthenticatedGossiper) updateChannel(info *channeldb.ChannelEdgeInfo1, +func (d *AuthenticatedGossiper) updateChannel(edgeInfo models.ChannelEdgeInfo, edge *channeldb.ChannelEdgePolicy) (lnwire.ChannelAnnouncement, *lnwire.ChannelUpdate, error) { // Parse the unsigned edge into a channel update. - chanUpdate := netann.UnsignedChannelUpdateFromEdge(info, edge) + chanUpdate := netann.UnsignedChannelUpdateFromEdge( + edgeInfo.GetChainHash(), edge, + ) // We'll generate a new signature over a digest of the channel // announcement itself and update the timestamp to ensure it propagate. @@ -2179,7 +2183,7 @@ func (d *AuthenticatedGossiper) updateChannel(info *channeldb.ChannelEdgeInfo1, // To ensure that our signature is valid, we'll verify it ourself // before committing it to the slice returned. err = routing.ValidateChannelUpdateAnn( - d.selfKey, info.Capacity, chanUpdate, + d.selfKey, edgeInfo.GetCapacity(), chanUpdate, ) if err != nil { return nil, nil, fmt.Errorf("generated invalid channel "+ @@ -2194,48 +2198,67 @@ func (d *AuthenticatedGossiper) updateChannel(info *channeldb.ChannelEdgeInfo1, // We'll also create the original channel announcement so the two can // be broadcast along side each other (if necessary), but only if we // have a full channel announcement for this channel. - var chanAnn *lnwire.ChannelAnnouncement1 - if info.AuthProof != nil { - chanID := lnwire.NewShortChanIDFromInt(info.ChannelID) - chanAnn = &lnwire.ChannelAnnouncement1{ - ShortChannelID: chanID, - NodeID1: info.NodeKey1Bytes, - NodeID2: info.NodeKey2Bytes, - ChainHash: info.ChainHash, - BitcoinKey1: info.BitcoinKey1Bytes, - Features: lnwire.NewRawFeatureVector(), - BitcoinKey2: info.BitcoinKey2Bytes, - ExtraOpaqueData: edge.ExtraOpaqueData, - } - chanAnn.NodeSig1, err = lnwire.NewSigFromECDSARawSignature( - info.AuthProof.NodeSig1Bytes, - ) - if err != nil { - return nil, nil, err - } - chanAnn.NodeSig2, err = lnwire.NewSigFromECDSARawSignature( - info.AuthProof.NodeSig2Bytes, - ) - if err != nil { - return nil, nil, err - } - chanAnn.BitcoinSig1, err = lnwire.NewSigFromECDSARawSignature( - info.AuthProof.BitcoinSig1Bytes, - ) - if err != nil { - return nil, nil, err - } - chanAnn.BitcoinSig2, err = lnwire.NewSigFromECDSARawSignature( - info.AuthProof.BitcoinSig2Bytes, - ) - if err != nil { - return nil, nil, err + var chanAnn lnwire.ChannelAnnouncement + switch info := edgeInfo.(type) { + case *channeldb.ChannelEdgeInfo1: + if info.AuthProof != nil { + chanAnn, err = chanAnn1FromEdgeInfo1(info) + if err != nil { + return nil, nil, err + } } + default: + return nil, nil, fmt.Errorf("unhandled implementation of "+ + "lnwire.ChannelEdgeInfo: %T", info) } return chanAnn, chanUpdate, err } +func chanAnn1FromEdgeInfo1(info *channeldb.ChannelEdgeInfo1) ( + lnwire.ChannelAnnouncement, error) { + + var err error + + chanID := lnwire.NewShortChanIDFromInt(info.ChannelID) + chanAnn := &lnwire.ChannelAnnouncement1{ + ShortChannelID: chanID, + NodeID1: info.NodeKey1Bytes, + NodeID2: info.NodeKey2Bytes, + ChainHash: info.ChainHash, + BitcoinKey1: info.BitcoinKey1Bytes, + Features: lnwire.NewRawFeatureVector(), + BitcoinKey2: info.BitcoinKey2Bytes, + ExtraOpaqueData: info.ExtraOpaqueData, + } + chanAnn.NodeSig1, err = lnwire.NewSigFromECDSARawSignature( + info.AuthProof.NodeSig1Bytes, + ) + if err != nil { + return nil, err + } + chanAnn.NodeSig2, err = lnwire.NewSigFromECDSARawSignature( + info.AuthProof.NodeSig2Bytes, + ) + if err != nil { + return nil, err + } + chanAnn.BitcoinSig1, err = lnwire.NewSigFromECDSARawSignature( + info.AuthProof.BitcoinSig1Bytes, + ) + if err != nil { + return nil, err + } + chanAnn.BitcoinSig2, err = lnwire.NewSigFromECDSARawSignature( + info.AuthProof.BitcoinSig2Bytes, + ) + if err != nil { + return nil, err + } + + return chanAnn, nil +} + // SyncManager returns the gossiper's SyncManager instance. func (d *AuthenticatedGossiper) SyncManager() *SyncManager { return d.syncMgr @@ -2467,7 +2490,7 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, // With the proof validated (if necessary), we can now store it within // the database for our path finding and syncing needs. - edge, err := buildEdgeInfo(ann, nMsg.optionalMsgFields, proof) + edge, err := buildEdgeInfo(ann, nMsg.optionalMsgFields) if err != nil { log.Errorf("unable to build edge info from announcement: %v", err) @@ -2476,6 +2499,14 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, return nil, false } + err = edge.SetAuthProof(proof) + if err != nil { + log.Errorf("unable to set auth proof: %v", err) + nMsg.err <- err + + return nil, false + } + log.Debugf("Adding edge for short_chan_id: %v", scid.ToUint64()) // We will add the edge to the channel router. If the nodes present in @@ -2812,13 +2843,15 @@ func (d *AuthenticatedGossiper) handleChanUpdate(nMsg *networkMsg, } log.Debugf("Validating ChannelUpdate: channel=%v, from node=%x, has "+ - "edge=%v", chanInfo.ChannelID, pubKey.SerializeCompressed(), + "edge=%v", chanInfo.GetChanID(), pubKey.SerializeCompressed(), edgeToUpdate != nil) // Validate the channel announcement with the expected public key and // channel capacity. In the case of an invalid channel update, we'll // return an error to the caller and exit early. - err = routing.ValidateChannelUpdateAnn(pubKey, chanInfo.Capacity, upd) + err = routing.ValidateChannelUpdateAnn( + pubKey, chanInfo.GetCapacity(), upd, + ) if err != nil { rErr := fmt.Errorf("unable to validate channel update "+ "announcement for short_chan_id=%v: %v", @@ -2857,7 +2890,7 @@ func (d *AuthenticatedGossiper) handleChanUpdate(nMsg *networkMsg, // multiple aliases for a channel and we may otherwise // rate-limit only a single alias of the channel, // instead of the whole channel. - baseScid := chanInfo.ChannelID + baseScid := chanInfo.GetChanID() d.Lock() rls, ok := d.chanUpdateRateLimiter[baseScid] if !ok { @@ -2890,7 +2923,7 @@ func (d *AuthenticatedGossiper) handleChanUpdate(nMsg *networkMsg, // only be a difference if AuthProof == nil, this is fine. update := &channeldb.ChannelEdgePolicy{ SigBytes: upd.Signature.ToSignatureBytes(), - ChannelID: chanInfo.ChannelID, + ChannelID: chanInfo.GetChanID(), LastUpdate: timestamp, MessageFlags: upd.MessageFlags, ChannelFlags: upd.ChannelFlags, @@ -2915,7 +2948,7 @@ func (d *AuthenticatedGossiper) handleChanUpdate(nMsg *networkMsg, // Since we know the stored SCID in the graph, we'll // cache that SCID. key := newRejectCacheKey( - chanInfo.ChannelID, + chanInfo.GetChanID(), sourceToPub(nMsg.source), ) _, _ = d.recentRejects.Put(key, &cachedReject{}) @@ -2930,37 +2963,37 @@ func (d *AuthenticatedGossiper) handleChanUpdate(nMsg *networkMsg, // If this is a local ChannelUpdate without an AuthProof, it means it // is an update to a channel that is not (yet) supposed to be announced - // to the greater network. However, our channel counter party will need + // to the greater network. However, our channel counterparty will need // to be given the update, so we'll try sending the update directly to // the remote peer. - if !nMsg.isRemote && chanInfo.AuthProof == nil { - if nMsg.optionalMsgFields != nil { + if !nMsg.isRemote && chanInfo.GetAuthProof() == nil { + if nMsg.optionalMsgFields != nil && + nMsg.optionalMsgFields.remoteAlias != nil { + + // The remoteAlias field was specified, meaning + // that we should replace the SCID in the + // update with the remote's alias. We'll also + // need to re-sign the channel update. This is + // required for option-scid-alias feature-bit + // negotiated channels. remoteAlias := nMsg.optionalMsgFields.remoteAlias - if remoteAlias != nil { - // The remoteAlias field was specified, meaning - // that we should replace the SCID in the - // update with the remote's alias. We'll also - // need to re-sign the channel update. This is - // required for option-scid-alias feature-bit - // negotiated channels. - upd.ShortChannelID = *remoteAlias - - sig, err := d.cfg.SignAliasUpdate(upd) - if err != nil { - log.Error(err) - nMsg.err <- err - return nil, false - } + upd.ShortChannelID = *remoteAlias - lnSig, err := lnwire.NewSigFromSignature(sig) - if err != nil { - log.Error(err) - nMsg.err <- err - return nil, false - } + sig, err := d.cfg.SignAliasUpdate(upd) + if err != nil { + log.Error(err) + nMsg.err <- err + return nil, false + } - upd.Signature = lnSig + lnSig, err := lnwire.NewSigFromSignature(sig) + if err != nil { + log.Error(err) + nMsg.err <- err + return nil, false } + + upd.Signature = lnSig } // Get our peer's public key. @@ -2990,7 +3023,9 @@ func (d *AuthenticatedGossiper) handleChanUpdate(nMsg *networkMsg, // authentication proof. We also won't broadcast the update if it // contains an alias because the network would reject this. var announcements []networkMsg - if chanInfo.AuthProof != nil && !d.cfg.IsAlias(upd.ShortChannelID) { + if chanInfo.GetAuthProof() != nil && + !d.cfg.IsAlias(upd.ShortChannelID) { + announcements = append(announcements, networkMsg{ peer: nMsg.peer, source: nMsg.source, @@ -3079,9 +3114,14 @@ func (d *AuthenticatedGossiper) handleAnnSig(nMsg *networkMsg, return nil, false } + var ( + node1Bytes = chanInfo.Node1Bytes() + node2Bytes = chanInfo.Node2Bytes() + ) + nodeID := nMsg.source.SerializeCompressed() - isFirstNode := bytes.Equal(nodeID, chanInfo.NodeKey1Bytes[:]) - isSecondNode := bytes.Equal(nodeID, chanInfo.NodeKey2Bytes[:]) + isFirstNode := bytes.Equal(nodeID, node1Bytes[:]) + isSecondNode := bytes.Equal(nodeID, node2Bytes[:]) // Ensure that channel that was retrieved belongs to the peer which // sent the proof announcement. @@ -3100,9 +3140,9 @@ func (d *AuthenticatedGossiper) handleAnnSig(nMsg *networkMsg, if !nMsg.isRemote { var remotePubKey [33]byte if isFirstNode { - remotePubKey = chanInfo.NodeKey2Bytes + remotePubKey = node2Bytes } else { - remotePubKey = chanInfo.NodeKey1Bytes + remotePubKey = node1Bytes } // Since the remote peer might not be online we'll call a @@ -3119,7 +3159,8 @@ func (d *AuthenticatedGossiper) handleAnnSig(nMsg *networkMsg, } // Check if we already have the full proof for this channel. - if chanInfo.AuthProof != nil { + authProof := chanInfo.GetAuthProof() + if authProof != nil { // If we already have the fully assembled proof, then the peer // sending us their proof has probably not received our local // proof yet. So be kind and send them the full proof. @@ -3138,7 +3179,7 @@ func (d *AuthenticatedGossiper) handleAnnSig(nMsg *networkMsg, ann.ChannelID, peerID) ca, _, _, err := netann.CreateChanAnnouncement( - chanInfo.AuthProof, chanInfo, e1, e2, + authProof, chanInfo, e1, e2, ) if err != nil { log.Errorf("unable to gen ann: %v", @@ -3292,10 +3333,10 @@ func (d *AuthenticatedGossiper) handleAnnSig(nMsg *networkMsg, // it since the source gets skipped. This isn't necessary for channel // updates and announcement signatures since we send those directly to // our channel counterparty through the gossiper's reliable sender. - node1Ann, err := d.fetchNodeAnn(chanInfo.NodeKey1Bytes) + node1Ann, err := d.fetchNodeAnn(node1Bytes) if err != nil { log.Debugf("Unable to fetch node announcement for %x: %v", - chanInfo.NodeKey1Bytes, err) + node1Bytes, err) } else { if nodeKey1, err := chanInfo.NodeKey1(); err == nil { announcements = append(announcements, networkMsg{ @@ -3306,10 +3347,10 @@ func (d *AuthenticatedGossiper) handleAnnSig(nMsg *networkMsg, } } - node2Ann, err := d.fetchNodeAnn(chanInfo.NodeKey2Bytes) + node2Ann, err := d.fetchNodeAnn(node2Bytes) if err != nil { log.Debugf("Unable to fetch node announcement for %x: %v", - chanInfo.NodeKey2Bytes, err) + node2Bytes, err) } else { if nodeKey2, err := chanInfo.NodeKey2(); err == nil { announcements = append(announcements, networkMsg{ @@ -3341,9 +3382,8 @@ func buildChanProof(ann lnwire.ChannelAnnouncement) ( } } -func buildEdgeInfo(ann lnwire.ChannelAnnouncement, opts *optionalMsgFields, - proof *channeldb.ChannelAuthProof1) (*channeldb.ChannelEdgeInfo1, - error) { +func buildEdgeInfo(ann lnwire.ChannelAnnouncement, opts *optionalMsgFields) ( + models.ChannelEdgeInfo, error) { switch a := ann.(type) { case *lnwire.ChannelAnnouncement1: @@ -3360,7 +3400,6 @@ func buildEdgeInfo(ann lnwire.ChannelAnnouncement, opts *optionalMsgFields, BitcoinKey1Bytes: a.BitcoinKey1, BitcoinKey2Bytes: a.BitcoinKey2, Features: featureBuf.Bytes(), - AuthProof: proof, ExtraOpaqueData: a.ExtraOpaqueData, } diff --git a/discovery/gossiper_test.go b/discovery/gossiper_test.go index 26a7015657..8f964a3749 100644 --- a/discovery/gossiper_test.go +++ b/discovery/gossiper_test.go @@ -24,6 +24,7 @@ import ( "github.com/lightningnetwork/lnd/batch" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnpeer" @@ -91,7 +92,7 @@ type mockGraphSource struct { mu sync.Mutex nodes []channeldb.LightningNode - infos map[uint64]channeldb.ChannelEdgeInfo1 + infos map[uint64]models.ChannelEdgeInfo edges map[uint64][]channeldb.ChannelEdgePolicy zombies map[uint64][][33]byte chansToReject map[uint64]struct{} @@ -100,7 +101,7 @@ type mockGraphSource struct { func newMockRouter(height uint32) *mockGraphSource { return &mockGraphSource{ bestHeight: height, - infos: make(map[uint64]channeldb.ChannelEdgeInfo1), + infos: make(map[uint64]models.ChannelEdgeInfo), edges: make(map[uint64][]channeldb.ChannelEdgePolicy), zombies: make(map[uint64][][33]byte), chansToReject: make(map[uint64]struct{}), @@ -119,21 +120,22 @@ func (r *mockGraphSource) AddNode(node *channeldb.LightningNode, return nil } -func (r *mockGraphSource) AddEdge(info *channeldb.ChannelEdgeInfo1, +func (r *mockGraphSource) AddEdge(info models.ChannelEdgeInfo, _ ...batch.SchedulerOption) error { r.mu.Lock() defer r.mu.Unlock() - if _, ok := r.infos[info.ChannelID]; ok { + if _, ok := r.infos[info.GetChanID()]; ok { return errors.New("info already exist") } - if _, ok := r.chansToReject[info.ChannelID]; ok { + if _, ok := r.chansToReject[info.GetChanID()]; ok { return errors.New("validation failed") } - r.infos[info.ChannelID] = *info + r.infos[info.GetChanID()] = info + return nil } @@ -168,7 +170,7 @@ func (r *mockGraphSource) CurrentBlockHeight() (uint32, error) { } func (r *mockGraphSource) AddProof(chanID lnwire.ShortChannelID, - proof *channeldb.ChannelAuthProof1) error { + proof models.ChannelAuthProof) error { r.mu.Lock() defer r.mu.Unlock() @@ -179,8 +181,14 @@ func (r *mockGraphSource) AddProof(chanID lnwire.ShortChannelID, return errors.New("channel does not exist") } - info.AuthProof = proof - r.infos[chanIDInt] = info + infoCP := info.Copy() + + err := infoCP.SetAuthProof(proof) + if err != nil { + return err + } + + r.infos[chanIDInt] = infoCP return nil } @@ -190,8 +198,7 @@ func (r *mockGraphSource) ForEachNode(func(node *channeldb.LightningNode) error) } func (r *mockGraphSource) ForAllOutgoingChannels(cb func(tx kvdb.RTx, - i *channeldb.ChannelEdgeInfo1, - c *channeldb.ChannelEdgePolicy) error) error { + i models.ChannelEdgeInfo, c *channeldb.ChannelEdgePolicy) error) error { r.mu.Lock() defer r.mu.Unlock() @@ -200,9 +207,9 @@ func (r *mockGraphSource) ForAllOutgoingChannels(cb func(tx kvdb.RTx, for _, info := range r.infos { info := info - edgeInfo := chans[info.ChannelID] - edgeInfo.Info = &info - chans[info.ChannelID] = edgeInfo + edgeInfo := chans[info.GetChanID()] + edgeInfo.Info = info + chans[info.GetChanID()] = edgeInfo } for _, edges := range r.edges { edges := edges @@ -221,14 +228,14 @@ func (r *mockGraphSource) ForAllOutgoingChannels(cb func(tx kvdb.RTx, return nil } -func (r *mockGraphSource) ForEachChannel(func(chanInfo *channeldb.ChannelEdgeInfo1, +func (r *mockGraphSource) ForEachChannel(_ func(chanInfo models.ChannelEdgeInfo, e1, e2 *channeldb.ChannelEdgePolicy) error) error { + return nil } func (r *mockGraphSource) GetChannelByID(chanID lnwire.ShortChannelID) ( - *channeldb.ChannelEdgeInfo1, - *channeldb.ChannelEdgePolicy, + models.ChannelEdgeInfo, *channeldb.ChannelEdgePolicy, *channeldb.ChannelEdgePolicy, error) { r.mu.Lock() @@ -248,9 +255,11 @@ func (r *mockGraphSource) GetChannelByID(chanID lnwire.ShortChannelID) ( }, nil, nil, channeldb.ErrZombieEdge } + chanInfoCP := chanInfo.Copy() + edges := r.edges[chanID.ToUint64()] if len(edges) == 0 { - return &chanInfo, nil, nil, nil + return chanInfoCP, nil, nil, nil } var edge1 *channeldb.ChannelEdgePolicy @@ -263,7 +272,7 @@ func (r *mockGraphSource) GetChannelByID(chanID lnwire.ShortChannelID) ( edge2 = &edges[1] } - return &chanInfo, edge1, edge2, nil + return chanInfoCP, edge1, edge2, nil } func (r *mockGraphSource) FetchLightningNode( @@ -295,10 +304,10 @@ func (r *mockGraphSource) IsStaleNode(nodePub route.Vertex, timestamp time.Time) // require the node to already have a channel in the graph to not be // considered stale. for _, info := range r.infos { - if info.NodeKey1Bytes == nodePub { + if info.Node1Bytes() == nodePub { return false } - if info.NodeKey2Bytes == nodePub { + if info.Node2Bytes() == nodePub { return false } } @@ -309,12 +318,16 @@ func (r *mockGraphSource) IsStaleNode(nodePub route.Vertex, timestamp time.Time) // the graph from the graph's source node's point of view. func (r *mockGraphSource) IsPublicNode(node route.Vertex) (bool, error) { for _, info := range r.infos { - if !bytes.Equal(node[:], info.NodeKey1Bytes[:]) && - !bytes.Equal(node[:], info.NodeKey2Bytes[:]) { + n1 := info.Node1Bytes() + n2 := info.Node2Bytes() + + if !bytes.Equal(node[:], n1[:]) && + !bytes.Equal(node[:], n2[:]) { + continue } - if info.AuthProof != nil { + if info.GetAuthProof() != nil { return true, nil } } @@ -3442,7 +3455,7 @@ out: var edgesToUpdate []EdgeWithInfo err = ctx.router.ForAllOutgoingChannels(func( _ kvdb.RTx, - info *channeldb.ChannelEdgeInfo1, + info models.ChannelEdgeInfo, edge *channeldb.ChannelEdgePolicy) error { edge.TimeLockDelta = uint16(newTimeLockDelta) @@ -3550,17 +3563,9 @@ func TestProcessChannelAnnouncementOptionalMsgFields(t *testing.T) { t.Helper() edge, _, _, err := ctx.router.GetChannelByID(chanID) - if err != nil { - t.Fatalf("unable to get channel by id: %v", err) - } - if edge.Capacity != capacity { - t.Fatalf("expected capacity %v, got %v", capacity, - edge.Capacity) - } - if edge.ChannelPoint != channelPoint { - t.Fatalf("expected channel point %v, got %v", - channelPoint, edge.ChannelPoint) - } + require.NoError(t, err) + require.Equal(t, capacity, edge.GetCapacity()) + require.Equal(t, channelPoint, edge.GetChanPoint()) } // We'll process the first announcement without any optional fields. We diff --git a/lnrpc/invoicesrpc/addinvoice.go b/lnrpc/invoicesrpc/addinvoice.go index 5396579d70..050871c9c0 100644 --- a/lnrpc/invoicesrpc/addinvoice.go +++ b/lnrpc/invoicesrpc/addinvoice.go @@ -17,6 +17,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/invoices" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" @@ -544,8 +545,11 @@ func chanCanBeHopHint(channel *HopHintInfo, cfg *SelectHopHintsCfg) ( // Now, we'll need to determine which is the correct policy for HTLCs // being sent from the remote node. - var remotePolicy *channeldb.ChannelEdgePolicy - if bytes.Equal(remotePub[:], info.NodeKey1Bytes[:]) { + var ( + remotePolicy *channeldb.ChannelEdgePolicy + node1Bytes = info.Node1Bytes() + ) + if bytes.Equal(remotePub[:], node1Bytes[:]) { remotePolicy = p1 } else { remotePolicy = p2 @@ -627,7 +631,7 @@ type SelectHopHintsCfg struct { // FetchChannelEdgesByID attempts to lookup the two directed edges for // the channel identified by the channel ID. - FetchChannelEdgesByID func(chanID uint64) (*channeldb.ChannelEdgeInfo1, + FetchChannelEdgesByID func(chanID uint64) (models.ChannelEdgeInfo, *channeldb.ChannelEdgePolicy, *channeldb.ChannelEdgePolicy, error) diff --git a/lnrpc/invoicesrpc/addinvoice_test.go b/lnrpc/invoicesrpc/addinvoice_test.go index 0cc1f9c516..223a24495c 100644 --- a/lnrpc/invoicesrpc/addinvoice_test.go +++ b/lnrpc/invoicesrpc/addinvoice_test.go @@ -8,6 +8,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/zpay32" "github.com/stretchr/testify/mock" @@ -51,7 +52,7 @@ func (h *hopHintsConfigMock) FetchAllChannels() ([]*channeldb.OpenChannel, // FetchChannelEdgesByID attempts to lookup the two directed edges for // the channel identified by the channel ID. func (h *hopHintsConfigMock) FetchChannelEdgesByID(chanID uint64) ( - *channeldb.ChannelEdgeInfo1, *channeldb.ChannelEdgePolicy, + models.ChannelEdgeInfo, *channeldb.ChannelEdgePolicy, *channeldb.ChannelEdgePolicy, error) { args := h.Mock.Called(chanID) @@ -64,7 +65,12 @@ func (h *hopHintsConfigMock) FetchChannelEdgesByID(chanID uint64) ( return nil, nil, nil, err } - edgeInfo := args.Get(0).(*channeldb.ChannelEdgeInfo1) + edgeInfo, ok := args.Get(0).(*channeldb.ChannelEdgeInfo1) + if !ok { + return nil, nil, nil, fmt.Errorf("unexpected "+ + "ChannelEdgeInfo impl received: %T", args.Get(0)) + } + policy1 := args.Get(1).(*channeldb.ChannelEdgePolicy) policy2 := args.Get(2).(*channeldb.ChannelEdgePolicy) diff --git a/netann/chan_status_manager.go b/netann/chan_status_manager.go index 7ee543d2ac..43b21fda7d 100644 --- a/netann/chan_status_manager.go +++ b/netann/chan_status_manager.go @@ -661,7 +661,8 @@ func (m *ChanStatusManager) fetchLastChanUpdateByOutPoint(op wire.OutPoint) ( update, err := ExtractChannelUpdate( m.ourPubKeyBytes, info, edge1, edge2, ) - return update, info.AuthProof == nil, err + + return update, info.GetAuthProof() == nil, err } // loadInitialChanState determines the initial ChannelState for a particular diff --git a/netann/chan_status_manager_test.go b/netann/chan_status_manager_test.go index 65e0bf70bd..e4cedd25ab 100644 --- a/netann/chan_status_manager_test.go +++ b/netann/chan_status_manager_test.go @@ -14,6 +14,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/netann" @@ -160,7 +161,7 @@ func (g *mockGraph) FetchAllOpenChannels() ([]*channeldb.OpenChannel, error) { } func (g *mockGraph) FetchChannelEdgesByOutpoint( - op *wire.OutPoint) (*channeldb.ChannelEdgeInfo1, + op *wire.OutPoint) (models.ChannelEdgeInfo, *channeldb.ChannelEdgePolicy, *channeldb.ChannelEdgePolicy, error) { g.mu.Lock() diff --git a/netann/channel_announcement.go b/netann/channel_announcement.go index e383286990..ee3e4791c2 100644 --- a/netann/channel_announcement.go +++ b/netann/channel_announcement.go @@ -2,8 +2,10 @@ package netann import ( "bytes" + "fmt" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/lnwire" ) @@ -12,7 +14,29 @@ import ( // function is used to transform out database structs into the corresponding wire // structs for announcing new channels to other peers, or simply syncing up a // peer's initial routing table upon connect. -func CreateChanAnnouncement(chanProof *channeldb.ChannelAuthProof1, +func CreateChanAnnouncement(chanProof models.ChannelAuthProof, + chanInfo models.ChannelEdgeInfo, + e1, e2 *channeldb.ChannelEdgePolicy) (lnwire.ChannelAnnouncement, + *lnwire.ChannelUpdate, *lnwire.ChannelUpdate, error) { + + switch proof := chanProof.(type) { + case *channeldb.ChannelAuthProof1: + info, ok := chanInfo.(*channeldb.ChannelEdgeInfo1) + if !ok { + return nil, nil, nil, fmt.Errorf("expected type "+ + "ChannelEdgeInfo1 to be paired with "+ + "ChannelAuthProof1, got: %T", chanInfo) + } + + return createChanAnnouncement1(proof, info, e1, e2) + + default: + return nil, nil, nil, fmt.Errorf("unhandled "+ + "channeldb.ChannelAuthProof type: %T", chanProof) + } +} + +func createChanAnnouncement1(chanProof *channeldb.ChannelAuthProof1, chanInfo *channeldb.ChannelEdgeInfo1, e1, e2 *channeldb.ChannelEdgePolicy) (lnwire.ChannelAnnouncement, *lnwire.ChannelUpdate, *lnwire.ChannelUpdate, error) { diff --git a/netann/channel_update.go b/netann/channel_update.go index 9880a462b3..0e64613476 100644 --- a/netann/channel_update.go +++ b/netann/channel_update.go @@ -6,7 +6,9 @@ import ( "time" "github.com/btcsuite/btcd/btcec/v2" + "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" @@ -84,7 +86,7 @@ func SignChannelUpdate(signer lnwallet.MessageSigner, keyLoc keychain.KeyLocator // // NOTE: The passed policies can be nil. func ExtractChannelUpdate(ownerPubKey []byte, - info *channeldb.ChannelEdgeInfo1, + info models.ChannelEdgeInfo, policies ...*channeldb.ChannelEdgePolicy) ( *lnwire.ChannelUpdate, error) { @@ -117,11 +119,11 @@ func ExtractChannelUpdate(ownerPubKey []byte, // UnsignedChannelUpdateFromEdge reconstructs an unsigned ChannelUpdate from the // given edge info and policy. -func UnsignedChannelUpdateFromEdge(info *channeldb.ChannelEdgeInfo1, +func UnsignedChannelUpdateFromEdge(chainHash chainhash.Hash, policy *channeldb.ChannelEdgePolicy) *lnwire.ChannelUpdate { return &lnwire.ChannelUpdate{ - ChainHash: info.ChainHash, + ChainHash: chainHash, ShortChannelID: lnwire.NewShortChanIDFromInt(policy.ChannelID), Timestamp: uint32(policy.LastUpdate.Unix()), ChannelFlags: policy.ChannelFlags, @@ -137,10 +139,10 @@ func UnsignedChannelUpdateFromEdge(info *channeldb.ChannelEdgeInfo1, // ChannelUpdateFromEdge reconstructs a signed ChannelUpdate from the given edge // info and policy. -func ChannelUpdateFromEdge(info *channeldb.ChannelEdgeInfo1, +func ChannelUpdateFromEdge(info models.ChannelEdgeInfo, policy *channeldb.ChannelEdgePolicy) (*lnwire.ChannelUpdate, error) { - update := UnsignedChannelUpdateFromEdge(info, policy) + update := UnsignedChannelUpdateFromEdge(info.GetChainHash(), policy) var err error update.Signature, err = lnwire.NewSigFromECDSARawSignature( diff --git a/netann/interface.go b/netann/interface.go index 68afdd32e9..cdeefc0987 100644 --- a/netann/interface.go +++ b/netann/interface.go @@ -3,6 +3,7 @@ package netann import ( "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/channeldb/models" ) // DB abstracts the required database functionality needed by the @@ -18,6 +19,6 @@ type DB interface { type ChannelGraph interface { // FetchChannelEdgesByOutpoint returns the channel edge info and most // recent channel edge policies for a given outpoint. - FetchChannelEdgesByOutpoint(*wire.OutPoint) (*channeldb.ChannelEdgeInfo1, + FetchChannelEdgesByOutpoint(*wire.OutPoint) (models.ChannelEdgeInfo, *channeldb.ChannelEdgePolicy, *channeldb.ChannelEdgePolicy, error) } diff --git a/peer/brontide.go b/peer/brontide.go index 349cdab794..32e88fade5 100644 --- a/peer/brontide.go +++ b/peer/brontide.go @@ -933,13 +933,14 @@ func (p *Brontide) loadActiveChannels(chans []*channeldb.OpenChannel) ( // // TODO(roasbeef): can add helper method to get policy for // particular channel. - var selfPolicy *channeldb.ChannelEdgePolicy - if info != nil && bytes.Equal(info.NodeKey1Bytes[:], - p.cfg.ServerPubKey[:]) { + selfPolicy := p2 + if info != nil { + node1Bytes := info.Node1Bytes() + if bytes.Equal(node1Bytes[:], + p.cfg.ServerPubKey[:]) { - selfPolicy = p1 - } else { - selfPolicy = p2 + selfPolicy = p1 + } } // If we don't yet have an advertised routing policy, then diff --git a/routing/localchans/manager.go b/routing/localchans/manager.go index b450e14a14..221ed69cb0 100644 --- a/routing/localchans/manager.go +++ b/routing/localchans/manager.go @@ -31,7 +31,7 @@ type Manager struct { // ForAllOutgoingChannels is required to iterate over all our local // channels. ForAllOutgoingChannels func(cb func(kvdb.RTx, - *channeldb.ChannelEdgeInfo1, + models.ChannelEdgeInfo, *channeldb.ChannelEdgePolicy) error) error // FetchChannel is used to query local channel parameters. Optionally an @@ -73,25 +73,27 @@ func (r *Manager) UpdatePolicy(newSchema routing.ChannelPolicy, // otherwise we'll collect them all. err := r.ForAllOutgoingChannels(func( tx kvdb.RTx, - info *channeldb.ChannelEdgeInfo1, + info models.ChannelEdgeInfo, edge *channeldb.ChannelEdgePolicy) error { + chanPoint := info.GetChanPoint() + // If we have a channel filter, and this channel isn't a part // of it, then we'll skip it. - _, ok := unprocessedChans[info.ChannelPoint] + _, ok := unprocessedChans[chanPoint] if !ok && haveChanFilter { return nil } // Mark this channel as found by removing it. unprocessedChans // will be used to report invalid channels later on. - delete(unprocessedChans, info.ChannelPoint) + delete(unprocessedChans, chanPoint) // Apply the new policy to the edge. - err := r.updateEdge(tx, info.ChannelPoint, edge, newSchema) + err := r.updateEdge(tx, chanPoint, edge, newSchema) if err != nil { failedUpdates = append(failedUpdates, - makeFailureItem(info.ChannelPoint, + makeFailureItem(chanPoint, lnrpc.UpdateFailure_UPDATE_FAILURE_INVALID_PARAMETER, err.Error(), )) @@ -106,7 +108,7 @@ func (r *Manager) UpdatePolicy(newSchema routing.ChannelPolicy, }) // Add updated policy to list of policies to send to switch. - policiesToUpdate[info.ChannelPoint] = models.ForwardingPolicy{ + policiesToUpdate[chanPoint] = models.ForwardingPolicy{ BaseFee: edge.FeeBaseMSat, FeeRate: edge.FeeProportionalMillionths, TimeLockDelta: uint32(edge.TimeLockDelta), diff --git a/routing/localchans/manager_test.go b/routing/localchans/manager_test.go index ca59012dcb..4409dff620 100644 --- a/routing/localchans/manager_test.go +++ b/routing/localchans/manager_test.go @@ -107,7 +107,7 @@ func TestManager(t *testing.T) { } forAllOutgoingChannels := func(cb func(kvdb.RTx, - *channeldb.ChannelEdgeInfo1, + models.ChannelEdgeInfo, *channeldb.ChannelEdgePolicy) error) error { for _, c := range channelSet { diff --git a/routing/notifications.go b/routing/notifications.go index 85aeb838bb..b0ff5c79c5 100644 --- a/routing/notifications.go +++ b/routing/notifications.go @@ -12,6 +12,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/go-errors/errors" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/lnwire" ) @@ -212,15 +213,15 @@ type ClosedChanSummary struct { // createCloseSummaries takes in a slice of channels closed at the target block // height and creates a slice of summaries which of each channel closure. func createCloseSummaries(blockHeight uint32, - closedChans ...*channeldb.ChannelEdgeInfo1) []*ClosedChanSummary { + closedChans ...models.ChannelEdgeInfo) []*ClosedChanSummary { closeSummaries := make([]*ClosedChanSummary, len(closedChans)) for i, closedChan := range closedChans { closeSummaries[i] = &ClosedChanSummary{ - ChanID: closedChan.ChannelID, - Capacity: closedChan.Capacity, + ChanID: closedChan.GetChanID(), + Capacity: closedChan.GetCapacity(), ClosedHeight: blockHeight, - ChanPoint: closedChan.ChannelPoint, + ChanPoint: closedChan.GetChanPoint(), } } @@ -333,7 +334,7 @@ func addToTopologyChange(graph *channeldb.ChannelGraph, update *TopologyChange, // We ignore initial channel announcements as we'll only send out // updates once the individual edges themselves have been updated. - case *channeldb.ChannelEdgeInfo1: + case models.ChannelEdgeInfo: return nil // Any new ChannelUpdateAnnouncements will generate a corresponding @@ -368,9 +369,9 @@ func addToTopologyChange(graph *channeldb.ChannelGraph, update *TopologyChange, edgeUpdate := &ChannelEdgeUpdate{ ChanID: m.ChannelID, - ChanPoint: edgeInfo.ChannelPoint, + ChanPoint: edgeInfo.GetChanPoint(), TimeLockDelta: m.TimeLockDelta, - Capacity: edgeInfo.Capacity, + Capacity: edgeInfo.GetCapacity(), MinHTLC: m.MinHTLC, MaxHTLC: m.MaxHTLC, BaseFee: m.FeeBaseMSat, diff --git a/routing/router.go b/routing/router.go index 6fb07c749b..37674e515e 100644 --- a/routing/router.go +++ b/routing/router.go @@ -21,9 +21,9 @@ import ( "github.com/lightningnetwork/lnd/batch" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/htlcswitch" - "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnutils" @@ -135,13 +135,12 @@ type ChannelGraphSource interface { // AddEdge is used to add edge/channel to the topology of the router, // after all information about channel will be gathered this // edge/channel might be used in construction of payment path. - AddEdge(edge *channeldb.ChannelEdgeInfo1, - op ...batch.SchedulerOption) error + AddEdge(edge models.ChannelEdgeInfo, op ...batch.SchedulerOption) error // AddProof updates the channel edge info with proof which is needed to // properly announce the edge to the rest of the network. AddProof(chanID lnwire.ShortChannelID, - proof *channeldb.ChannelAuthProof1) error + proof models.ChannelAuthProof) error // UpdateEdge is used to update edge information, without this message // edge considered as not fully constructed. @@ -176,7 +175,7 @@ type ChannelGraphSource interface { // emanating from the "source" node which is the center of the // star-graph. ForAllOutgoingChannels(cb func(tx kvdb.RTx, - c *channeldb.ChannelEdgeInfo1, + c models.ChannelEdgeInfo, e *channeldb.ChannelEdgePolicy) error) error // CurrentBlockHeight returns the block height from POV of the router @@ -185,7 +184,7 @@ type ChannelGraphSource interface { // GetChannelByID return the channel by the channel id. GetChannelByID(chanID lnwire.ShortChannelID) ( - *channeldb.ChannelEdgeInfo1, *channeldb.ChannelEdgePolicy, + models.ChannelEdgeInfo, *channeldb.ChannelEdgePolicy, *channeldb.ChannelEdgePolicy, error) // FetchLightningNode attempts to look up a target node by its identity @@ -897,18 +896,22 @@ func (r *ChannelRouter) pruneZombieChans() error { log.Infof("Examining channel graph for zombie channels") // A helper method to detect if the channel belongs to this node - isSelfChannelEdge := func(info *channeldb.ChannelEdgeInfo1) bool { - return info.NodeKey1Bytes == r.selfNode.PubKeyBytes || - info.NodeKey2Bytes == r.selfNode.PubKeyBytes + isSelfChannelEdge := func(info models.ChannelEdgeInfo) bool { + return info.Node1Bytes() == r.selfNode.PubKeyBytes || + info.Node2Bytes() == r.selfNode.PubKeyBytes } // First, we'll collect all the channels which are eligible for garbage // collection due to being zombies. - filterPruneChans := func(info *channeldb.ChannelEdgeInfo1, + filterPruneChans := func(info models.ChannelEdgeInfo, e1, e2 *channeldb.ChannelEdgePolicy) error { - // Exit early in case this channel is already marked to be pruned - if _, markedToPrune := chansToPrune[info.ChannelID]; markedToPrune { + chanID := info.GetChanID() + + // Exit early in case this channel is already marked to be + // pruned. + _, markedToPrune := chansToPrune[chanID] + if markedToPrune { return nil } @@ -927,11 +930,11 @@ func (r *ChannelRouter) pruneZombieChans() error { if e1Zombie { log.Tracef("Node1 pubkey=%x of chan_id=%v is zombie", - info.NodeKey1Bytes, info.ChannelID) + info.Node1Bytes(), chanID) } if e2Zombie { log.Tracef("Node2 pubkey=%x of chan_id=%v is zombie", - info.NodeKey2Bytes, info.ChannelID) + info.Node2Bytes(), chanID) } // If we're using strict zombie pruning, then a channel is only @@ -956,10 +959,10 @@ func (r *ChannelRouter) pruneZombieChans() error { } log.Debugf("ChannelID(%v) is a zombie, collecting to prune", - info.ChannelID) + chanID) // TODO(roasbeef): add ability to delete single directional edge - chansToPrune[info.ChannelID] = struct{}{} + chansToPrune[chanID] = struct{}{} return nil } @@ -983,7 +986,8 @@ func (r *ChannelRouter) pruneZombieChans() error { // Ensuring we won't prune our own channel from the graph. for _, disabledEdge := range disabledEdges { if !isSelfChannelEdge(disabledEdge.Info) { - chansToPrune[disabledEdge.Info.ChannelID] = struct{}{} + chanID := disabledEdge.Info.GetChanID() + chansToPrune[chanID] = struct{}{} } } } @@ -1450,70 +1454,6 @@ func (r *ChannelRouter) addZombieEdge(chanID uint64) error { return nil } -// makeFundingScript is used to make the funding script for both segwit v0 and -// segwit v1 (taproot) channels. -// -// TODO(roasbeef: export and use elsewhere? -func makeFundingScript(bitcoinKey1, bitcoinKey2 []byte, - chanFeatures []byte) ([]byte, error) { - - legacyFundingScript := func() ([]byte, error) { - witnessScript, err := input.GenMultiSigScript( - bitcoinKey1, bitcoinKey2, - ) - if err != nil { - return nil, err - } - pkScript, err := input.WitnessScriptHash(witnessScript) - if err != nil { - return nil, err - } - - return pkScript, nil - } - - if len(chanFeatures) == 0 { - return legacyFundingScript() - } - - // In order to make the correct funding script, we'll need to parse the - // chanFeatures bytes into a feature vector we can interact with. - rawFeatures := lnwire.NewRawFeatureVector() - err := rawFeatures.Decode(bytes.NewReader(chanFeatures)) - if err != nil { - return nil, fmt.Errorf("unable to parse chan feature "+ - "bits: %w", err) - } - - chanFeatureBits := lnwire.NewFeatureVector( - rawFeatures, lnwire.Features, - ) - if chanFeatureBits.HasFeature( - lnwire.SimpleTaprootChannelsOptionalStaging, - ) { - - pubKey1, err := btcec.ParsePubKey(bitcoinKey1) - if err != nil { - return nil, err - } - pubKey2, err := btcec.ParsePubKey(bitcoinKey2) - if err != nil { - return nil, err - } - - fundingScript, _, err := input.GenTaprootFundingScript( - pubKey1, pubKey2, 0, - ) - if err != nil { - return nil, err - } - - return fundingScript, nil - } - - return legacyFundingScript() -} - // processUpdate processes a new relate authenticated channel/edge, node or // channel/edge update network update. If the update didn't affect the internal // state of the draft due to either being out of date, invalid, or redundant, @@ -1539,14 +1479,19 @@ func (r *ChannelRouter) processUpdate(msg interface{}, log.Tracef("Updated vertex data for node=%x", msg.PubKeyBytes) r.stats.incNumNodeUpdates() - case *channeldb.ChannelEdgeInfo1: - log.Debugf("Received ChannelEdgeInfo1 for channel %v", - msg.ChannelID) + case models.ChannelEdgeInfo: + var ( + chanID = msg.GetChanID() + node1Bytes = msg.Node1Bytes() + node2Bytes = msg.Node2Bytes() + ) + + log.Debugf("Received ChannelEdgeInfo for channel %v", chanID) // Prior to processing the announcement we first check if we // already know of this channel, if so, then we can exit early. _, _, exists, isZombie, err := r.cfg.Graph.HasChannelEdge( - msg.ChannelID, + chanID, ) if err != nil && err != channeldb.ErrGraphNoEdgesFound { return errors.Errorf("unable to check for edge "+ @@ -1554,11 +1499,11 @@ func (r *ChannelRouter) processUpdate(msg interface{}, } if isZombie { return newErrf(ErrIgnored, "ignoring msg for zombie "+ - "chan_id=%v", msg.ChannelID) + "chan_id=%v", chanID) } if exists { return newErrf(ErrIgnored, "ignoring msg for known "+ - "chan_id=%v", msg.ChannelID) + "chan_id=%v", chanID) } // If AssumeChannelValid is present, then we are unable to @@ -1568,15 +1513,15 @@ func (r *ChannelRouter) processUpdate(msg interface{}, // skip validation as it will not map to a legitimate tx. This // is not a DoS vector as only we can add an alias // ChannelAnnouncement1 from the gossiper. - scid := lnwire.NewShortChanIDFromInt(msg.ChannelID) + scid := lnwire.NewShortChanIDFromInt(chanID) if r.cfg.AssumeChannelValid || r.cfg.IsAlias(scid) { - if err := r.cfg.Graph.AddChannelEdge(msg, op...); err != nil { + err := r.cfg.Graph.AddChannelEdge(msg, op...) + if err != nil { return fmt.Errorf("unable to add edge: %v", err) } log.Tracef("New channel discovered! Link "+ "connects %x and %x with ChannelID(%v)", - msg.NodeKey1Bytes, msg.NodeKey2Bytes, - msg.ChannelID) + node1Bytes, node2Bytes, chanID) r.stats.incNumEdgesDiscovered() break @@ -1585,7 +1530,7 @@ func (r *ChannelRouter) processUpdate(msg interface{}, // Before we can add the channel to the channel graph, we need // to obtain the full funding outpoint that's encoded within // the channel ID. - channelID := lnwire.NewShortChanIDFromInt(msg.ChannelID) + channelID := lnwire.NewShortChanIDFromInt(chanID) fundingTx, err := r.fetchFundingTx(&channelID) if err != nil { // In order to ensure we don't erroneously mark a @@ -1608,7 +1553,7 @@ func (r *ChannelRouter) processUpdate(msg interface{}, // zombie so we don't continue to request it. // We use the "zero key" for both node pubkeys // so this edge can't be resurrected. - zErr := r.addZombieEdge(msg.ChannelID) + zErr := r.addZombieEdge(chanID) if zErr != nil { return zErr } @@ -1623,10 +1568,7 @@ func (r *ChannelRouter) processUpdate(msg interface{}, // Recreate witness output to be sure that declared in channel // edge bitcoin keys and channel value corresponds to the // reality. - fundingPkScript, err := makeFundingScript( - msg.BitcoinKey1Bytes[:], msg.BitcoinKey2Bytes[:], - msg.Features, - ) + fundingPkScript, err := msg.FundingScript() if err != nil { return err } @@ -1645,7 +1587,7 @@ func (r *ChannelRouter) processUpdate(msg interface{}, if err != nil { // Mark the edge as a zombie so we won't try to // re-validate it on start up. - if err := r.addZombieEdge(msg.ChannelID); err != nil { + if err := r.addZombieEdge(chanID); err != nil { return err } @@ -1662,7 +1604,7 @@ func (r *ChannelRouter) processUpdate(msg interface{}, ) if err != nil { if errors.Is(err, btcwallet.ErrOutputSpent) { - zErr := r.addZombieEdge(msg.ChannelID) + zErr := r.addZombieEdge(chanID) if zErr != nil { return zErr } @@ -1670,22 +1612,28 @@ func (r *ChannelRouter) processUpdate(msg interface{}, return newErrf(ErrChannelSpent, "unable to fetch utxo "+ "for chan_id=%v, chan_point=%v: %v", - msg.ChannelID, fundingPoint, err) + chanID, fundingPoint, err) } // TODO(roasbeef): this is a hack, needs to be removed // after commitment fees are dynamic. - msg.Capacity = btcutil.Amount(chanUtxo.Value) - msg.ChannelPoint = *fundingPoint + switch m := msg.(type) { + case *channeldb.ChannelEdgeInfo1: + m.Capacity = btcutil.Amount(chanUtxo.Value) + m.ChannelPoint = *fundingPoint + default: + return errors.Errorf("unhandled implementation of "+ + "ChannelEdgeInfo: %T", msg) + } + if err := r.cfg.Graph.AddChannelEdge(msg, op...); err != nil { return errors.Errorf("unable to add edge: %v", err) } log.Debugf("New channel discovered! Link "+ "connects %x and %x with ChannelPoint(%v): "+ - "chan_id=%v, capacity=%v", - msg.NodeKey1Bytes, msg.NodeKey2Bytes, - fundingPoint, msg.ChannelID, msg.Capacity) + "chan_id=%v, capacity=%v", node1Bytes, node2Bytes, + fundingPoint, chanID, msg.GetCapacity()) r.stats.incNumEdgesDiscovered() // As a new edge has been added to the channel graph, we'll @@ -2662,7 +2610,7 @@ func (r *ChannelRouter) applyChannelUpdate(msg *lnwire.ChannelUpdate) bool { return false } - err = ValidateChannelUpdateAnn(pubKey, ch.Capacity, msg) + err = ValidateChannelUpdateAnn(pubKey, ch.GetCapacity(), msg) if err != nil { log.Errorf("Unable to validate channel update: %v", err) return false @@ -2720,7 +2668,7 @@ func (r *ChannelRouter) AddNode(node *channeldb.LightningNode, // in construction of payment path. // // NOTE: This method is part of the ChannelGraphSource interface. -func (r *ChannelRouter) AddEdge(edge *channeldb.ChannelEdgeInfo1, +func (r *ChannelRouter) AddEdge(edge models.ChannelEdgeInfo, op ...batch.SchedulerOption) error { rMsg := &routingMsg{ @@ -2787,8 +2735,7 @@ func (r *ChannelRouter) SyncedHeight() uint32 { // // NOTE: This method is part of the ChannelGraphSource interface. func (r *ChannelRouter) GetChannelByID(chanID lnwire.ShortChannelID) ( - *channeldb.ChannelEdgeInfo1, - *channeldb.ChannelEdgePolicy, + models.ChannelEdgeInfo, *channeldb.ChannelEdgePolicy, *channeldb.ChannelEdgePolicy, error) { return r.cfg.Graph.FetchChannelEdgesByID(chanID.ToUint64()) @@ -2822,10 +2769,10 @@ func (r *ChannelRouter) ForEachNode( // // NOTE: This method is part of the ChannelGraphSource interface. func (r *ChannelRouter) ForAllOutgoingChannels(cb func(kvdb.RTx, - *channeldb.ChannelEdgeInfo1, *channeldb.ChannelEdgePolicy) error) error { + models.ChannelEdgeInfo, *channeldb.ChannelEdgePolicy) error) error { return r.selfNode.ForEachChannel(r.cfg.Graph.DB(), nil, - func(_ kvdb.Backend, tx kvdb.RTx, c *channeldb.ChannelEdgeInfo1, + func(_ kvdb.Backend, tx kvdb.RTx, c models.ChannelEdgeInfo, e, _ *channeldb.ChannelEdgePolicy) error { if e == nil { @@ -2842,14 +2789,18 @@ func (r *ChannelRouter) ForAllOutgoingChannels(cb func(kvdb.RTx, // // NOTE: This method is part of the ChannelGraphSource interface. func (r *ChannelRouter) AddProof(chanID lnwire.ShortChannelID, - proof *channeldb.ChannelAuthProof1) error { + proof models.ChannelAuthProof) error { info, _, _, err := r.cfg.Graph.FetchChannelEdgesByID(chanID.ToUint64()) if err != nil { return err } - info.AuthProof = proof + err = info.SetAuthProof(proof) + if err != nil { + return err + } + return r.cfg.Graph.UpdateChannelEdge(info) } diff --git a/routing/router_test.go b/routing/router_test.go index 614f347767..8bcd79627c 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -1262,9 +1262,7 @@ func TestAddProof(t *testing.T) { info, _, _, err := ctx.router.GetChannelByID(*chanID) require.NoError(t, err, "unable to get channel") - if info.AuthProof == nil { - t.Fatal("proof have been updated") - } + require.NotNil(t, info.GetAuthProof()) } // TestIgnoreNodeAnnouncement tests that adding a node to the router that is diff --git a/routing/validation_barrier.go b/routing/validation_barrier.go index ef3109d25b..24423402aa 100644 --- a/routing/validation_barrier.go +++ b/routing/validation_barrier.go @@ -5,6 +5,7 @@ import ( "sync" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" ) @@ -126,9 +127,8 @@ func (v *ValidationBarrier) InitJobDependencies(job interface{}) { v.nodeAnnDependencies[msg.Node1KeyBytes()] = signals v.nodeAnnDependencies[msg.Node2KeyBytes()] = signals } - case *channeldb.ChannelEdgeInfo1: - - shortID := lnwire.NewShortChanIDFromInt(msg.ChannelID) + case models.ChannelEdgeInfo: + shortID := lnwire.NewShortChanIDFromInt(msg.GetChanID()) if _, ok := v.chanAnnFinSignal[shortID]; !ok { signals := &validationSignals{ allow: make(chan struct{}), @@ -138,8 +138,8 @@ func (v *ValidationBarrier) InitJobDependencies(job interface{}) { v.chanAnnFinSignal[shortID] = signals v.chanEdgeDependencies[shortID] = signals - v.nodeAnnDependencies[route.Vertex(msg.NodeKey1Bytes)] = signals - v.nodeAnnDependencies[route.Vertex(msg.NodeKey2Bytes)] = signals + v.nodeAnnDependencies[msg.Node1Bytes()] = signals + v.nodeAnnDependencies[msg.Node2Bytes()] = signals } // These other types don't have any dependants, so no further @@ -218,7 +218,7 @@ func (v *ValidationBarrier) WaitForDependants(job interface{}) error { // return directly. case *lnwire.AnnounceSignatures: // TODO(roasbeef): need to wait on chan ann? - case *channeldb.ChannelEdgeInfo1: + case models.ChannelEdgeInfo: case lnwire.ChannelAnnouncement: } @@ -264,8 +264,8 @@ func (v *ValidationBarrier) SignalDependants(job interface{}, allow bool) { // If we've just finished executing a ChannelAnnouncement1, then we'll // close out the signal, and remove the signal from the map of active // ones. This will allow/deny any dependent jobs to continue execution. - case *channeldb.ChannelEdgeInfo1: - shortID := lnwire.NewShortChanIDFromInt(msg.ChannelID) + case models.ChannelEdgeInfo: + shortID := lnwire.NewShortChanIDFromInt(msg.GetChanID()) finSignals, ok := v.chanAnnFinSignal[shortID] if ok { if allow { diff --git a/rpcserver.go b/rpcserver.go index 3aed256fc5..a4461c6c90 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -38,6 +38,7 @@ import ( "github.com/lightningnetwork/lnd/chanbackup" "github.com/lightningnetwork/lnd/chanfitness" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/channelnotifier" "github.com/lightningnetwork/lnd/contractcourt" "github.com/lightningnetwork/lnd/discovery" @@ -661,7 +662,8 @@ func (r *rpcServer) addDeps(s *server, macService *macaroons.Service, if err != nil { return 0, err } - return info.Capacity, nil + + return info.GetCapacity(), nil }, FetchAmountPairCapacity: func(nodeFrom, nodeTo route.Vertex, amount lnwire.MilliSatoshi) (btcutil.Amount, error) { @@ -698,7 +700,7 @@ func (r *rpcServer) addDeps(s *server, macService *macaroons.Service, chanID, err) } - return info.NodeKey1Bytes, info.NodeKey2Bytes, nil + return info.Node1Bytes(), info.Node2Bytes(), nil }, FindRoute: s.chanRouter.FindRoute, MissionControl: s.missionControl, @@ -5921,18 +5923,22 @@ func (r *rpcServer) DescribeGraph(ctx context.Context, // Next, for each active channel we know of within the graph, create a // similar response which details both the edge information as well as // the routing policies of th nodes connecting the two edges. - err = graph.ForEachChannel(func(edgeInfo *channeldb.ChannelEdgeInfo1, + err = graph.ForEachChannel(func(edgeInfo models.ChannelEdgeInfo, c1, c2 *channeldb.ChannelEdgePolicy) error { // Do not include unannounced channels unless specifically // requested. Unannounced channels include both private channels as // well as public channels whose authentication proof were not // confirmed yet, hence were not announced. - if !includeUnannounced && edgeInfo.AuthProof == nil { + if !includeUnannounced && edgeInfo.GetAuthProof() == nil { return nil } - edge := marshalDbEdge(edgeInfo, c1, c2) + edge, err := marshalDBEdge(edgeInfo, c1, c2) + if err != nil { + return err + } + resp.Edges = append(resp.Edges, edge) return nil @@ -5977,8 +5983,8 @@ func marshalExtraOpaqueData(data []byte) map[uint64][]byte { return records } -func marshalDbEdge(edgeInfo *channeldb.ChannelEdgeInfo1, - c1, c2 *channeldb.ChannelEdgePolicy) *lnrpc.ChannelEdge { +func marshalDBEdge(edgeInfo models.ChannelEdgeInfo, + c1, c2 *channeldb.ChannelEdgePolicy) (*lnrpc.ChannelEdge, error) { // Make sure the policies match the node they belong to. c1 should point // to the policy for NodeKey1, and c2 for NodeKey2. @@ -5996,28 +6002,41 @@ func marshalDbEdge(edgeInfo *channeldb.ChannelEdgeInfo1, lastUpdate = c2.LastUpdate.Unix() } - customRecords := marshalExtraOpaqueData(edgeInfo.ExtraOpaqueData) + var edge *lnrpc.ChannelEdge - edge := &lnrpc.ChannelEdge{ - ChannelId: edgeInfo.ChannelID, - ChanPoint: edgeInfo.ChannelPoint.String(), - // TODO(roasbeef): update should be on edge info itself - LastUpdate: uint32(lastUpdate), - Node1Pub: hex.EncodeToString(edgeInfo.NodeKey1Bytes[:]), - Node2Pub: hex.EncodeToString(edgeInfo.NodeKey2Bytes[:]), - Capacity: int64(edgeInfo.Capacity), - CustomRecords: customRecords, - } + switch info := edgeInfo.(type) { + case *channeldb.ChannelEdgeInfo1: + customRecords := marshalExtraOpaqueData(info.ExtraOpaqueData) - if c1 != nil { - edge.Node1Policy = marshalDBRoutingPolicy(c1) - } + edge = &lnrpc.ChannelEdge{ + ChannelId: info.ChannelID, + ChanPoint: info.ChannelPoint.String(), + // TODO(roasbeef): update should be on edge info itself + LastUpdate: uint32(lastUpdate), + Node1Pub: hex.EncodeToString( + info.NodeKey1Bytes[:], + ), + Node2Pub: hex.EncodeToString( + info.NodeKey2Bytes[:], + ), + Capacity: int64(edgeInfo.GetCapacity()), + CustomRecords: customRecords, + } - if c2 != nil { - edge.Node2Policy = marshalDBRoutingPolicy(c2) + if c1 != nil { + edge.Node1Policy = marshalDBRoutingPolicy(c1) + } + + if c2 != nil { + edge.Node2Policy = marshalDBRoutingPolicy(c2) + } + + default: + return nil, fmt.Errorf("unhandled implementation of "+ + "channeldb.ChannelEdgeInfo: %T", edgeInfo) } - return edge + return edge, nil } func marshalDBRoutingPolicy( @@ -6113,7 +6132,10 @@ func (r *rpcServer) GetChanInfo(ctx context.Context, // Convert the database's edge format into the network/RPC edge format // which couples the edge itself along with the directional node // routing policies of each node involved within the channel. - channelEdge := marshalDbEdge(edgeInfo, edge1, edge2) + channelEdge, err := marshalDBEdge(edgeInfo, edge1, edge2) + if err != nil { + return nil, err + } return channelEdge, nil } @@ -6152,24 +6174,28 @@ func (r *rpcServer) GetNodeInfo(ctx context.Context, ) if err := node.ForEachChannel(graph.DB(), nil, func(_ kvdb.Backend, - _ kvdb.RTx, edge *channeldb.ChannelEdgeInfo1, + _ kvdb.RTx, edge models.ChannelEdgeInfo, c1, c2 *channeldb.ChannelEdgePolicy) error { numChannels++ - totalCapacity += edge.Capacity + totalCapacity += edge.GetCapacity() // Only populate the node's channels if the user requested them. if in.IncludeChannels { // Do not include unannounced channels - private // channels or public channels whose authentication // proof were not confirmed yet. - if edge.AuthProof == nil { + if edge.GetAuthProof() == nil { return nil } // Convert the database's edge format into the // network/RPC edge format. - channelEdge := marshalDbEdge(edge, c1, c2) + channelEdge, err := marshalDBEdge(edge, c1, c2) + if err != nil { + return err + } + channels = append(channels, channelEdge) } @@ -6765,27 +6791,30 @@ func (r *rpcServer) FeeReport(ctx context.Context, var feeReports []*lnrpc.ChannelFeeReport err = selfNode.ForEachChannel(channelGraph.DB(), nil, func(_ kvdb.Backend, _ kvdb.RTx, - chanInfo *channeldb.ChannelEdgeInfo1, + chanInfo models.ChannelEdgeInfo, edgePolicy, _ *channeldb.ChannelEdgePolicy) error { - // Self node should always have policies for its channels. + // Self node should always have policies for its + // channels. if edgePolicy == nil { - return fmt.Errorf("no policy for outgoing channel %v ", - chanInfo.ChannelID) + return fmt.Errorf("no policy for outgoing "+ + "channel %v ", chanInfo.GetChanID()) } - // We'll compute the effective fee rate by converting from a - // fixed point fee rate to a floating point fee rate. The fee - // rate field in the database the amount of mSAT charged per - // 1mil mSAT sent, so will divide by this to get the proper fee - // rate. - feeRateFixedPoint := edgePolicy.FeeProportionalMillionths + // We'll compute the effective fee rate by converting + // from a fixed point fee rate to a floating point fee + // rate. The fee rate field in the database the amount + // of mSAT charged per 1mil mSAT sent, so will divide + // by this to get the proper fee rate. + feeRateFixedPoint := edgePolicy. + FeeProportionalMillionths feeRate := float64(feeRateFixedPoint) / feeBase - // TODO(roasbeef): also add stats for revenue for each channel + // TODO(roasbeef): also add stats for revenue for each + // channel. feeReports = append(feeReports, &lnrpc.ChannelFeeReport{ - ChanId: chanInfo.ChannelID, - ChannelPoint: chanInfo.ChannelPoint.String(), + ChanId: chanInfo.GetChanID(), + ChannelPoint: chanInfo.GetChanPoint().String(), BaseFeeMsat: int64(edgePolicy.FeeBaseMSat), FeePerMil: int64(feeRateFixedPoint), FeeRate: feeRate, diff --git a/server.go b/server.go index 7ac07cdab8..5ef7886fbf 100644 --- a/server.go +++ b/server.go @@ -32,6 +32,7 @@ import ( "github.com/lightningnetwork/lnd/chanbackup" "github.com/lightningnetwork/lnd/chanfitness" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/channelnotifier" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/contractcourt" @@ -1254,7 +1255,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, copy(ourKey[:], nodeKeyDesc.PubKey.SerializeCompressed()) var ourPolicy *channeldb.ChannelEdgePolicy - if info != nil && info.NodeKey1Bytes == ourKey { + if info != nil && info.Node1Bytes() == ourKey { ourPolicy = e1 } else { ourPolicy = e2 @@ -3096,7 +3097,7 @@ func (s *server) establishPersistentConnections() error { selfPub := s.identityECDH.PubKey().SerializeCompressed() err = sourceNode.ForEachChannel(s.graphDB.DB(), nil, func( db kvdb.Backend, tx kvdb.RTx, - chanInfo *channeldb.ChannelEdgeInfo1, + chanInfo models.ChannelEdgeInfo, policy, _ *channeldb.ChannelEdgePolicy) error { // If the remote party has announced the channel to us, but we @@ -3104,15 +3105,17 @@ func (s *server) establishPersistentConnections() error { // need this to connect to the peer, so we'll log it and move on. if policy == nil { srvrLog.Warnf("No channel policy found for "+ - "ChannelPoint(%v): ", chanInfo.ChannelPoint) + "ChannelPoint(%v): ", chanInfo.GetChanPoint()) } // We'll now fetch the peer opposite from us within this // channel so we can queue up a direct connection to them. - channelPeer, err := chanInfo.FetchOtherNode(db, tx, selfPub) + channelPeer, err := channeldb.FetchOtherNode( + db, tx, chanInfo, selfPub, + ) if err != nil { return fmt.Errorf("unable to fetch channel peer for "+ - "ChannelPoint(%v): %v", chanInfo.ChannelPoint, + "ChannelPoint(%v): %v", chanInfo.GetChanPoint(), err) } From 7763e1a6443dfaf5cf8025cfa8cab186f441415d Mon Sep 17 00:00:00 2001 From: Elle Mouton <elle.mouton@gmail.com> Date: Tue, 17 Oct 2023 14:45:07 +0200 Subject: [PATCH 21/33] channeldb: prepare for reading new types of ChannelEdgeInfo --- channeldb/graph.go | 27 ++++++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/channeldb/graph.go b/channeldb/graph.go index f9e749ae86..d0a440d313 100644 --- a/channeldb/graph.go +++ b/channeldb/graph.go @@ -1,6 +1,7 @@ package channeldb import ( + "bufio" "bytes" "encoding/binary" "errors" @@ -163,6 +164,13 @@ const ( // feeRateParts is the total number of parts used to express fee rates. feeRateParts = 1e6 + + // chanEdgeNewEncodingPrefix is a byte used in the channel edge encoding + // to signal that the new style encoding which is prefixed with a type + // byte is being used instead of the legacy encoding which would start + // with either 0x02 or 0x03 due to the fact that the encoding would + // start with a node's compressed public key. + chanEdgeNewEncodingPrefix = 0xff ) // ChannelGraph is a persistent, on-disk graph representation of the Lightning @@ -4555,10 +4563,6 @@ func serializeChanEdgeInfo1(w io.Writer, edgeInfo *ChannelEdgeInfo1, if len(edgeInfo.ExtraOpaqueData) > MaxAllowedExtraOpaqueBytes { return ErrTooManyExtraOpaqueBytes(len(edgeInfo.ExtraOpaqueData)) } - err = wire.WriteVarBytes(w, 0, edgeInfo.ExtraOpaqueData) - if err != nil { - return err - } return wire.WriteVarBytes(w, 0, edgeInfo.ExtraOpaqueData) } @@ -4577,7 +4581,20 @@ func fetchChanEdgeInfo(edgeIndex kvdb.RBucket, } func deserializeChanEdgeInfo(reader io.Reader) (models.ChannelEdgeInfo, error) { - return deserializeChanEdgeInfo1(reader) + // Wrap the io.Reader in a bufio.Reader so that we can peak the first + // byte of the stream without actually consuming from the stream. + r := bufio.NewReader(reader) + + firstByte, err := r.Peek(1) + if err != nil { + return nil, err + } + + if firstByte[0] != chanEdgeNewEncodingPrefix { + return deserializeChanEdgeInfo1(r) + } + + return nil, fmt.Errorf("unknown channel edge encoding") } func deserializeChanEdgeInfo1(r io.Reader) (*ChannelEdgeInfo1, error) { From 2ffd49ecfdb99f632e03b7d855e97eb23483c4a4 Mon Sep 17 00:00:00 2001 From: Elle Mouton <elle.mouton@gmail.com> Date: Tue, 17 Oct 2023 16:09:12 +0200 Subject: [PATCH 22/33] multi: rename lnwire.ChannelUpdate to lnwire.ChannelUpdate1 --- channeldb/models/channel.go | 2 +- discovery/chan_series.go | 10 ++--- discovery/gossiper.go | 62 +++++++++++++-------------- discovery/gossiper_test.go | 62 +++++++++++++-------------- discovery/message_store.go | 10 ++--- discovery/message_store_test.go | 10 ++--- discovery/reliable_sender_test.go | 2 +- discovery/syncer.go | 6 +-- discovery/syncer_test.go | 18 ++++---- funding/manager.go | 30 ++++++------- funding/manager_test.go | 14 +++--- htlcswitch/interfaces.go | 2 +- htlcswitch/link.go | 26 +++++------ htlcswitch/link_test.go | 6 +-- htlcswitch/mock.go | 10 ++--- htlcswitch/switch.go | 20 ++++----- htlcswitch/switch_test.go | 10 ++--- htlcswitch/test_utils.go | 4 +- itest/lnd_channel_policy_test.go | 4 +- itest/lnd_zero_conf_test.go | 2 +- lnrpc/routerrpc/router_backend.go | 2 +- lnwire/channel_update.go | 26 +++++------ lnwire/lnwire_test.go | 6 +-- lnwire/message.go | 4 +- lnwire/message_test.go | 6 +-- lnwire/onion_error.go | 44 +++++++++---------- lnwire/onion_error_test.go | 6 +-- netann/chan_status_manager.go | 10 ++--- netann/chan_status_manager_test.go | 6 +-- netann/channel_announcement.go | 6 +-- netann/channel_update.go | 28 ++++++------ netann/channel_update_test.go | 4 +- netann/sign.go | 2 +- peer/brontide.go | 10 ++--- peer/test_utils.go | 2 +- routing/ann_validation.go | 6 +-- routing/missioncontrol_test.go | 6 +-- routing/mock_test.go | 4 +- routing/payment_session.go | 4 +- routing/payment_session_test.go | 2 +- routing/result_interpretation_test.go | 4 +- routing/router.go | 8 ++-- routing/router_test.go | 16 +++---- routing/validation_barrier.go | 10 ++--- routing/validation_barrier_test.go | 4 +- server.go | 14 +++--- 46 files changed, 275 insertions(+), 275 deletions(-) diff --git a/channeldb/models/channel.go b/channeldb/models/channel.go index 4a65462e72..f3d8b8868c 100644 --- a/channeldb/models/channel.go +++ b/channeldb/models/channel.go @@ -96,7 +96,7 @@ func (k CircuitKey) String() string { // constraints will be consulted in order to ensure that adequate fees are // paid, and our time-lock parameters are respected. In the event that an // incoming HTLC violates any of these constraints, it is to be _rejected_ with -// the error possibly carrying along a ChannelUpdate message that includes the +// the error possibly carrying along a ChannelUpdate1 message that includes the // latest policy. type ForwardingPolicy struct { // MinHTLCOut is the smallest HTLC that is to be forwarded. diff --git a/discovery/chan_series.go b/discovery/chan_series.go index 891c660e9e..a8cdcd8ce9 100644 --- a/discovery/chan_series.go +++ b/discovery/chan_series.go @@ -50,7 +50,7 @@ type ChannelGraphTimeSeries interface { // their updates that match the set of specified short channel ID's. // We'll use this to reply to a QueryShortChanIDs message sent by a // remote peer. The response will contain a unique set of - // ChannelAnnouncements, the latest ChannelUpdate for each of the + // ChannelAnnouncements, the latest ChannelUpdate1 for each of the // announcements, and a unique set of NodeAnnouncements. FetchChanAnns(chain chainhash.Hash, shortChanIDs []lnwire.ShortChannelID) ([]lnwire.Message, error) @@ -59,7 +59,7 @@ type ChannelGraphTimeSeries interface { // specified short channel ID. If no channel updates are known for the // channel, then an empty slice will be returned. FetchChanUpdates(chain chainhash.Hash, - shortChanID lnwire.ShortChannelID) ([]*lnwire.ChannelUpdate, error) + shortChanID lnwire.ShortChannelID) ([]*lnwire.ChannelUpdate1, error) } // ChanSeries is an implementation of the ChannelGraphTimeSeries @@ -233,7 +233,7 @@ func (c *ChanSeries) FilterChannelRange(chain chainhash.Hash, // FetchChanAnns returns a full set of channel announcements as well as their // updates that match the set of specified short channel ID's. We'll use this // to reply to a QueryShortChanIDs message sent by a remote peer. The response -// will contain a unique set of ChannelAnnouncements, the latest ChannelUpdate +// will contain a unique set of ChannelAnnouncements, the latest ChannelUpdate1 // for each of the announcements, and a unique set of NodeAnnouncements. // // NOTE: This is part of the ChannelGraphTimeSeries interface. @@ -318,7 +318,7 @@ func (c *ChanSeries) FetchChanAnns(chain chainhash.Hash, // // NOTE: This is part of the ChannelGraphTimeSeries interface. func (c *ChanSeries) FetchChanUpdates(chain chainhash.Hash, - shortChanID lnwire.ShortChannelID) ([]*lnwire.ChannelUpdate, error) { + shortChanID lnwire.ShortChannelID) ([]*lnwire.ChannelUpdate1, error) { chanInfo, e1, e2, err := c.graph.FetchChannelEdgesByID( shortChanID.ToUint64(), @@ -327,7 +327,7 @@ func (c *ChanSeries) FetchChanUpdates(chain chainhash.Hash, return nil, err } - chanUpdates := make([]*lnwire.ChannelUpdate, 0, 2) + chanUpdates := make([]*lnwire.ChannelUpdate1, 0, 2) if e1 != nil { chanUpdate, err := netann.ChannelUpdateFromEdge(chanInfo, e1) if err != nil { diff --git a/discovery/gossiper.go b/discovery/gossiper.go index de0ef16be9..46f3c61e72 100644 --- a/discovery/gossiper.go +++ b/discovery/gossiper.go @@ -141,7 +141,7 @@ type networkMsg struct { } // chanPolicyUpdateRequest is a request that is sent to the server when a caller -// wishes to update a particular set of channels. New ChannelUpdate messages +// wishes to update a particular set of channels. New ChannelUpdate1 messages // will be crafted to be sent out during the next broadcast epoch and the fee // updates committed to the lower layer. type chanPolicyUpdateRequest struct { @@ -312,7 +312,7 @@ type Config struct { // SignAliasUpdate is used to re-sign a channel update using the // remote's alias if the option-scid-alias feature bit was negotiated. - SignAliasUpdate func(u *lnwire.ChannelUpdate) (*ecdsa.Signature, + SignAliasUpdate func(u *lnwire.ChannelUpdate1) (*ecdsa.Signature, error) // FindBaseByAlias finds the SCID stored in the graph by an alias SCID. @@ -890,7 +890,7 @@ func (d *AuthenticatedGossiper) ProcessLocalAnnouncement(msg lnwire.Message, return nMsg.err } -// channelUpdateID is a unique identifier for ChannelUpdate messages, as +// channelUpdateID is a unique identifier for ChannelUpdate1 messages, as // channel updates can be identified by the (ShortChannelID, ChannelFlags) // tuple. type channelUpdateID struct { @@ -1010,7 +1010,7 @@ func (d *deDupedAnnouncements) addMsg(message networkMsg) { // Channel updates are identified by the (short channel id, // channelflags) tuple. - case *lnwire.ChannelUpdate: + case *lnwire.ChannelUpdate1: sender := route.NewVertex(message.source) deDupKey := channelUpdateID{ msg.ShortChannelID, @@ -1022,7 +1022,7 @@ func (d *deDupedAnnouncements) addMsg(message networkMsg) { if ok { // If we already have seen this message, record its // timestamp. - oldTimestamp = mws.msg.(*lnwire.ChannelUpdate).Timestamp + oldTimestamp = mws.msg.(*lnwire.ChannelUpdate1).Timestamp } // If we already had this message with a strictly newer @@ -1359,7 +1359,7 @@ func (d *AuthenticatedGossiper) networkHandler() { select { // A new policy update has arrived. We'll commit it to the // sub-systems below us, then craft, sign, and broadcast a new - // ChannelUpdate for the set of affected clients. + // ChannelUpdate1 for the set of affected clients. case policyUpdate := <-d.chanPolicyUpdates: log.Tracef("Received channel %d policy update requests", len(policyUpdate.edgesToUpdate)) @@ -1557,7 +1557,7 @@ func (d *AuthenticatedGossiper) isRecentlyRejectedMsg(msg lnwire.Message, var scid uint64 switch m := msg.(type) { - case *lnwire.ChannelUpdate: + case *lnwire.ChannelUpdate1: scid = m.ShortChannelID.ToUint64() case lnwire.ChannelAnnouncement: @@ -1608,7 +1608,7 @@ func (d *AuthenticatedGossiper) retransmitStaleAnns(now time.Time) error { // announcement below. havePublicChannels = true - // If this edge has a ChannelUpdate that was created before the + // If this edge has a ChannelUpdate1 that was created before the // introduction of the MaxHTLC field, then we'll update this // edge to propagate this information in the network. if !edge.MessageFlags.HasMaxHtlc() { @@ -1648,7 +1648,7 @@ func (d *AuthenticatedGossiper) retransmitStaleAnns(now time.Time) error { var signedUpdates []lnwire.Message for _, chanToUpdate := range edgesToUpdate { // Re-sign and update the channel on disk and retrieve our - // ChannelUpdate to broadcast. + // ChannelUpdate1 to broadcast. chanAnn, chanUpdate, err := d.updateChannel( chanToUpdate.info, chanToUpdate.edge, ) @@ -1724,7 +1724,7 @@ func (d *AuthenticatedGossiper) processChanPolicyUpdate( for _, edgeInfo := range edgesToUpdate { // Now that we've collected all the channels we need to update, // we'll re-sign and update the backing ChannelGraphSource, and - // retrieve our ChannelUpdate to broadcast. + // retrieve our ChannelUpdate1 to broadcast. _, chanUpdate, err := d.updateChannel( edgeInfo.Info, edgeInfo.Edge, ) @@ -2009,7 +2009,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement( // A new authenticated channel edge update has arrived. This indicates // that the directional information for an already known channel has // been updated. - case *lnwire.ChannelUpdate: + case *lnwire.ChannelUpdate1: return d.handleChanUpdate(nMsg, msg, schedulerOp) // A new signature announcement has been received. This indicates @@ -2028,7 +2028,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement( // processZombieUpdate determines whether the provided channel update should // resurrect a given zombie edge. func (d *AuthenticatedGossiper) processZombieUpdate( - chanInfo models.ChannelEdgeInfo, msg *lnwire.ChannelUpdate) error { + chanInfo models.ChannelEdgeInfo, msg *lnwire.ChannelUpdate1) error { // The least-significant bit in the flag on the channel update tells us // which edge is being updated. @@ -2113,7 +2113,7 @@ func (d *AuthenticatedGossiper) isMsgStale(msg lnwire.Message) bool { // can safely delete the local proof from the database. return chanInfo.GetAuthProof() != nil - case *lnwire.ChannelUpdate: + case *lnwire.ChannelUpdate1: _, p1, p2, err := d.cfg.Router.GetChannelByID(msg.ShortChannelID) // If the channel cannot be found, it is most likely a leftover @@ -2158,7 +2158,7 @@ func (d *AuthenticatedGossiper) isMsgStale(msg lnwire.Message) bool { // the underlying graph with the new state. func (d *AuthenticatedGossiper) updateChannel(edgeInfo models.ChannelEdgeInfo, edge *channeldb.ChannelEdgePolicy) (lnwire.ChannelAnnouncement, - *lnwire.ChannelUpdate, error) { + *lnwire.ChannelUpdate1, error) { // Parse the unsigned edge into a channel update. chanUpdate := netann.UnsignedChannelUpdateFromEdge( @@ -2267,7 +2267,7 @@ func (d *AuthenticatedGossiper) SyncManager() *SyncManager { // IsKeepAliveUpdate determines whether this channel update is considered a // keep-alive update based on the previous channel update processed for the same // direction. -func IsKeepAliveUpdate(update *lnwire.ChannelUpdate, +func IsKeepAliveUpdate(update *lnwire.ChannelUpdate1, prev *channeldb.ChannelEdgePolicy) bool { // Both updates should be from the same direction. @@ -2586,7 +2586,7 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, channelUpdates = append(channelUpdates, chanMsgs.msgs...) } - // Launch a new goroutine to handle each ChannelUpdate, this is to + // Launch a new goroutine to handle each ChannelUpdate1, this is to // ensure we don't block here, as we can handle only one announcement // at a time. for _, cu := range channelUpdates { @@ -2595,9 +2595,9 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, continue } - // Mark the ChannelUpdate as processed. This ensures that a + // Mark the ChannelUpdate1 as processed. This ensures that a // subsequent announcement in the option-scid-alias case does - // not re-use an old ChannelUpdate. + // not re-use an old ChannelUpdate1. cu.processed = true d.wg.Add(1) @@ -2608,8 +2608,8 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, // Reprocess the message, making sure we return an // error to the original caller in case the gossiper // shuts down. - case *lnwire.ChannelUpdate: - log.Debugf("Reprocessing ChannelUpdate for "+ + case *lnwire.ChannelUpdate1: + log.Debugf("Reprocessing ChannelUpdate1 for "+ "shortChanID=%v", msg.ShortChannelID.ToUint64()) @@ -2620,7 +2620,7 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, } // We don't expect any other message type than - // ChannelUpdate to be in this cache. + // ChannelUpdate1 to be in this cache. default: log.Errorf("Unsupported message type found "+ "among ChannelUpdates: %T", msg) @@ -2652,16 +2652,16 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, // handleChanUpdate processes a new channel update. func (d *AuthenticatedGossiper) handleChanUpdate(nMsg *networkMsg, - upd *lnwire.ChannelUpdate, + upd *lnwire.ChannelUpdate1, ops []batch.SchedulerOption) ([]networkMsg, bool) { - log.Debugf("Processing ChannelUpdate: peer=%v, short_chan_id=%v, ", + log.Debugf("Processing ChannelUpdate1: peer=%v, short_chan_id=%v, ", nMsg.peer, upd.ShortChannelID.ToUint64()) // We'll ignore any channel updates that target any chain other than // the set of chains we know of. if !bytes.Equal(upd.ChainHash[:], d.cfg.ChainHash[:]) { - err := fmt.Errorf("ignoring ChannelUpdate from chain=%v, "+ + err := fmt.Errorf("ignoring ChannelUpdate1 from chain=%v, "+ "gossiper on chain=%v", upd.ChainHash, d.cfg.ChainHash) log.Errorf(err.Error()) @@ -2761,7 +2761,7 @@ func (d *AuthenticatedGossiper) handleChanUpdate(nMsg *networkMsg, case channeldb.ErrGraphNoEdgesFound: fallthrough case channeldb.ErrEdgeNotFound: - // If the edge corresponding to this ChannelUpdate was not + // If the edge corresponding to this ChannelUpdate1 was not // found in the graph, this might be a channel in the process // of being opened, and we haven't processed our own // ChannelAnnouncement yet, hence it is not found in the @@ -2801,13 +2801,13 @@ func (d *AuthenticatedGossiper) handleChanUpdate(nMsg *networkMsg, }) } - log.Debugf("Got ChannelUpdate for edge not found in graph"+ + log.Debugf("Got ChannelUpdate1 for edge not found in graph"+ "(shortChanID=%v), saving for reprocessing later", shortChanID) // NOTE: We don't return anything on the error channel for this // message, as we expect that will be done when this - // ChannelUpdate is later reprocessed. + // ChannelUpdate1 is later reprocessed. return nil, false default: @@ -2842,7 +2842,7 @@ func (d *AuthenticatedGossiper) handleChanUpdate(nMsg *networkMsg, edgeToUpdate = e2 } - log.Debugf("Validating ChannelUpdate: channel=%v, from node=%x, has "+ + log.Debugf("Validating ChannelUpdate1: channel=%v, from node=%x, has "+ "edge=%v", chanInfo.GetChanID(), pubKey.SerializeCompressed(), edgeToUpdate != nil) @@ -2915,7 +2915,7 @@ func (d *AuthenticatedGossiper) handleChanUpdate(nMsg *networkMsg, } // We'll use chanInfo.ChannelID rather than the peer-supplied - // ShortChannelID in the ChannelUpdate to avoid the router having to + // ShortChannelID in the ChannelUpdate1 to avoid the router having to // lookup the stored SCID. If we're sending the update, we'll always // use the SCID stored in the database rather than a potentially // different alias. This might mean that SigBytes is incorrect as it @@ -2961,7 +2961,7 @@ func (d *AuthenticatedGossiper) handleChanUpdate(nMsg *networkMsg, return nil, false } - // If this is a local ChannelUpdate without an AuthProof, it means it + // If this is a local ChannelUpdate1 without an AuthProof, it means it // is an update to a channel that is not (yet) supposed to be announced // to the greater network. However, our channel counterparty will need // to be given the update, so we'll try sending the update directly to @@ -3036,7 +3036,7 @@ func (d *AuthenticatedGossiper) handleChanUpdate(nMsg *networkMsg, nMsg.err <- nil - log.Debugf("Processed ChannelUpdate: peer=%v, short_chan_id=%v, "+ + log.Debugf("Processed ChannelUpdate1: peer=%v, short_chan_id=%v, "+ "timestamp=%v", nMsg.peer, upd.ShortChannelID.ToUint64(), timestamp) return announcements, true diff --git a/discovery/gossiper_test.go b/discovery/gossiper_test.go index 8f964a3749..8b3a81141b 100644 --- a/discovery/gossiper_test.go +++ b/discovery/gossiper_test.go @@ -479,8 +479,8 @@ type annBatch struct { chanAnn *lnwire.ChannelAnnouncement1 - chanUpdAnn1 *lnwire.ChannelUpdate - chanUpdAnn2 *lnwire.ChannelUpdate + chanUpdAnn1 *lnwire.ChannelUpdate1 + chanUpdAnn2 *lnwire.ChannelUpdate1 localProofAnn *lnwire.AnnounceSignatures remoteProofAnn *lnwire.AnnounceSignatures @@ -586,12 +586,12 @@ func createNodeAnnouncement(priv *btcec.PrivateKey, func createUpdateAnnouncement(blockHeight uint32, flags lnwire.ChanUpdateChanFlags, nodeKey *btcec.PrivateKey, timestamp uint32, - extraBytes ...[]byte) (*lnwire.ChannelUpdate, error) { + extraBytes ...[]byte) (*lnwire.ChannelUpdate1, error) { var err error htlcMinMsat := lnwire.MilliSatoshi(prand.Int63()) - a := &lnwire.ChannelUpdate{ + a := &lnwire.ChannelUpdate1{ ShortChannelID: lnwire.ShortChannelID{ BlockHeight: blockHeight, }, @@ -619,7 +619,7 @@ func createUpdateAnnouncement(blockHeight uint32, return a, nil } -func signUpdate(nodeKey *btcec.PrivateKey, a *lnwire.ChannelUpdate) error { +func signUpdate(nodeKey *btcec.PrivateKey, a *lnwire.ChannelUpdate1) error { signer := mock.SingleSigner{Privkey: nodeKey} sig, err := netann.SignAnnouncement(&signer, testKeyLoc, a) if err != nil { @@ -748,7 +748,7 @@ func createTestCtx(t *testing.T, startHeight uint32) (*testCtx, error) { return false } - signAliasUpdate := func(*lnwire.ChannelUpdate) (*ecdsa.Signature, + signAliasUpdate := func(*lnwire.ChannelUpdate1) (*ecdsa.Signature, error) { return nil, nil @@ -1057,7 +1057,7 @@ func TestSignatureAnnouncementLocalFirst(t *testing.T) { case <-time.After(2 * trickleDelay): } - // The local ChannelUpdate should now be sent directly to the remote peer, + // The local ChannelUpdate1 should now be sent directly to the remote peer, // such that the edge can be used for routing, regardless if this channel // is announced or not (private channel). select { @@ -1261,7 +1261,7 @@ func TestOrphanSignatureAnnouncement(t *testing.T) { case <-time.After(2 * trickleDelay): } - // The local ChannelUpdate should now be sent directly to the remote peer, + // The local ChannelUpdate1 should now be sent directly to the remote peer, // such that the edge can be used for routing, regardless if this channel // is announced or not (private channel). select { @@ -1450,7 +1450,7 @@ func TestSignatureAnnouncementRetryAtStartup(t *testing.T) { return false } - signAliasUpdate := func(*lnwire.ChannelUpdate) (*ecdsa.Signature, + signAliasUpdate := func(*lnwire.ChannelUpdate1) (*ecdsa.Signature, error) { return nil, nil @@ -1526,7 +1526,7 @@ out: for { select { case msg := <-sentToPeer: - // Since the ChannelUpdate will also be resent as it is + // Since the ChannelUpdate1 will also be resent as it is // sent reliably, we'll need to filter it out. if _, ok := msg.(*lnwire.AnnounceSignatures); !ok { continue @@ -1827,7 +1827,7 @@ func TestDeDuplicatedAnnouncements(t *testing.T) { } // Adding the very same announcement shouldn't cause an increase in the - // number of ChannelUpdate announcements stored. + // number of ChannelUpdate1 announcements stored. ua2, err := createUpdateAnnouncement(0, 0, remoteKeyPriv1, timestamp) require.NoError(t, err, "can't create update announcement") announcements.AddMsgs(networkMsg{ @@ -1852,7 +1852,7 @@ func TestDeDuplicatedAnnouncements(t *testing.T) { t.Fatal("channel update not replaced in batch") } - assertChannelUpdate := func(channelUpdate *lnwire.ChannelUpdate) { + assertChannelUpdate := func(channelUpdate *lnwire.ChannelUpdate1) { channelKey := channelUpdateID{ ua3.ShortChannelID, ua3.ChannelFlags, @@ -2371,7 +2371,7 @@ func TestProcessZombieEdgeNowLive(t *testing.T) { } } -// TestReceiveRemoteChannelUpdateFirst tests that if we receive a ChannelUpdate +// TestReceiveRemoteChannelUpdateFirst tests that if we receive a ChannelUpdate1 // from the remote before we have processed our own ChannelAnnouncement1, it will // be reprocessed later, after our ChannelAnnouncement1. func TestReceiveRemoteChannelUpdateFirst(t *testing.T) { @@ -2399,9 +2399,9 @@ func TestReceiveRemoteChannelUpdateFirst(t *testing.T) { peerChan <- remotePeer } - // Recreate the case where the remote node is sending us its ChannelUpdate + // Recreate the case where the remote node is sending us its ChannelUpdate1 // before we have been able to process our own ChannelAnnouncement1 and - // ChannelUpdate. + // ChannelUpdate1. errRemoteAnn := ctx.gossiper.ProcessRemoteAnnouncement( batch.chanUpdAnn2, remotePeer, ) @@ -2419,7 +2419,7 @@ func TestReceiveRemoteChannelUpdateFirst(t *testing.T) { case <-time.After(2 * trickleDelay): } - // Since the remote ChannelUpdate was added for an edge that + // Since the remote ChannelUpdate1 was added for an edge that // we did not already know about, it should have been added // to the map of premature ChannelUpdates. Check that nothing // was added to the graph. @@ -2469,7 +2469,7 @@ func TestReceiveRemoteChannelUpdateFirst(t *testing.T) { case <-time.After(2 * trickleDelay): } - // The local ChannelUpdate should now be sent directly to the remote peer, + // The local ChannelUpdate1 should now be sent directly to the remote peer, // such that the edge can be used for routing, regardless if this channel // is announced or not (private channel). select { @@ -2479,7 +2479,7 @@ func TestReceiveRemoteChannelUpdateFirst(t *testing.T) { t.Fatal("gossiper did not send channel update to peer") } - // At this point the remote ChannelUpdate we received earlier should + // At this point the remote ChannelUpdate1 we received earlier should // be reprocessed, as we now have the necessary edge entry in the graph. select { case err := <-errRemoteAnn: @@ -2599,7 +2599,7 @@ func TestExtraDataChannelAnnouncementValidation(t *testing.T) { } // TestExtraDataChannelUpdateValidation tests that we're able to properly -// validate a ChannelUpdate that includes opaque bytes that we don't currently +// validate a ChannelUpdate1 that includes opaque bytes that we don't currently // know of. func TestExtraDataChannelUpdateValidation(t *testing.T) { t.Parallel() @@ -2790,7 +2790,7 @@ func TestRetransmit(t *testing.T) { switch msg.(type) { case *lnwire.ChannelAnnouncement1: chanAnn++ - case *lnwire.ChannelUpdate: + case *lnwire.ChannelUpdate1: chanUpd++ case *lnwire.NodeAnnouncement: nodeAnn++ @@ -2911,7 +2911,7 @@ func TestNodeAnnouncementNoChannels(t *testing.T) { } // TestOptionalFieldsChannelUpdateValidation tests that we're able to properly -// validate the msg flags and max HTLC field of a ChannelUpdate. +// validate the msg flags and max HTLC field of a ChannelUpdate1. func TestOptionalFieldsChannelUpdateValidation(t *testing.T) { t.Parallel() @@ -3221,7 +3221,7 @@ func TestSendChannelUpdateReliably(t *testing.T) { // already been announced. We'll keep track of the old message that is // now stale to use later on. staleChannelUpdate := batch.chanUpdAnn1 - newChannelUpdate := &lnwire.ChannelUpdate{} + newChannelUpdate := &lnwire.ChannelUpdate1{} *newChannelUpdate = *staleChannelUpdate newChannelUpdate.Timestamp++ if err := signUpdate(selfKeyPriv, newChannelUpdate); err != nil { @@ -3265,7 +3265,7 @@ func TestSendChannelUpdateReliably(t *testing.T) { peerChan <- remotePeer // At this point, we should have sent both the AnnounceSignatures and - // stale ChannelUpdate. + // stale ChannelUpdate1. for i := 0; i < 2; i++ { var msg lnwire.Message select { @@ -3275,7 +3275,7 @@ func TestSendChannelUpdateReliably(t *testing.T) { } switch msg := msg.(type) { - case *lnwire.ChannelUpdate: + case *lnwire.ChannelUpdate1: assertMessage(t, staleChannelUpdate, msg) case *lnwire.AnnounceSignatures: assertMessage(t, batch.localProofAnn, msg) @@ -3385,7 +3385,7 @@ func TestPropagateChanPolicyUpdate(t *testing.T) { sentMsgs := make(chan lnwire.Message, 10) remotePeer := &mockPeer{remoteKey, sentMsgs, ctx.gossiper.quit} - // The forced code path for sending the private ChannelUpdate to the + // The forced code path for sending the private ChannelUpdate1 to the // remote peer will be hit, forcing it to request a notification that // the remote peer is active. We'll ensure that it targets the proper // pubkey, and hand it our mock peer above. @@ -3477,7 +3477,7 @@ out: // being the channel our first private channel. for i := 0; i < numChannels-1; i++ { assertBroadcastMsg(t, ctx, func(msg lnwire.Message) error { - upd, ok := msg.(*lnwire.ChannelUpdate) + upd, ok := msg.(*lnwire.ChannelUpdate1) if !ok { return fmt.Errorf("channel update not "+ "broadcast, instead %T was", msg) @@ -3497,11 +3497,11 @@ out: }) } - // Finally the ChannelUpdate should have been sent directly to the + // Finally the ChannelUpdate1 should have been sent directly to the // remote peer via the reliable sender. select { case msg := <-sentMsgs: - upd, ok := msg.(*lnwire.ChannelUpdate) + upd, ok := msg.(*lnwire.ChannelUpdate1) if !ok { t.Fatalf("channel update not "+ "broadcast, instead %T was", msg) @@ -3519,13 +3519,13 @@ out: t.Fatalf("message not sent directly to peer") } - // At this point, no other ChannelUpdate messages should be broadcast + // At this point, no other ChannelUpdate1 messages should be broadcast // as we sent the two public ones to the network, and the private one // was sent directly to the peer. for { select { case msg := <-ctx.broadcastedMessage: - if upd, ok := msg.msg.(*lnwire.ChannelUpdate); ok { + if upd, ok := msg.msg.(*lnwire.ChannelUpdate1); ok { if upd.ShortChannelID == firstChanID { t.Fatalf("chan update msg received: %v", spew.Sdump(msg)) @@ -3843,7 +3843,7 @@ func TestRateLimitChannelUpdates(t *testing.T) { // We'll define a helper to assert whether updates should be rate // limited or not depending on their contents. - assertRateLimit := func(update *lnwire.ChannelUpdate, peer lnpeer.Peer, + assertRateLimit := func(update *lnwire.ChannelUpdate1, peer lnpeer.Peer, shouldRateLimit bool) { t.Helper() diff --git a/discovery/message_store.go b/discovery/message_store.go index cf228eee71..80ecf4d7c4 100644 --- a/discovery/message_store.go +++ b/discovery/message_store.go @@ -55,7 +55,7 @@ type GossipMessageStore interface { // MessageStore is an implementation of the GossipMessageStore interface backed // by a channeldb instance. By design, this store will only keep the latest -// version of a message (like in the case of multiple ChannelUpdate's) for a +// version of a message (like in the case of multiple ChannelUpdate1's) for a // channel with a peer. type MessageStore struct { db kvdb.Backend @@ -85,7 +85,7 @@ func msgShortChanID(msg lnwire.Message) (lnwire.ShortChannelID, error) { switch msg := msg.(type) { case *lnwire.AnnounceSignatures: shortChanID = msg.ShortChannelID - case *lnwire.ChannelUpdate: + case *lnwire.ChannelUpdate1: shortChanID = msg.ShortChannelID default: return shortChanID, ErrUnsupportedMessage @@ -157,10 +157,10 @@ func (s *MessageStore) DeleteMessage(msg lnwire.Message, return ErrCorruptedMessageStore } - // In the event that we're attempting to delete a ChannelUpdate + // In the event that we're attempting to delete a ChannelUpdate1 // from the store, we'll make sure that we're actually deleting // the correct one as it can be overwritten. - if msg, ok := msg.(*lnwire.ChannelUpdate); ok { + if msg, ok := msg.(*lnwire.ChannelUpdate1); ok { // Deleting a value from a bucket that doesn't exist // acts as a NOP, so we'll return if a message doesn't // exist under this key. @@ -176,7 +176,7 @@ func (s *MessageStore) DeleteMessage(msg lnwire.Message, // If the timestamps don't match, then the update stored // should be the latest one, so we'll avoid deleting it. - if msg.Timestamp != dbMsg.(*lnwire.ChannelUpdate).Timestamp { + if msg.Timestamp != dbMsg.(*lnwire.ChannelUpdate1).Timestamp { return nil } } diff --git a/discovery/message_store_test.go b/discovery/message_store_test.go index e812c3f1a3..a60e9d72d2 100644 --- a/discovery/message_store_test.go +++ b/discovery/message_store_test.go @@ -59,8 +59,8 @@ func randAnnounceSignatures() *lnwire.AnnounceSignatures { } } -func randChannelUpdate() *lnwire.ChannelUpdate { - return &lnwire.ChannelUpdate{ +func randChannelUpdate() *lnwire.ChannelUpdate1 { + return &lnwire.ChannelUpdate1{ ShortChannelID: lnwire.NewShortChanIDFromInt(rand.Uint64()), ExtraOpaqueData: make([]byte, 0), } @@ -118,7 +118,7 @@ func TestMessageStoreMessages(t *testing.T) { switch msg := msg.(type) { case *lnwire.AnnounceSignatures: shortChanID = msg.ShortChannelID.ToUint64() - case *lnwire.ChannelUpdate: + case *lnwire.ChannelUpdate1: shortChanID = msg.ShortChannelID.ToUint64() default: t.Fatalf("found unexpected message type %T", msg) @@ -297,7 +297,7 @@ func TestMessageStoreDeleteMessage(t *testing.T) { // The store allows overwriting ChannelUpdates, since there can be // multiple versions, so we'll test things slightly different. // - // The ChannelUpdate message should exist within the store after adding + // The ChannelUpdate1 message should exist within the store after adding // it. chanUpdate := randChannelUpdate() if err := msgStore.AddMessage(chanUpdate, peer); err != nil { @@ -305,7 +305,7 @@ func TestMessageStoreDeleteMessage(t *testing.T) { } assertMsg(chanUpdate, peer, true) - // Now, we'll create a new version for the same ChannelUpdate message. + // Now, we'll create a new version for the same ChannelUpdate1 message. // Adding this one to the store will overwrite the previous one, so only // the new one should exist. newChanUpdate := randChannelUpdate() diff --git a/discovery/reliable_sender_test.go b/discovery/reliable_sender_test.go index d1e69b11fb..b95fc87d4a 100644 --- a/discovery/reliable_sender_test.go +++ b/discovery/reliable_sender_test.go @@ -282,7 +282,7 @@ func TestReliableSenderStaleMessages(t *testing.T) { } // Finally, notifying the peer is online should prompt the message to be - // sent. Only the ChannelUpdate will be sent in this case since the + // sent. Only the ChannelUpdate1 will be sent in this case since the // AnnounceSignatures message above was seen as stale. peerChan <- peer diff --git a/discovery/syncer.go b/discovery/syncer.go index c8a727b384..ebf1259756 100644 --- a/discovery/syncer.go +++ b/discovery/syncer.go @@ -1273,9 +1273,9 @@ func (g *GossipSyncer) FilterGossipMsgs(msgs ...msgWithSenders) { // set of channel announcements and channel updates. This will allow us // to quickly check if we should forward a chan ann, based on the known // channel updates for a channel. - chanUpdateIndex := make(map[lnwire.ShortChannelID][]*lnwire.ChannelUpdate) + chanUpdateIndex := make(map[lnwire.ShortChannelID][]*lnwire.ChannelUpdate1) for _, msg := range msgs { - chanUpdate, ok := msg.msg.(*lnwire.ChannelUpdate) + chanUpdate, ok := msg.msg.(*lnwire.ChannelUpdate1) if !ok { continue } @@ -1345,7 +1345,7 @@ func (g *GossipSyncer) FilterGossipMsgs(msgs ...msgWithSenders) { // For each channel update, we'll only send if it the timestamp // is between our time range. - case *lnwire.ChannelUpdate: + case *lnwire.ChannelUpdate1: if passesFilter(msg.Timestamp) { msgsToSend = append(msgsToSend, msg) } diff --git a/discovery/syncer_test.go b/discovery/syncer_test.go index ef88e5245a..54fbd4e48e 100644 --- a/discovery/syncer_test.go +++ b/discovery/syncer_test.go @@ -52,7 +52,7 @@ type mockChannelGraphTimeSeries struct { annResp chan []lnwire.Message updateReq chan lnwire.ShortChannelID - updateResp chan []*lnwire.ChannelUpdate + updateResp chan []*lnwire.ChannelUpdate1 } func newMockChannelGraphTimeSeries( @@ -74,7 +74,7 @@ func newMockChannelGraphTimeSeries( annResp: make(chan []lnwire.Message, 1), updateReq: make(chan lnwire.ShortChannelID, 1), - updateResp: make(chan []*lnwire.ChannelUpdate, 1), + updateResp: make(chan []*lnwire.ChannelUpdate1, 1), } } @@ -137,7 +137,7 @@ func (m *mockChannelGraphTimeSeries) FetchChanAnns(chain chainhash.Hash, return <-m.annResp, nil } func (m *mockChannelGraphTimeSeries) FetchChanUpdates(chain chainhash.Hash, - shortChanID lnwire.ShortChannelID) ([]*lnwire.ChannelUpdate, error) { + shortChanID lnwire.ShortChannelID) ([]*lnwire.ChannelUpdate1, error) { m.updateReq <- shortChanID @@ -288,7 +288,7 @@ func TestGossipSyncerFilterGossipMsgsAllInMemory(t *testing.T) { }, }, { - msg: &lnwire.ChannelUpdate{ + msg: &lnwire.ChannelUpdate1{ ShortChannelID: lnwire.NewShortChanIDFromInt(10), Timestamp: unixStamp(5), }, @@ -300,7 +300,7 @@ func TestGossipSyncerFilterGossipMsgsAllInMemory(t *testing.T) { }, }, { - msg: &lnwire.ChannelUpdate{ + msg: &lnwire.ChannelUpdate1{ ShortChannelID: lnwire.NewShortChanIDFromInt(15), Timestamp: unixStamp(25002), }, @@ -312,7 +312,7 @@ func TestGossipSyncerFilterGossipMsgsAllInMemory(t *testing.T) { }, }, { - msg: &lnwire.ChannelUpdate{ + msg: &lnwire.ChannelUpdate1{ ShortChannelID: lnwire.NewShortChanIDFromInt(20), Timestamp: unixStamp(999999), }, @@ -346,7 +346,7 @@ func TestGossipSyncerFilterGossipMsgsAllInMemory(t *testing.T) { } // If so, then we'll send back the missing update. - chanSeries.updateResp <- []*lnwire.ChannelUpdate{ + chanSeries.updateResp <- []*lnwire.ChannelUpdate1{ { ShortChannelID: lnwire.NewShortChanIDFromInt(25), Timestamp: unixStamp(5), @@ -528,7 +528,7 @@ func TestGossipSyncerApplyGossipFilter(t *testing.T) { // For this first response, we'll send back a proper // set of messages that should be echoed back. chanSeries.horizonResp <- []lnwire.Message{ - &lnwire.ChannelUpdate{ + &lnwire.ChannelUpdate1{ ShortChannelID: lnwire.NewShortChanIDFromInt(25), Timestamp: unixStamp(5), }, @@ -686,7 +686,7 @@ func TestGossipSyncerReplyShortChanIDs(t *testing.T) { &lnwire.ChannelAnnouncement1{ ShortChannelID: lnwire.NewShortChanIDFromInt(20), }, - &lnwire.ChannelUpdate{ + &lnwire.ChannelUpdate1{ ShortChannelID: lnwire.NewShortChanIDFromInt(20), Timestamp: unixStamp(999999), }, diff --git a/funding/manager.go b/funding/manager.go index c82d81d321..50a83f2fe6 100644 --- a/funding/manager.go +++ b/funding/manager.go @@ -3296,12 +3296,12 @@ func (f *Manager) receivedChannelReady(node *btcec.PublicKey, // extractAnnounceParams extracts the various channel announcement and update // parameters that will be needed to construct a ChannelAnnouncement1 and a -// ChannelUpdate. +// ChannelUpdate1. func (f *Manager) extractAnnounceParams(c *channeldb.OpenChannel) ( lnwire.MilliSatoshi, lnwire.MilliSatoshi) { // We'll obtain the min HTLC value we can forward in our direction, as - // we'll use this value within our ChannelUpdate. This constraint is + // we'll use this value within our ChannelUpdate1. This constraint is // originally set by the remote node, as it will be the one that will // need to determine the smallest HTLC it deems economically relevant. fwdMinHTLC := c.LocalChanCfg.MinHTLC @@ -3313,7 +3313,7 @@ func (f *Manager) extractAnnounceParams(c *channeldb.OpenChannel) ( } // We'll obtain the max HTLC value we can forward in our direction, as - // we'll use this value within our ChannelUpdate. This value must be <= + // we'll use this value within our ChannelUpdate1. This value must be <= // channel capacity and <= the maximum in-flight msats set by the peer. fwdMaxHTLC := c.LocalChanCfg.MaxPendingAmount capacityMSat := lnwire.NewMSatFromSatoshis(c.Capacity) @@ -3324,13 +3324,13 @@ func (f *Manager) extractAnnounceParams(c *channeldb.OpenChannel) ( return fwdMinHTLC, fwdMaxHTLC } -// addToRouterGraph sends a ChannelAnnouncement1 and a ChannelUpdate to the +// addToRouterGraph sends a ChannelAnnouncement1 and a ChannelUpdate1 to the // gossiper so that the channel is added to the Router's internal graph. // These announcement messages are NOT broadcasted to the greater network, // only to the channel counter party. The proofs required to announce the // channel to the greater network will be created and sent in annAfterSixConfs. // The peerAlias is used for zero-conf channels to give the counter-party a -// ChannelUpdate they understand. ourPolicy may be set for various +// ChannelUpdate1 they understand. ourPolicy may be set for various // option-scid-alias channels to re-use the same policy. func (f *Manager) addToRouterGraph(completeChan *channeldb.OpenChannel, shortChanID *lnwire.ShortChannelID, @@ -3353,7 +3353,7 @@ func (f *Manager) addToRouterGraph(completeChan *channeldb.OpenChannel, "announcement: %v", err) } - // Send ChannelAnnouncement1 and ChannelUpdate to the gossiper to add + // Send ChannelAnnouncement1 and ChannelUpdate1 to the gossiper to add // to the Router's topology. errChan := f.cfg.SendAnnouncement( ann.chanAnn, discovery.ChannelCapacity(completeChan.Capacity), @@ -3386,7 +3386,7 @@ func (f *Manager) addToRouterGraph(completeChan *channeldb.OpenChannel, routing.ErrIgnored) { log.Debugf("Router rejected "+ - "ChannelUpdate: %v", err) + "ChannelUpdate1: %v", err) } else { return fmt.Errorf("error sending channel "+ "update: %v", err) @@ -3514,7 +3514,7 @@ func (f *Manager) annAfterSixConfs(completeChan *channeldb.OpenChannel, // We'll delete the edge and add it again via // addToRouterGraph. This is because the peer may have - // sent us a ChannelUpdate with an alias and we don't + // sent us a ChannelUpdate1 with an alias and we don't // want to relay this. ourPolicy, err := f.cfg.DeleteAliasEdge(baseScid) if err != nil { @@ -3762,7 +3762,7 @@ func (f *Manager) handleChannelReady(peer lnpeer.Peer, //nolint:funlen // We'll need to store the received TLV alias if the option_scid_alias // feature was negotiated. This will be used to provide route hints // during invoice creation. In the zero-conf case, it is also used to - // provide a ChannelUpdate to the remote peer. This is done before the + // provide a ChannelUpdate1 to the remote peer. This is done before the // call to InsertNextRevocation in case the call to PutPeerAlias fails. // If it were to fail on the first call to handleChannelReady, we // wouldn't want the channel to be usable yet. @@ -3937,7 +3937,7 @@ func (f *Manager) handleChannelReadyReceived(channel *channeldb.OpenChannel, // We'll need to wait until channel_ready has been received and // the peer lets us know the alias they want to use for the // channel. With this information, we can then construct a - // ChannelUpdate for them. If an alias does not yet exist, + // ChannelUpdate1 for them. If an alias does not yet exist, // we'll just return, letting the next iteration of the loop // check again. var defaultAlias lnwire.ShortChannelID @@ -4051,7 +4051,7 @@ func (f *Manager) ensureInitialForwardingPolicy(chanID lnwire.ChannelID, // send out to the network after a new channel has been created locally. type chanAnnouncement struct { chanAnn lnwire.ChannelAnnouncement - chanUpdateAnn *lnwire.ChannelUpdate + chanUpdateAnn *lnwire.ChannelUpdate1 chanProof *lnwire.AnnounceSignatures } @@ -4144,8 +4144,8 @@ func (f *Manager) newChanAnnouncement(localPubKey, msgFlags := lnwire.ChanUpdateRequiredMaxHtlc // We announce the channel with the default values. Some of - // these values can later be changed by crafting a new ChannelUpdate. - chanUpdateAnn := &lnwire.ChannelUpdate{ + // these values can later be changed by crafting a new ChannelUpdate1. + chanUpdateAnn := &lnwire.ChannelUpdate1{ ShortChannelID: shortChanID, ChainHash: chainHash, Timestamp: uint32(time.Now().Unix()), @@ -4171,7 +4171,7 @@ func (f *Manager) newChanAnnouncement(localPubKey, switch { case ourPolicy != nil: // If ourPolicy is non-nil, modify the default parameters of the - // ChannelUpdate. + // ChannelUpdate1. chanUpdateAnn.MessageFlags = ourPolicy.MessageFlags chanUpdateAnn.ChannelFlags = ourPolicy.ChannelFlags chanUpdateAnn.TimeLockDelta = ourPolicy.TimeLockDelta @@ -4294,7 +4294,7 @@ func (f *Manager) announceChannel(localIDKey, remoteIDKey *btcec.PublicKey, // We only send the channel proof announcement and the node announcement // because addToRouterGraph previously sent the ChannelAnnouncement1 and - // the ChannelUpdate announcement messages. The channel proof and node + // the ChannelUpdate1 announcement messages. The channel proof and node // announcements are broadcast to the greater network. errChan := f.cfg.SendAnnouncement(ann.chanProof) select { diff --git a/funding/manager_test.go b/funding/manager_test.go index 62cc231acd..9b6b301118 100644 --- a/funding/manager_test.go +++ b/funding/manager_test.go @@ -1144,9 +1144,9 @@ func assertAddedToRouterGraph(t *testing.T, alice, bob *testNode, } // assertChannelAnnouncements checks that alice and bob both sends the expected -// announcements (ChannelAnnouncement1, ChannelUpdate) after the funding tx has +// announcements (ChannelAnnouncement1, ChannelUpdate1) after the funding tx has // confirmed. The last arguments can be set if we expect the nodes to advertise -// custom min_htlc values as part of their ChannelUpdate. We expect Alice to +// custom min_htlc values as part of their ChannelUpdate1. We expect Alice to // advertise the value required by Bob and vice versa. If they are not set the // advertised value will be checked against the other node's default min_htlc, // base fee and fee rate values. @@ -1176,8 +1176,8 @@ func assertChannelAnnouncements(t *testing.T, alice, bob *testNode, // After the ChannelReady message is sent, Alice and Bob will each send // the following messages to their gossiper: // 1) ChannelAnnouncement1 - // 2) ChannelUpdate - // The ChannelAnnouncement1 is kept locally, while the ChannelUpdate is + // 2) ChannelUpdate1 + // The ChannelAnnouncement1 is kept locally, while the ChannelUpdate1 is // sent directly to the other peer, so the edge policies are known to // both peers. nodes := []*testNode{alice, bob} @@ -1197,7 +1197,7 @@ func assertChannelAnnouncements(t *testing.T, alice, bob *testNode, switch m := msg.(type) { case *lnwire.ChannelAnnouncement1: gotChannelAnnouncement = true - case *lnwire.ChannelUpdate: + case *lnwire.ChannelUpdate1: // The channel update sent by the node should // advertise the MinHTLC value required by the @@ -1246,7 +1246,7 @@ func assertChannelAnnouncements(t *testing.T, alice, bob *testNode, t, gotChannelAnnouncement, "ChannelAnnouncement1 from %d", j, ) - require.Truef(t, gotChannelUpdate, "ChannelUpdate from %d", j) + require.Truef(t, gotChannelUpdate, "ChannelUpdate1 from %d", j) // Make sure no other message is sent. select { @@ -4549,7 +4549,7 @@ func testZeroConf(t *testing.T, chanType *lnwire.ChannelType) { assertHandleChannelReady(t, alice, bob) // We'll now assert that both sides send ChannelAnnouncement1 and - // ChannelUpdate messages. + // ChannelUpdate1 messages. assertChannelAnnouncements( t, alice, bob, fundingAmt, nil, nil, nil, nil, ) diff --git a/htlcswitch/interfaces.go b/htlcswitch/interfaces.go index 1ce4ffb2e5..9cadf41400 100644 --- a/htlcswitch/interfaces.go +++ b/htlcswitch/interfaces.go @@ -78,7 +78,7 @@ type scidAliasHandler interface { // HTLCs on option_scid_alias channels. attachFailAliasUpdate(failClosure func( sid lnwire.ShortChannelID, - incoming bool) *lnwire.ChannelUpdate) + incoming bool) *lnwire.ChannelUpdate1) // getAliases fetches the link's underlying aliases. This is used by // the Switch to determine whether to forward an HTLC and where to diff --git a/htlcswitch/link.go b/htlcswitch/link.go index 2a1e49923d..3d5e75bd25 100644 --- a/htlcswitch/link.go +++ b/htlcswitch/link.go @@ -124,7 +124,7 @@ type ChannelLinkConfig struct { // specified when we receive an incoming HTLC. This will be used to // provide payment senders our latest policy when sending encrypted // error messages. - FetchLastChannelUpdate func(lnwire.ShortChannelID) (*lnwire.ChannelUpdate, error) + FetchLastChannelUpdate func(lnwire.ShortChannelID) (*lnwire.ChannelUpdate1, error) // Peer is a lightning network node with which we have the channel link // opened. @@ -266,7 +266,7 @@ type ChannelLinkConfig struct { // FailAliasUpdate is a function used to fail an HTLC for an // option_scid_alias channel. FailAliasUpdate func(sid lnwire.ShortChannelID, - incoming bool) *lnwire.ChannelUpdate + incoming bool) *lnwire.ChannelUpdate1 // GetAliases is used by the link and switch to fetch the set of // aliases for a given link. @@ -613,9 +613,9 @@ func shouldAdjustCommitFee(netFee, chanFee, } // failCb is used to cut down on the argument verbosity. -type failCb func(update *lnwire.ChannelUpdate) lnwire.FailureMessage +type failCb func(update *lnwire.ChannelUpdate1) lnwire.FailureMessage -// createFailureWithUpdate creates a ChannelUpdate when failing an incoming or +// createFailureWithUpdate creates a ChannelUpdate1 when failing an incoming or // outgoing HTLC. It may return a FailureMessage that references a channel's // alias. If the channel does not have an alias, then the regular channel // update from disk will be returned. @@ -623,7 +623,7 @@ func (l *channelLink) createFailureWithUpdate(incoming bool, outgoingScid lnwire.ShortChannelID, cb failCb) lnwire.FailureMessage { // Determine which SCID to use in case we need to use aliases in the - // ChannelUpdate. + // ChannelUpdate1. scid := outgoingScid if incoming { scid = l.ShortChanID() @@ -2491,7 +2491,7 @@ func (l *channelLink) getAliases() []lnwire.ShortChannelID { // // Part of the scidAliasHandler interface. func (l *channelLink) attachFailAliasUpdate(closure func( - sid lnwire.ShortChannelID, incoming bool) *lnwire.ChannelUpdate) { + sid lnwire.ShortChannelID, incoming bool) *lnwire.ChannelUpdate1) { l.Lock() l.cfg.FailAliasUpdate = closure @@ -2568,7 +2568,7 @@ func (l *channelLink) CheckHtlcForward(payHash [32]byte, // As part of the returned error, we'll send our latest routing // policy so the sending node obtains the most up to date data. - cb := func(upd *lnwire.ChannelUpdate) lnwire.FailureMessage { + cb := func(upd *lnwire.ChannelUpdate1) lnwire.FailureMessage { return lnwire.NewFeeInsufficient(amtToForward, *upd) } failure := l.createFailureWithUpdate(false, originalScid, cb) @@ -2596,7 +2596,7 @@ func (l *channelLink) CheckHtlcForward(payHash [32]byte, // Grab the latest routing policy so the sending node is up to // date with our current policy. - cb := func(upd *lnwire.ChannelUpdate) lnwire.FailureMessage { + cb := func(upd *lnwire.ChannelUpdate1) lnwire.FailureMessage { return lnwire.NewIncorrectCltvExpiry( incomingTimeout, *upd, ) @@ -2645,7 +2645,7 @@ func (l *channelLink) canSendHtlc(policy models.ForwardingPolicy, // As part of the returned error, we'll send our latest routing // policy so the sending node obtains the most up to date data. - cb := func(upd *lnwire.ChannelUpdate) lnwire.FailureMessage { + cb := func(upd *lnwire.ChannelUpdate1) lnwire.FailureMessage { return lnwire.NewAmountBelowMinimum(amt, *upd) } failure := l.createFailureWithUpdate(false, originalScid, cb) @@ -2660,7 +2660,7 @@ func (l *channelLink) canSendHtlc(policy models.ForwardingPolicy, // As part of the returned error, we'll send our latest routing // policy so the sending node obtains the most up-to-date data. - cb := func(upd *lnwire.ChannelUpdate) lnwire.FailureMessage { + cb := func(upd *lnwire.ChannelUpdate1) lnwire.FailureMessage { return lnwire.NewTemporaryChannelFailure(upd) } failure := l.createFailureWithUpdate(false, originalScid, cb) @@ -2675,7 +2675,7 @@ func (l *channelLink) canSendHtlc(policy models.ForwardingPolicy, "outgoing_expiry=%v, best_height=%v", payHash[:], timeout, heightNow) - cb := func(upd *lnwire.ChannelUpdate) lnwire.FailureMessage { + cb := func(upd *lnwire.ChannelUpdate1) lnwire.FailureMessage { return lnwire.NewExpiryTooSoon(*upd) } failure := l.createFailureWithUpdate(false, originalScid, cb) @@ -2695,7 +2695,7 @@ func (l *channelLink) canSendHtlc(policy models.ForwardingPolicy, if amt > l.Bandwidth() { l.log.Warnf("insufficient bandwidth to route htlc: %v is "+ "larger than %v", amt, l.Bandwidth()) - cb := func(upd *lnwire.ChannelUpdate) lnwire.FailureMessage { + cb := func(upd *lnwire.ChannelUpdate1) lnwire.FailureMessage { return lnwire.NewTemporaryChannelFailure(upd) } failure := l.createFailureWithUpdate(false, originalScid, cb) @@ -3152,7 +3152,7 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg, l.log.Errorf("unable to encode the "+ "remaining route %v", err) - cb := func(upd *lnwire.ChannelUpdate) lnwire.FailureMessage { + cb := func(upd *lnwire.ChannelUpdate1) lnwire.FailureMessage { return lnwire.NewTemporaryChannelFailure(upd) } diff --git a/htlcswitch/link_test.go b/htlcswitch/link_test.go index 9c07cd53c1..187e48dd5b 100644 --- a/htlcswitch/link_test.go +++ b/htlcswitch/link_test.go @@ -5730,13 +5730,13 @@ func TestForwardingAsymmetricTimeLockPolicies(t *testing.T) { // forwarding policy. func TestCheckHtlcForward(t *testing.T) { fetchLastChannelUpdate := func(lnwire.ShortChannelID) ( - *lnwire.ChannelUpdate, error) { + *lnwire.ChannelUpdate1, error) { - return &lnwire.ChannelUpdate{}, nil + return &lnwire.ChannelUpdate1{}, nil } failAliasUpdate := func(sid lnwire.ShortChannelID, - incoming bool) *lnwire.ChannelUpdate { + incoming bool) *lnwire.ChannelUpdate1 { return nil } diff --git a/htlcswitch/mock.go b/htlcswitch/mock.go index 9b0453dea9..9cdc583bb7 100644 --- a/htlcswitch/mock.go +++ b/htlcswitch/mock.go @@ -165,7 +165,7 @@ type mockServer struct { var _ lnpeer.Peer = (*mockServer)(nil) func initSwitchWithDB(startingHeight uint32, db *channeldb.DB) (*Switch, error) { - signAliasUpdate := func(u *lnwire.ChannelUpdate) (*ecdsa.Signature, + signAliasUpdate := func(u *lnwire.ChannelUpdate1) (*ecdsa.Signature, error) { return testSig, nil @@ -181,9 +181,9 @@ func initSwitchWithDB(startingHeight uint32, db *channeldb.DB) (*Switch, error) events: make(map[time.Time]channeldb.ForwardingEvent), }, FetchLastChannelUpdate: func(scid lnwire.ShortChannelID) ( - *lnwire.ChannelUpdate, error) { + *lnwire.ChannelUpdate1, error) { - return &lnwire.ChannelUpdate{ + return &lnwire.ChannelUpdate1{ ShortChannelID: scid, }, nil }, @@ -731,7 +731,7 @@ type mockChannelLink struct { checkHtlcForwardResult *LinkError failAliasUpdate func(sid lnwire.ShortChannelID, - incoming bool) *lnwire.ChannelUpdate + incoming bool) *lnwire.ChannelUpdate1 confirmedZC bool } @@ -860,7 +860,7 @@ func (f *mockChannelLink) AttachMailBox(mailBox MailBox) { } func (f *mockChannelLink) attachFailAliasUpdate(closure func( - sid lnwire.ShortChannelID, incoming bool) *lnwire.ChannelUpdate) { + sid lnwire.ShortChannelID, incoming bool) *lnwire.ChannelUpdate1) { f.failAliasUpdate = closure } diff --git a/htlcswitch/switch.go b/htlcswitch/switch.go index 592d03a1ce..94d9aada57 100644 --- a/htlcswitch/switch.go +++ b/htlcswitch/switch.go @@ -170,7 +170,7 @@ type Config struct { // specified when we receive an incoming HTLC. This will be used to // provide payment senders our latest policy when sending encrypted // error messages. - FetchLastChannelUpdate func(lnwire.ShortChannelID) (*lnwire.ChannelUpdate, error) + FetchLastChannelUpdate func(lnwire.ShortChannelID) (*lnwire.ChannelUpdate1, error) // Notifier is an instance of a chain notifier that we'll use to signal // the switch when a new block has arrived. @@ -216,8 +216,8 @@ type Config struct { // SignAliasUpdate is used when sending FailureMessages backwards for // option_scid_alias channels. This avoids a potential privacy leak by // replacing the public, confirmed SCID with the alias in the - // ChannelUpdate. - SignAliasUpdate func(u *lnwire.ChannelUpdate) (*ecdsa.Signature, + // ChannelUpdate1. + SignAliasUpdate func(u *lnwire.ChannelUpdate1) (*ecdsa.Signature, error) // IsAlias returns whether or not a given SCID is an alias. @@ -767,7 +767,7 @@ func (s *Switch) ForwardPackets(linkQuit chan struct{}, incomingID := failedPackets[0].incomingChanID // If the incoming channel is an option_scid_alias channel, - // then we'll need to replace the SCID in the ChannelUpdate. + // then we'll need to replace the SCID in the ChannelUpdate1. update := s.failAliasUpdate(incomingID, true) if update == nil { // Fallback to the original non-option behavior. @@ -2848,16 +2848,16 @@ func (s *Switch) failMailboxUpdate(outgoingScid, return lnwire.NewTemporaryChannelFailure(update) } -// failAliasUpdate prepares a ChannelUpdate for a failed incoming or outgoing +// failAliasUpdate prepares a ChannelUpdate1 for a failed incoming or outgoing // HTLC on a channel where the option-scid-alias feature bit was negotiated. If // the associated channel is not one of these, this function will return nil // and the caller is expected to handle this properly. In this case, a return // to the original non-alias behavior is expected. func (s *Switch) failAliasUpdate(scid lnwire.ShortChannelID, - incoming bool) *lnwire.ChannelUpdate { + incoming bool) *lnwire.ChannelUpdate1 { // This function does not defer the unlocking because of the database - // lookups for ChannelUpdate. + // lookups for ChannelUpdate1. s.indexMtx.RLock() if s.cfg.IsAlias(scid) { @@ -2932,7 +2932,7 @@ func (s *Switch) failAliasUpdate(scid lnwire.ShortChannelID, } // Fetch the link so we can get an alias to use in the ShortChannelID - // of the ChannelUpdate. + // of the ChannelUpdate1. link, ok := s.forwardingIndex[baseScid] s.indexMtx.RUnlock() if !ok { @@ -2947,14 +2947,14 @@ func (s *Switch) failAliasUpdate(scid lnwire.ShortChannelID, return nil } - // Fetch the ChannelUpdate via the real, confirmed SCID. + // Fetch the ChannelUpdate1 via the real, confirmed SCID. update, err := s.cfg.FetchLastChannelUpdate(scid) if err != nil { return nil } // The incoming case will replace the ShortChannelID in the retrieved - // ChannelUpdate with the alias to ensure no privacy leak occurs. This + // ChannelUpdate1 with the alias to ensure no privacy leak occurs. This // would happen if a private non-zero-conf option-scid-alias // feature-bit channel leaked its UTXO here rather than supplying an // alias. In the outgoing case, the confirmed SCID was actually used diff --git a/htlcswitch/switch_test.go b/htlcswitch/switch_test.go index 2d507f01b9..eb33c31f9e 100644 --- a/htlcswitch/switch_test.go +++ b/htlcswitch/switch_test.go @@ -3949,7 +3949,7 @@ func TestSwitchHoldForward(t *testing.T) { // Simulate an error during the composition of the failure message. currentCallback := c.s.cfg.FetchLastChannelUpdate c.s.cfg.FetchLastChannelUpdate = func( - lnwire.ShortChannelID) (*lnwire.ChannelUpdate, error) { + lnwire.ShortChannelID) (*lnwire.ChannelUpdate1, error) { return nil, errors.New("cannot fetch update") } @@ -4820,7 +4820,7 @@ func TestSwitchResolution(t *testing.T) { } // TestSwitchForwardFailAlias tests that if ForwardPackets returns a failure -// before actually forwarding, the ChannelUpdate uses the SCID from the +// before actually forwarding, the ChannelUpdate1 uses the SCID from the // incoming channel and does not leak private information like the UTXO. func TestSwitchForwardFailAlias(t *testing.T) { tests := []struct { @@ -5170,7 +5170,7 @@ func testSwitchAliasFailAdd(t *testing.T, zeroConf, private, useAlias bool) { select { case failPacket := <-bobLink.packets: // Assert that failPacket returns the expected SCID in the - // ChannelUpdate. + // ChannelUpdate1. msg := failPacket.linkFailure.msg failMsg, ok := msg.(*lnwire.FailTemporaryChannelFailure) require.True(t, ok) @@ -5369,7 +5369,7 @@ func testSwitchHandlePacketForward(t *testing.T, zeroConf, private, select { case failPacket := <-bobLink.packets: - // Assert that failPacket returns the expected ChannelUpdate. + // Assert that failPacket returns the expected ChannelUpdate1. msg := failPacket.linkFailure.msg failMsg, ok := msg.(*lnwire.FailAmountBelowMinimum) require.True(t, ok) @@ -5514,7 +5514,7 @@ func testSwitchAliasInterceptFail(t *testing.T, zeroConf bool) { select { case failPacket := <-aliceLink.packets: - // Assert that failPacket returns the expected ChannelUpdate. + // Assert that failPacket returns the expected ChannelUpdate1. failHtlc, ok := failPacket.htlc.(*lnwire.UpdateFailHTLC) require.True(t, ok) diff --git a/htlcswitch/test_utils.go b/htlcswitch/test_utils.go index da36e214cd..2848920713 100644 --- a/htlcswitch/test_utils.go +++ b/htlcswitch/test_utils.go @@ -91,8 +91,8 @@ func genIDs() (lnwire.ChannelID, lnwire.ChannelID, lnwire.ShortChannelID, // mockGetChanUpdateMessage helper function which returns topology update of // the channel -func mockGetChanUpdateMessage(cid lnwire.ShortChannelID) (*lnwire.ChannelUpdate, error) { - return &lnwire.ChannelUpdate{ +func mockGetChanUpdateMessage(cid lnwire.ShortChannelID) (*lnwire.ChannelUpdate1, error) { + return &lnwire.ChannelUpdate1{ Signature: wireSig, }, nil } diff --git a/itest/lnd_channel_policy_test.go b/itest/lnd_channel_policy_test.go index 85527a0821..6f819308da 100644 --- a/itest/lnd_channel_policy_test.go +++ b/itest/lnd_channel_policy_test.go @@ -84,7 +84,7 @@ func testUpdateChannelPolicy(ht *lntest.HarnessTest) { // Open the channel Carol->Bob with a custom min_htlc value set. Since // Carol is opening the channel, she will require Bob to not forward // HTLCs smaller than this value, and hence he should advertise it as - // part of his ChannelUpdate. + // part of his ChannelUpdate1. const customMinHtlc = 5000 chanPoint2 := ht.OpenChannel( carol, bob, lntest.OpenChannelParams{ @@ -540,7 +540,7 @@ func testSendUpdateDisableChannel(ht *lntest.HarnessTest) { assertPolicyUpdate(eve, expectedPolicy, chanPointEveCarol, 2) // We restart Carol. Since the channel now becomes active again, Eve - // should send a ChannelUpdate setting the channel no longer disabled. + // should send a ChannelUpdate1 setting the channel no longer disabled. require.NoError(ht, restartCarol(), "unable to restart carol") expectedPolicy.Disabled = false diff --git a/itest/lnd_zero_conf_test.go b/itest/lnd_zero_conf_test.go index f2190bd2a9..fd15ec7aa3 100644 --- a/itest/lnd_zero_conf_test.go +++ b/itest/lnd_zero_conf_test.go @@ -488,7 +488,7 @@ func testPrivateUpdateAlias(ht *lntest.HarnessTest, ht.EnsureConnected(carol, dave) // We'll open a regular public channel between Eve and Carol here. Eve - // will be the one receiving the onion-encrypted ChannelUpdate. + // will be the one receiving the onion-encrypted ChannelUpdate1. ht.EnsureConnected(eve, carol) chanAmt := btcutil.Amount(1_000_000) diff --git a/lnrpc/routerrpc/router_backend.go b/lnrpc/routerrpc/router_backend.go index feb3d4b856..1e9e26e52d 100644 --- a/lnrpc/routerrpc/router_backend.go +++ b/lnrpc/routerrpc/router_backend.go @@ -1499,7 +1499,7 @@ func marshallWireError(msg lnwire.FailureMessage, // marshallChannelUpdate marshalls a channel update as received over the wire to // the router rpc format. -func marshallChannelUpdate(update *lnwire.ChannelUpdate) *lnrpc.ChannelUpdate { +func marshallChannelUpdate(update *lnwire.ChannelUpdate1) *lnrpc.ChannelUpdate { if update == nil { return nil } diff --git a/lnwire/channel_update.go b/lnwire/channel_update.go index 7f42a58b46..95f946da07 100644 --- a/lnwire/channel_update.go +++ b/lnwire/channel_update.go @@ -9,12 +9,12 @@ import ( ) // ChanUpdateMsgFlags is a bitfield that signals whether optional fields are -// present in the ChannelUpdate. +// present in the ChannelUpdate1. type ChanUpdateMsgFlags uint8 const ( // ChanUpdateRequiredMaxHtlc is a bit that indicates whether the - // required htlc_maximum_msat field is present in this ChannelUpdate. + // required htlc_maximum_msat field is present in this ChannelUpdate1. ChanUpdateRequiredMaxHtlc ChanUpdateMsgFlags = 1 << iota ) @@ -31,7 +31,7 @@ func (c ChanUpdateMsgFlags) HasMaxHtlc() bool { // ChanUpdateChanFlags is a bitfield that signals various options concerning a // particular channel edge. Each bit is to be examined in order to determine -// how the ChannelUpdate message is to be interpreted. +// how the ChannelUpdate1 message is to be interpreted. type ChanUpdateChanFlags uint8 const ( @@ -56,11 +56,11 @@ func (c ChanUpdateChanFlags) String() string { return fmt.Sprintf("%08b", c) } -// ChannelUpdate message is used after channel has been initially announced. +// ChannelUpdate1 message is used after channel has been initially announced. // Each side independently announces its fees and minimum expiry for HTLCs and // other parameters. Also this message is used to redeclare initially set // channel parameters. -type ChannelUpdate struct { +type ChannelUpdate1 struct { // Signature is used to validate the announced data and prove the // ownership of node id. Signature Sig @@ -120,15 +120,15 @@ type ChannelUpdate struct { ExtraOpaqueData ExtraOpaqueData } -// A compile time check to ensure ChannelUpdate implements the lnwire.Message +// A compile time check to ensure ChannelUpdate1 implements the lnwire.Message // interface. -var _ Message = (*ChannelUpdate)(nil) +var _ Message = (*ChannelUpdate1)(nil) -// Decode deserializes a serialized ChannelUpdate stored in the passed +// Decode deserializes a serialized ChannelUpdate1 stored in the passed // io.Reader observing the specified protocol version. // // This is part of the lnwire.Message interface. -func (a *ChannelUpdate) Decode(r io.Reader, pver uint32) error { +func (a *ChannelUpdate1) Decode(r io.Reader, _ uint32) error { err := ReadElements(r, &a.Signature, a.ChainHash[:], @@ -155,11 +155,11 @@ func (a *ChannelUpdate) Decode(r io.Reader, pver uint32) error { return a.ExtraOpaqueData.Decode(r) } -// Encode serializes the target ChannelUpdate into the passed io.Writer +// Encode serializes the target ChannelUpdate1 into the passed io.Writer // observing the protocol version specified. // // This is part of the lnwire.Message interface. -func (a *ChannelUpdate) Encode(w *bytes.Buffer, pver uint32) error { +func (a *ChannelUpdate1) Encode(w *bytes.Buffer, _ uint32) error { if err := WriteSig(w, a.Signature); err != nil { return err } @@ -217,13 +217,13 @@ func (a *ChannelUpdate) Encode(w *bytes.Buffer, pver uint32) error { // wire. // // This is part of the lnwire.Message interface. -func (a *ChannelUpdate) MsgType() MessageType { +func (a *ChannelUpdate1) MsgType() MessageType { return MsgChannelUpdate } // DataToSign is used to retrieve part of the announcement message which should // be signed. -func (a *ChannelUpdate) DataToSign() ([]byte, error) { +func (a *ChannelUpdate1) DataToSign() ([]byte, error) { // We should not include the signatures itself. b := make([]byte, 0, MaxMsgBody) buf := bytes.NewBuffer(b) diff --git a/lnwire/lnwire_test.go b/lnwire/lnwire_test.go index e559d064cc..5e15469ac7 100644 --- a/lnwire/lnwire_test.go +++ b/lnwire/lnwire_test.go @@ -876,14 +876,14 @@ func TestLightningWireProtocol(t *testing.T) { maxHtlc := MilliSatoshi(r.Int63()) // We make the max_htlc field zero if it is not flagged - // as being part of the ChannelUpdate, to pass + // as being part of the ChannelUpdate1, to pass // serialization tests, as it will be ignored if the bit // is not set. if msgFlags&ChanUpdateRequiredMaxHtlc == 0 { maxHtlc = 0 } - req := ChannelUpdate{ + req := ChannelUpdate1{ ShortChannelID: NewShortChanIDFromInt(uint64(r.Int63())), Timestamp: uint32(r.Int31()), MessageFlags: msgFlags, @@ -1418,7 +1418,7 @@ func TestLightningWireProtocol(t *testing.T) { }, { msgType: MsgChannelUpdate, - scenario: func(m ChannelUpdate) bool { + scenario: func(m ChannelUpdate1) bool { return mainScenario(&m) }, }, diff --git a/lnwire/message.go b/lnwire/message.go index 1a2a2ad9fe..7d7e73bde2 100644 --- a/lnwire/message.go +++ b/lnwire/message.go @@ -117,7 +117,7 @@ func (t MessageType) String() string { case MsgChannelAnnouncement: return "ChannelAnnouncement1" case MsgChannelUpdate: - return "ChannelUpdate" + return "ChannelUpdate1" case MsgNodeAnnouncement: return "NodeAnnouncement" case MsgPing: @@ -229,7 +229,7 @@ func makeEmptyMessage(msgType MessageType) (Message, error) { case MsgChannelAnnouncement: msg = &ChannelAnnouncement1{} case MsgChannelUpdate: - msg = &ChannelUpdate{} + msg = &ChannelUpdate1{} case MsgNodeAnnouncement: msg = &NodeAnnouncement{} case MsgPing: diff --git a/lnwire/message_test.go b/lnwire/message_test.go index 540539b28a..e4268e5db9 100644 --- a/lnwire/message_test.go +++ b/lnwire/message_test.go @@ -692,21 +692,21 @@ func newMsgNodeAnnouncement(t testing.TB, return msg } -func newMsgChannelUpdate(t testing.TB, r *rand.Rand) *lnwire.ChannelUpdate { +func newMsgChannelUpdate(t testing.TB, r *rand.Rand) *lnwire.ChannelUpdate1 { t.Helper() msgFlags := lnwire.ChanUpdateMsgFlags(r.Int31()) maxHtlc := lnwire.MilliSatoshi(r.Int63()) // We make the max_htlc field zero if it is not flagged - // as being part of the ChannelUpdate, to pass + // as being part of the ChannelUpdate1, to pass // serialization tests, as it will be ignored if the bit // is not set. if msgFlags&lnwire.ChanUpdateRequiredMaxHtlc == 0 { maxHtlc = 0 } - msg := &lnwire.ChannelUpdate{ + msg := &lnwire.ChannelUpdate1{ ShortChannelID: lnwire.NewShortChanIDFromInt(r.Uint64()), Timestamp: uint32(r.Int31()), MessageFlags: msgFlags, diff --git a/lnwire/onion_error.go b/lnwire/onion_error.go index adf12781ac..608ac9d4ca 100644 --- a/lnwire/onion_error.go +++ b/lnwire/onion_error.go @@ -597,7 +597,7 @@ func (f *FailInvalidOnionKey) Error() string { // unable to pull out a fully valid version, then we'll fall back to the // regular parsing mechanism which includes the length prefix an NO type byte. func parseChannelUpdateCompatibilityMode(reader io.Reader, length uint16, - chanUpdate *ChannelUpdate, pver uint32) error { + chanUpdate *ChannelUpdate1, pver uint32) error { // Instantiate a LimitReader because there may be additional data // present after the channel update. Without limiting the stream, the @@ -615,7 +615,7 @@ func parseChannelUpdateCompatibilityMode(reader io.Reader, length uint16, // Some nodes well prefix an additional set of bytes in front of their // channel updates. These bytes will _almost_ always be 258 or the type - // of the ChannelUpdate message. + // of the ChannelUpdate1 message. typeInt := binary.BigEndian.Uint16(maybeTypeBytes) if typeInt == MsgChannelUpdate { // At this point it's likely the case that this is a channel @@ -644,11 +644,11 @@ type FailTemporaryChannelFailure struct { // which caused the failure. // // NOTE: This field is optional. - Update *ChannelUpdate + Update *ChannelUpdate1 } // NewTemporaryChannelFailure creates new instance of the FailTemporaryChannelFailure. -func NewTemporaryChannelFailure(update *ChannelUpdate) *FailTemporaryChannelFailure { +func NewTemporaryChannelFailure(update *ChannelUpdate1) *FailTemporaryChannelFailure { return &FailTemporaryChannelFailure{Update: update} } @@ -682,7 +682,7 @@ func (f *FailTemporaryChannelFailure) Decode(r io.Reader, pver uint32) error { } if length != 0 { - f.Update = &ChannelUpdate{} + f.Update = &ChannelUpdate1{} return parseChannelUpdateCompatibilityMode( r, length, f.Update, pver, @@ -717,12 +717,12 @@ type FailAmountBelowMinimum struct { // Update is used to update information about state of the channel // which caused the failure. - Update ChannelUpdate + Update ChannelUpdate1 } // NewAmountBelowMinimum creates new instance of the FailAmountBelowMinimum. func NewAmountBelowMinimum(htlcMsat MilliSatoshi, - update ChannelUpdate) *FailAmountBelowMinimum { + update ChannelUpdate1) *FailAmountBelowMinimum { return &FailAmountBelowMinimum{ HtlcMsat: htlcMsat, @@ -758,7 +758,7 @@ func (f *FailAmountBelowMinimum) Decode(r io.Reader, pver uint32) error { return err } - f.Update = ChannelUpdate{} + f.Update = ChannelUpdate1{} return parseChannelUpdateCompatibilityMode( r, length, &f.Update, pver, @@ -787,12 +787,12 @@ type FailFeeInsufficient struct { // Update is used to update information about state of the channel // which caused the failure. - Update ChannelUpdate + Update ChannelUpdate1 } // NewFeeInsufficient creates new instance of the FailFeeInsufficient. func NewFeeInsufficient(htlcMsat MilliSatoshi, - update ChannelUpdate) *FailFeeInsufficient { + update ChannelUpdate1) *FailFeeInsufficient { return &FailFeeInsufficient{ HtlcMsat: htlcMsat, Update: update, @@ -827,7 +827,7 @@ func (f *FailFeeInsufficient) Decode(r io.Reader, pver uint32) error { return err } - f.Update = ChannelUpdate{} + f.Update = ChannelUpdate1{} return parseChannelUpdateCompatibilityMode( r, length, &f.Update, pver, @@ -858,12 +858,12 @@ type FailIncorrectCltvExpiry struct { // Update is used to update information about state of the channel // which caused the failure. - Update ChannelUpdate + Update ChannelUpdate1 } // NewIncorrectCltvExpiry creates new instance of the FailIncorrectCltvExpiry. func NewIncorrectCltvExpiry(cltvExpiry uint32, - update ChannelUpdate) *FailIncorrectCltvExpiry { + update ChannelUpdate1) *FailIncorrectCltvExpiry { return &FailIncorrectCltvExpiry{ CltvExpiry: cltvExpiry, @@ -896,7 +896,7 @@ func (f *FailIncorrectCltvExpiry) Decode(r io.Reader, pver uint32) error { return err } - f.Update = ChannelUpdate{} + f.Update = ChannelUpdate1{} return parseChannelUpdateCompatibilityMode( r, length, &f.Update, pver, @@ -921,11 +921,11 @@ func (f *FailIncorrectCltvExpiry) Encode(w *bytes.Buffer, pver uint32) error { type FailExpiryTooSoon struct { // Update is used to update information about state of the channel // which caused the failure. - Update ChannelUpdate + Update ChannelUpdate1 } // NewExpiryTooSoon creates new instance of the FailExpiryTooSoon. -func NewExpiryTooSoon(update ChannelUpdate) *FailExpiryTooSoon { +func NewExpiryTooSoon(update ChannelUpdate1) *FailExpiryTooSoon { return &FailExpiryTooSoon{ Update: update, } @@ -954,7 +954,7 @@ func (f *FailExpiryTooSoon) Decode(r io.Reader, pver uint32) error { return err } - f.Update = ChannelUpdate{} + f.Update = ChannelUpdate1{} return parseChannelUpdateCompatibilityMode( r, length, &f.Update, pver, @@ -980,11 +980,11 @@ type FailChannelDisabled struct { // Update is used to update information about state of the channel // which caused the failure. - Update ChannelUpdate + Update ChannelUpdate1 } // NewChannelDisabled creates new instance of the FailChannelDisabled. -func NewChannelDisabled(flags uint16, update ChannelUpdate) *FailChannelDisabled { +func NewChannelDisabled(flags uint16, update ChannelUpdate1) *FailChannelDisabled { return &FailChannelDisabled{ Flags: flags, Update: update, @@ -1019,7 +1019,7 @@ func (f *FailChannelDisabled) Decode(r io.Reader, pver uint32) error { return err } - f.Update = ChannelUpdate{} + f.Update = ChannelUpdate1{} return parseChannelUpdateCompatibilityMode( r, length, &f.Update, pver, @@ -1456,10 +1456,10 @@ func makeEmptyOnionError(code FailCode) (FailureMessage, error) { } } -// writeOnionErrorChanUpdate writes out a ChannelUpdate using the onion error +// writeOnionErrorChanUpdate writes out a ChannelUpdate1 using the onion error // format. The format is that we first write out the true serialized length of // the channel update, followed by the serialized channel update itself. -func writeOnionErrorChanUpdate(w *bytes.Buffer, chanUpdate *ChannelUpdate, +func writeOnionErrorChanUpdate(w *bytes.Buffer, chanUpdate *ChannelUpdate1, pver uint32) error { // First, we encode the channel update in a temporary buffer in order diff --git a/lnwire/onion_error_test.go b/lnwire/onion_error_test.go index 438306b666..4f230bd22f 100644 --- a/lnwire/onion_error_test.go +++ b/lnwire/onion_error_test.go @@ -20,7 +20,7 @@ var ( testType = uint64(3) testOffset = uint16(24) sig, _ = NewSigFromSignature(testSig) - testChannelUpdate = ChannelUpdate{ + testChannelUpdate = ChannelUpdate1{ Signature: sig, ShortChannelID: NewShortChanIDFromInt(1), Timestamp: 1, @@ -136,7 +136,7 @@ func TestChannelUpdateCompatibilityParsing(t *testing.T) { // Now that we have the set of bytes encoded, we'll ensure that we're // able to decode it using our compatibility method, as it's a regular // encoded channel update message. - var newChanUpdate ChannelUpdate + var newChanUpdate ChannelUpdate1 err := parseChannelUpdateCompatibilityMode( &b, uint16(b.Len()), &newChanUpdate, 0, ) @@ -163,7 +163,7 @@ func TestChannelUpdateCompatibilityParsing(t *testing.T) { // We should be able to properly parse the encoded channel update // message even with the extra two bytes. - var newChanUpdate2 ChannelUpdate + var newChanUpdate2 ChannelUpdate1 err = parseChannelUpdateCompatibilityMode( &b, uint16(b.Len()), &newChanUpdate2, 0, ) diff --git a/netann/chan_status_manager.go b/netann/chan_status_manager.go index 43b21fda7d..1b746b2365 100644 --- a/netann/chan_status_manager.go +++ b/netann/chan_status_manager.go @@ -60,7 +60,7 @@ type ChanStatusConfig struct { // ApplyChannelUpdate processes new ChannelUpdates signed by our node by // updating our local routing table and broadcasting the update to our // peers. - ApplyChannelUpdate func(*lnwire.ChannelUpdate, *wire.OutPoint, + ApplyChannelUpdate func(*lnwire.ChannelUpdate1, *wire.OutPoint, bool) error // DB stores the set of channels that are to be monitored. @@ -90,7 +90,7 @@ type ChanStatusConfig struct { } // ChanStatusManager facilitates requests to enable or disable a channel via a -// network announcement that sets the disable bit on the ChannelUpdate +// network announcement that sets the disable bit on the ChannelUpdate1 // accordingly. The manager will periodically sample to detect cases where a // link has become inactive, and facilitate the process of disabling the channel // passively. The ChanStatusManager state machine is designed to reduce the @@ -202,7 +202,7 @@ func (m *ChanStatusManager) start() error { continue // If we are in the process of opening a channel, the funding - // manager might not have added the ChannelUpdate to the graph + // manager might not have added the ChannelUpdate1 to the graph // yet. We'll ignore the channel for now. case err == ErrUnableToExtractChanUpdate: log.Warnf("Unable to find channel policies for %v, "+ @@ -646,11 +646,11 @@ func (m *ChanStatusManager) signAndSendNextUpdate(outpoint wire.OutPoint, } // fetchLastChanUpdateByOutPoint fetches the latest policy for our direction of -// a channel, and crafts a new ChannelUpdate with this policy. Returns an error +// a channel, and crafts a new ChannelUpdate1 with this policy. Returns an error // in case our ChannelEdgePolicy is not found in the database. Also returns if // the channel is private by checking AuthProof for nil. func (m *ChanStatusManager) fetchLastChanUpdateByOutPoint(op wire.OutPoint) ( - *lnwire.ChannelUpdate, bool, error) { + *lnwire.ChannelUpdate1, bool, error) { // Get the edge info and policies for this channel from the graph. info, edge1, edge2, err := m.cfg.Graph.FetchChannelEdgesByOutpoint(&op) diff --git a/netann/chan_status_manager_test.go b/netann/chan_status_manager_test.go index e4cedd25ab..9ff359f2cc 100644 --- a/netann/chan_status_manager_test.go +++ b/netann/chan_status_manager_test.go @@ -127,7 +127,7 @@ type mockGraph struct { chanPols2 map[wire.OutPoint]*channeldb.ChannelEdgePolicy sidToCid map[lnwire.ShortChannelID]wire.OutPoint - updates chan *lnwire.ChannelUpdate + updates chan *lnwire.ChannelUpdate1 } func newMockGraph(t *testing.T, numChannels int, @@ -139,7 +139,7 @@ func newMockGraph(t *testing.T, numChannels int, chanPols1: make(map[wire.OutPoint]*channeldb.ChannelEdgePolicy), chanPols2: make(map[wire.OutPoint]*channeldb.ChannelEdgePolicy), sidToCid: make(map[lnwire.ShortChannelID]wire.OutPoint), - updates: make(chan *lnwire.ChannelUpdate, 2*numChannels), + updates: make(chan *lnwire.ChannelUpdate1, 2*numChannels), } for i := 0; i < numChannels; i++ { @@ -178,7 +178,7 @@ func (g *mockGraph) FetchChannelEdgesByOutpoint( return info, pol1, pol2, nil } -func (g *mockGraph) ApplyChannelUpdate(update *lnwire.ChannelUpdate, +func (g *mockGraph) ApplyChannelUpdate(update *lnwire.ChannelUpdate1, op *wire.OutPoint, private bool) error { g.mu.Lock() diff --git a/netann/channel_announcement.go b/netann/channel_announcement.go index ee3e4791c2..1da5f610b9 100644 --- a/netann/channel_announcement.go +++ b/netann/channel_announcement.go @@ -17,7 +17,7 @@ import ( func CreateChanAnnouncement(chanProof models.ChannelAuthProof, chanInfo models.ChannelEdgeInfo, e1, e2 *channeldb.ChannelEdgePolicy) (lnwire.ChannelAnnouncement, - *lnwire.ChannelUpdate, *lnwire.ChannelUpdate, error) { + *lnwire.ChannelUpdate1, *lnwire.ChannelUpdate1, error) { switch proof := chanProof.(type) { case *channeldb.ChannelAuthProof1: @@ -39,7 +39,7 @@ func CreateChanAnnouncement(chanProof models.ChannelAuthProof, func createChanAnnouncement1(chanProof *channeldb.ChannelAuthProof1, chanInfo *channeldb.ChannelEdgeInfo1, e1, e2 *channeldb.ChannelEdgePolicy) (lnwire.ChannelAnnouncement, - *lnwire.ChannelUpdate, *lnwire.ChannelUpdate, error) { + *lnwire.ChannelUpdate1, *lnwire.ChannelUpdate1, error) { // First, using the parameters of the channel, along with the channel // authentication chanProof, we'll create re-create the original @@ -92,7 +92,7 @@ func createChanAnnouncement1(chanProof *channeldb.ChannelAuthProof1, // Since it's up to a node's policy as to whether they advertise the // edge in a direction, we don't create an advertisement if the edge is // nil. - var edge1Ann, edge2Ann *lnwire.ChannelUpdate + var edge1Ann, edge2Ann *lnwire.ChannelUpdate1 if e1 != nil { edge1Ann, err = ChannelUpdateFromEdge(chanInfo, e1) if err != nil { diff --git a/netann/channel_update.go b/netann/channel_update.go index 0e64613476..73f0f9996a 100644 --- a/netann/channel_update.go +++ b/netann/channel_update.go @@ -16,16 +16,16 @@ import ( // ErrUnableToExtractChanUpdate is returned when a channel update cannot be // found for one of our active channels. -var ErrUnableToExtractChanUpdate = fmt.Errorf("unable to extract ChannelUpdate") +var ErrUnableToExtractChanUpdate = fmt.Errorf("unable to extract ChannelUpdate1") // ChannelUpdateModifier is a closure that makes in-place modifications to an -// lnwire.ChannelUpdate. -type ChannelUpdateModifier func(*lnwire.ChannelUpdate) +// lnwire.ChannelUpdate1. +type ChannelUpdateModifier func(*lnwire.ChannelUpdate1) // ChanUpdSetDisable is a functional option that sets the disabled channel flag // if disabled is true, and clears the bit otherwise. func ChanUpdSetDisable(disabled bool) ChannelUpdateModifier { - return func(update *lnwire.ChannelUpdate) { + return func(update *lnwire.ChannelUpdate1) { if disabled { // Set the bit responsible for marking a channel as // disabled. @@ -41,7 +41,7 @@ func ChanUpdSetDisable(disabled bool) ChannelUpdateModifier { // ChanUpdSetTimestamp is a functional option that sets the timestamp of the // update to the current time, or increments it if the timestamp is already in // the future. -func ChanUpdSetTimestamp(update *lnwire.ChannelUpdate) { +func ChanUpdSetTimestamp(update *lnwire.ChannelUpdate1) { newTimestamp := uint32(time.Now().Unix()) if newTimestamp <= update.Timestamp { // Increment the prior value to ensure the timestamp @@ -53,13 +53,13 @@ func ChanUpdSetTimestamp(update *lnwire.ChannelUpdate) { } // SignChannelUpdate applies the given modifiers to the passed -// lnwire.ChannelUpdate, then signs the resulting update. The provided update +// lnwire.ChannelUpdate1, then signs the resulting update. The provided update // should be the most recent, valid update, otherwise the timestamp may not // monotonically increase from the prior. // // NOTE: This method modifies the given update. func SignChannelUpdate(signer lnwallet.MessageSigner, keyLoc keychain.KeyLocator, - update *lnwire.ChannelUpdate, mods ...ChannelUpdateModifier) error { + update *lnwire.ChannelUpdate1, mods ...ChannelUpdateModifier) error { // Apply the requested changes to the channel update. for _, modifier := range mods { @@ -81,14 +81,14 @@ func SignChannelUpdate(signer lnwallet.MessageSigner, keyLoc keychain.KeyLocator return nil } -// ExtractChannelUpdate attempts to retrieve a lnwire.ChannelUpdate message from +// ExtractChannelUpdate attempts to retrieve a lnwire.ChannelUpdate1 message from // an edge's info and a set of routing policies. // // NOTE: The passed policies can be nil. func ExtractChannelUpdate(ownerPubKey []byte, info models.ChannelEdgeInfo, policies ...*channeldb.ChannelEdgePolicy) ( - *lnwire.ChannelUpdate, error) { + *lnwire.ChannelUpdate1, error) { // Helper function to extract the owner of the given policy. owner := func(edge *channeldb.ChannelEdgePolicy) []byte { @@ -117,12 +117,12 @@ func ExtractChannelUpdate(ownerPubKey []byte, return nil, ErrUnableToExtractChanUpdate } -// UnsignedChannelUpdateFromEdge reconstructs an unsigned ChannelUpdate from the +// UnsignedChannelUpdateFromEdge reconstructs an unsigned ChannelUpdate1 from the // given edge info and policy. func UnsignedChannelUpdateFromEdge(chainHash chainhash.Hash, - policy *channeldb.ChannelEdgePolicy) *lnwire.ChannelUpdate { + policy *channeldb.ChannelEdgePolicy) *lnwire.ChannelUpdate1 { - return &lnwire.ChannelUpdate{ + return &lnwire.ChannelUpdate1{ ChainHash: chainHash, ShortChannelID: lnwire.NewShortChanIDFromInt(policy.ChannelID), Timestamp: uint32(policy.LastUpdate.Unix()), @@ -137,10 +137,10 @@ func UnsignedChannelUpdateFromEdge(chainHash chainhash.Hash, } } -// ChannelUpdateFromEdge reconstructs a signed ChannelUpdate from the given edge +// ChannelUpdateFromEdge reconstructs a signed ChannelUpdate1 from the given edge // info and policy. func ChannelUpdateFromEdge(info models.ChannelEdgeInfo, - policy *channeldb.ChannelEdgePolicy) (*lnwire.ChannelUpdate, error) { + policy *channeldb.ChannelEdgePolicy) (*lnwire.ChannelUpdate1, error) { update := UnsignedChannelUpdateFromEdge(info.GetChainHash(), policy) diff --git a/netann/channel_update_test.go b/netann/channel_update_test.go index e49e5c65e8..a32d96d88b 100644 --- a/netann/channel_update_test.go +++ b/netann/channel_update_test.go @@ -111,7 +111,7 @@ func TestUpdateDisableFlag(t *testing.T) { // Create the initial update, the only fields we are // concerned with in this test are the timestamp and the // channel flags. - ogUpdate := &lnwire.ChannelUpdate{ + ogUpdate := &lnwire.ChannelUpdate1{ Timestamp: uint32(tc.startTime.Unix()), } if !tc.startEnabled { @@ -122,7 +122,7 @@ func TestUpdateDisableFlag(t *testing.T) { // the original. UpdateDisableFlag will mutate the // passed channel update, so we keep the old one to test // against. - newUpdate := &lnwire.ChannelUpdate{ + newUpdate := &lnwire.ChannelUpdate1{ Timestamp: ogUpdate.Timestamp, ChannelFlags: ogUpdate.ChannelFlags, } diff --git a/netann/sign.go b/netann/sign.go index 93bd8cdc9b..66266310df 100644 --- a/netann/sign.go +++ b/netann/sign.go @@ -22,7 +22,7 @@ func SignAnnouncement(signer lnwallet.MessageSigner, keyLoc keychain.KeyLocator, switch m := msg.(type) { case *lnwire.ChannelAnnouncement1: data, err = m.DataToSign() - case *lnwire.ChannelUpdate: + case *lnwire.ChannelUpdate1: data, err = m.DataToSign() case *lnwire.NodeAnnouncement: data, err = m.DataToSign() diff --git a/peer/brontide.go b/peer/brontide.go index 32e88fade5..9b6e4f74e1 100644 --- a/peer/brontide.go +++ b/peer/brontide.go @@ -290,7 +290,7 @@ type Config struct { // FetchLastChanUpdate fetches our latest channel update for a target // channel. - FetchLastChanUpdate func(lnwire.ShortChannelID) (*lnwire.ChannelUpdate, + FetchLastChanUpdate func(lnwire.ShortChannelID) (*lnwire.ChannelUpdate1, error) // FundingManager is an implementation of the funding.Controller interface. @@ -1725,7 +1725,7 @@ out: nextMsg.MsgType()) } - case *lnwire.ChannelUpdate, + case *lnwire.ChannelUpdate1, lnwire.ChannelAnnouncement, *lnwire.NodeAnnouncement, *lnwire.AnnounceSignatures, @@ -1983,7 +1983,7 @@ func messageSummary(msg lnwire.Message) string { return fmt.Sprintf("chain_hash=%v, short_chan_id=%v", msg.GetChainHash(), msg.SCID().ToUint64()) - case *lnwire.ChannelUpdate: + case *lnwire.ChannelUpdate1: return fmt.Sprintf("chain_hash=%v, short_chan_id=%v, "+ "mflags=%v, cflags=%v, update_time=%v", msg.ChainHash, msg.ShortChannelID.ToUint64(), msg.MessageFlags, @@ -2501,7 +2501,7 @@ out: // reenableActiveChannels searches the index of channels maintained with this // peer, and reenables each public, non-pending channel. This is done at the -// gossip level by broadcasting a new ChannelUpdate with the disabled bit unset. +// gossip level by broadcasting a new ChannelUpdate1 with the disabled bit unset. // No message will be sent if the channel is already enabled. func (p *Brontide) reenableActiveChannels() { // First, filter all known channels with this peer for ones that are @@ -2512,7 +2512,7 @@ func (p *Brontide) reenableActiveChannels() { retryChans := make(map[wire.OutPoint]struct{}, len(activePublicChans)) // For each of the public, non-pending channels, set the channel - // disabled bit to false and send out a new ChannelUpdate. If this + // disabled bit to false and send out a new ChannelUpdate1. If this // channel is already active, the update won't be sent. for _, chanPoint := range activePublicChans { err := p.cfg.ChanStatusMgr.RequestEnable(chanPoint, false) diff --git a/peer/test_utils.go b/peer/test_utils.go index add15cf19d..41d7546edd 100644 --- a/peer/test_utils.go +++ b/peer/test_utils.go @@ -337,7 +337,7 @@ func createTestPeer(t *testing.T, notifier chainntnfs.ChainNotifier, OurPubKey: aliceKeyPub, OurKeyLoc: testKeyLoc, IsChannelActive: func(lnwire.ChannelID) bool { return true }, - ApplyChannelUpdate: func(*lnwire.ChannelUpdate, + ApplyChannelUpdate: func(*lnwire.ChannelUpdate1, *wire.OutPoint, bool) error { return nil diff --git a/routing/ann_validation.go b/routing/ann_validation.go index 5e3b7616c4..85f8d42833 100644 --- a/routing/ann_validation.go +++ b/routing/ann_validation.go @@ -194,7 +194,7 @@ func ValidateNodeAnn(a *lnwire.NodeAnnouncement) error { // signed by the node's private key, and (2) that the announcement's message // flags and optional fields are sane. func ValidateChannelUpdateAnn(pubKey *btcec.PublicKey, capacity btcutil.Amount, - a *lnwire.ChannelUpdate) error { + a *lnwire.ChannelUpdate1) error { if err := ValidateChannelUpdateFields(capacity, a); err != nil { return err @@ -205,7 +205,7 @@ func ValidateChannelUpdateAnn(pubKey *btcec.PublicKey, capacity btcutil.Amount, // VerifyChannelUpdateSignature verifies that the channel update message was // signed by the party with the given node public key. -func VerifyChannelUpdateSignature(msg *lnwire.ChannelUpdate, +func VerifyChannelUpdateSignature(msg *lnwire.ChannelUpdate1, pubKey *btcec.PublicKey) error { data, err := msg.DataToSign() @@ -230,7 +230,7 @@ func VerifyChannelUpdateSignature(msg *lnwire.ChannelUpdate, // ValidateChannelUpdateFields validates a channel update's message flags and // corresponding update fields. func ValidateChannelUpdateFields(capacity btcutil.Amount, - msg *lnwire.ChannelUpdate) error { + msg *lnwire.ChannelUpdate1) error { // The maxHTLC flag is mandatory. if !msg.MessageFlags.HasMaxHtlc() { diff --git a/routing/missioncontrol_test.go b/routing/missioncontrol_test.go index 27391d53e2..4a0f738715 100644 --- a/routing/missioncontrol_test.go +++ b/routing/missioncontrol_test.go @@ -197,7 +197,7 @@ func TestMissionControl(t *testing.T) { // A node level failure should bring probability of all known channels // back to zero. - ctx.reportFailure(0, lnwire.NewExpiryTooSoon(lnwire.ChannelUpdate{})) + ctx.reportFailure(0, lnwire.NewExpiryTooSoon(lnwire.ChannelUpdate1{})) ctx.expectP(1000, 0) // Check whether history snapshot looks sane. @@ -219,14 +219,14 @@ func TestMissionControlChannelUpdate(t *testing.T) { // Report a policy related failure. Because it is the first, we don't // expect a penalty. ctx.reportFailure( - 0, lnwire.NewFeeInsufficient(0, lnwire.ChannelUpdate{}), + 0, lnwire.NewFeeInsufficient(0, lnwire.ChannelUpdate1{}), ) ctx.expectP(100, testAprioriHopProbability) // Report another failure for the same channel. We expect it to be // pruned. ctx.reportFailure( - 0, lnwire.NewFeeInsufficient(0, lnwire.ChannelUpdate{}), + 0, lnwire.NewFeeInsufficient(0, lnwire.ChannelUpdate1{}), ) ctx.expectP(100, 0) } diff --git a/routing/mock_test.go b/routing/mock_test.go index c24ca2508e..4f765d89e9 100644 --- a/routing/mock_test.go +++ b/routing/mock_test.go @@ -173,7 +173,7 @@ func (m *mockPaymentSessionOld) RequestRoute(_, _ lnwire.MilliSatoshi, return r, nil } -func (m *mockPaymentSessionOld) UpdateAdditionalEdge(_ *lnwire.ChannelUpdate, +func (m *mockPaymentSessionOld) UpdateAdditionalEdge(_ *lnwire.ChannelUpdate1, _ *btcec.PublicKey, _ *channeldb.CachedEdgePolicy) bool { return false @@ -675,7 +675,7 @@ func (m *mockPaymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, return args.Get(0).(*route.Route), args.Error(1) } -func (m *mockPaymentSession) UpdateAdditionalEdge(msg *lnwire.ChannelUpdate, +func (m *mockPaymentSession) UpdateAdditionalEdge(msg *lnwire.ChannelUpdate1, pubKey *btcec.PublicKey, policy *channeldb.CachedEdgePolicy) bool { args := m.Called(msg, pubKey, policy) diff --git a/routing/payment_session.go b/routing/payment_session.go index 0a90785022..e54d5e2760 100644 --- a/routing/payment_session.go +++ b/routing/payment_session.go @@ -143,7 +143,7 @@ type PaymentSession interface { // (private channels) and applies the update from the message. Returns // a boolean to indicate whether the update has been applied without // error. - UpdateAdditionalEdge(msg *lnwire.ChannelUpdate, pubKey *btcec.PublicKey, + UpdateAdditionalEdge(msg *lnwire.ChannelUpdate1, pubKey *btcec.PublicKey, policy *channeldb.CachedEdgePolicy) bool // GetAdditionalEdgePolicy uses the public key and channel ID to query @@ -404,7 +404,7 @@ func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, // validates the message signature and checks it's up to date, then applies the // updates to the supplied policy. It returns a boolean to indicate whether // there's an error when applying the updates. -func (p *paymentSession) UpdateAdditionalEdge(msg *lnwire.ChannelUpdate, +func (p *paymentSession) UpdateAdditionalEdge(msg *lnwire.ChannelUpdate1, pubKey *btcec.PublicKey, policy *channeldb.CachedEdgePolicy) bool { // Validate the message signature. diff --git a/routing/payment_session_test.go b/routing/payment_session_test.go index 18858f1f92..6797f14b28 100644 --- a/routing/payment_session_test.go +++ b/routing/payment_session_test.go @@ -148,7 +148,7 @@ func TestUpdateAdditionalEdge(t *testing.T) { ) // Create the channel update message and sign. - msg := &lnwire.ChannelUpdate{ + msg := &lnwire.ChannelUpdate1{ ShortChannelID: lnwire.NewShortChanIDFromInt(testChannelID), Timestamp: uint32(time.Now().Unix()), BaseFee: newFeeBaseMSat, diff --git a/routing/result_interpretation_test.go b/routing/result_interpretation_test.go index b9fcc68d14..8eb39c01d9 100644 --- a/routing/result_interpretation_test.go +++ b/routing/result_interpretation_test.go @@ -94,7 +94,7 @@ var resultTestCases = []resultTestCase{ name: "fail expiry too soon", route: &routeFourHop, failureSrcIdx: 3, - failure: lnwire.NewExpiryTooSoon(lnwire.ChannelUpdate{}), + failure: lnwire.NewExpiryTooSoon(lnwire.ChannelUpdate1{}), expectedResult: &interpretedResult{ pairResults: map[DirectedNodePair]pairResult{ @@ -196,7 +196,7 @@ var resultTestCases = []resultTestCase{ name: "fail fee insufficient intermediate", route: &routeFourHop, failureSrcIdx: 2, - failure: lnwire.NewFeeInsufficient(0, lnwire.ChannelUpdate{}), + failure: lnwire.NewFeeInsufficient(0, lnwire.ChannelUpdate1{}), expectedResult: &interpretedResult{ pairResults: map[DirectedNodePair]pairResult{ diff --git a/routing/router.go b/routing/router.go index 37674e515e..ec4f3caa19 100644 --- a/routing/router.go +++ b/routing/router.go @@ -286,7 +286,7 @@ type FeeSchema struct { // ChannelPolicy holds the parameters that determine the policy we enforce // when forwarding payments on a channel. These parameters are communicated -// to the rest of the network in ChannelUpdate messages. +// to the rest of the network in ChannelUpdate1 messages. type ChannelPolicy struct { // FeeSchema holds the fee configuration for a channel. FeeSchema @@ -2563,9 +2563,9 @@ func (r *ChannelRouter) sendPayment(feeLimit lnwire.MilliSatoshi, // extractChannelUpdate examines the error and extracts the channel update. func (r *ChannelRouter) extractChannelUpdate( - failure lnwire.FailureMessage) *lnwire.ChannelUpdate { + failure lnwire.FailureMessage) *lnwire.ChannelUpdate1 { - var update *lnwire.ChannelUpdate + var update *lnwire.ChannelUpdate1 switch onionErr := failure.(type) { case *lnwire.FailExpiryTooSoon: update = &onionErr.Update @@ -2586,7 +2586,7 @@ func (r *ChannelRouter) extractChannelUpdate( // applyChannelUpdate validates a channel update and if valid, applies it to the // database. It returns a bool indicating whether the updates were successful. -func (r *ChannelRouter) applyChannelUpdate(msg *lnwire.ChannelUpdate) bool { +func (r *ChannelRouter) applyChannelUpdate(msg *lnwire.ChannelUpdate1) bool { ch, _, _, err := r.GetChannelByID(msg.ShortChannelID) if err != nil { log.Errorf("Unable to retrieve channel by id: %v", err) diff --git a/routing/router_test.go b/routing/router_test.go index 8bcd79627c..0d7f12a788 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -236,7 +236,7 @@ func createTestCtxFromFile(t *testing.T, // Add valid signature to channel update simulated as error received from the // network. func signErrChanUpdate(t *testing.T, key *btcec.PrivateKey, - errChanUpdate *lnwire.ChannelUpdate) { + errChanUpdate *lnwire.ChannelUpdate1) { chanUpdateMsg, err := errChanUpdate.DataToSign() require.NoError(t, err, "failed to retrieve data to sign") @@ -510,7 +510,7 @@ func TestChannelUpdateValidation(t *testing.T) { // Set up a channel update message with an invalid signature to be // returned to the sender. var invalidSignature lnwire.Sig - errChanUpdate := lnwire.ChannelUpdate{ + errChanUpdate := lnwire.ChannelUpdate1{ Signature: invalidSignature, FeeRate: 500, ShortChannelID: lnwire.NewShortChanIDFromInt(1), @@ -612,7 +612,7 @@ func TestSendPaymentErrorRepeatedFeeInsufficient(t *testing.T) { ) require.NoError(t, err, "unable to fetch chan id") - errChanUpdate := lnwire.ChannelUpdate{ + errChanUpdate := lnwire.ChannelUpdate1{ ShortChannelID: lnwire.NewShortChanIDFromInt( songokuSophonChanID, ), @@ -731,7 +731,7 @@ func TestSendPaymentErrorFeeInsufficientPrivateEdge(t *testing.T) { // Prepare an error update for the private channel, with twice the // original fee. updatedFeeBaseMSat := feeBaseMSat * 2 - errChanUpdate := lnwire.ChannelUpdate{ + errChanUpdate := lnwire.ChannelUpdate1{ ShortChannelID: lnwire.NewShortChanIDFromInt(privateChannelID), Timestamp: uint32(testTime.Add(time.Minute).Unix()), BaseFee: updatedFeeBaseMSat, @@ -803,7 +803,7 @@ func TestSendPaymentErrorFeeInsufficientPrivateEdge(t *testing.T) { } // TestSendPaymentPrivateEdgeUpdateFeeExceedsLimit tests that upon receiving a -// ChannelUpdate in a fee related error from the private channel, we won't +// ChannelUpdate1 in a fee related error from the private channel, we won't // choose the route in our second attempt if the updated fee exceeds our fee // limit specified in the payment. // @@ -860,7 +860,7 @@ func TestSendPaymentPrivateEdgeUpdateFeeExceedsLimit(t *testing.T) { // Prepare an error update for the private channel. The updated fee // will exceeds the feeLimit. updatedFeeBaseMSat := feeBaseMSat + uint32(feeLimit) - errChanUpdate := lnwire.ChannelUpdate{ + errChanUpdate := lnwire.ChannelUpdate1{ ShortChannelID: lnwire.NewShortChanIDFromInt(privateChannelID), Timestamp: uint32(testTime.Add(time.Minute).Unix()), BaseFee: updatedFeeBaseMSat, @@ -964,7 +964,7 @@ func TestSendPaymentErrorNonFinalTimeLockErrors(t *testing.T) { _, _, edgeUpdateToFail, err := ctx.graph.FetchChannelEdgesByID(chanID) require.NoError(t, err, "unable to fetch chan id") - errChanUpdate := lnwire.ChannelUpdate{ + errChanUpdate := lnwire.ChannelUpdate1{ ShortChannelID: lnwire.NewShortChanIDFromInt(chanID), Timestamp: uint32(edgeUpdateToFail.LastUpdate.Unix()), MessageFlags: edgeUpdateToFail.MessageFlags, @@ -2921,7 +2921,7 @@ func TestSendToRouteStructuredError(t *testing.T) { testCases := map[int]lnwire.FailureMessage{ finalHopIndex: lnwire.NewFailIncorrectDetails(payAmt, 100), 1: &lnwire.FailFeeInsufficient{ - Update: lnwire.ChannelUpdate{}, + Update: lnwire.ChannelUpdate1{}, }, } diff --git a/routing/validation_barrier.go b/routing/validation_barrier.go index 24423402aa..783aa78bfe 100644 --- a/routing/validation_barrier.go +++ b/routing/validation_barrier.go @@ -146,7 +146,7 @@ func (v *ValidationBarrier) InitJobDependencies(job interface{}) { // initialization needs to be done beyond just occupying a job slot. case *channeldb.ChannelEdgePolicy: return - case *lnwire.ChannelUpdate: + case *lnwire.ChannelUpdate1: return case *lnwire.NodeAnnouncement: // TODO(roasbeef): node ann needs to wait on existing channel updates @@ -186,7 +186,7 @@ func (v *ValidationBarrier) WaitForDependants(job interface{}) error { v.Lock() switch msg := job.(type) { - // Any ChannelUpdate or NodeAnnouncement jobs will need to wait on the + // Any ChannelUpdate1 or NodeAnnouncement jobs will need to wait on the // completion of any active ChannelAnnouncement1 jobs related to them. case *channeldb.ChannelEdgePolicy: shortID := lnwire.NewShortChanIDFromInt(msg.ChannelID) @@ -202,10 +202,10 @@ func (v *ValidationBarrier) WaitForDependants(job interface{}) error { jobDesc = fmt.Sprintf("job=channeldb.LightningNode, pub=%s", vertex) - case *lnwire.ChannelUpdate: + case *lnwire.ChannelUpdate1: signals, ok = v.chanEdgeDependencies[msg.ShortChannelID] - jobDesc = fmt.Sprintf("job=lnwire.ChannelUpdate, scid=%v", + jobDesc = fmt.Sprintf("job=lnwire.ChannelUpdate1, scid=%v", msg.ShortChannelID.ToUint64()) case *lnwire.NodeAnnouncement: @@ -297,7 +297,7 @@ func (v *ValidationBarrier) SignalDependants(job interface{}, allow bool) { delete(v.nodeAnnDependencies, route.Vertex(msg.PubKeyBytes)) case *lnwire.NodeAnnouncement: delete(v.nodeAnnDependencies, route.Vertex(msg.NodeID)) - case *lnwire.ChannelUpdate: + case *lnwire.ChannelUpdate1: delete(v.chanEdgeDependencies, msg.ShortChannelID) case *channeldb.ChannelEdgePolicy: shortID := lnwire.NewShortChanIDFromInt(msg.ChannelID) diff --git a/routing/validation_barrier_test.go b/routing/validation_barrier_test.go index e56c95a472..a26ba4d9a0 100644 --- a/routing/validation_barrier_test.go +++ b/routing/validation_barrier_test.go @@ -85,9 +85,9 @@ func TestValidationBarrierQuit(t *testing.T) { // Create a set of channel updates, that must wait until their // associated channel announcement has been verified. - chanUpds := make([]*lnwire.ChannelUpdate, 0, numTasks) + chanUpds := make([]*lnwire.ChannelUpdate1, 0, numTasks) for i := 0; i < numTasks; i++ { - chanUpds = append(chanUpds, &lnwire.ChannelUpdate{ + chanUpds = append(chanUpds, &lnwire.ChannelUpdate1{ ShortChannelID: lnwire.NewShortChanIDFromInt(uint64(i)), }) barrier.InitJobDependencies(chanUpds[i]) diff --git a/server.go b/server.go index 5ef7886fbf..fe4f2ef478 100644 --- a/server.go +++ b/server.go @@ -1662,10 +1662,10 @@ func newServer(cfg *Config, listenAddrs []net.Addr, return s, nil } -// signAliasUpdate takes a ChannelUpdate and returns the signature. This is -// used for option_scid_alias channels where the ChannelUpdate to be sent back +// signAliasUpdate takes a ChannelUpdate1 and returns the signature. This is +// used for option_scid_alias channels where the ChannelUpdate1 to be sent back // may differ from what is on disk. -func (s *server) signAliasUpdate(u *lnwire.ChannelUpdate) (*ecdsa.Signature, +func (s *server) signAliasUpdate(u *lnwire.ChannelUpdate1) (*ecdsa.Signature, error) { data, err := u.DataToSign() @@ -4614,10 +4614,10 @@ func (s *server) fetchNodeAdvertisedAddrs(pub *btcec.PublicKey) ([]net.Addr, err // fetchLastChanUpdate returns a function which is able to retrieve our latest // channel update for a target channel. func (s *server) fetchLastChanUpdate() func(lnwire.ShortChannelID) ( - *lnwire.ChannelUpdate, error) { + *lnwire.ChannelUpdate1, error) { ourPubKey := s.identityECDH.PubKey().SerializeCompressed() - return func(cid lnwire.ShortChannelID) (*lnwire.ChannelUpdate, error) { + return func(cid lnwire.ShortChannelID) (*lnwire.ChannelUpdate1, error) { info, edge1, edge2, err := s.chanRouter.GetChannelByID(cid) if err != nil { return nil, err @@ -4632,7 +4632,7 @@ func (s *server) fetchLastChanUpdate() func(lnwire.ShortChannelID) ( // applyChannelUpdate applies the channel update to the different sub-systems of // the server. The useAlias boolean denotes whether or not to send an alias in // place of the real SCID. -func (s *server) applyChannelUpdate(update *lnwire.ChannelUpdate, +func (s *server) applyChannelUpdate(update *lnwire.ChannelUpdate1, op *wire.OutPoint, useAlias bool) error { var ( @@ -4643,7 +4643,7 @@ func (s *server) applyChannelUpdate(update *lnwire.ChannelUpdate, chanID := lnwire.NewChanIDFromOutPoint(op) // Fetch the peer's alias from the lnwire.ChannelID so it can be used - // in the ChannelUpdate if it hasn't been announced yet. + // in the ChannelUpdate1 if it hasn't been announced yet. if useAlias { foundAlias, _ := s.aliasMgr.GetPeerAlias(chanID) if foundAlias != defaultAlias { From 1e311bc2ff3f69317586c3d49a4c35d6584487da Mon Sep 17 00:00:00 2001 From: Elle Mouton <elle.mouton@gmail.com> Date: Tue, 17 Oct 2023 16:10:37 +0200 Subject: [PATCH 23/33] multi: rename ChannelEdgePolicy to ChannelEdgePolicy1 --- autopilot/graph.go | 6 +- channeldb/graph.go | 84 ++++++++++++++-------------- channeldb/graph_cache.go | 14 ++--- channeldb/graph_cache_test.go | 16 +++--- channeldb/graph_test.go | 40 ++++++------- discovery/gossiper.go | 16 +++--- discovery/gossiper_test.go | 32 +++++------ funding/manager.go | 6 +- funding/manager_test.go | 2 +- lnrpc/devrpc/dev_server.go | 4 +- lnrpc/invoicesrpc/addinvoice.go | 10 ++-- lnrpc/invoicesrpc/addinvoice_test.go | 60 ++++++++++---------- netann/chan_status_manager.go | 2 +- netann/chan_status_manager_test.go | 22 ++++---- netann/channel_announcement.go | 4 +- netann/channel_update.go | 8 +-- netann/interface.go | 2 +- routing/localchans/manager.go | 6 +- routing/localchans/manager_test.go | 6 +- routing/notifications.go | 2 +- routing/notifications_test.go | 6 +- routing/pathfind.go | 4 +- routing/pathfind_test.go | 6 +- routing/router.go | 28 +++++----- routing/router_test.go | 14 ++--- routing/validation_barrier.go | 8 +-- rpcserver.go | 10 ++-- server.go | 6 +- 28 files changed, 212 insertions(+), 212 deletions(-) diff --git a/autopilot/graph.go b/autopilot/graph.go index e30639fea8..aba7d8365a 100644 --- a/autopilot/graph.go +++ b/autopilot/graph.go @@ -91,7 +91,7 @@ func (d dbNode) Addrs() []net.Addr { func (d dbNode) ForEachChannel(cb func(ChannelEdge) error) error { return d.node.ForEachChannel(d.db, d.tx, func(db kvdb.Backend, tx kvdb.RTx, ei models.ChannelEdgeInfo, ep, - _ *channeldb.ChannelEdgePolicy) error { + _ *channeldb.ChannelEdgePolicy1) error { // Skip channels for which no outgoing edge policy is available. // @@ -236,7 +236,7 @@ func (d *databaseChannelGraph) addRandChannel(node1, node2 *btcec.PublicKey, if err := d.db.AddChannelEdge(edge); err != nil { return nil, nil, err } - edgePolicy := &channeldb.ChannelEdgePolicy{ + edgePolicy := &channeldb.ChannelEdgePolicy1{ SigBytes: testSig.Serialize(), ChannelID: chanID.ToUint64(), LastUpdate: time.Now(), @@ -252,7 +252,7 @@ func (d *databaseChannelGraph) addRandChannel(node1, node2 *btcec.PublicKey, if err := d.db.UpdateEdgePolicy(edgePolicy); err != nil { return nil, nil, err } - edgePolicy = &channeldb.ChannelEdgePolicy{ + edgePolicy = &channeldb.ChannelEdgePolicy1{ SigBytes: testSig.Serialize(), ChannelID: chanID.ToUint64(), LastUpdate: time.Now(), diff --git a/channeldb/graph.go b/channeldb/graph.go index d0a440d313..4b919350e0 100644 --- a/channeldb/graph.go +++ b/channeldb/graph.go @@ -239,7 +239,7 @@ func NewChannelGraph(db kvdb.Backend, rejectCacheSize, chanCacheSize int, } err = g.ForEachChannel(func(info models.ChannelEdgeInfo, - policy1, policy2 *ChannelEdgePolicy) error { + policy1, policy2 *ChannelEdgePolicy1) error { g.graphCache.AddChannel(info, policy1, policy2) @@ -269,10 +269,10 @@ type channelMapKey struct { // getChannelMap loads all channel edge policies from the database and stores // them in a map. func (c *ChannelGraph) getChannelMap(edges kvdb.RBucket) ( - map[channelMapKey]*ChannelEdgePolicy, error) { + map[channelMapKey]*ChannelEdgePolicy1, error) { // Create a map to store all channel edge policies. - channelMap := make(map[channelMapKey]*ChannelEdgePolicy) + channelMap := make(map[channelMapKey]*ChannelEdgePolicy1) err := kvdb.ForAll(edges, func(k, edgeBytes []byte) error { // Skip embedded buckets. @@ -423,7 +423,7 @@ func (c *ChannelGraph) NewPathFindTx() (kvdb.RTx, error) { // for that particular channel edge routing policy will be passed into the // callback. func (c *ChannelGraph) ForEachChannel(cb func(models.ChannelEdgeInfo, - *ChannelEdgePolicy, *ChannelEdgePolicy) error) error { + *ChannelEdgePolicy1, *ChannelEdgePolicy1) error) error { return c.db.View(func(tx kvdb.RTx) error { edges := tx.ReadBucket(edgeBucket) @@ -495,7 +495,7 @@ func (c *ChannelGraph) ForEachNodeChannel(tx kvdb.RTx, node route.Vertex, } dbCallback := func(_ kvdb.Backend, tx kvdb.RTx, - e models.ChannelEdgeInfo, p1, p2 *ChannelEdgePolicy) error { + e models.ChannelEdgeInfo, p1, p2 *ChannelEdgePolicy1) error { var cachedInPolicy *CachedEdgePolicy if p2 != nil { @@ -570,7 +570,7 @@ func (c *ChannelGraph) ForEachNodeCached(cb func(node route.Vertex, err := node.ForEachChannel(c.db, tx, func(_ kvdb.Backend, tx kvdb.RTx, e models.ChannelEdgeInfo, - p1 *ChannelEdgePolicy, p2 *ChannelEdgePolicy) error { + p1 *ChannelEdgePolicy1, p2 *ChannelEdgePolicy1) error { toNodeCallback := func() route.Vertex { return node.PubKeyBytes @@ -1873,11 +1873,11 @@ type ChannelEdge struct { // Policy1 points to the "first" edge policy of the channel containing // the dynamic information required to properly route through the edge. - Policy1 *ChannelEdgePolicy + Policy1 *ChannelEdgePolicy1 // Policy2 points to the "second" edge policy of the channel containing // the dynamic information required to properly route through the edge. - Policy2 *ChannelEdgePolicy + Policy2 *ChannelEdgePolicy1 } // ChanUpdatesInHorizon returns all the known channel edges which have at least @@ -2317,7 +2317,7 @@ func (c *ChannelGraph) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) { } func delEdgeUpdateIndexEntry(edgesBucket kvdb.RwBucket, chanID uint64, - edge1, edge2 *ChannelEdgePolicy) error { + edge1, edge2 *ChannelEdgePolicy1) error { // First, we'll fetch the edge update index bucket which currently // stores an entry for the channel we're about to delete. @@ -2462,7 +2462,7 @@ func (c *ChannelGraph) delChannelEdge(edges, edgeIndex, chanIndex, zombieIndex, // marked with the correct lagging channel since we received an update from only // one side. func makeZombiePubkeys(info models.ChannelEdgeInfo, - e1, e2 *ChannelEdgePolicy) ([33]byte, [33]byte) { + e1, e2 *ChannelEdgePolicy1) ([33]byte, [33]byte) { var ( node1Bytes = info.Node1Bytes() @@ -2492,12 +2492,12 @@ func makeZombiePubkeys(info models.ChannelEdgeInfo, // UpdateEdgePolicy updates the edge routing policy for a single directed edge // within the database for the referenced channel. The `flags` attribute within -// the ChannelEdgePolicy determines which of the directed edges are being +// the ChannelEdgePolicy1 determines which of the directed edges are being // updated. If the flag is 1, then the first node's information is being // updated, otherwise it's the second node's information. The node ordering is // determined by the lexicographical ordering of the identity public keys of the // nodes on either side of the channel. -func (c *ChannelGraph) UpdateEdgePolicy(edge *ChannelEdgePolicy, +func (c *ChannelGraph) UpdateEdgePolicy(edge *ChannelEdgePolicy1, op ...batch.SchedulerOption) error { var ( @@ -2545,7 +2545,7 @@ func (c *ChannelGraph) UpdateEdgePolicy(edge *ChannelEdgePolicy, return c.chanScheduler.Execute(r) } -func (c *ChannelGraph) updateEdgeCache(e *ChannelEdgePolicy, isUpdate1 bool) { +func (c *ChannelGraph) updateEdgeCache(e *ChannelEdgePolicy1, isUpdate1 bool) { // If an entry for this channel is found in reject cache, we'll modify // the entry with the updated timestamp for the direction that was just // written. If the edge doesn't exist, we'll load the cache entry lazily @@ -2577,7 +2577,7 @@ func (c *ChannelGraph) updateEdgeCache(e *ChannelEdgePolicy, isUpdate1 bool) { // buckets using an existing database transaction. The returned boolean will be // true if the updated policy belongs to node1, and false if the policy belonged // to node2. -func updateEdgePolicy(tx kvdb.RwTx, edge *ChannelEdgePolicy, +func updateEdgePolicy(tx kvdb.RwTx, edge *ChannelEdgePolicy1, graphCache *GraphCache) (bool, error) { edges := tx.ReadWriteBucket(edgeBucket) @@ -2785,7 +2785,7 @@ func (l *LightningNode) isPublic(db kvdb.Backend, tx kvdb.RTx, nodeIsPublic := false errDone := errors.New("done") err := l.ForEachChannel(db, tx, func(_ kvdb.Backend, _ kvdb.RTx, - info models.ChannelEdgeInfo, _, _ *ChannelEdgePolicy) error { + info models.ChannelEdgeInfo, _, _ *ChannelEdgePolicy1) error { // If this edge doesn't extend to the source node, we'll // terminate our search as we can now conclude that the node is @@ -2896,7 +2896,7 @@ func (n *graphCacheNode) Features() *lnwire.FeatureVector { // Unknown policies are passed into the callback as nil values. func (n *graphCacheNode) ForEachChannel(db kvdb.Backend, tx kvdb.RTx, cb func(kvdb.Backend, kvdb.RTx, models.ChannelEdgeInfo, - *ChannelEdgePolicy, *ChannelEdgePolicy) error) error { + *ChannelEdgePolicy1, *ChannelEdgePolicy1) error) error { return nodeTraversal(tx, n.pubKeyBytes[:], db, cb) } @@ -2957,7 +2957,7 @@ func (c *ChannelGraph) HasLightningNode(nodePub [33]byte) (time.Time, bool, erro // public key and passes channel information into the specified callback. func nodeTraversal(tx kvdb.RTx, nodePub []byte, db kvdb.Backend, cb func(kvdb.Backend, kvdb.RTx, models.ChannelEdgeInfo, - *ChannelEdgePolicy, *ChannelEdgePolicy) error) error { + *ChannelEdgePolicy1, *ChannelEdgePolicy1) error) error { traversal := func(tx kvdb.RTx) error { nodes := tx.ReadBucket(nodeBucket) @@ -3068,7 +3068,7 @@ func nodeTraversal(tx kvdb.RTx, nodePub []byte, db kvdb.Backend, // traversal. func (l *LightningNode) ForEachChannel(db kvdb.Backend, tx kvdb.RTx, cb func(kvdb.Backend, kvdb.RTx, models.ChannelEdgeInfo, - *ChannelEdgePolicy, *ChannelEdgePolicy) error) error { + *ChannelEdgePolicy1, *ChannelEdgePolicy1) error) error { nodePub := l.PubKeyBytes[:] @@ -3079,7 +3079,7 @@ func (l *LightningNode) ForEachChannel(db kvdb.Backend, tx kvdb.RTx, // unique attributes. Once an authenticated channel announcement has been // processed on the network, then an instance of ChannelEdgeInfo1 encapsulating // the channels attributes is stored. The other portions relevant to routing -// policy of a channel are stored within a ChannelEdgePolicy for each direction +// policy of a channel are stored within a ChannelEdgePolicy1 for each direction // of the channel. type ChannelEdgeInfo1 struct { // ChannelID is the unique channel ID for the channel. The first 3 @@ -3589,12 +3589,12 @@ func (c *ChannelAuthProof1) IsEmpty() bool { len(c.BitcoinSig2Bytes) == 0 } -// ChannelEdgePolicy represents a *directed* edge within the channel graph. For +// ChannelEdgePolicy1 represents a *directed* edge within the channel graph. For // each channel in the database, there are two distinct edges: one for each // possible direction of travel along the channel. The edges themselves hold // information concerning fees, and minimum time-lock information which is // utilized during path finding. -type ChannelEdgePolicy struct { +type ChannelEdgePolicy1 struct { // SigBytes is the raw bytes of the signature of the channel edge // policy. We'll only parse these if the caller needs to access the // signature for validation purposes. Do not set SigBytes directly, but @@ -3660,7 +3660,7 @@ type ChannelEdgePolicy struct { // // NOTE: By having this method to access an attribute, we ensure we only need // to fully deserialize the signature if absolutely necessary. -func (c *ChannelEdgePolicy) Signature() (*ecdsa.Signature, error) { +func (c *ChannelEdgePolicy1) Signature() (*ecdsa.Signature, error) { if c.sig != nil { return c.sig, nil } @@ -3677,20 +3677,20 @@ func (c *ChannelEdgePolicy) Signature() (*ecdsa.Signature, error) { // SetSigBytes updates the signature and invalidates the cached parsed // signature. -func (c *ChannelEdgePolicy) SetSigBytes(sig []byte) { +func (c *ChannelEdgePolicy1) SetSigBytes(sig []byte) { c.SigBytes = sig c.sig = nil } // IsDisabled determines whether the edge has the disabled bit set. -func (c *ChannelEdgePolicy) IsDisabled() bool { +func (c *ChannelEdgePolicy1) IsDisabled() bool { return c.ChannelFlags.IsDisabled() } // ComputeFee computes the fee to forward an HTLC of `amt` milli-satoshis over // the passed active payment channel. This value is currently computed as // specified in BOLT07, but will likely change in the near future. -func (c *ChannelEdgePolicy) ComputeFee( +func (c *ChannelEdgePolicy1) ComputeFee( amt lnwire.MilliSatoshi) lnwire.MilliSatoshi { return c.FeeBaseMSat + (amt*c.FeeProportionalMillionths)/feeRateParts @@ -3703,7 +3703,7 @@ func divideCeil(dividend, factor lnwire.MilliSatoshi) lnwire.MilliSatoshi { // ComputeFeeFromIncoming computes the fee to forward an HTLC given the incoming // amount. -func (c *ChannelEdgePolicy) ComputeFeeFromIncoming( +func (c *ChannelEdgePolicy1) ComputeFeeFromIncoming( incomingAmt lnwire.MilliSatoshi) lnwire.MilliSatoshi { return incomingAmt - divideCeil( @@ -3718,12 +3718,12 @@ func (c *ChannelEdgePolicy) ComputeFeeFromIncoming( // information for the channel itself is returned as well as two structs that // contain the routing policies for the channel in either direction. func (c *ChannelGraph) FetchChannelEdgesByOutpoint(op *wire.OutPoint, -) (models.ChannelEdgeInfo, *ChannelEdgePolicy, *ChannelEdgePolicy, error) { +) (models.ChannelEdgeInfo, *ChannelEdgePolicy1, *ChannelEdgePolicy1, error) { var ( edgeInfo models.ChannelEdgeInfo - policy1 *ChannelEdgePolicy - policy2 *ChannelEdgePolicy + policy1 *ChannelEdgePolicy1 + policy2 *ChannelEdgePolicy1 ) err := kvdb.View(c.db, func(tx kvdb.RTx) error { @@ -3801,15 +3801,15 @@ func (c *ChannelGraph) FetchChannelEdgesByOutpoint(op *wire.OutPoint, // routing policies for the channel in either direction. // // ErrZombieEdge an be returned if the edge is currently marked as a zombie -// within the database. In this case, the ChannelEdgePolicy's will be nil, and +// within the database. In this case, the ChannelEdgePolicy1's will be nil, and // the ChannelEdgeInfo will only include the public keys of each node. func (c *ChannelGraph) FetchChannelEdgesByID(chanID uint64, -) (models.ChannelEdgeInfo, *ChannelEdgePolicy, *ChannelEdgePolicy, error) { +) (models.ChannelEdgeInfo, *ChannelEdgePolicy1, *ChannelEdgePolicy1, error) { var ( edgeInfo models.ChannelEdgeInfo - policy1 *ChannelEdgePolicy - policy2 *ChannelEdgePolicy + policy1 *ChannelEdgePolicy1 + policy2 *ChannelEdgePolicy1 channelID [8]byte ) @@ -4674,7 +4674,7 @@ func deserializeChanEdgeInfo1(r io.Reader) (*ChannelEdgeInfo1, error) { return &edgeInfo, nil } -func putChanEdgePolicy(edges, nodes kvdb.RwBucket, edge *ChannelEdgePolicy, +func putChanEdgePolicy(edges, nodes kvdb.RwBucket, edge *ChannelEdgePolicy1, from, to []byte) error { var edgeKey [33 + 8]byte @@ -4746,7 +4746,7 @@ func putChanEdgePolicy(edges, nodes kvdb.RwBucket, edge *ChannelEdgePolicy, } // updateEdgePolicyDisabledIndex is used to update the disabledEdgePolicyIndex -// bucket by either add a new disabled ChannelEdgePolicy or remove an existing +// bucket by either add a new disabled ChannelEdgePolicy1 or remove an existing // one. // The direction represents the direction of the edge and disabled is used for // deciding whether to remove or add an entry to the bucket. @@ -4795,7 +4795,7 @@ func putChanEdgePolicyUnknown(edges kvdb.RwBucket, channelID uint64, } func fetchChanEdgePolicy(edges kvdb.RBucket, chanID []byte, - nodePub []byte, nodes kvdb.RBucket) (*ChannelEdgePolicy, error) { + nodePub []byte, nodes kvdb.RBucket) (*ChannelEdgePolicy1, error) { var edgeKey [33 + 8]byte copy(edgeKey[:], nodePub) @@ -4828,8 +4828,8 @@ func fetchChanEdgePolicy(edges kvdb.RBucket, chanID []byte, } func fetchChanEdgePolicies(edgeIndex kvdb.RBucket, edges kvdb.RBucket, - nodes kvdb.RBucket, chanID []byte) (*ChannelEdgePolicy, - *ChannelEdgePolicy, error) { + nodes kvdb.RBucket, chanID []byte) (*ChannelEdgePolicy1, + *ChannelEdgePolicy1, error) { edgeInfoBytes := edgeIndex.Get(chanID) if edgeInfoBytes == nil { @@ -4861,7 +4861,7 @@ func fetchChanEdgePolicies(edgeIndex kvdb.RBucket, edges kvdb.RBucket, return edge1, edge2, nil } -func serializeChanEdgePolicy(w io.Writer, edge *ChannelEdgePolicy, +func serializeChanEdgePolicy(w io.Writer, edge *ChannelEdgePolicy1, to []byte) error { err := wire.WriteVarBytes(w, 0, edge.SigBytes) @@ -4929,7 +4929,7 @@ func serializeChanEdgePolicy(w io.Writer, edge *ChannelEdgePolicy, } func deserializeChanEdgePolicy(r io.Reader, - nodes kvdb.RBucket) (*ChannelEdgePolicy, error) { + nodes kvdb.RBucket) (*ChannelEdgePolicy1, error) { // Deserialize the policy. Note that in case an optional field is not // found, both an error and a populated policy object are returned. @@ -4951,8 +4951,8 @@ func deserializeChanEdgePolicy(r io.Reader, return edge, deserializeErr } -func deserializeChanEdgePolicyRaw(r io.Reader) (*ChannelEdgePolicy, error) { - edge := &ChannelEdgePolicy{} +func deserializeChanEdgePolicyRaw(r io.Reader) (*ChannelEdgePolicy1, error) { + edge := &ChannelEdgePolicy1{} var err error edge.SigBytes, err = wire.ReadVarBytes(r, 0, 80, "sig") diff --git a/channeldb/graph_cache.go b/channeldb/graph_cache.go index e4cd9e5c53..c52ba8f1d6 100644 --- a/channeldb/graph_cache.go +++ b/channeldb/graph_cache.go @@ -29,11 +29,11 @@ type GraphCacheNode interface { // to the caller. ForEachChannel(kvdb.Backend, kvdb.RTx, func(kvdb.Backend, kvdb.RTx, models.ChannelEdgeInfo, - *ChannelEdgePolicy, *ChannelEdgePolicy) error) error + *ChannelEdgePolicy1, *ChannelEdgePolicy1) error) error } // CachedEdgePolicy is a struct that only caches the information of a -// ChannelEdgePolicy that we actually use for pathfinding and therefore need to +// ChannelEdgePolicy1 that we actually use for pathfinding and therefore need to // store in the cache. type CachedEdgePolicy struct { // ChannelID is the unique channel ID for the channel. The first 3 @@ -105,7 +105,7 @@ func (c *CachedEdgePolicy) ComputeFeeFromIncoming( } // NewCachedPolicy turns a full policy into a minimal one that can be cached. -func NewCachedPolicy(policy *ChannelEdgePolicy) *CachedEdgePolicy { +func NewCachedPolicy(policy *ChannelEdgePolicy1) *CachedEdgePolicy { return &CachedEdgePolicy{ ChannelID: policy.ChannelID, MessageFlags: policy.MessageFlags, @@ -224,8 +224,8 @@ func (c *GraphCache) AddNode(tx kvdb.RTx, node GraphCacheNode) error { c.AddNodeFeatures(node) return node.ForEachChannel(nil, tx, func(_ kvdb.Backend, tx kvdb.RTx, - info models.ChannelEdgeInfo, outPolicy *ChannelEdgePolicy, - inPolicy *ChannelEdgePolicy) error { + info models.ChannelEdgeInfo, outPolicy *ChannelEdgePolicy1, + inPolicy *ChannelEdgePolicy1) error { c.AddChannel(info, outPolicy, inPolicy) @@ -238,7 +238,7 @@ func (c *GraphCache) AddNode(tx kvdb.RTx, node GraphCacheNode) error { // and policy flags automatically. The policy will be set as the outgoing policy // on one node and the incoming policy on the peer's side. func (c *GraphCache) AddChannel(info models.ChannelEdgeInfo, - policy1 *ChannelEdgePolicy, policy2 *ChannelEdgePolicy) { + policy1 *ChannelEdgePolicy1, policy2 *ChannelEdgePolicy1) { if info == nil { return @@ -300,7 +300,7 @@ func (c *GraphCache) updateOrAddEdge(node route.Vertex, edge *DirectedChannel) { // of the from and to node is not strictly important. But we assume that a // channel edge was added beforehand so that the directed channel struct already // exists in the cache. -func (c *GraphCache) UpdatePolicy(policy *ChannelEdgePolicy, fromNode, +func (c *GraphCache) UpdatePolicy(policy *ChannelEdgePolicy1, fromNode, toNode route.Vertex, edge1 bool) { c.mtx.Lock() diff --git a/channeldb/graph_cache_test.go b/channeldb/graph_cache_test.go index 7436451ec8..a0ab701940 100644 --- a/channeldb/graph_cache_test.go +++ b/channeldb/graph_cache_test.go @@ -30,8 +30,8 @@ type node struct { features *lnwire.FeatureVector edgeInfos []*ChannelEdgeInfo1 - outPolicies []*ChannelEdgePolicy - inPolicies []*ChannelEdgePolicy + outPolicies []*ChannelEdgePolicy1 + inPolicies []*ChannelEdgePolicy1 } func (n *node) PubKey() route.Vertex { @@ -43,7 +43,7 @@ func (n *node) Features() *lnwire.FeatureVector { func (n *node) ForEachChannel(db kvdb.Backend, tx kvdb.RTx, cb func(kvdb.Backend, kvdb.RTx, models.ChannelEdgeInfo, - *ChannelEdgePolicy, *ChannelEdgePolicy) error) error { + *ChannelEdgePolicy1, *ChannelEdgePolicy1) error) error { for idx := range n.edgeInfos { err := cb( @@ -71,7 +71,7 @@ func TestGraphCacheAddNode(t *testing.T) { channelFlagA, channelFlagB = 1, 0 } - outPolicy1 := &ChannelEdgePolicy{ + outPolicy1 := &ChannelEdgePolicy1{ ChannelID: 1000, ChannelFlags: lnwire.ChanUpdateChanFlags(channelFlagA), Node: &LightningNode{ @@ -79,7 +79,7 @@ func TestGraphCacheAddNode(t *testing.T) { Features: lnwire.EmptyFeatureVector(), }, } - inPolicy1 := &ChannelEdgePolicy{ + inPolicy1 := &ChannelEdgePolicy1{ ChannelID: 1000, ChannelFlags: lnwire.ChanUpdateChanFlags(channelFlagB), Node: &LightningNode{ @@ -97,8 +97,8 @@ func TestGraphCacheAddNode(t *testing.T) { NodeKey2Bytes: pubKey2, Capacity: 500, }}, - outPolicies: []*ChannelEdgePolicy{outPolicy1}, - inPolicies: []*ChannelEdgePolicy{inPolicy1}, + outPolicies: []*ChannelEdgePolicy1{outPolicy1}, + inPolicies: []*ChannelEdgePolicy1{inPolicy1}, } cache := NewGraphCache(10) require.NoError(t, cache.AddNode(nil, node)) @@ -145,7 +145,7 @@ func TestGraphCacheAddNode(t *testing.T) { runTest(pubKey2, pubKey1) } -func assertCachedPolicyEqual(t *testing.T, original *ChannelEdgePolicy, +func assertCachedPolicyEqual(t *testing.T, original *ChannelEdgePolicy1, cached *CachedEdgePolicy) { require.Equal(t, original.ChannelID, cached.ChannelID) diff --git a/channeldb/graph_test.go b/channeldb/graph_test.go index 22a570b62a..b470ba4b25 100644 --- a/channeldb/graph_test.go +++ b/channeldb/graph_test.go @@ -629,7 +629,7 @@ func assertEdgeInfo1Equal(t *testing.T, e1 *ChannelEdgeInfo1, } func createChannelEdge(db kvdb.Backend, node1, node2 *LightningNode) ( - *ChannelEdgeInfo1, *ChannelEdgePolicy, *ChannelEdgePolicy) { + *ChannelEdgeInfo1, *ChannelEdgePolicy1, *ChannelEdgePolicy1) { var ( firstNode *LightningNode @@ -671,7 +671,7 @@ func createChannelEdge(db kvdb.Backend, node1, node2 *LightningNode) ( copy(edgeInfo.BitcoinKey1Bytes[:], firstNode.PubKeyBytes[:]) copy(edgeInfo.BitcoinKey2Bytes[:], secondNode.PubKeyBytes[:]) - edge1 := &ChannelEdgePolicy{ + edge1 := &ChannelEdgePolicy1{ SigBytes: testSig.Serialize(), ChannelID: chanID, LastUpdate: time.Unix(433453, 0), @@ -685,7 +685,7 @@ func createChannelEdge(db kvdb.Backend, node1, node2 *LightningNode) ( Node: secondNode, ExtraOpaqueData: []byte("new unknown feature2"), } - edge2 := &ChannelEdgePolicy{ + edge2 := &ChannelEdgePolicy1{ SigBytes: testSig.Serialize(), ChannelID: chanID, LastUpdate: time.Unix(124234, 0), @@ -913,7 +913,7 @@ func assertNoEdge(t *testing.T, g *ChannelGraph, chanID uint64) { } func assertEdgeWithPolicyInCache(t *testing.T, g *ChannelGraph, - e models.ChannelEdgeInfo, p *ChannelEdgePolicy, policy1 bool) { + e models.ChannelEdgeInfo, p *ChannelEdgePolicy1, policy1 bool) { var ( node1Bytes = e.Node1Bytes() @@ -995,16 +995,16 @@ func assertEdgeWithPolicyInCache(t *testing.T, g *ChannelGraph, } } -func randEdgePolicy(chanID uint64, db kvdb.Backend) *ChannelEdgePolicy { +func randEdgePolicy(chanID uint64, db kvdb.Backend) *ChannelEdgePolicy1 { update := prand.Int63() return newEdgePolicy(chanID, db, update) } func newEdgePolicy(chanID uint64, db kvdb.Backend, - updateTime int64) *ChannelEdgePolicy { + updateTime int64) *ChannelEdgePolicy1 { - return &ChannelEdgePolicy{ + return &ChannelEdgePolicy1{ ChannelID: chanID, LastUpdate: time.Unix(updateTime, 0), MessageFlags: 1, @@ -1062,7 +1062,7 @@ func TestGraphTraversal(t *testing.T) { // again if the map is empty that indicates that all edges have // properly been reached. err = graph.ForEachChannel(func(ei models.ChannelEdgeInfo, - _ *ChannelEdgePolicy, _ *ChannelEdgePolicy) error { + _ *ChannelEdgePolicy1, _ *ChannelEdgePolicy1) error { delete(chanIndex, ei.GetChanID()) return nil @@ -1076,7 +1076,7 @@ func TestGraphTraversal(t *testing.T) { firstNode, secondNode := nodeList[0], nodeList[1] err = firstNode.ForEachChannel(graph.DB(), nil, func(_ kvdb.Backend, _ kvdb.RTx, _ models.ChannelEdgeInfo, - outEdge, inEdge *ChannelEdgePolicy) error { + outEdge, inEdge *ChannelEdgePolicy1) error { // All channels between first and second node should // have fully (both sides) specified policies. @@ -1159,8 +1159,8 @@ func TestGraphTraversalCacheable(t *testing.T) { err := node.ForEachChannel( graph.db, tx, func(_ kvdb.Backend, _ kvdb.RTx, info models.ChannelEdgeInfo, - _ *ChannelEdgePolicy, - _ *ChannelEdgePolicy) error { + _ *ChannelEdgePolicy1, + _ *ChannelEdgePolicy1) error { delete(chanIndex, info.GetChanID()) @@ -1344,7 +1344,7 @@ func assertPruneTip(t *testing.T, graph *ChannelGraph, blockHash *chainhash.Hash func assertNumChans(t *testing.T, graph *ChannelGraph, n int) { numChans := 0 if err := graph.ForEachChannel(func(models.ChannelEdgeInfo, - *ChannelEdgePolicy, *ChannelEdgePolicy) error { + *ChannelEdgePolicy1, *ChannelEdgePolicy1) error { numChans++ return nil @@ -2311,7 +2311,7 @@ func TestIncompleteChannelPolicies(t *testing.T) { err := node.ForEachChannel(graph.DB(), nil, func(_ kvdb.Backend, _ kvdb.RTx, _ models.ChannelEdgeInfo, outEdge, - inEdge *ChannelEdgePolicy) error { + inEdge *ChannelEdgePolicy1) error { if !expectedOut && outEdge != nil { t.Fatalf("Expected no outgoing policy") @@ -2904,7 +2904,7 @@ func TestDisabledChannelIDs(t *testing.T) { } } -// TestEdgePolicyMissingMaxHtcl tests that if we find a ChannelEdgePolicy in +// TestEdgePolicyMissingMaxHtcl tests that if we find a ChannelEdgePolicy1 in // the DB that indicates that it should support the htlc_maximum_value_msat // field, but it is not part of the opaque data, then we'll handle it as it is // unknown. It also checks that we are correctly able to overwrite it when we @@ -3180,7 +3180,7 @@ func compareNodes(a, b *LightningNode) error { // compareEdgePolicies is used to compare two ChannelEdgePolices using // compareNodes, so as to exclude comparisons of the Nodes' Features struct. -func compareEdgePolicies(a, b *ChannelEdgePolicy) error { +func compareEdgePolicies(a, b *ChannelEdgePolicy1) error { if a.ChannelID != b.ChannelID { return fmt.Errorf("ChannelID doesn't match: expected %v, "+ "got %v", a.ChannelID, b.ChannelID) @@ -3270,7 +3270,7 @@ func TestLightningNodeSigVerification(t *testing.T) { // TestComputeFee tests fee calculation based on both in- and outgoing amt. func TestComputeFee(t *testing.T) { var ( - policy = ChannelEdgePolicy{ + policy = ChannelEdgePolicy1{ FeeBaseMSat: 10000, FeeProportionalMillionths: 30000, } @@ -3403,7 +3403,7 @@ func TestBatchedUpdateEdgePolicy(t *testing.T) { errTimeout := errors.New("timeout adding batched channel") - updates := []*ChannelEdgePolicy{edge1, edge2} + updates := []*ChannelEdgePolicy1{edge1, edge2} errChan := make(chan error, len(updates)) @@ -3411,7 +3411,7 @@ func TestBatchedUpdateEdgePolicy(t *testing.T) { var wg sync.WaitGroup for _, update := range updates { wg.Add(1) - go func(update *ChannelEdgePolicy) { + go func(update *ChannelEdgePolicy1) { defer wg.Done() select { @@ -3463,8 +3463,8 @@ func BenchmarkForEachChannel(b *testing.B) { graph.db, tx, func(_ kvdb.Backend, _ kvdb.RTx, info models.ChannelEdgeInfo, - policy *ChannelEdgePolicy, - policy2 *ChannelEdgePolicy) error { + policy *ChannelEdgePolicy1, + policy2 *ChannelEdgePolicy1) error { // We need to do something with // the data here, otherwise the diff --git a/discovery/gossiper.go b/discovery/gossiper.go index 46f3c61e72..209d58fadb 100644 --- a/discovery/gossiper.go +++ b/discovery/gossiper.go @@ -531,7 +531,7 @@ type EdgeWithInfo struct { Info models.ChannelEdgeInfo // Edge describes the policy in one direction of the channel. - Edge *channeldb.ChannelEdgePolicy + Edge *channeldb.ChannelEdgePolicy1 } // PropagateChanPolicyUpdate signals the AuthenticatedGossiper to perform the @@ -1581,7 +1581,7 @@ func (d *AuthenticatedGossiper) retransmitStaleAnns(now time.Time) error { // within the prune interval or re-broadcast interval. type updateTuple struct { info models.ChannelEdgeInfo - edge *channeldb.ChannelEdgePolicy + edge *channeldb.ChannelEdgePolicy1 } var ( @@ -1590,7 +1590,7 @@ func (d *AuthenticatedGossiper) retransmitStaleAnns(now time.Time) error { ) err := d.cfg.Router.ForAllOutgoingChannels(func( _ kvdb.RTx, info models.ChannelEdgeInfo, - edge *channeldb.ChannelEdgePolicy) error { + edge *channeldb.ChannelEdgePolicy1) error { // If there's no auth proof attached to this edge, it means // that it is a private channel not meant to be announced to @@ -2131,7 +2131,7 @@ func (d *AuthenticatedGossiper) isMsgStale(msg lnwire.Message) bool { // Otherwise, we'll retrieve the correct policy that we // currently have stored within our graph to check if this // message is stale by comparing its timestamp. - var p *channeldb.ChannelEdgePolicy + var p *channeldb.ChannelEdgePolicy1 if msg.ChannelFlags&lnwire.ChanUpdateDirection == 0 { p = p1 } else { @@ -2157,7 +2157,7 @@ func (d *AuthenticatedGossiper) isMsgStale(msg lnwire.Message) bool { // updateChannel creates a new fully signed update for the channel, and updates // the underlying graph with the new state. func (d *AuthenticatedGossiper) updateChannel(edgeInfo models.ChannelEdgeInfo, - edge *channeldb.ChannelEdgePolicy) (lnwire.ChannelAnnouncement, + edge *channeldb.ChannelEdgePolicy1) (lnwire.ChannelAnnouncement, *lnwire.ChannelUpdate1, error) { // Parse the unsigned edge into a channel update. @@ -2268,7 +2268,7 @@ func (d *AuthenticatedGossiper) SyncManager() *SyncManager { // keep-alive update based on the previous channel update processed for the same // direction. func IsKeepAliveUpdate(update *lnwire.ChannelUpdate1, - prev *channeldb.ChannelEdgePolicy) bool { + prev *channeldb.ChannelEdgePolicy1) bool { // Both updates should be from the same direction. if update.ChannelFlags&lnwire.ChanUpdateDirection != @@ -2830,7 +2830,7 @@ func (d *AuthenticatedGossiper) handleChanUpdate(nMsg *networkMsg, // being updated. var ( pubKey *btcec.PublicKey - edgeToUpdate *channeldb.ChannelEdgePolicy + edgeToUpdate *channeldb.ChannelEdgePolicy1 ) direction := upd.ChannelFlags & lnwire.ChanUpdateDirection switch direction { @@ -2921,7 +2921,7 @@ func (d *AuthenticatedGossiper) handleChanUpdate(nMsg *networkMsg, // different alias. This might mean that SigBytes is incorrect as it // signs a different SCID than the database SCID, but since there will // only be a difference if AuthProof == nil, this is fine. - update := &channeldb.ChannelEdgePolicy{ + update := &channeldb.ChannelEdgePolicy1{ SigBytes: upd.Signature.ToSignatureBytes(), ChannelID: chanInfo.GetChanID(), LastUpdate: timestamp, diff --git a/discovery/gossiper_test.go b/discovery/gossiper_test.go index 8b3a81141b..305d2eb0cd 100644 --- a/discovery/gossiper_test.go +++ b/discovery/gossiper_test.go @@ -93,7 +93,7 @@ type mockGraphSource struct { mu sync.Mutex nodes []channeldb.LightningNode infos map[uint64]models.ChannelEdgeInfo - edges map[uint64][]channeldb.ChannelEdgePolicy + edges map[uint64][]channeldb.ChannelEdgePolicy1 zombies map[uint64][][33]byte chansToReject map[uint64]struct{} } @@ -102,7 +102,7 @@ func newMockRouter(height uint32) *mockGraphSource { return &mockGraphSource{ bestHeight: height, infos: make(map[uint64]models.ChannelEdgeInfo), - edges: make(map[uint64][]channeldb.ChannelEdgePolicy), + edges: make(map[uint64][]channeldb.ChannelEdgePolicy1), zombies: make(map[uint64][][33]byte), chansToReject: make(map[uint64]struct{}), } @@ -146,14 +146,14 @@ func (r *mockGraphSource) queueValidationFail(chanID uint64) { r.chansToReject[chanID] = struct{}{} } -func (r *mockGraphSource) UpdateEdge(edge *channeldb.ChannelEdgePolicy, +func (r *mockGraphSource) UpdateEdge(edge *channeldb.ChannelEdgePolicy1, _ ...batch.SchedulerOption) error { r.mu.Lock() defer r.mu.Unlock() if len(r.edges[edge.ChannelID]) == 0 { - r.edges[edge.ChannelID] = make([]channeldb.ChannelEdgePolicy, 2) + r.edges[edge.ChannelID] = make([]channeldb.ChannelEdgePolicy1, 2) } if edge.ChannelFlags&lnwire.ChanUpdateDirection == 0 { @@ -198,7 +198,7 @@ func (r *mockGraphSource) ForEachNode(func(node *channeldb.LightningNode) error) } func (r *mockGraphSource) ForAllOutgoingChannels(cb func(tx kvdb.RTx, - i models.ChannelEdgeInfo, c *channeldb.ChannelEdgePolicy) error) error { + i models.ChannelEdgeInfo, c *channeldb.ChannelEdgePolicy1) error) error { r.mu.Lock() defer r.mu.Unlock() @@ -229,14 +229,14 @@ func (r *mockGraphSource) ForAllOutgoingChannels(cb func(tx kvdb.RTx, } func (r *mockGraphSource) ForEachChannel(_ func(chanInfo models.ChannelEdgeInfo, - e1, e2 *channeldb.ChannelEdgePolicy) error) error { + e1, e2 *channeldb.ChannelEdgePolicy1) error) error { return nil } func (r *mockGraphSource) GetChannelByID(chanID lnwire.ShortChannelID) ( - models.ChannelEdgeInfo, *channeldb.ChannelEdgePolicy, - *channeldb.ChannelEdgePolicy, error) { + models.ChannelEdgeInfo, *channeldb.ChannelEdgePolicy1, + *channeldb.ChannelEdgePolicy1, error) { r.mu.Lock() defer r.mu.Unlock() @@ -262,13 +262,13 @@ func (r *mockGraphSource) GetChannelByID(chanID lnwire.ShortChannelID) ( return chanInfoCP, nil, nil, nil } - var edge1 *channeldb.ChannelEdgePolicy - if !reflect.DeepEqual(edges[0], channeldb.ChannelEdgePolicy{}) { + var edge1 *channeldb.ChannelEdgePolicy1 + if !reflect.DeepEqual(edges[0], channeldb.ChannelEdgePolicy1{}) { edge1 = &edges[0] } - var edge2 *channeldb.ChannelEdgePolicy - if !reflect.DeepEqual(edges[1], channeldb.ChannelEdgePolicy{}) { + var edge2 *channeldb.ChannelEdgePolicy1 + if !reflect.DeepEqual(edges[1], channeldb.ChannelEdgePolicy1{}) { edge2 = &edges[1] } @@ -372,12 +372,12 @@ func (r *mockGraphSource) IsStaleEdgePolicy(chanID lnwire.ShortChannelID, switch { case flags&lnwire.ChanUpdateDirection == 0 && - !reflect.DeepEqual(edges[0], channeldb.ChannelEdgePolicy{}): + !reflect.DeepEqual(edges[0], channeldb.ChannelEdgePolicy1{}): return !timestamp.After(edges[0].LastUpdate) case flags&lnwire.ChanUpdateDirection == 1 && - !reflect.DeepEqual(edges[1], channeldb.ChannelEdgePolicy{}): + !reflect.DeepEqual(edges[1], channeldb.ChannelEdgePolicy1{}): return !timestamp.After(edges[1].LastUpdate) @@ -2490,7 +2490,7 @@ func TestReceiveRemoteChannelUpdateFirst(t *testing.T) { t.Fatalf("remote update was not processed") } - // Check that the ChannelEdgePolicy was added to the graph. + // Check that the ChannelEdgePolicy1 was added to the graph. chanInfo, e1, e2, err = ctx.router.GetChannelByID( batch.chanUpdAnn1.ShortChannelID, ) @@ -3456,7 +3456,7 @@ out: err = ctx.router.ForAllOutgoingChannels(func( _ kvdb.RTx, info models.ChannelEdgeInfo, - edge *channeldb.ChannelEdgePolicy) error { + edge *channeldb.ChannelEdgePolicy1) error { edge.TimeLockDelta = uint16(newTimeLockDelta) edgesToUpdate = append(edgesToUpdate, EdgeWithInfo{ diff --git a/funding/manager.go b/funding/manager.go index 50a83f2fe6..65b90ec7ba 100644 --- a/funding/manager.go +++ b/funding/manager.go @@ -524,7 +524,7 @@ type Config struct { // DeleteAliasEdge allows the Manager to delete an alias channel edge // from the graph. It also returns our local to-be-deleted policy. DeleteAliasEdge func(scid lnwire.ShortChannelID) ( - *channeldb.ChannelEdgePolicy, error) + *channeldb.ChannelEdgePolicy1, error) // AliasManager is an implementation of the aliasHandler interface that // abstracts away the handling of many alias functions. @@ -3335,7 +3335,7 @@ func (f *Manager) extractAnnounceParams(c *channeldb.OpenChannel) ( func (f *Manager) addToRouterGraph(completeChan *channeldb.OpenChannel, shortChanID *lnwire.ShortChannelID, peerAlias *lnwire.ShortChannelID, - ourPolicy *channeldb.ChannelEdgePolicy) error { + ourPolicy *channeldb.ChannelEdgePolicy1) error { chanID := lnwire.NewChanIDFromOutPoint(&completeChan.FundingOutpoint) @@ -4067,7 +4067,7 @@ func (f *Manager) newChanAnnouncement(localPubKey, remotePubKey *btcec.PublicKey, localFundingKey *keychain.KeyDescriptor, remoteFundingKey *btcec.PublicKey, shortChanID lnwire.ShortChannelID, chanID lnwire.ChannelID, fwdMinHTLC, fwdMaxHTLC lnwire.MilliSatoshi, - ourPolicy *channeldb.ChannelEdgePolicy, + ourPolicy *channeldb.ChannelEdgePolicy1, chanType channeldb.ChannelType) (*chanAnnouncement, error) { chainHash := *f.cfg.Wallet.Cfg.NetParams.GenesisHash diff --git a/funding/manager_test.go b/funding/manager_test.go index 9b6b301118..8d7f2e95d6 100644 --- a/funding/manager_test.go +++ b/funding/manager_test.go @@ -550,7 +550,7 @@ func createTestFundingManager(t *testing.T, privKey *btcec.PrivateKey, OpenChannelPredicate: chainedAcceptor, NotifyPendingOpenChannelEvent: evt.NotifyPendingOpenChannelEvent, DeleteAliasEdge: func(scid lnwire.ShortChannelID) ( - *channeldb.ChannelEdgePolicy, error) { + *channeldb.ChannelEdgePolicy1, error) { return nil, nil }, diff --git a/lnrpc/devrpc/dev_server.go b/lnrpc/devrpc/dev_server.go index 8cd88c34c0..0bb27d0237 100644 --- a/lnrpc/devrpc/dev_server.go +++ b/lnrpc/devrpc/dev_server.go @@ -288,8 +288,8 @@ func (s *Server) ImportGraph(ctx context.Context, rpcEdge.ChanPoint, err) } - makePolicy := func(rpcPolicy *lnrpc.RoutingPolicy) *channeldb.ChannelEdgePolicy { - policy := &channeldb.ChannelEdgePolicy{ + makePolicy := func(rpcPolicy *lnrpc.RoutingPolicy) *channeldb.ChannelEdgePolicy1 { + policy := &channeldb.ChannelEdgePolicy1{ ChannelID: rpcEdge.ChannelId, LastUpdate: time.Unix( int64(rpcPolicy.LastUpdate), 0, diff --git a/lnrpc/invoicesrpc/addinvoice.go b/lnrpc/invoicesrpc/addinvoice.go index 050871c9c0..7648cc7a20 100644 --- a/lnrpc/invoicesrpc/addinvoice.go +++ b/lnrpc/invoicesrpc/addinvoice.go @@ -490,7 +490,7 @@ func AddInvoice(ctx context.Context, cfg *AddInvoiceConfig, // chanCanBeHopHint returns true if the target channel is eligible to be a hop // hint. func chanCanBeHopHint(channel *HopHintInfo, cfg *SelectHopHintsCfg) ( - *channeldb.ChannelEdgePolicy, bool) { + *channeldb.ChannelEdgePolicy1, bool) { // Since we're only interested in our private channels, we'll skip // public ones. @@ -546,7 +546,7 @@ func chanCanBeHopHint(channel *HopHintInfo, cfg *SelectHopHintsCfg) ( // Now, we'll need to determine which is the correct policy for HTLCs // being sent from the remote node. var ( - remotePolicy *channeldb.ChannelEdgePolicy + remotePolicy *channeldb.ChannelEdgePolicy1 node1Bytes = info.Node1Bytes() ) if bytes.Equal(remotePub[:], node1Bytes[:]) { @@ -606,9 +606,9 @@ func newHopHintInfo(c *channeldb.OpenChannel, isActive bool) *HopHintInfo { } // newHopHint returns a new hop hint using the relevant data from a hopHintInfo -// and a ChannelEdgePolicy. +// and a ChannelEdgePolicy1. func newHopHint(hopHintInfo *HopHintInfo, - chanPolicy *channeldb.ChannelEdgePolicy) zpay32.HopHint { + chanPolicy *channeldb.ChannelEdgePolicy1) zpay32.HopHint { return zpay32.HopHint{ NodeID: hopHintInfo.RemotePubkey, @@ -632,7 +632,7 @@ type SelectHopHintsCfg struct { // FetchChannelEdgesByID attempts to lookup the two directed edges for // the channel identified by the channel ID. FetchChannelEdgesByID func(chanID uint64) (models.ChannelEdgeInfo, - *channeldb.ChannelEdgePolicy, *channeldb.ChannelEdgePolicy, + *channeldb.ChannelEdgePolicy1, *channeldb.ChannelEdgePolicy1, error) // GetAlias allows the peer's alias SCID to be retrieved for private diff --git a/lnrpc/invoicesrpc/addinvoice_test.go b/lnrpc/invoicesrpc/addinvoice_test.go index 223a24495c..af038b9415 100644 --- a/lnrpc/invoicesrpc/addinvoice_test.go +++ b/lnrpc/invoicesrpc/addinvoice_test.go @@ -52,8 +52,8 @@ func (h *hopHintsConfigMock) FetchAllChannels() ([]*channeldb.OpenChannel, // FetchChannelEdgesByID attempts to lookup the two directed edges for // the channel identified by the channel ID. func (h *hopHintsConfigMock) FetchChannelEdgesByID(chanID uint64) ( - models.ChannelEdgeInfo, *channeldb.ChannelEdgePolicy, - *channeldb.ChannelEdgePolicy, error) { + models.ChannelEdgeInfo, *channeldb.ChannelEdgePolicy1, + *channeldb.ChannelEdgePolicy1, error) { args := h.Mock.Called(chanID) @@ -71,8 +71,8 @@ func (h *hopHintsConfigMock) FetchChannelEdgesByID(chanID uint64) ( "ChannelEdgeInfo impl received: %T", args.Get(0)) } - policy1 := args.Get(1).(*channeldb.ChannelEdgePolicy) - policy2 := args.Get(2).(*channeldb.ChannelEdgePolicy) + policy1 := args.Get(1).(*channeldb.ChannelEdgePolicy1) + policy2 := args.Get(2).(*channeldb.ChannelEdgePolicy1) return edgeInfo, policy1, policy2, err } @@ -222,8 +222,8 @@ var shouldIncludeChannelTestCases = []struct { "FetchChannelEdgesByID", mock.Anything, ).Once().Return( &channeldb.ChannelEdgeInfo1{}, - &channeldb.ChannelEdgePolicy{}, - &channeldb.ChannelEdgePolicy{}, nil, + &channeldb.ChannelEdgePolicy1{}, + &channeldb.ChannelEdgePolicy1{}, nil, ) h.Mock.On( @@ -260,8 +260,8 @@ var shouldIncludeChannelTestCases = []struct { "FetchChannelEdgesByID", mock.Anything, ).Once().Return( &channeldb.ChannelEdgeInfo1{}, - &channeldb.ChannelEdgePolicy{}, - &channeldb.ChannelEdgePolicy{}, nil, + &channeldb.ChannelEdgePolicy1{}, + &channeldb.ChannelEdgePolicy1{}, nil, ) alias := lnwire.ShortChannelID{TxPosition: 5} h.Mock.On( @@ -303,12 +303,12 @@ var shouldIncludeChannelTestCases = []struct { &channeldb.ChannelEdgeInfo1{ NodeKey1Bytes: selectedPolicy, }, - &channeldb.ChannelEdgePolicy{ + &channeldb.ChannelEdgePolicy1{ FeeBaseMSat: 1000, FeeProportionalMillionths: 20, TimeLockDelta: 13, }, - &channeldb.ChannelEdgePolicy{}, + &channeldb.ChannelEdgePolicy1{}, nil, ) }, @@ -349,8 +349,8 @@ var shouldIncludeChannelTestCases = []struct { "FetchChannelEdgesByID", mock.Anything, ).Once().Return( &channeldb.ChannelEdgeInfo1{}, - &channeldb.ChannelEdgePolicy{}, - &channeldb.ChannelEdgePolicy{ + &channeldb.ChannelEdgePolicy1{}, + &channeldb.ChannelEdgePolicy1{ FeeBaseMSat: 1000, FeeProportionalMillionths: 20, TimeLockDelta: 13, @@ -394,8 +394,8 @@ var shouldIncludeChannelTestCases = []struct { "FetchChannelEdgesByID", mock.Anything, ).Once().Return( &channeldb.ChannelEdgeInfo1{}, - &channeldb.ChannelEdgePolicy{}, - &channeldb.ChannelEdgePolicy{ + &channeldb.ChannelEdgePolicy1{}, + &channeldb.ChannelEdgePolicy1{ FeeBaseMSat: 1000, FeeProportionalMillionths: 20, TimeLockDelta: 13, @@ -561,8 +561,8 @@ var populateHopHintsTestCases = []struct { "FetchChannelEdgesByID", mock.Anything, ).Once().Return( &channeldb.ChannelEdgeInfo1{}, - &channeldb.ChannelEdgePolicy{}, - &channeldb.ChannelEdgePolicy{}, nil, + &channeldb.ChannelEdgePolicy1{}, + &channeldb.ChannelEdgePolicy1{}, nil, ) }, maxHopHints: 1, @@ -611,8 +611,8 @@ var populateHopHintsTestCases = []struct { "FetchChannelEdgesByID", mock.Anything, ).Once().Return( &channeldb.ChannelEdgeInfo1{}, - &channeldb.ChannelEdgePolicy{}, - &channeldb.ChannelEdgePolicy{}, nil, + &channeldb.ChannelEdgePolicy1{}, + &channeldb.ChannelEdgePolicy1{}, nil, ) }, maxHopHints: 10, @@ -662,8 +662,8 @@ var populateHopHintsTestCases = []struct { "FetchChannelEdgesByID", mock.Anything, ).Once().Return( &channeldb.ChannelEdgeInfo1{}, - &channeldb.ChannelEdgePolicy{}, - &channeldb.ChannelEdgePolicy{}, nil, + &channeldb.ChannelEdgePolicy1{}, + &channeldb.ChannelEdgePolicy1{}, nil, ) }, maxHopHints: 1, @@ -695,8 +695,8 @@ var populateHopHintsTestCases = []struct { "FetchChannelEdgesByID", mock.Anything, ).Once().Return( &channeldb.ChannelEdgeInfo1{}, - &channeldb.ChannelEdgePolicy{}, - &channeldb.ChannelEdgePolicy{}, nil, + &channeldb.ChannelEdgePolicy1{}, + &channeldb.ChannelEdgePolicy1{}, nil, ) // Prepare the mock for the second channel. @@ -712,8 +712,8 @@ var populateHopHintsTestCases = []struct { "FetchChannelEdgesByID", mock.Anything, ).Once().Return( &channeldb.ChannelEdgeInfo1{}, - &channeldb.ChannelEdgePolicy{}, - &channeldb.ChannelEdgePolicy{}, nil, + &channeldb.ChannelEdgePolicy1{}, + &channeldb.ChannelEdgePolicy1{}, nil, ) }, maxHopHints: 10, @@ -749,8 +749,8 @@ var populateHopHintsTestCases = []struct { "FetchChannelEdgesByID", mock.Anything, ).Once().Return( &channeldb.ChannelEdgeInfo1{}, - &channeldb.ChannelEdgePolicy{}, - &channeldb.ChannelEdgePolicy{}, nil, + &channeldb.ChannelEdgePolicy1{}, + &channeldb.ChannelEdgePolicy1{}, nil, ) // Prepare the mock for the second channel. @@ -766,8 +766,8 @@ var populateHopHintsTestCases = []struct { "FetchChannelEdgesByID", mock.Anything, ).Once().Return( &channeldb.ChannelEdgeInfo1{}, - &channeldb.ChannelEdgePolicy{}, - &channeldb.ChannelEdgePolicy{}, nil, + &channeldb.ChannelEdgePolicy1{}, + &channeldb.ChannelEdgePolicy1{}, nil, ) }, maxHopHints: 10, @@ -804,8 +804,8 @@ var populateHopHintsTestCases = []struct { "FetchChannelEdgesByID", mock.Anything, ).Once().Return( &channeldb.ChannelEdgeInfo1{}, - &channeldb.ChannelEdgePolicy{}, - &channeldb.ChannelEdgePolicy{}, nil, + &channeldb.ChannelEdgePolicy1{}, + &channeldb.ChannelEdgePolicy1{}, nil, ) }, maxHopHints: 1, diff --git a/netann/chan_status_manager.go b/netann/chan_status_manager.go index 1b746b2365..1d3aa6994b 100644 --- a/netann/chan_status_manager.go +++ b/netann/chan_status_manager.go @@ -647,7 +647,7 @@ func (m *ChanStatusManager) signAndSendNextUpdate(outpoint wire.OutPoint, // fetchLastChanUpdateByOutPoint fetches the latest policy for our direction of // a channel, and crafts a new ChannelUpdate1 with this policy. Returns an error -// in case our ChannelEdgePolicy is not found in the database. Also returns if +// in case our ChannelEdgePolicy1 is not found in the database. Also returns if // the channel is private by checking AuthProof for nil. func (m *ChanStatusManager) fetchLastChanUpdateByOutPoint(op wire.OutPoint) ( *lnwire.ChannelUpdate1, bool, error) { diff --git a/netann/chan_status_manager_test.go b/netann/chan_status_manager_test.go index 9ff359f2cc..3bed4b0a51 100644 --- a/netann/chan_status_manager_test.go +++ b/netann/chan_status_manager_test.go @@ -67,8 +67,8 @@ func createChannel(t *testing.T) *channeldb.OpenChannel { // update will be created with the disabled bit set if startEnabled is false. func createEdgePolicies(t *testing.T, channel *channeldb.OpenChannel, pubkey *btcec.PublicKey, startEnabled bool) ( - *channeldb.ChannelEdgeInfo1, *channeldb.ChannelEdgePolicy, - *channeldb.ChannelEdgePolicy) { + *channeldb.ChannelEdgeInfo1, *channeldb.ChannelEdgePolicy1, + *channeldb.ChannelEdgePolicy1) { var ( pubkey1 [33]byte @@ -105,13 +105,13 @@ func createEdgePolicies(t *testing.T, channel *channeldb.OpenChannel, NodeKey1Bytes: pubkey1, NodeKey2Bytes: pubkey2, }, - &channeldb.ChannelEdgePolicy{ + &channeldb.ChannelEdgePolicy1{ ChannelID: channel.ShortChanID().ToUint64(), ChannelFlags: dir1, LastUpdate: time.Now(), SigBytes: testSigBytes, }, - &channeldb.ChannelEdgePolicy{ + &channeldb.ChannelEdgePolicy1{ ChannelID: channel.ShortChanID().ToUint64(), ChannelFlags: dir2, LastUpdate: time.Now(), @@ -123,8 +123,8 @@ type mockGraph struct { mu sync.Mutex channels []*channeldb.OpenChannel chanInfos map[wire.OutPoint]*channeldb.ChannelEdgeInfo1 - chanPols1 map[wire.OutPoint]*channeldb.ChannelEdgePolicy - chanPols2 map[wire.OutPoint]*channeldb.ChannelEdgePolicy + chanPols1 map[wire.OutPoint]*channeldb.ChannelEdgePolicy1 + chanPols2 map[wire.OutPoint]*channeldb.ChannelEdgePolicy1 sidToCid map[lnwire.ShortChannelID]wire.OutPoint updates chan *lnwire.ChannelUpdate1 @@ -136,8 +136,8 @@ func newMockGraph(t *testing.T, numChannels int, g := &mockGraph{ channels: make([]*channeldb.OpenChannel, 0, numChannels), chanInfos: make(map[wire.OutPoint]*channeldb.ChannelEdgeInfo1), - chanPols1: make(map[wire.OutPoint]*channeldb.ChannelEdgePolicy), - chanPols2: make(map[wire.OutPoint]*channeldb.ChannelEdgePolicy), + chanPols1: make(map[wire.OutPoint]*channeldb.ChannelEdgePolicy1), + chanPols2: make(map[wire.OutPoint]*channeldb.ChannelEdgePolicy1), sidToCid: make(map[lnwire.ShortChannelID]wire.OutPoint), updates: make(chan *lnwire.ChannelUpdate1, 2*numChannels), } @@ -162,7 +162,7 @@ func (g *mockGraph) FetchAllOpenChannels() ([]*channeldb.OpenChannel, error) { func (g *mockGraph) FetchChannelEdgesByOutpoint( op *wire.OutPoint) (models.ChannelEdgeInfo, - *channeldb.ChannelEdgePolicy, *channeldb.ChannelEdgePolicy, error) { + *channeldb.ChannelEdgePolicy1, *channeldb.ChannelEdgePolicy1, error) { g.mu.Lock() defer g.mu.Unlock() @@ -211,7 +211,7 @@ func (g *mockGraph) ApplyChannelUpdate(update *lnwire.ChannelUpdate1, timestamp := time.Unix(int64(update.Timestamp), 0) - policy := &channeldb.ChannelEdgePolicy{ + policy := &channeldb.ChannelEdgePolicy1{ ChannelID: update.ShortChannelID.ToUint64(), ChannelFlags: update.ChannelFlags, LastUpdate: timestamp, @@ -250,7 +250,7 @@ func (g *mockGraph) addChannel(channel *channeldb.OpenChannel) { func (g *mockGraph) addEdgePolicy(c *channeldb.OpenChannel, info *channeldb.ChannelEdgeInfo1, - pol1, pol2 *channeldb.ChannelEdgePolicy) { + pol1, pol2 *channeldb.ChannelEdgePolicy1) { g.mu.Lock() defer g.mu.Unlock() diff --git a/netann/channel_announcement.go b/netann/channel_announcement.go index 1da5f610b9..db870d0568 100644 --- a/netann/channel_announcement.go +++ b/netann/channel_announcement.go @@ -16,7 +16,7 @@ import ( // peer's initial routing table upon connect. func CreateChanAnnouncement(chanProof models.ChannelAuthProof, chanInfo models.ChannelEdgeInfo, - e1, e2 *channeldb.ChannelEdgePolicy) (lnwire.ChannelAnnouncement, + e1, e2 *channeldb.ChannelEdgePolicy1) (lnwire.ChannelAnnouncement, *lnwire.ChannelUpdate1, *lnwire.ChannelUpdate1, error) { switch proof := chanProof.(type) { @@ -38,7 +38,7 @@ func CreateChanAnnouncement(chanProof models.ChannelAuthProof, func createChanAnnouncement1(chanProof *channeldb.ChannelAuthProof1, chanInfo *channeldb.ChannelEdgeInfo1, - e1, e2 *channeldb.ChannelEdgePolicy) (lnwire.ChannelAnnouncement, + e1, e2 *channeldb.ChannelEdgePolicy1) (lnwire.ChannelAnnouncement, *lnwire.ChannelUpdate1, *lnwire.ChannelUpdate1, error) { // First, using the parameters of the channel, along with the channel diff --git a/netann/channel_update.go b/netann/channel_update.go index 73f0f9996a..7302fdf561 100644 --- a/netann/channel_update.go +++ b/netann/channel_update.go @@ -87,11 +87,11 @@ func SignChannelUpdate(signer lnwallet.MessageSigner, keyLoc keychain.KeyLocator // NOTE: The passed policies can be nil. func ExtractChannelUpdate(ownerPubKey []byte, info models.ChannelEdgeInfo, - policies ...*channeldb.ChannelEdgePolicy) ( + policies ...*channeldb.ChannelEdgePolicy1) ( *lnwire.ChannelUpdate1, error) { // Helper function to extract the owner of the given policy. - owner := func(edge *channeldb.ChannelEdgePolicy) []byte { + owner := func(edge *channeldb.ChannelEdgePolicy1) []byte { var pubKey *btcec.PublicKey if edge.ChannelFlags&lnwire.ChanUpdateDirection == 0 { pubKey, _ = info.NodeKey1() @@ -120,7 +120,7 @@ func ExtractChannelUpdate(ownerPubKey []byte, // UnsignedChannelUpdateFromEdge reconstructs an unsigned ChannelUpdate1 from the // given edge info and policy. func UnsignedChannelUpdateFromEdge(chainHash chainhash.Hash, - policy *channeldb.ChannelEdgePolicy) *lnwire.ChannelUpdate1 { + policy *channeldb.ChannelEdgePolicy1) *lnwire.ChannelUpdate1 { return &lnwire.ChannelUpdate1{ ChainHash: chainHash, @@ -140,7 +140,7 @@ func UnsignedChannelUpdateFromEdge(chainHash chainhash.Hash, // ChannelUpdateFromEdge reconstructs a signed ChannelUpdate1 from the given edge // info and policy. func ChannelUpdateFromEdge(info models.ChannelEdgeInfo, - policy *channeldb.ChannelEdgePolicy) (*lnwire.ChannelUpdate1, error) { + policy *channeldb.ChannelEdgePolicy1) (*lnwire.ChannelUpdate1, error) { update := UnsignedChannelUpdateFromEdge(info.GetChainHash(), policy) diff --git a/netann/interface.go b/netann/interface.go index cdeefc0987..91b1eb7196 100644 --- a/netann/interface.go +++ b/netann/interface.go @@ -20,5 +20,5 @@ type ChannelGraph interface { // FetchChannelEdgesByOutpoint returns the channel edge info and most // recent channel edge policies for a given outpoint. FetchChannelEdgesByOutpoint(*wire.OutPoint) (models.ChannelEdgeInfo, - *channeldb.ChannelEdgePolicy, *channeldb.ChannelEdgePolicy, error) + *channeldb.ChannelEdgePolicy1, *channeldb.ChannelEdgePolicy1, error) } diff --git a/routing/localchans/manager.go b/routing/localchans/manager.go index 221ed69cb0..39e799043e 100644 --- a/routing/localchans/manager.go +++ b/routing/localchans/manager.go @@ -32,7 +32,7 @@ type Manager struct { // channels. ForAllOutgoingChannels func(cb func(kvdb.RTx, models.ChannelEdgeInfo, - *channeldb.ChannelEdgePolicy) error) error + *channeldb.ChannelEdgePolicy1) error) error // FetchChannel is used to query local channel parameters. Optionally an // existing db tx can be supplied. @@ -74,7 +74,7 @@ func (r *Manager) UpdatePolicy(newSchema routing.ChannelPolicy, err := r.ForAllOutgoingChannels(func( tx kvdb.RTx, info models.ChannelEdgeInfo, - edge *channeldb.ChannelEdgePolicy) error { + edge *channeldb.ChannelEdgePolicy1) error { chanPoint := info.GetChanPoint() @@ -174,7 +174,7 @@ func (r *Manager) UpdatePolicy(newSchema routing.ChannelPolicy, // updateEdge updates the given edge with the new schema. func (r *Manager) updateEdge(tx kvdb.RTx, chanPoint wire.OutPoint, - edge *channeldb.ChannelEdgePolicy, + edge *channeldb.ChannelEdgePolicy1, newSchema routing.ChannelPolicy) error { // Update forwarding fee scheme and required time lock delta. diff --git a/routing/localchans/manager_test.go b/routing/localchans/manager_test.go index 4409dff620..9bd9528d0d 100644 --- a/routing/localchans/manager_test.go +++ b/routing/localchans/manager_test.go @@ -44,7 +44,7 @@ func TestManager(t *testing.T) { MaxHTLC: 5000, } - currentPolicy := channeldb.ChannelEdgePolicy{ + currentPolicy := channeldb.ChannelEdgePolicy1{ MinHTLC: minHTLC, MessageFlags: lnwire.ChanUpdateRequiredMaxHtlc, } @@ -108,7 +108,7 @@ func TestManager(t *testing.T) { forAllOutgoingChannels := func(cb func(kvdb.RTx, models.ChannelEdgeInfo, - *channeldb.ChannelEdgePolicy) error) error { + *channeldb.ChannelEdgePolicy1) error) error { for _, c := range channelSet { if err := cb(nil, c.edgeInfo, ¤tPolicy); err != nil { @@ -152,7 +152,7 @@ func TestManager(t *testing.T) { tests := []struct { name string - currentPolicy channeldb.ChannelEdgePolicy + currentPolicy channeldb.ChannelEdgePolicy1 newPolicy routing.ChannelPolicy channelSet []channel specifiedChanPoints []wire.OutPoint diff --git a/routing/notifications.go b/routing/notifications.go index b0ff5c79c5..6ee494299b 100644 --- a/routing/notifications.go +++ b/routing/notifications.go @@ -339,7 +339,7 @@ func addToTopologyChange(graph *channeldb.ChannelGraph, update *TopologyChange, // Any new ChannelUpdateAnnouncements will generate a corresponding // ChannelEdgeUpdate notification. - case *channeldb.ChannelEdgePolicy: + case *channeldb.ChannelEdgePolicy1: // We'll need to fetch the edge's information from the database // in order to get the information concerning which nodes are // being connected. diff --git a/routing/notifications_test.go b/routing/notifications_test.go index 63f03923ff..21fa57fa97 100644 --- a/routing/notifications_test.go +++ b/routing/notifications_test.go @@ -74,9 +74,9 @@ func createTestNode() (*channeldb.LightningNode, error) { } func randEdgePolicy(chanID *lnwire.ShortChannelID, - node *channeldb.LightningNode) *channeldb.ChannelEdgePolicy { + node *channeldb.LightningNode) *channeldb.ChannelEdgePolicy1 { - return &channeldb.ChannelEdgePolicy{ + return &channeldb.ChannelEdgePolicy1{ SigBytes: testSig.Serialize(), ChannelID: chanID.ToUint64(), LastUpdate: time.Unix(int64(prand.Int31()), 0), @@ -455,7 +455,7 @@ func TestEdgeUpdateNotification(t *testing.T) { } assertEdgeCorrect := func(t *testing.T, edgeUpdate *ChannelEdgeUpdate, - edgeAnn *channeldb.ChannelEdgePolicy) { + edgeAnn *channeldb.ChannelEdgePolicy1) { if edgeUpdate.ChanID != edgeAnn.ChannelID { t.Fatalf("channel ID of edge doesn't match: "+ "expected %v, got %v", chanID.ToUint64(), edgeUpdate.ChanID) diff --git a/routing/pathfind.go b/routing/pathfind.go index d007be05e7..c67239c77d 100644 --- a/routing/pathfind.go +++ b/routing/pathfind.go @@ -83,7 +83,7 @@ var ( ) // edgePolicyWithSource is a helper struct to keep track of the source node -// of a channel edge. ChannelEdgePolicy only contains to destination node +// of a channel edge. ChannelEdgePolicy1 only contains to destination node // of the edge. type edgePolicyWithSource struct { sourceNode route.Vertex @@ -1029,7 +1029,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, // destination features were provided. This is fine though, since our // route construction does not care where the features are actually // taken from. In the future we may wish to do route construction within - // findPath, and avoid using ChannelEdgePolicy altogether. + // findPath, and avoid using ChannelEdgePolicy1 altogether. pathEdges[len(pathEdges)-1].ToNodeFeatures = features log.Debugf("Found route: probability=%v, hops=%v, fee=%v", diff --git a/routing/pathfind_test.go b/routing/pathfind_test.go index 3731f3d9e0..ce10db1311 100644 --- a/routing/pathfind_test.go +++ b/routing/pathfind_test.go @@ -376,7 +376,7 @@ func parseTestGraph(t *testing.T, useCache bool, path string) ( targetNode = edgeInfo.NodeKey2Bytes } - edgePolicy := &channeldb.ChannelEdgePolicy{ + edgePolicy := &channeldb.ChannelEdgePolicy1{ SigBytes: testSig.Serialize(), MessageFlags: lnwire.ChanUpdateMsgFlags(edge.MessageFlags), ChannelFlags: channelFlags, @@ -689,7 +689,7 @@ func createTestGraphFromChannels(t *testing.T, useCache bool, node2Features = node2.Features } - edgePolicy := &channeldb.ChannelEdgePolicy{ + edgePolicy := &channeldb.ChannelEdgePolicy1{ SigBytes: testSig.Serialize(), MessageFlags: msgFlags, ChannelFlags: channelFlags, @@ -727,7 +727,7 @@ func createTestGraphFromChannels(t *testing.T, useCache bool, node1Features = node1.Features } - edgePolicy := &channeldb.ChannelEdgePolicy{ + edgePolicy := &channeldb.ChannelEdgePolicy1{ SigBytes: testSig.Serialize(), MessageFlags: msgFlags, ChannelFlags: channelFlags, diff --git a/routing/router.go b/routing/router.go index ec4f3caa19..c0cc8cca6d 100644 --- a/routing/router.go +++ b/routing/router.go @@ -144,7 +144,7 @@ type ChannelGraphSource interface { // UpdateEdge is used to update edge information, without this message // edge considered as not fully constructed. - UpdateEdge(policy *channeldb.ChannelEdgePolicy, + UpdateEdge(policy *channeldb.ChannelEdgePolicy1, op ...batch.SchedulerOption) error // IsStaleNode returns true if the graph source has a node announcement @@ -176,7 +176,7 @@ type ChannelGraphSource interface { // star-graph. ForAllOutgoingChannels(cb func(tx kvdb.RTx, c models.ChannelEdgeInfo, - e *channeldb.ChannelEdgePolicy) error) error + e *channeldb.ChannelEdgePolicy1) error) error // CurrentBlockHeight returns the block height from POV of the router // subsystem. @@ -184,8 +184,8 @@ type ChannelGraphSource interface { // GetChannelByID return the channel by the channel id. GetChannelByID(chanID lnwire.ShortChannelID) ( - models.ChannelEdgeInfo, *channeldb.ChannelEdgePolicy, - *channeldb.ChannelEdgePolicy, error) + models.ChannelEdgeInfo, *channeldb.ChannelEdgePolicy1, + *channeldb.ChannelEdgePolicy1, error) // FetchLightningNode attempts to look up a target node by its identity // public key. channeldb.ErrGraphNodeNotFound is returned if the node @@ -472,7 +472,7 @@ type ChannelRouter struct { ntfnClientUpdates chan *topologyClientUpdate // channelEdgeMtx is a mutex we use to make sure we process only one - // ChannelEdgePolicy at a time for a given channelID, to ensure + // ChannelEdgePolicy1 at a time for a given channelID, to ensure // consistency between the various database accesses. channelEdgeMtx *multimutex.Mutex[uint64] @@ -904,7 +904,7 @@ func (r *ChannelRouter) pruneZombieChans() error { // First, we'll collect all the channels which are eligible for garbage // collection due to being zombies. filterPruneChans := func(info models.ChannelEdgeInfo, - e1, e2 *channeldb.ChannelEdgePolicy) error { + e1, e2 *channeldb.ChannelEdgePolicy1) error { chanID := info.GetChanID() @@ -1654,8 +1654,8 @@ func (r *ChannelRouter) processUpdate(msg interface{}, "view: %v", err) } - case *channeldb.ChannelEdgePolicy: - log.Debugf("Received ChannelEdgePolicy for channel %v", + case *channeldb.ChannelEdgePolicy1: + log.Debugf("Received ChannelEdgePolicy1 for channel %v", msg.ChannelID) // We make sure to hold the mutex for this channel ID, @@ -2616,7 +2616,7 @@ func (r *ChannelRouter) applyChannelUpdate(msg *lnwire.ChannelUpdate1) bool { return false } - err = r.UpdateEdge(&channeldb.ChannelEdgePolicy{ + err = r.UpdateEdge(&channeldb.ChannelEdgePolicy1{ SigBytes: msg.Signature.ToSignatureBytes(), ChannelID: msg.ShortChannelID.ToUint64(), LastUpdate: time.Unix(int64(msg.Timestamp), 0), @@ -2694,7 +2694,7 @@ func (r *ChannelRouter) AddEdge(edge models.ChannelEdgeInfo, // considered as not fully constructed. // // NOTE: This method is part of the ChannelGraphSource interface. -func (r *ChannelRouter) UpdateEdge(update *channeldb.ChannelEdgePolicy, +func (r *ChannelRouter) UpdateEdge(update *channeldb.ChannelEdgePolicy1, op ...batch.SchedulerOption) error { rMsg := &routingMsg{ @@ -2735,8 +2735,8 @@ func (r *ChannelRouter) SyncedHeight() uint32 { // // NOTE: This method is part of the ChannelGraphSource interface. func (r *ChannelRouter) GetChannelByID(chanID lnwire.ShortChannelID) ( - models.ChannelEdgeInfo, *channeldb.ChannelEdgePolicy, - *channeldb.ChannelEdgePolicy, error) { + models.ChannelEdgeInfo, *channeldb.ChannelEdgePolicy1, + *channeldb.ChannelEdgePolicy1, error) { return r.cfg.Graph.FetchChannelEdgesByID(chanID.ToUint64()) } @@ -2769,11 +2769,11 @@ func (r *ChannelRouter) ForEachNode( // // NOTE: This method is part of the ChannelGraphSource interface. func (r *ChannelRouter) ForAllOutgoingChannels(cb func(kvdb.RTx, - models.ChannelEdgeInfo, *channeldb.ChannelEdgePolicy) error) error { + models.ChannelEdgeInfo, *channeldb.ChannelEdgePolicy1) error) error { return r.selfNode.ForEachChannel(r.cfg.Graph.DB(), nil, func(_ kvdb.Backend, tx kvdb.RTx, c models.ChannelEdgeInfo, - e, _ *channeldb.ChannelEdgePolicy) error { + e, _ *channeldb.ChannelEdgePolicy1) error { if e == nil { return fmt.Errorf("channel from self node " + diff --git a/routing/router_test.go b/routing/router_test.go index 0d7f12a788..d714abb43d 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -1336,7 +1336,7 @@ func TestIgnoreChannelEdgePolicyForUnknownChannel(t *testing.T) { BitcoinKey2Bytes: pub2, AuthProof: nil, } - edgePolicy := &channeldb.ChannelEdgePolicy{ + edgePolicy := &channeldb.ChannelEdgePolicy1{ SigBytes: testSig.Serialize(), ChannelID: edge.ChannelID, LastUpdate: testTime, @@ -1421,7 +1421,7 @@ func TestAddEdgeUnknownVertexes(t *testing.T) { // We must add the edge policy to be able to use the edge for route // finding. - edgePolicy := &channeldb.ChannelEdgePolicy{ + edgePolicy := &channeldb.ChannelEdgePolicy1{ SigBytes: testSig.Serialize(), ChannelID: edge.ChannelID, LastUpdate: testTime, @@ -1440,7 +1440,7 @@ func TestAddEdgeUnknownVertexes(t *testing.T) { } // Create edge in the other direction as well. - edgePolicy = &channeldb.ChannelEdgePolicy{ + edgePolicy = &channeldb.ChannelEdgePolicy1{ SigBytes: testSig.Serialize(), ChannelID: edge.ChannelID, LastUpdate: testTime, @@ -1518,7 +1518,7 @@ func TestAddEdgeUnknownVertexes(t *testing.T) { t.Fatalf("unable to add edge to the channel graph: %v.", err) } - edgePolicy = &channeldb.ChannelEdgePolicy{ + edgePolicy = &channeldb.ChannelEdgePolicy1{ SigBytes: testSig.Serialize(), ChannelID: edge.ChannelID, LastUpdate: testTime, @@ -1536,7 +1536,7 @@ func TestAddEdgeUnknownVertexes(t *testing.T) { t.Fatalf("unable to update edge policy: %v", err) } - edgePolicy = &channeldb.ChannelEdgePolicy{ + edgePolicy = &channeldb.ChannelEdgePolicy1{ SigBytes: testSig.Serialize(), ChannelID: edge.ChannelID, LastUpdate: testTime, @@ -2651,7 +2651,7 @@ func TestIsStaleEdgePolicy(t *testing.T) { } // We'll also add two edge policies, one for each direction. - edgePolicy := &channeldb.ChannelEdgePolicy{ + edgePolicy := &channeldb.ChannelEdgePolicy1{ SigBytes: testSig.Serialize(), ChannelID: edge.ChannelID, LastUpdate: updateTimeStamp, @@ -2665,7 +2665,7 @@ func TestIsStaleEdgePolicy(t *testing.T) { t.Fatalf("unable to update edge policy: %v", err) } - edgePolicy = &channeldb.ChannelEdgePolicy{ + edgePolicy = &channeldb.ChannelEdgePolicy1{ SigBytes: testSig.Serialize(), ChannelID: edge.ChannelID, LastUpdate: updateTimeStamp, diff --git a/routing/validation_barrier.go b/routing/validation_barrier.go index 783aa78bfe..1353da980b 100644 --- a/routing/validation_barrier.go +++ b/routing/validation_barrier.go @@ -144,7 +144,7 @@ func (v *ValidationBarrier) InitJobDependencies(job interface{}) { // These other types don't have any dependants, so no further // initialization needs to be done beyond just occupying a job slot. - case *channeldb.ChannelEdgePolicy: + case *channeldb.ChannelEdgePolicy1: return case *lnwire.ChannelUpdate1: return @@ -188,11 +188,11 @@ func (v *ValidationBarrier) WaitForDependants(job interface{}) error { switch msg := job.(type) { // Any ChannelUpdate1 or NodeAnnouncement jobs will need to wait on the // completion of any active ChannelAnnouncement1 jobs related to them. - case *channeldb.ChannelEdgePolicy: + case *channeldb.ChannelEdgePolicy1: shortID := lnwire.NewShortChanIDFromInt(msg.ChannelID) signals, ok = v.chanEdgeDependencies[shortID] - jobDesc = fmt.Sprintf("job=lnwire.ChannelEdgePolicy, scid=%v", + jobDesc = fmt.Sprintf("job=lnwire.ChannelEdgePolicy1, scid=%v", msg.ChannelID) case *channeldb.LightningNode: @@ -299,7 +299,7 @@ func (v *ValidationBarrier) SignalDependants(job interface{}, allow bool) { delete(v.nodeAnnDependencies, route.Vertex(msg.NodeID)) case *lnwire.ChannelUpdate1: delete(v.chanEdgeDependencies, msg.ShortChannelID) - case *channeldb.ChannelEdgePolicy: + case *channeldb.ChannelEdgePolicy1: shortID := lnwire.NewShortChanIDFromInt(msg.ChannelID) delete(v.chanEdgeDependencies, shortID) diff --git a/rpcserver.go b/rpcserver.go index a4461c6c90..6b16e2b686 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -5924,7 +5924,7 @@ func (r *rpcServer) DescribeGraph(ctx context.Context, // similar response which details both the edge information as well as // the routing policies of th nodes connecting the two edges. err = graph.ForEachChannel(func(edgeInfo models.ChannelEdgeInfo, - c1, c2 *channeldb.ChannelEdgePolicy) error { + c1, c2 *channeldb.ChannelEdgePolicy1) error { // Do not include unannounced channels unless specifically // requested. Unannounced channels include both private channels as @@ -5984,7 +5984,7 @@ func marshalExtraOpaqueData(data []byte) map[uint64][]byte { } func marshalDBEdge(edgeInfo models.ChannelEdgeInfo, - c1, c2 *channeldb.ChannelEdgePolicy) (*lnrpc.ChannelEdge, error) { + c1, c2 *channeldb.ChannelEdgePolicy1) (*lnrpc.ChannelEdge, error) { // Make sure the policies match the node they belong to. c1 should point // to the policy for NodeKey1, and c2 for NodeKey2. @@ -6040,7 +6040,7 @@ func marshalDBEdge(edgeInfo models.ChannelEdgeInfo, } func marshalDBRoutingPolicy( - policy *channeldb.ChannelEdgePolicy) *lnrpc.RoutingPolicy { + policy *channeldb.ChannelEdgePolicy1) *lnrpc.RoutingPolicy { disabled := policy.ChannelFlags&lnwire.ChanUpdateDisabled != 0 @@ -6175,7 +6175,7 @@ func (r *rpcServer) GetNodeInfo(ctx context.Context, if err := node.ForEachChannel(graph.DB(), nil, func(_ kvdb.Backend, _ kvdb.RTx, edge models.ChannelEdgeInfo, - c1, c2 *channeldb.ChannelEdgePolicy) error { + c1, c2 *channeldb.ChannelEdgePolicy1) error { numChannels++ totalCapacity += edge.GetCapacity() @@ -6792,7 +6792,7 @@ func (r *rpcServer) FeeReport(ctx context.Context, err = selfNode.ForEachChannel(channelGraph.DB(), nil, func(_ kvdb.Backend, _ kvdb.RTx, chanInfo models.ChannelEdgeInfo, - edgePolicy, _ *channeldb.ChannelEdgePolicy) error { + edgePolicy, _ *channeldb.ChannelEdgePolicy1) error { // Self node should always have policies for its // channels. diff --git a/server.go b/server.go index fe4f2ef478..778e89ee8f 100644 --- a/server.go +++ b/server.go @@ -1235,7 +1235,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, // Wrap the DeleteChannelEdges method so that the funding manager can // use it without depending on several layers of indirection. deleteAliasEdge := func(scid lnwire.ShortChannelID) ( - *channeldb.ChannelEdgePolicy, error) { + *channeldb.ChannelEdgePolicy1, error) { info, e1, e2, err := s.graphDB.FetchChannelEdgesByID( scid.ToUint64(), @@ -1254,7 +1254,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, var ourKey [33]byte copy(ourKey[:], nodeKeyDesc.PubKey.SerializeCompressed()) - var ourPolicy *channeldb.ChannelEdgePolicy + var ourPolicy *channeldb.ChannelEdgePolicy1 if info != nil && info.Node1Bytes() == ourKey { ourPolicy = e1 } else { @@ -3098,7 +3098,7 @@ func (s *server) establishPersistentConnections() error { err = sourceNode.ForEachChannel(s.graphDB.DB(), nil, func( db kvdb.Backend, tx kvdb.RTx, chanInfo models.ChannelEdgeInfo, - policy, _ *channeldb.ChannelEdgePolicy) error { + policy, _ *channeldb.ChannelEdgePolicy1) error { // If the remote party has announced the channel to us, but we // haven't yet, then we won't have a policy. However, we don't From 9eea29221ccb572914e59d70d709eebe3850ec50 Mon Sep 17 00:00:00 2001 From: Elle Mouton <elle.mouton@gmail.com> Date: Tue, 17 Oct 2023 16:14:39 +0200 Subject: [PATCH 24/33] temp: add ChannelUpdate interface --- lnwire/channel_update.go | 62 ++++++++++++++++++++++++++++++++++++++++ lnwire/interfaces.go | 24 +++++++++++++++- 2 files changed, 85 insertions(+), 1 deletion(-) diff --git a/lnwire/channel_update.go b/lnwire/channel_update.go index 95f946da07..dc6047cb12 100644 --- a/lnwire/channel_update.go +++ b/lnwire/channel_update.go @@ -6,6 +6,7 @@ import ( "io" "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/lightningnetwork/lnd/input" ) // ChanUpdateMsgFlags is a bitfield that signals whether optional fields are @@ -279,3 +280,64 @@ func (a *ChannelUpdate1) DataToSign() ([]byte, error) { return buf.Bytes(), nil } + +func (a *ChannelUpdate1) GetSignature() Sig { + return a.Signature +} + +func (a *ChannelUpdate1) GetBaseFee() MilliSatoshi { + return MilliSatoshi(a.BaseFee) +} + +func (a *ChannelUpdate1) GetFeeRate() MilliSatoshi { + return MilliSatoshi(a.FeeRate) +} + +func (a *ChannelUpdate1) GetTimeLock() uint16 { + return a.TimeLockDelta +} + +func (a *ChannelUpdate1) SetSCID(scid ShortChannelID) { + a.ShortChannelID = scid +} + +func (a *ChannelUpdate1) SetSig(sig input.Signature) error { + s, err := NewSigFromSignature(sig) + if err != nil { + return err + } + + a.Signature = s + + return nil +} + +func (a *ChannelUpdate1) IsDisabled() bool { + return a.ChannelFlags&ChanUpdateDisabled == ChanUpdateDisabled +} + +func (a *ChannelUpdate1) GetChainHash() chainhash.Hash { + return a.ChainHash +} + +func (a *ChannelUpdate1) SetDisabled(disabled bool) { + if disabled { + // Set the bit responsible for marking a channel as + // disabled. + a.ChannelFlags |= ChanUpdateDisabled + } else { + // Clear the bit responsible for marking a channel as + // disabled. + a.ChannelFlags &= ^ChanUpdateDisabled + } +} + +func (a *ChannelUpdate1) IsNode1() bool { + return a.ChannelFlags&ChanUpdateDirection == 0 +} + +func (a *ChannelUpdate1) SCID() ShortChannelID { + return a.ShortChannelID +} + +var _ ChannelUpdate = (*ChannelUpdate1)(nil) diff --git a/lnwire/interfaces.go b/lnwire/interfaces.go index d50fad3862..f6b6c1d98a 100644 --- a/lnwire/interfaces.go +++ b/lnwire/interfaces.go @@ -1,6 +1,11 @@ package lnwire -import "github.com/btcsuite/btcd/chaincfg/chainhash" +import ( + "io" + + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/lightningnetwork/lnd/input" +) // ChannelAnnouncement is an interface that must be satisfied by any message // used to announce and prove the existence of a channel. @@ -22,3 +27,20 @@ type ChannelAnnouncement interface { Message } + +type ChannelUpdate interface { //nolint:interfacebloat + Decode(r io.Reader, pver uint32) error + SCID() ShortChannelID + IsNode1() bool + SetDisabled(bool) + IsDisabled() bool + GetChainHash() chainhash.Hash + SetSig(signature input.Signature) error + SetSCID(scid ShortChannelID) + GetTimeLock() uint16 + GetBaseFee() MilliSatoshi + GetFeeRate() MilliSatoshi + GetSignature() Sig + + Message +} From f03e646d2588383a50f7ca1df9dfaa5ffd3d6c8c Mon Sep 17 00:00:00 2001 From: Elle Mouton <elle.mouton@gmail.com> Date: Wed, 18 Oct 2023 15:08:20 +0200 Subject: [PATCH 25/33] multi: use ChannelUpdate interface for failure messages --- channeldb/graph.go | 24 +++++ htlcswitch/interceptable_switch.go | 2 +- htlcswitch/link.go | 8 +- htlcswitch/switch.go | 4 +- htlcswitch/switch_test.go | 12 ++- lnrpc/routerrpc/router_backend.go | 87 +++++++++++----- lnwire/onion_error.go | 143 +++++++++++++++++--------- lnwire/onion_error_test.go | 18 ++-- routing/ann_validation.go | 26 ++++- routing/missioncontrol_test.go | 6 +- routing/mock_test.go | 4 +- routing/payment_lifecycle.go | 5 +- routing/payment_session.go | 10 +- routing/result_interpretation_test.go | 8 +- routing/router.go | 50 ++++----- routing/router_test.go | 17 ++- 16 files changed, 271 insertions(+), 153 deletions(-) diff --git a/channeldb/graph.go b/channeldb/graph.go index 4b919350e0..fd8b1e2504 100644 --- a/channeldb/graph.go +++ b/channeldb/graph.go @@ -3712,6 +3712,30 @@ func (c *ChannelEdgePolicy1) ComputeFeeFromIncoming( ) } +func EdgePolicyFromUpdate(update lnwire.ChannelUpdate) (*ChannelEdgePolicy1, + error) { + + switch upd := update.(type) { + case *lnwire.ChannelUpdate1: + //nolint:lll + return &ChannelEdgePolicy1{ + SigBytes: upd.Signature.ToSignatureBytes(), + ChannelID: upd.ShortChannelID.ToUint64(), + LastUpdate: time.Unix(int64(upd.Timestamp), 0), + MessageFlags: upd.MessageFlags, + ChannelFlags: upd.ChannelFlags, + TimeLockDelta: upd.TimeLockDelta, + MinHTLC: upd.HtlcMinimumMsat, + MaxHTLC: upd.HtlcMaximumMsat, + FeeBaseMSat: lnwire.MilliSatoshi(upd.BaseFee), + FeeProportionalMillionths: lnwire.MilliSatoshi(upd.FeeRate), + }, nil + default: + return nil, fmt.Errorf("unhandled implementation of "+ + "lnwire.ChannelUpdate: %T", update) + } +} + // FetchChannelEdgesByOutpoint attempts to lookup the two directed edges for // the channel identified by the funding outpoint. If the channel can't be // found, then ErrEdgeNotFound is returned. A struct which houses the general diff --git a/htlcswitch/interceptable_switch.go b/htlcswitch/interceptable_switch.go index 09646bf2d0..aa4294de08 100644 --- a/htlcswitch/interceptable_switch.go +++ b/htlcswitch/interceptable_switch.go @@ -676,7 +676,7 @@ func (f *interceptedForward) FailWithCode(code lnwire.FailCode) error { return err } - failureMsg = lnwire.NewExpiryTooSoon(*update) + failureMsg = lnwire.NewExpiryTooSoon(update) default: return ErrUnsupportedFailureCode diff --git a/htlcswitch/link.go b/htlcswitch/link.go index 3d5e75bd25..f24c5cb2fd 100644 --- a/htlcswitch/link.go +++ b/htlcswitch/link.go @@ -2569,7 +2569,7 @@ func (l *channelLink) CheckHtlcForward(payHash [32]byte, // As part of the returned error, we'll send our latest routing // policy so the sending node obtains the most up to date data. cb := func(upd *lnwire.ChannelUpdate1) lnwire.FailureMessage { - return lnwire.NewFeeInsufficient(amtToForward, *upd) + return lnwire.NewFeeInsufficient(amtToForward, upd) } failure := l.createFailureWithUpdate(false, originalScid, cb) return NewLinkError(failure) @@ -2598,7 +2598,7 @@ func (l *channelLink) CheckHtlcForward(payHash [32]byte, // date with our current policy. cb := func(upd *lnwire.ChannelUpdate1) lnwire.FailureMessage { return lnwire.NewIncorrectCltvExpiry( - incomingTimeout, *upd, + incomingTimeout, upd, ) } failure := l.createFailureWithUpdate(false, originalScid, cb) @@ -2646,7 +2646,7 @@ func (l *channelLink) canSendHtlc(policy models.ForwardingPolicy, // As part of the returned error, we'll send our latest routing // policy so the sending node obtains the most up to date data. cb := func(upd *lnwire.ChannelUpdate1) lnwire.FailureMessage { - return lnwire.NewAmountBelowMinimum(amt, *upd) + return lnwire.NewAmountBelowMinimum(amt, upd) } failure := l.createFailureWithUpdate(false, originalScid, cb) return NewLinkError(failure) @@ -2676,7 +2676,7 @@ func (l *channelLink) canSendHtlc(policy models.ForwardingPolicy, timeout, heightNow) cb := func(upd *lnwire.ChannelUpdate1) lnwire.FailureMessage { - return lnwire.NewExpiryTooSoon(*upd) + return lnwire.NewExpiryTooSoon(upd) } failure := l.createFailureWithUpdate(false, originalScid, cb) return NewLinkError(failure) diff --git a/htlcswitch/switch.go b/htlcswitch/switch.go index 94d9aada57..e2f6cd750a 100644 --- a/htlcswitch/switch.go +++ b/htlcswitch/switch.go @@ -1103,7 +1103,9 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error { // sure that HTLC is not from the source node. if s.cfg.RejectHTLC { failure := NewDetailedLinkError( - &lnwire.FailChannelDisabled{}, + &lnwire.FailChannelDisabled{ + Update: &lnwire.ChannelUpdate1{}, + }, OutgoingFailureForwardsDisabled, ) diff --git a/htlcswitch/switch_test.go b/htlcswitch/switch_test.go index eb33c31f9e..6ae5a10065 100644 --- a/htlcswitch/switch_test.go +++ b/htlcswitch/switch_test.go @@ -3380,7 +3380,9 @@ func TestHtlcNotifier(t *testing.T) { return getThreeHopEvents( channels, htlcID, ts, htlc, hops, &LinkError{ - msg: &lnwire.FailChannelDisabled{}, + msg: &lnwire.FailChannelDisabled{ + Update: &lnwire.ChannelUpdate1{}, + }, FailureDetail: OutgoingFailureForwardsDisabled, }, preimage, @@ -4991,7 +4993,7 @@ func testSwitchForwardFailAlias(t *testing.T, zeroConf bool) { msg := failPacket.linkFailure.msg failMsg, ok := msg.(*lnwire.FailTemporaryChannelFailure) require.True(t, ok) - require.Equal(t, aliceAlias, failMsg.Update.ShortChannelID) + require.Equal(t, aliceAlias, failMsg.Update.SCID()) case <-s2.quit: t.Fatal("switch shutting down, failed to forward packet") } @@ -5174,7 +5176,7 @@ func testSwitchAliasFailAdd(t *testing.T, zeroConf, private, useAlias bool) { msg := failPacket.linkFailure.msg failMsg, ok := msg.(*lnwire.FailTemporaryChannelFailure) require.True(t, ok) - require.Equal(t, outgoingChanID, failMsg.Update.ShortChannelID) + require.Equal(t, outgoingChanID, failMsg.Update.SCID()) case <-s.quit: t.Fatal("switch shutting down, failed to receive fail packet") } @@ -5373,7 +5375,7 @@ func testSwitchHandlePacketForward(t *testing.T, zeroConf, private, msg := failPacket.linkFailure.msg failMsg, ok := msg.(*lnwire.FailAmountBelowMinimum) require.True(t, ok) - require.Equal(t, outgoingChanID, failMsg.Update.ShortChannelID) + require.Equal(t, outgoingChanID, failMsg.Update.SCID()) case <-s.quit: t.Fatal("switch shutting down, failed to receive failure") } @@ -5528,7 +5530,7 @@ func testSwitchAliasInterceptFail(t *testing.T, zeroConf bool) { failureMsg, ok := failure.(*lnwire.FailTemporaryChannelFailure) require.True(t, ok) - failScid := failureMsg.Update.ShortChannelID + failScid := failureMsg.Update.SCID() isAlias := failScid == aliceAlias || failScid == aliceAlias2 require.True(t, isAlias) diff --git a/lnrpc/routerrpc/router_backend.go b/lnrpc/routerrpc/router_backend.go index 1e9e26e52d..dc9f74d483 100644 --- a/lnrpc/routerrpc/router_backend.go +++ b/lnrpc/routerrpc/router_backend.go @@ -1421,8 +1421,13 @@ func marshallWireError(msg lnwire.FailureMessage, response.Code = lnrpc.Failure_INVALID_REALM case *lnwire.FailExpiryTooSoon: + update, err := marshallChannelUpdate(onionErr.Update) + if err != nil { + return err + } + response.Code = lnrpc.Failure_EXPIRY_TOO_SOON - response.ChannelUpdate = marshallChannelUpdate(&onionErr.Update) + response.ChannelUpdate = update case *lnwire.FailExpiryTooFar: response.Code = lnrpc.Failure_EXPIRY_TOO_FAR @@ -1440,28 +1445,55 @@ func marshallWireError(msg lnwire.FailureMessage, response.OnionSha_256 = onionErr.OnionSHA256[:] case *lnwire.FailAmountBelowMinimum: + update, err := marshallChannelUpdate(onionErr.Update) + if err != nil { + return err + + } + response.Code = lnrpc.Failure_AMOUNT_BELOW_MINIMUM - response.ChannelUpdate = marshallChannelUpdate(&onionErr.Update) + response.ChannelUpdate = update response.HtlcMsat = uint64(onionErr.HtlcMsat) case *lnwire.FailFeeInsufficient: + update, err := marshallChannelUpdate(onionErr.Update) + if err != nil { + return err + + } + response.Code = lnrpc.Failure_FEE_INSUFFICIENT - response.ChannelUpdate = marshallChannelUpdate(&onionErr.Update) + response.ChannelUpdate = update response.HtlcMsat = uint64(onionErr.HtlcMsat) case *lnwire.FailIncorrectCltvExpiry: + update, err := marshallChannelUpdate(onionErr.Update) + if err != nil { + return err + } + response.Code = lnrpc.Failure_INCORRECT_CLTV_EXPIRY - response.ChannelUpdate = marshallChannelUpdate(&onionErr.Update) + response.ChannelUpdate = update response.CltvExpiry = onionErr.CltvExpiry case *lnwire.FailChannelDisabled: + update, err := marshallChannelUpdate(onionErr.Update) + if err != nil { + return err + } + response.Code = lnrpc.Failure_CHANNEL_DISABLED - response.ChannelUpdate = marshallChannelUpdate(&onionErr.Update) + response.ChannelUpdate = update response.Flags = uint32(onionErr.Flags) case *lnwire.FailTemporaryChannelFailure: + update, err := marshallChannelUpdate(onionErr.Update) + if err != nil { + return err + } + response.Code = lnrpc.Failure_TEMPORARY_CHANNEL_FAILURE - response.ChannelUpdate = marshallChannelUpdate(onionErr.Update) + response.ChannelUpdate = update case *lnwire.FailRequiredNodeFeatureMissing: response.Code = lnrpc.Failure_REQUIRED_NODE_FEATURE_MISSING @@ -1499,24 +1531,33 @@ func marshallWireError(msg lnwire.FailureMessage, // marshallChannelUpdate marshalls a channel update as received over the wire to // the router rpc format. -func marshallChannelUpdate(update *lnwire.ChannelUpdate1) *lnrpc.ChannelUpdate { +func marshallChannelUpdate(update lnwire.ChannelUpdate) (*lnrpc.ChannelUpdate, + error) { + if update == nil { - return nil - } - - return &lnrpc.ChannelUpdate{ - Signature: update.Signature.RawBytes(), - ChainHash: update.ChainHash[:], - ChanId: update.ShortChannelID.ToUint64(), - Timestamp: update.Timestamp, - MessageFlags: uint32(update.MessageFlags), - ChannelFlags: uint32(update.ChannelFlags), - TimeLockDelta: uint32(update.TimeLockDelta), - HtlcMinimumMsat: uint64(update.HtlcMinimumMsat), - BaseFee: update.BaseFee, - FeeRate: update.FeeRate, - HtlcMaximumMsat: uint64(update.HtlcMaximumMsat), - ExtraOpaqueData: update.ExtraOpaqueData, + return nil, nil + } + + switch upd := update.(type) { + case *lnwire.ChannelUpdate1: + return &lnrpc.ChannelUpdate{ + Signature: upd.Signature.RawBytes(), + ChainHash: upd.ChainHash[:], + ChanId: upd.ShortChannelID.ToUint64(), + Timestamp: upd.Timestamp, + MessageFlags: uint32(upd.MessageFlags), + ChannelFlags: uint32(upd.ChannelFlags), + TimeLockDelta: uint32(upd.TimeLockDelta), + HtlcMinimumMsat: uint64(upd.HtlcMinimumMsat), + BaseFee: upd.BaseFee, + FeeRate: upd.FeeRate, + HtlcMaximumMsat: uint64(upd.HtlcMaximumMsat), + ExtraOpaqueData: upd.ExtraOpaqueData, + }, nil + + default: + return nil, fmt.Errorf("unhandled implementation of "+ + "lnwire.ChannelUpdate: %T", update) } } diff --git a/lnwire/onion_error.go b/lnwire/onion_error.go index 608ac9d4ca..b1bf127b11 100644 --- a/lnwire/onion_error.go +++ b/lnwire/onion_error.go @@ -597,7 +597,7 @@ func (f *FailInvalidOnionKey) Error() string { // unable to pull out a fully valid version, then we'll fall back to the // regular parsing mechanism which includes the length prefix an NO type byte. func parseChannelUpdateCompatibilityMode(reader io.Reader, length uint16, - chanUpdate *ChannelUpdate1, pver uint32) error { + pver uint32) (ChannelUpdate, error) { // Instantiate a LimitReader because there may be additional data // present after the channel update. Without limiting the stream, the @@ -610,28 +610,45 @@ func parseChannelUpdateCompatibilityMode(reader io.Reader, length uint16, // buffer so we can decide how to parse the remainder of it. maybeTypeBytes, err := r.Peek(2) if err != nil { - return err + return nil, err + } + + var ( + typeInt = binary.BigEndian.Uint16(maybeTypeBytes) + chanUpdate ChannelUpdate + hasTypeBytes bool + ) + switch typeInt { + case MsgChannelUpdate: + chanUpdate = &ChannelUpdate1{} + hasTypeBytes = true + default: + // Some older nodes will not have the type prefix in front of + // their channel updates as there was initially some ambiguity + // in the spec. This should ony be the case for the + // ChannelUpdate2 message. + chanUpdate = &ChannelUpdate1{} } - // Some nodes well prefix an additional set of bytes in front of their - // channel updates. These bytes will _almost_ always be 258 or the type - // of the ChannelUpdate1 message. - typeInt := binary.BigEndian.Uint16(maybeTypeBytes) - if typeInt == MsgChannelUpdate { + if hasTypeBytes { // At this point it's likely the case that this is a channel // update message with its type prefixed, so we'll snip off the // first two bytes and parse it as normal. var throwAwayTypeBytes [2]byte _, err := r.Read(throwAwayTypeBytes[:]) if err != nil { - return err + return nil, err } } // At this pint, we've either decided to keep the entire thing, or snip // off the first two bytes. In either case, we can just read it as // normal. - return chanUpdate.Decode(r, pver) + if err = chanUpdate.Decode(r, pver); err != nil { + return nil, err + } + + return chanUpdate, nil } // FailTemporaryChannelFailure is if an otherwise unspecified transient error @@ -644,11 +661,13 @@ type FailTemporaryChannelFailure struct { // which caused the failure. // // NOTE: This field is optional. - Update *ChannelUpdate1 + Update ChannelUpdate } // NewTemporaryChannelFailure creates new instance of the FailTemporaryChannelFailure. -func NewTemporaryChannelFailure(update *ChannelUpdate1) *FailTemporaryChannelFailure { +func NewTemporaryChannelFailure( + update ChannelUpdate) *FailTemporaryChannelFailure { + return &FailTemporaryChannelFailure{Update: update} } @@ -682,11 +701,14 @@ func (f *FailTemporaryChannelFailure) Decode(r io.Reader, pver uint32) error { } if length != 0 { - f.Update = &ChannelUpdate1{} - - return parseChannelUpdateCompatibilityMode( - r, length, f.Update, pver, + update, err := parseChannelUpdateCompatibilityMode( + r, length, pver, ) + if err != nil { + return err + } + + f.Update = update } return nil @@ -717,12 +739,12 @@ type FailAmountBelowMinimum struct { // Update is used to update information about state of the channel // which caused the failure. - Update ChannelUpdate1 + Update ChannelUpdate } // NewAmountBelowMinimum creates new instance of the FailAmountBelowMinimum. func NewAmountBelowMinimum(htlcMsat MilliSatoshi, - update ChannelUpdate1) *FailAmountBelowMinimum { + update ChannelUpdate) *FailAmountBelowMinimum { return &FailAmountBelowMinimum{ HtlcMsat: htlcMsat, @@ -758,11 +780,16 @@ func (f *FailAmountBelowMinimum) Decode(r io.Reader, pver uint32) error { return err } - f.Update = ChannelUpdate1{} - - return parseChannelUpdateCompatibilityMode( - r, length, &f.Update, pver, + update, err := parseChannelUpdateCompatibilityMode( + r, length, pver, ) + if err != nil { + return err + } + + f.Update = update + + return nil } // Encode writes the failure in bytes stream. @@ -773,7 +800,7 @@ func (f *FailAmountBelowMinimum) Encode(w *bytes.Buffer, pver uint32) error { return err } - return writeOnionErrorChanUpdate(w, &f.Update, pver) + return writeOnionErrorChanUpdate(w, f.Update, pver) } // FailFeeInsufficient is returned if the HTLC does not pay sufficient fee, we @@ -787,12 +814,12 @@ type FailFeeInsufficient struct { // Update is used to update information about state of the channel // which caused the failure. - Update ChannelUpdate1 + Update ChannelUpdate } // NewFeeInsufficient creates new instance of the FailFeeInsufficient. func NewFeeInsufficient(htlcMsat MilliSatoshi, - update ChannelUpdate1) *FailFeeInsufficient { + update ChannelUpdate) *FailFeeInsufficient { return &FailFeeInsufficient{ HtlcMsat: htlcMsat, Update: update, @@ -827,11 +854,14 @@ func (f *FailFeeInsufficient) Decode(r io.Reader, pver uint32) error { return err } - f.Update = ChannelUpdate1{} + update, err := parseChannelUpdateCompatibilityMode(r, length, pver) + if err != nil { + return err + } + + f.Update = update - return parseChannelUpdateCompatibilityMode( - r, length, &f.Update, pver, - ) + return nil } // Encode writes the failure in bytes stream. @@ -842,7 +872,7 @@ func (f *FailFeeInsufficient) Encode(w *bytes.Buffer, pver uint32) error { return err } - return writeOnionErrorChanUpdate(w, &f.Update, pver) + return writeOnionErrorChanUpdate(w, f.Update, pver) } // FailIncorrectCltvExpiry is returned if outgoing cltv value does not match @@ -858,12 +888,12 @@ type FailIncorrectCltvExpiry struct { // Update is used to update information about state of the channel // which caused the failure. - Update ChannelUpdate1 + Update ChannelUpdate } // NewIncorrectCltvExpiry creates new instance of the FailIncorrectCltvExpiry. func NewIncorrectCltvExpiry(cltvExpiry uint32, - update ChannelUpdate1) *FailIncorrectCltvExpiry { + update ChannelUpdate) *FailIncorrectCltvExpiry { return &FailIncorrectCltvExpiry{ CltvExpiry: cltvExpiry, @@ -896,11 +926,14 @@ func (f *FailIncorrectCltvExpiry) Decode(r io.Reader, pver uint32) error { return err } - f.Update = ChannelUpdate1{} + update, err := parseChannelUpdateCompatibilityMode(r, length, pver) + if err != nil { + return err + } + + f.Update = update - return parseChannelUpdateCompatibilityMode( - r, length, &f.Update, pver, - ) + return nil } // Encode writes the failure in bytes stream. @@ -911,7 +944,7 @@ func (f *FailIncorrectCltvExpiry) Encode(w *bytes.Buffer, pver uint32) error { return err } - return writeOnionErrorChanUpdate(w, &f.Update, pver) + return writeOnionErrorChanUpdate(w, f.Update, pver) } // FailExpiryTooSoon is returned if the ctlv-expiry is too near, we tell them @@ -921,11 +954,11 @@ func (f *FailIncorrectCltvExpiry) Encode(w *bytes.Buffer, pver uint32) error { type FailExpiryTooSoon struct { // Update is used to update information about state of the channel // which caused the failure. - Update ChannelUpdate1 + Update ChannelUpdate } // NewExpiryTooSoon creates new instance of the FailExpiryTooSoon. -func NewExpiryTooSoon(update ChannelUpdate1) *FailExpiryTooSoon { +func NewExpiryTooSoon(update ChannelUpdate) *FailExpiryTooSoon { return &FailExpiryTooSoon{ Update: update, } @@ -954,18 +987,21 @@ func (f *FailExpiryTooSoon) Decode(r io.Reader, pver uint32) error { return err } - f.Update = ChannelUpdate1{} + update, err := parseChannelUpdateCompatibilityMode(r, length, pver) + if err != nil { + return err + } - return parseChannelUpdateCompatibilityMode( - r, length, &f.Update, pver, - ) + f.Update = update + + return nil } // Encode writes the failure in bytes stream. // // NOTE: Part of the Serializable interface. func (f *FailExpiryTooSoon) Encode(w *bytes.Buffer, pver uint32) error { - return writeOnionErrorChanUpdate(w, &f.Update, pver) + return writeOnionErrorChanUpdate(w, f.Update, pver) } // FailChannelDisabled is returned if the channel is disabled, we tell them the @@ -980,11 +1016,13 @@ type FailChannelDisabled struct { // Update is used to update information about state of the channel // which caused the failure. - Update ChannelUpdate1 + Update ChannelUpdate } // NewChannelDisabled creates new instance of the FailChannelDisabled. -func NewChannelDisabled(flags uint16, update ChannelUpdate1) *FailChannelDisabled { +func NewChannelDisabled(flags uint16, + update ChannelUpdate) *FailChannelDisabled { + return &FailChannelDisabled{ Flags: flags, Update: update, @@ -1019,11 +1057,14 @@ func (f *FailChannelDisabled) Decode(r io.Reader, pver uint32) error { return err } - f.Update = ChannelUpdate1{} + update, err := parseChannelUpdateCompatibilityMode(r, length, pver) + if err != nil { + return err + } - return parseChannelUpdateCompatibilityMode( - r, length, &f.Update, pver, - ) + f.Update = update + + return nil } // Encode writes the failure in bytes stream. @@ -1034,7 +1075,7 @@ func (f *FailChannelDisabled) Encode(w *bytes.Buffer, pver uint32) error { return err } - return writeOnionErrorChanUpdate(w, &f.Update, pver) + return writeOnionErrorChanUpdate(w, f.Update, pver) } // FailFinalIncorrectCltvExpiry is returned if the outgoing_cltv_value does not @@ -1459,7 +1500,7 @@ func makeEmptyOnionError(code FailCode) (FailureMessage, error) { // writeOnionErrorChanUpdate writes out a ChannelUpdate1 using the onion error // format. The format is that we first write out the true serialized length of // the channel update, followed by the serialized channel update itself. -func writeOnionErrorChanUpdate(w *bytes.Buffer, chanUpdate *ChannelUpdate1, +func writeOnionErrorChanUpdate(w *bytes.Buffer, chanUpdate ChannelUpdate, pver uint32) error { // First, we encode the channel update in a temporary buffer in order diff --git a/lnwire/onion_error_test.go b/lnwire/onion_error_test.go index 4f230bd22f..d1cdc91e56 100644 --- a/lnwire/onion_error_test.go +++ b/lnwire/onion_error_test.go @@ -20,7 +20,7 @@ var ( testType = uint64(3) testOffset = uint16(24) sig, _ = NewSigFromSignature(testSig) - testChannelUpdate = ChannelUpdate1{ + testChannelUpdate = &ChannelUpdate1{ Signature: sig, ShortChannelID: NewShortChanIDFromInt(1), Timestamp: 1, @@ -46,7 +46,7 @@ var onionFailures = []FailureMessage{ NewInvalidOnionVersion(testOnionHash), NewInvalidOnionHmac(testOnionHash), NewInvalidOnionKey(testOnionHash), - NewTemporaryChannelFailure(&testChannelUpdate), + NewTemporaryChannelFailure(testChannelUpdate), NewTemporaryChannelFailure(nil), NewAmountBelowMinimum(testAmount, testChannelUpdate), NewFeeInsufficient(testAmount, testChannelUpdate), @@ -136,9 +136,8 @@ func TestChannelUpdateCompatibilityParsing(t *testing.T) { // Now that we have the set of bytes encoded, we'll ensure that we're // able to decode it using our compatibility method, as it's a regular // encoded channel update message. - var newChanUpdate ChannelUpdate1 - err := parseChannelUpdateCompatibilityMode( - &b, uint16(b.Len()), &newChanUpdate, 0, + newChanUpdate, err := parseChannelUpdateCompatibilityMode( + &b, uint16(b.Len()), 0, ) require.NoError(t, err, "unable to parse channel update") @@ -163,9 +162,8 @@ func TestChannelUpdateCompatibilityParsing(t *testing.T) { // We should be able to properly parse the encoded channel update // message even with the extra two bytes. - var newChanUpdate2 ChannelUpdate1 - err = parseChannelUpdateCompatibilityMode( - &b, uint16(b.Len()), &newChanUpdate2, 0, + newChanUpdate2, err := parseChannelUpdateCompatibilityMode( + &b, uint16(b.Len()), 0, ) require.NoError(t, err, "unable to parse channel update") @@ -184,7 +182,7 @@ func TestWriteOnionErrorChanUpdate(t *testing.T) { // raw serialized length. var b bytes.Buffer update := testChannelUpdate - trueUpdateLength, err := WriteMessage(&b, &update, 0) + trueUpdateLength, err := WriteMessage(&b, update, 0) if err != nil { t.Fatalf("unable to write update: %v", err) } @@ -192,7 +190,7 @@ func TestWriteOnionErrorChanUpdate(t *testing.T) { // Next, we'll use the function to encode the update as we would in a // onion error message. var errorBuf bytes.Buffer - err = writeOnionErrorChanUpdate(&errorBuf, &update, 0) + err = writeOnionErrorChanUpdate(&errorBuf, update, 0) require.NoError(t, err, "unable to encode onion error") // Finally, read the length encoded and ensure that it matches the raw diff --git a/routing/ann_validation.go b/routing/ann_validation.go index 85f8d42833..ec61b866bb 100644 --- a/routing/ann_validation.go +++ b/routing/ann_validation.go @@ -194,18 +194,36 @@ func ValidateNodeAnn(a *lnwire.NodeAnnouncement) error { // signed by the node's private key, and (2) that the announcement's message // flags and optional fields are sane. func ValidateChannelUpdateAnn(pubKey *btcec.PublicKey, capacity btcutil.Amount, - a *lnwire.ChannelUpdate1) error { + a lnwire.ChannelUpdate) error { - if err := ValidateChannelUpdateFields(capacity, a); err != nil { + update, ok := a.(*lnwire.ChannelUpdate1) + if !ok { + return fmt.Errorf("unhandled implementation of "+ + "lnwire.ChannelUpdate: %T", a) + } + + if err := ValidateChannelUpdateFields(capacity, update); err != nil { return err } - return VerifyChannelUpdateSignature(a, pubKey) + return VerifyChannelUpdateSignature(update, pubKey) } // VerifyChannelUpdateSignature verifies that the channel update message was // signed by the party with the given node public key. -func VerifyChannelUpdateSignature(msg *lnwire.ChannelUpdate1, +func VerifyChannelUpdateSignature(msg lnwire.ChannelUpdate, + pubKey *btcec.PublicKey) error { + + switch m := msg.(type) { + case *lnwire.ChannelUpdate1: + return verifyChannelUpdate1Signature(m, pubKey) + default: + return fmt.Errorf("unhandled implementation of "+ + "lnwire.ChannelUpdate: %T", msg) + } +} + +func verifyChannelUpdate1Signature(msg *lnwire.ChannelUpdate1, pubKey *btcec.PublicKey) error { data, err := msg.DataToSign() diff --git a/routing/missioncontrol_test.go b/routing/missioncontrol_test.go index 4a0f738715..f68e433d30 100644 --- a/routing/missioncontrol_test.go +++ b/routing/missioncontrol_test.go @@ -197,7 +197,7 @@ func TestMissionControl(t *testing.T) { // A node level failure should bring probability of all known channels // back to zero. - ctx.reportFailure(0, lnwire.NewExpiryTooSoon(lnwire.ChannelUpdate1{})) + ctx.reportFailure(0, lnwire.NewExpiryTooSoon(&lnwire.ChannelUpdate1{})) ctx.expectP(1000, 0) // Check whether history snapshot looks sane. @@ -219,14 +219,14 @@ func TestMissionControlChannelUpdate(t *testing.T) { // Report a policy related failure. Because it is the first, we don't // expect a penalty. ctx.reportFailure( - 0, lnwire.NewFeeInsufficient(0, lnwire.ChannelUpdate1{}), + 0, lnwire.NewFeeInsufficient(0, &lnwire.ChannelUpdate1{}), ) ctx.expectP(100, testAprioriHopProbability) // Report another failure for the same channel. We expect it to be // pruned. ctx.reportFailure( - 0, lnwire.NewFeeInsufficient(0, lnwire.ChannelUpdate1{}), + 0, lnwire.NewFeeInsufficient(0, &lnwire.ChannelUpdate1{}), ) ctx.expectP(100, 0) } diff --git a/routing/mock_test.go b/routing/mock_test.go index 4f765d89e9..33a491c580 100644 --- a/routing/mock_test.go +++ b/routing/mock_test.go @@ -173,7 +173,7 @@ func (m *mockPaymentSessionOld) RequestRoute(_, _ lnwire.MilliSatoshi, return r, nil } -func (m *mockPaymentSessionOld) UpdateAdditionalEdge(_ *lnwire.ChannelUpdate1, +func (m *mockPaymentSessionOld) UpdateAdditionalEdge(_ lnwire.ChannelUpdate, _ *btcec.PublicKey, _ *channeldb.CachedEdgePolicy) bool { return false @@ -675,7 +675,7 @@ func (m *mockPaymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, return args.Get(0).(*route.Route), args.Error(1) } -func (m *mockPaymentSession) UpdateAdditionalEdge(msg *lnwire.ChannelUpdate1, +func (m *mockPaymentSession) UpdateAdditionalEdge(msg lnwire.ChannelUpdate, pubKey *btcec.PublicKey, policy *channeldb.CachedEdgePolicy) bool { args := m.Called(msg, pubKey, policy) diff --git a/routing/payment_lifecycle.go b/routing/payment_lifecycle.go index d84c264c20..76f31c93ad 100644 --- a/routing/payment_lifecycle.go +++ b/routing/payment_lifecycle.go @@ -885,7 +885,7 @@ func (p *shardHandler) handleFailureMessage(rt *route.Route, // SendToRoute where there's no payment lifecycle. if p.paySession != nil { policy = p.paySession.GetAdditionalEdgePolicy( - errSource, update.ShortChannelID.ToUint64(), + errSource, update.SCID().ToUint64(), ) if policy != nil { isAdditionalEdge = true @@ -895,7 +895,8 @@ func (p *shardHandler) handleFailureMessage(rt *route.Route, // Apply channel update to additional edge policy. if isAdditionalEdge { if !p.paySession.UpdateAdditionalEdge( - update, errSource, policy) { + update, errSource, policy, + ) { log.Debugf("Invalid channel update received: node=%v", errVertex) diff --git a/routing/payment_session.go b/routing/payment_session.go index e54d5e2760..3c294e7ca7 100644 --- a/routing/payment_session.go +++ b/routing/payment_session.go @@ -143,7 +143,7 @@ type PaymentSession interface { // (private channels) and applies the update from the message. Returns // a boolean to indicate whether the update has been applied without // error. - UpdateAdditionalEdge(msg *lnwire.ChannelUpdate1, pubKey *btcec.PublicKey, + UpdateAdditionalEdge(msg lnwire.ChannelUpdate, pubKey *btcec.PublicKey, policy *channeldb.CachedEdgePolicy) bool // GetAdditionalEdgePolicy uses the public key and channel ID to query @@ -404,7 +404,7 @@ func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, // validates the message signature and checks it's up to date, then applies the // updates to the supplied policy. It returns a boolean to indicate whether // there's an error when applying the updates. -func (p *paymentSession) UpdateAdditionalEdge(msg *lnwire.ChannelUpdate1, +func (p *paymentSession) UpdateAdditionalEdge(msg lnwire.ChannelUpdate, pubKey *btcec.PublicKey, policy *channeldb.CachedEdgePolicy) bool { // Validate the message signature. @@ -416,9 +416,9 @@ func (p *paymentSession) UpdateAdditionalEdge(msg *lnwire.ChannelUpdate1, } // Update channel policy for the additional edge. - policy.TimeLockDelta = msg.TimeLockDelta - policy.FeeBaseMSat = lnwire.MilliSatoshi(msg.BaseFee) - policy.FeeProportionalMillionths = lnwire.MilliSatoshi(msg.FeeRate) + policy.TimeLockDelta = msg.GetTimeLock() + policy.FeeBaseMSat = msg.GetBaseFee() + policy.FeeProportionalMillionths = msg.GetFeeRate() log.Debugf("New private channel update applied: %v", newLogClosure(func() string { return spew.Sdump(msg) })) diff --git a/routing/result_interpretation_test.go b/routing/result_interpretation_test.go index 8eb39c01d9..81507427fd 100644 --- a/routing/result_interpretation_test.go +++ b/routing/result_interpretation_test.go @@ -94,7 +94,9 @@ var resultTestCases = []resultTestCase{ name: "fail expiry too soon", route: &routeFourHop, failureSrcIdx: 3, - failure: lnwire.NewExpiryTooSoon(lnwire.ChannelUpdate1{}), + failure: lnwire.NewExpiryTooSoon( + &lnwire.ChannelUpdate1{}, + ), expectedResult: &interpretedResult{ pairResults: map[DirectedNodePair]pairResult{ @@ -196,7 +198,9 @@ var resultTestCases = []resultTestCase{ name: "fail fee insufficient intermediate", route: &routeFourHop, failureSrcIdx: 2, - failure: lnwire.NewFeeInsufficient(0, lnwire.ChannelUpdate1{}), + failure: lnwire.NewFeeInsufficient( + 0, &lnwire.ChannelUpdate1{}, + ), expectedResult: &interpretedResult{ pairResults: map[DirectedNodePair]pairResult{ diff --git a/routing/router.go b/routing/router.go index c0cc8cca6d..1c3511e444 100644 --- a/routing/router.go +++ b/routing/router.go @@ -2563,20 +2563,20 @@ func (r *ChannelRouter) sendPayment(feeLimit lnwire.MilliSatoshi, // extractChannelUpdate examines the error and extracts the channel update. func (r *ChannelRouter) extractChannelUpdate( - failure lnwire.FailureMessage) *lnwire.ChannelUpdate1 { + failure lnwire.FailureMessage) lnwire.ChannelUpdate { - var update *lnwire.ChannelUpdate1 + var update lnwire.ChannelUpdate switch onionErr := failure.(type) { case *lnwire.FailExpiryTooSoon: - update = &onionErr.Update + update = onionErr.Update case *lnwire.FailAmountBelowMinimum: - update = &onionErr.Update + update = onionErr.Update case *lnwire.FailFeeInsufficient: - update = &onionErr.Update + update = onionErr.Update case *lnwire.FailIncorrectCltvExpiry: - update = &onionErr.Update + update = onionErr.Update case *lnwire.FailChannelDisabled: - update = &onionErr.Update + update = onionErr.Update case *lnwire.FailTemporaryChannelFailure: update = onionErr.Update } @@ -2586,8 +2586,8 @@ func (r *ChannelRouter) extractChannelUpdate( // applyChannelUpdate validates a channel update and if valid, applies it to the // database. It returns a bool indicating whether the updates were successful. -func (r *ChannelRouter) applyChannelUpdate(msg *lnwire.ChannelUpdate1) bool { - ch, _, _, err := r.GetChannelByID(msg.ShortChannelID) +func (r *ChannelRouter) applyChannelUpdate(msg lnwire.ChannelUpdate) bool { + ch, _, _, err := r.GetChannelByID(msg.SCID()) if err != nil { log.Errorf("Unable to retrieve channel by id: %v", err) return false @@ -2595,39 +2595,27 @@ func (r *ChannelRouter) applyChannelUpdate(msg *lnwire.ChannelUpdate1) bool { var pubKey *btcec.PublicKey - switch msg.ChannelFlags & lnwire.ChanUpdateDirection { - case 0: + if msg.IsNode1() { pubKey, _ = ch.NodeKey1() - - case 1: + } else { pubKey, _ = ch.NodeKey2() } - // Exit early if the pubkey cannot be decided. - if pubKey == nil { - log.Errorf("Unable to decide pubkey with ChannelFlags=%v", - msg.ChannelFlags) + err = ValidateChannelUpdateAnn(pubKey, ch.GetCapacity(), msg) + if err != nil { + log.Errorf("Unable to validate channel update: %v", err) return false } - err = ValidateChannelUpdateAnn(pubKey, ch.GetCapacity(), msg) + edgePolicy, err := channeldb.EdgePolicyFromUpdate(msg) if err != nil { - log.Errorf("Unable to validate channel update: %v", err) + log.Errorf("Unable to convert update message to edge "+ + "policy: %v", err) + return false } - err = r.UpdateEdge(&channeldb.ChannelEdgePolicy1{ - SigBytes: msg.Signature.ToSignatureBytes(), - ChannelID: msg.ShortChannelID.ToUint64(), - LastUpdate: time.Unix(int64(msg.Timestamp), 0), - MessageFlags: msg.MessageFlags, - ChannelFlags: msg.ChannelFlags, - TimeLockDelta: msg.TimeLockDelta, - MinHTLC: msg.HtlcMinimumMsat, - MaxHTLC: msg.HtlcMaximumMsat, - FeeBaseMSat: lnwire.MilliSatoshi(msg.BaseFee), - FeeProportionalMillionths: lnwire.MilliSatoshi(msg.FeeRate), - }) + err = r.UpdateEdge(edgePolicy) if err != nil && !IsError(err, ErrIgnored, ErrOutdated) { log.Errorf("Unable to apply channel update: %v", err) return false diff --git a/routing/router_test.go b/routing/router_test.go index d714abb43d..8b9c88f04c 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -527,9 +527,8 @@ func TestChannelUpdateValidation(t *testing.T) { func(firstHop lnwire.ShortChannelID) ([32]byte, error) { return [32]byte{}, htlcswitch.NewForwardingError( &lnwire.FailFeeInsufficient{ - Update: errChanUpdate, - }, - 1, + Update: &errChanUpdate, + }, 1, ) }) @@ -644,7 +643,7 @@ func TestSendPaymentErrorRepeatedFeeInsufficient(t *testing.T) { // reflect the new fee schedule for the // node/channel. &lnwire.FailFeeInsufficient{ - Update: errChanUpdate, + Update: &errChanUpdate, }, 1, ) } @@ -756,7 +755,7 @@ func TestSendPaymentErrorFeeInsufficientPrivateEdge(t *testing.T) { // reflect the new fee schedule for the // node/channel. &lnwire.FailFeeInsufficient{ - Update: errChanUpdate, + Update: &errChanUpdate, }, 1, ) }, @@ -885,7 +884,7 @@ func TestSendPaymentPrivateEdgeUpdateFeeExceedsLimit(t *testing.T) { // reflect the new fee schedule for the // node/channel. &lnwire.FailFeeInsufficient{ - Update: errChanUpdate, + Update: &errChanUpdate, }, 1, ) }, @@ -985,7 +984,7 @@ func TestSendPaymentErrorNonFinalTimeLockErrors(t *testing.T) { if firstHop == roasbeefSongoku { return [32]byte{}, htlcswitch.NewForwardingError( &lnwire.FailExpiryTooSoon{ - Update: errChanUpdate, + Update: &errChanUpdate, }, 1, ) } @@ -1032,7 +1031,7 @@ func TestSendPaymentErrorNonFinalTimeLockErrors(t *testing.T) { if firstHop == roasbeefSongoku { return [32]byte{}, htlcswitch.NewForwardingError( &lnwire.FailIncorrectCltvExpiry{ - Update: errChanUpdate, + Update: &errChanUpdate, }, 1, ) } @@ -2921,7 +2920,7 @@ func TestSendToRouteStructuredError(t *testing.T) { testCases := map[int]lnwire.FailureMessage{ finalHopIndex: lnwire.NewFailIncorrectDetails(payAmt, 100), 1: &lnwire.FailFeeInsufficient{ - Update: lnwire.ChannelUpdate1{}, + Update: &lnwire.ChannelUpdate1{}, }, } From e54aac336723a08280e1204ed532cb1d1925490f Mon Sep 17 00:00:00 2001 From: Elle Mouton <elle.mouton@gmail.com> Date: Thu, 19 Oct 2023 11:50:47 +0200 Subject: [PATCH 26/33] channeldb: add ChannelEdgeInfo2 impl --- channeldb/graph.go | 376 ++++++++++++++++++++++++++++++- channeldb/graph_test.go | 155 +++++++++---- lnwire/channel_announcement_2.go | 4 + 3 files changed, 484 insertions(+), 51 deletions(-) diff --git a/channeldb/graph.go b/channeldb/graph.go index fd8b1e2504..1f4a41ea60 100644 --- a/channeldb/graph.go +++ b/channeldb/graph.go @@ -16,6 +16,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcec/v2/ecdsa" + "github.com/btcsuite/btcd/btcec/v2/schnorr" "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" @@ -26,6 +27,7 @@ import ( "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" + "github.com/lightningnetwork/lnd/tlv" ) var ( @@ -3399,6 +3401,10 @@ func (c *ChannelEdgeInfo1) BitcoinKey2() (*btcec.PublicKey, error) { return key, nil } +// A compile-time check to ensure that ChannelEdgeInfo1 implements +// modesl.ChannelEdgeInfo. +var _ models.ChannelEdgeInfo = (*ChannelEdgeInfo1)(nil) + // FetchOtherNode attempts to fetch the full LightningNode that's opposite of // the target node in the channel. This is useful when one knows the pubkey of // one of the nodes, and wishes to obtain the full LightningNode for the other @@ -3453,10 +3459,6 @@ func FetchOtherNode(db kvdb.Backend, tx kvdb.RTx, edge models.ChannelEdgeInfo, return targetNode, err } -// A compile-time check to ensure that ChannelEdgeInfo1 implements -// modesl.ChannelEdgeInfo. -var _ models.ChannelEdgeInfo = (*ChannelEdgeInfo1)(nil) - // ChannelAuthProof1 is the authentication proof (the signature portion) for a // channel. Using the four signatures contained in the struct, and some // auxiliary knowledge (the funding script, node identities, and outpoint) nodes @@ -4507,24 +4509,66 @@ func deserializeLightningNode(r io.Reader) (LightningNode, error) { return node, nil } +// edgeInfoEncodingType indicate how the bytes for a channel edge have been +// serialised. +type edgeInfoEncodingType uint8 + +const ( + // edgeInfo2EncodingType will be used as a prefix for edge's advertised + // using the ChannelAnnouncement2 message. The type indicates how the + // bytes following should be deserialized. + edgeInfo2EncodingType edgeInfoEncodingType = 0 +) + // putChanEdgeInfo encodes and writes the given edge to the edge index bucket. // The encoding used will depend on the channel type. func putChanEdgeInfo(edgeIndex kvdb.RwBucket, edgeInfo models.ChannelEdgeInfo, chanID [8]byte) error { - var b bytes.Buffer + var ( + b bytes.Buffer + withTypeByte bool + typeByte edgeInfoEncodingType + serialize func(w io.Writer) error + ) switch info := edgeInfo.(type) { case *ChannelEdgeInfo1: - err := serializeChanEdgeInfo1(&b, info, chanID) - if err != nil { - return err + serialize = func(w io.Writer) error { + return serializeChanEdgeInfo1(&b, info, chanID) + } + case *ChannelEdgeInfo2: + withTypeByte = true + typeByte = edgeInfo2EncodingType + + serialize = func(w io.Writer) error { + return serializeChanEdgeInfo2(&b, info) } default: return fmt.Errorf("unhandled implementation of "+ "ChannelEdgeInfo: %T", edgeInfo) } + if withTypeByte { + // First, write the identifying encoding byte to signal that + // this is not using the legacy encoding. + _, err := b.Write([]byte{chanEdgeNewEncodingPrefix}) + if err != nil { + return err + } + + // Now, write the encoding type. + _, err = b.Write([]byte{byte(typeByte)}) + if err != nil { + return err + } + } + + err := serialize(&b) + if err != nil { + return err + } + return edgeIndex.Put(chanID[:], b.Bytes()) } @@ -4618,7 +4662,26 @@ func deserializeChanEdgeInfo(reader io.Reader) (models.ChannelEdgeInfo, error) { return deserializeChanEdgeInfo1(r) } - return nil, fmt.Errorf("unknown channel edge encoding") + // Pop the encoding type byte. + var scratch [1]byte + if _, err = r.Read(scratch[:]); err != nil { + return nil, err + } + + // Now, read the encoding type byte. + if _, err = r.Read(scratch[:]); err != nil { + return nil, err + } + + encoding := edgeInfoEncodingType(scratch[0]) + switch encoding { + case edgeInfo2EncodingType: + return deserializeChanEdgeInfo2(r) + + default: + return nil, fmt.Errorf("unknown edge info encoding type: %d", + encoding) + } } func deserializeChanEdgeInfo1(r io.Reader) (*ChannelEdgeInfo1, error) { @@ -5063,3 +5126,298 @@ func deserializeChanEdgePolicyRaw(r io.Reader) (*ChannelEdgePolicy1, error) { return edge, nil } + +const ( + EdgeInfo2MsgType = tlv.Type(0) + EdgeInfo2ChanPoint = tlv.Type(1) + EdgeInfo2Sig = tlv.Type(2) +) + +type ChannelEdgeInfo2 struct { + lnwire.ChannelAnnouncement2 + + // ChannelPoint is the funding outpoint of the channel. This can be + // used to uniquely identify the channel within the channel graph. + ChannelPoint wire.OutPoint + + // AuthProof is the authentication proof for this channel. + AuthProof *ChannelAuthProof2 + + nodeKey1 *btcec.PublicKey + nodeKey2 *btcec.PublicKey + bitcoinKey1 *btcec.PublicKey + bitcoinKey2 *btcec.PublicKey +} + +// Copy returns a copy of the ChannelEdgeInfo. +// +// NOTE: this is part of the models.ChannelEdgeInfo interface. +func (c *ChannelEdgeInfo2) Copy() models.ChannelEdgeInfo { + return &ChannelEdgeInfo2{ + ChannelAnnouncement2: lnwire.ChannelAnnouncement2{ + ChainHash: c.ChainHash, + Features: c.Features, + ShortChannelID: c.ShortChannelID, + Capacity: c.Capacity, + NodeID1: c.NodeID1, + NodeID2: c.NodeID2, + BitcoinKey1: c.BitcoinKey1, + BitcoinKey2: c.BitcoinKey2, + MerkleRootHash: c.MerkleRootHash, + ExtraOpaqueData: c.ExtraOpaqueData, + }, + ChannelPoint: c.ChannelPoint, + AuthProof: c.AuthProof, + } +} + +// Node1Bytes returns bytes of the public key of node 1. +// +// NOTE: this is part of the models.ChannelEdgeInfo interface. +func (c *ChannelEdgeInfo2) Node1Bytes() [33]byte { + return c.NodeID1 +} + +// Node2Bytes returns bytes of the public key of node 2. +// +// NOTE: this is part of the models.ChannelEdgeInfo interface. +func (c *ChannelEdgeInfo2) Node2Bytes() [33]byte { + return c.NodeID2 +} + +// GetChainHash returns the hash of the genesis block of the chain that the edge +// is on. +// +// NOTE: this is part of the models.ChannelEdgeInfo interface. +func (c *ChannelEdgeInfo2) GetChainHash() chainhash.Hash { + return c.ChainHash +} + +// GetChanID returns the channel ID. +// +// NOTE: this is part of the models.ChannelEdgeInfo interface. +func (c *ChannelEdgeInfo2) GetChanID() uint64 { + return c.ShortChannelID.ToUint64() +} + +// GetAuthProof returns the ChannelAuthProof for the edge. +// +// NOTE: this is part of the models.ChannelEdgeInfo interface. +func (c *ChannelEdgeInfo2) GetAuthProof() models.ChannelAuthProof { + // Cant just return AuthProof cause you then run into the + // nil interface gotcha. + if c.AuthProof == nil { + return nil + } + + return c.AuthProof +} + +// GetCapacity returns the capacity of the channel. +// +// NOTE: this is part of the models.ChannelEdgeInfo interface. +func (c *ChannelEdgeInfo2) GetCapacity() btcutil.Amount { + return btcutil.Amount(c.Capacity) +} + +// SetAuthProof sets the proof of the channel. +// +// NOTE: this is part of the models.ChannelEdgeInfo interface. +func (c *ChannelEdgeInfo2) SetAuthProof(proof models.ChannelAuthProof) error { + if proof == nil { + c.AuthProof = nil + + return nil + } + + p, ok := proof.(*ChannelAuthProof2) + if !ok { + return fmt.Errorf("expected type ChannelAuthProof2 for "+ + "ChannelEdgeInfo2, got %T", proof) + } + + c.AuthProof = p + + return nil +} + +// GetChanPoint returns the outpoint of the funding transaction of the channel. +// +// NOTE: this is part of the models.ChannelEdgeInfo interface. +func (c *ChannelEdgeInfo2) GetChanPoint() wire.OutPoint { + return c.ChannelPoint +} + +// FundingScript returns the pk script for the funding output of the +// channel. +// +// NOTE: this is part of the models.ChannelEdgeInfo interface. +func (c *ChannelEdgeInfo2) FundingScript() ([]byte, error) { + pubKey1, err := btcec.ParsePubKey(c.BitcoinKey1[:]) + if err != nil { + return nil, err + } + pubKey2, err := btcec.ParsePubKey(c.BitcoinKey2[:]) + if err != nil { + return nil, err + } + + fundingScript, _, err := input.GenTaprootFundingScript( + pubKey1, pubKey2, 0, + ) + if err != nil { + return nil, err + } + + return fundingScript, nil +} + +// NodeKey1 is the identity public key of the "first" node that was involved in +// the creation of this channel. A node is considered "first" if the +// lexicographical ordering the its serialized public key is "smaller" than +// that of the other node involved in channel creation. +// +// NOTE: By having this method to access an attribute, we ensure we only need +// to fully deserialize the pubkey if absolutely necessary. +// +// NOTE: this is part of the models.ChannelEdgeInfo interface. +func (c *ChannelEdgeInfo2) NodeKey1() (*btcec.PublicKey, error) { + if c.nodeKey1 != nil { + return c.nodeKey1, nil + } + + key, err := btcec.ParsePubKey(c.NodeID1[:]) + if err != nil { + return nil, err + } + c.nodeKey1 = key + + return key, nil +} + +// NodeKey2 is the identity public key of the "second" node that was +// involved in the creation of this channel. A node is considered +// "second" if the lexicographical ordering the its serialized public +// key is "larger" than that of the other node involved in channel +// creation. +// +// NOTE: By having this method to access an attribute, we ensure we only need +// to fully deserialize the pubkey if absolutely necessary. +// +// NOTE: this is part of the models.ChannelEdgeInfo interface. +func (c *ChannelEdgeInfo2) NodeKey2() (*btcec.PublicKey, error) { + if c.nodeKey2 != nil { + return c.nodeKey2, nil + } + + key, err := btcec.ParsePubKey(c.NodeID2[:]) + if err != nil { + return nil, err + } + c.nodeKey2 = key + + return key, nil +} + +func serializeChanEdgeInfo2(w io.Writer, edge *ChannelEdgeInfo2) error { + if len(edge.ExtraOpaqueData) > MaxAllowedExtraOpaqueBytes { + return ErrTooManyExtraOpaqueBytes(len(edge.ExtraOpaqueData)) + } + + serializedMsg, err := edge.DataToSign() + if err != nil { + return err + } + + records := []tlv.Record{ + tlv.MakePrimitiveRecord(EdgeInfo2MsgType, &serializedMsg), + tlv.MakeStaticRecord( + EdgeInfo2ChanPoint, &edge.ChannelPoint, 34, + encodeOutpoint, decodeOutpoint, + ), + } + + if edge.AuthProof != nil { + records = append( + records, + tlv.MakePrimitiveRecord( + EdgeInfo2Sig, &edge.AuthProof.SchnorrSigBytes, + ), + ) + } + + stream, err := tlv.NewStream(records...) + if err != nil { + return err + } + + return stream.Encode(w) +} + +func deserializeChanEdgeInfo2(r io.Reader) (*ChannelEdgeInfo2, error) { + var ( + edgeInfo ChannelEdgeInfo2 + msgBytes []byte + sigBytes []byte + ) + + records := []tlv.Record{ + tlv.MakePrimitiveRecord(EdgeInfo2MsgType, &msgBytes), + tlv.MakeStaticRecord( + EdgeInfo2ChanPoint, &edgeInfo.ChannelPoint, 34, + encodeOutpoint, decodeOutpoint, + ), + tlv.MakePrimitiveRecord(EdgeInfo2Sig, &sigBytes), + } + + stream, err := tlv.NewStream(records...) + if err != nil { + return nil, err + } + + typeMap, err := stream.DecodeWithParsedTypes(r) + if err != nil { + return nil, err + } + + reader := bytes.NewReader(msgBytes) + err = edgeInfo.ChannelAnnouncement2.DecodeTLVRecords(reader) + if err != nil { + return nil, err + } + + if _, ok := typeMap[EdgeInfo2Sig]; ok { + edgeInfo.AuthProof = &ChannelAuthProof2{ + SchnorrSigBytes: sigBytes, + } + } + + return &edgeInfo, nil +} + +type ChannelAuthProof2 struct { + // SchnorrSigBytes are the raw bytes of the encoded schnorr signature. + SchnorrSigBytes []byte + + // schnorrSig is the cached instance of the schnorr signature. + schnorrSig *schnorr.Signature +} + +// A compile-time check to ensure that ChannelEdgeInfo2 implements +// models.ChannelEdgeInfo. +var _ models.ChannelEdgeInfo = (*ChannelEdgeInfo2)(nil) + +func encodeOutpoint(w io.Writer, val interface{}, _ *[8]byte) error { + if o, ok := val.(*wire.OutPoint); ok { + return writeOutpoint(w, o) + } + + return tlv.NewTypeForEncodingErr(val, "*wire.Outpoint") +} + +func decodeOutpoint(r io.Reader, val interface{}, _ *[8]byte, l uint64) error { + if o, ok := val.(*wire.OutPoint); ok { + return readOutpoint(r, o) + } + return tlv.NewTypeForDecodingErr(val, "*wire.Outpoint", l, l) +} diff --git a/channeldb/graph_test.go b/channeldb/graph_test.go index b470ba4b25..260a169980 100644 --- a/channeldb/graph_test.go +++ b/channeldb/graph_test.go @@ -36,13 +36,16 @@ var ( "[2001:db8:85a3:0:0:8a2e:370:7334]:80") testAddrs = []net.Addr{testAddr, anotherAddr} - testRBytes, _ = hex.DecodeString("8ce2bc69281ce27da07e6683571319d18e949ddfa2965fb6caa1bf0314f882d7") - testSBytes, _ = hex.DecodeString("299105481d63e0f4bc2a88121167221b6700d72a0ead154c03be696a292d24ae") - testRScalar = new(btcec.ModNScalar) - testSScalar = new(btcec.ModNScalar) - _ = testRScalar.SetByteSlice(testRBytes) - _ = testSScalar.SetByteSlice(testSBytes) - testSig = ecdsa.NewSignature(testRScalar, testSScalar) + testRBytes, _ = hex.DecodeString("8ce2bc69281ce27da07e6683571319d18e949ddfa2965fb6caa1bf0314f882d7") + testSBytes, _ = hex.DecodeString("299105481d63e0f4bc2a88121167221b6700d72a0ead154c03be696a292d24ae") + testRScalar = new(btcec.ModNScalar) + testSScalar = new(btcec.ModNScalar) + _ = testRScalar.SetByteSlice(testRBytes) + _ = testSScalar.SetByteSlice(testSBytes) + testSig = ecdsa.NewSignature(testRScalar, testSScalar) + testSchnorrSigStr, _ = hex.DecodeString("04E7F9037658A92AFEB4F2" + + "5BAE5339E3DDCA81A353493827D26F16D92308E49E2A25E9220867" + + "8A2DF86970DA91B03A8AF8815A8A60498B358DAF560B347AA557") testFeatures = lnwire.NewFeatureVector( lnwire.NewRawFeatureVector(lnwire.GossipQueriesRequired), @@ -385,7 +388,7 @@ func TestEdgeInsertionDeletion(t *testing.T) { } func createEdge(height, txIndex uint32, txPosition uint16, outPointIndex uint32, - node1, node2 *LightningNode) (ChannelEdgeInfo1, lnwire.ShortChannelID) { + node1, node2 *LightningNode) (*ChannelEdgeInfo1, lnwire.ShortChannelID) { shortChanID := lnwire.ShortChannelID{ BlockHeight: height, @@ -417,7 +420,48 @@ func createEdge(height, txIndex uint32, txPosition uint16, outPointIndex uint32, copy(edgeInfo.BitcoinKey1Bytes[:], node1Pub.SerializeCompressed()) copy(edgeInfo.BitcoinKey2Bytes[:], node2Pub.SerializeCompressed()) - return edgeInfo, shortChanID + return &edgeInfo, shortChanID +} + +func createEdge2(height, txIndex uint32, txPosition uint16, + outPointIndex uint32, node1, node2 *LightningNode) (*ChannelEdgeInfo2, + lnwire.ShortChannelID) { + + shortChanID := lnwire.ShortChannelID{ + BlockHeight: height, + TxIndex: txIndex, + TxPosition: txPosition, + } + outpoint := wire.OutPoint{ + Hash: rev, + Index: outPointIndex, + } + + node1Pub, _ := node1.PubKey() + node2Pub, _ := node2.PubKey() + edgeInfo := ChannelEdgeInfo2{ + ChannelAnnouncement2: lnwire.ChannelAnnouncement2{ + ChainHash: key, + Features: lnwire.RawFeatureVector{}, + ShortChannelID: shortChanID, + Capacity: 9000, + }, + ChannelPoint: outpoint, + AuthProof: &ChannelAuthProof2{ + SchnorrSigBytes: testSchnorrSigStr, + }, + } + + copy(edgeInfo.NodeID1[:], node1Pub.SerializeCompressed()) + copy(edgeInfo.NodeID2[:], node2Pub.SerializeCompressed()) + + var btc1, btc2 [33]byte + copy(btc1[:], node1Pub.SerializeCompressed()) + copy(btc2[:], node2Pub.SerializeCompressed()) + edgeInfo.BitcoinKey1 = &btc1 + edgeInfo.BitcoinKey2 = &btc2 + + return &edgeInfo, shortChanID } // TestDisconnectBlockAtHeight checks that the pruned state of the channel @@ -477,20 +521,20 @@ func TestDisconnectBlockAtHeight(t *testing.T) { edgeInfo3, _ := createEdge(height-1, 0, 0, 2, node1, node2) // Now add all these new edges to the database. - if err := graph.AddChannelEdge(&edgeInfo); err != nil { + if err := graph.AddChannelEdge(edgeInfo); err != nil { t.Fatalf("unable to create channel edge: %v", err) } - if err := graph.AddChannelEdge(&edgeInfo2); err != nil { + if err := graph.AddChannelEdge(edgeInfo2); err != nil { t.Fatalf("unable to create channel edge: %v", err) } - if err := graph.AddChannelEdge(&edgeInfo3); err != nil { + if err := graph.AddChannelEdge(edgeInfo3); err != nil { t.Fatalf("unable to create channel edge: %v", err) } - assertEdgeWithNoPoliciesInCache(t, graph, &edgeInfo) - assertEdgeWithNoPoliciesInCache(t, graph, &edgeInfo2) - assertEdgeWithNoPoliciesInCache(t, graph, &edgeInfo3) + assertEdgeWithNoPoliciesInCache(t, graph, edgeInfo) + assertEdgeWithNoPoliciesInCache(t, graph, edgeInfo2) + assertEdgeWithNoPoliciesInCache(t, graph, edgeInfo3) // Call DisconnectBlockAtHeight, which should prune every channel // that has a funding height of 'height' or greater. @@ -500,7 +544,7 @@ func TestDisconnectBlockAtHeight(t *testing.T) { } assertNoEdge(t, graph, edgeInfo.ChannelID) assertNoEdge(t, graph, edgeInfo2.ChannelID) - assertEdgeWithNoPoliciesInCache(t, graph, &edgeInfo3) + assertEdgeWithNoPoliciesInCache(t, graph, edgeInfo3) // The two edges should have been removed. if len(removed) != 2 { @@ -561,11 +605,30 @@ func assertEdgeInfoEqual(t *testing.T, e1, e2 models.ChannelEdgeInfo) { require.True(t, ok) assertEdgeInfo1Equal(t, edge1, edge2) + case *ChannelEdgeInfo2: + edge2, ok := e2.(*ChannelEdgeInfo2) + require.True(t, ok) + + assertEdgeInfo2Equal(t, edge1, edge2) default: t.Fatalf("unhandled ChannelEdgeInfo type: %T", e1) } } +func assertEdgeInfo2Equal(t *testing.T, e1 *ChannelEdgeInfo2, + e2 *ChannelEdgeInfo2) { + + if e1.Features.IsEmpty() { + require.True(t, e2.Features.IsEmpty()) + } else { + require.Equal(t, e1.Features, e2.Features) + } + e1.Features = lnwire.RawFeatureVector{} + e2.Features = lnwire.RawFeatureVector{} + + require.Equal(t, e1, e2) +} + func assertEdgeInfo1Equal(t *testing.T, e1 *ChannelEdgeInfo1, e2 *ChannelEdgeInfo1) { @@ -1631,10 +1694,10 @@ func TestHighestChanID(t *testing.T) { edge1, _ := createEdge(10, 0, 0, 0, node1, node2) edge2, chanID2 := createEdge(100, 0, 0, 0, node1, node2) - if err := graph.AddChannelEdge(&edge1); err != nil { + if err := graph.AddChannelEdge(edge1); err != nil { t.Fatalf("unable to create channel edge: %v", err) } - if err := graph.AddChannelEdge(&edge2); err != nil { + if err := graph.AddChannelEdge(edge2); err != nil { t.Fatalf("unable to create channel edge: %v", err) } @@ -1651,7 +1714,7 @@ func TestHighestChanID(t *testing.T) { // If we add another edge, then the current best chan ID should be // updated as well. edge3, chanID3 := createEdge(1000, 0, 0, 0, node1, node2) - if err := graph.AddChannelEdge(&edge3); err != nil { + if err := graph.AddChannelEdge(edge3); err != nil { t.Fatalf("unable to create channel edge: %v", err) } bestID, err = graph.HighestChanID() @@ -1706,7 +1769,7 @@ func TestChanUpdatesInHorizon(t *testing.T) { uint32(i*10), 0, 0, 0, node1, node2, ) - if err := graph.AddChannelEdge(&channel); err != nil { + if err := graph.AddChannelEdge(channel); err != nil { t.Fatalf("unable to create channel edge: %v", err) } @@ -1735,7 +1798,7 @@ func TestChanUpdatesInHorizon(t *testing.T) { } edges = append(edges, ChannelEdge{ - Info: &channel, + Info: channel, Policy1: edge1, Policy2: edge2, }) @@ -1977,7 +2040,7 @@ func TestFilterKnownChanIDs(t *testing.T) { uint32(i*10), 0, 0, 0, node1, node2, ) - if err := graph.AddChannelEdge(&channel); err != nil { + if err := graph.AddChannelEdge(channel); err != nil { t.Fatalf("unable to create channel edge: %v", err) } @@ -1990,7 +2053,7 @@ func TestFilterKnownChanIDs(t *testing.T) { channel, chanID := createEdge( uint32(i*10+1), 0, 0, 0, node1, node2, ) - if err := graph.AddChannelEdge(&channel); err != nil { + if err := graph.AddChannelEdge(channel); err != nil { t.Fatalf("unable to create channel edge: %v", err) } err := graph.DeleteChannelEdges(false, true, channel.ChannelID) @@ -2085,14 +2148,14 @@ func TestFilterChannelRange(t *testing.T) { channel1, chanID1 := createEdge( chanHeight, uint32(i+1), 0, 0, node1, node2, ) - if err := graph.AddChannelEdge(&channel1); err != nil { + if err := graph.AddChannelEdge(channel1); err != nil { t.Fatalf("unable to create channel edge: %v", err) } channel2, chanID2 := createEdge( chanHeight, uint32(i+2), 0, 0, node1, node2, ) - if err := graph.AddChannelEdge(&channel2); err != nil { + if err := graph.AddChannelEdge(channel2); err != nil { t.Fatalf("unable to create channel edge: %v", err) } @@ -2197,11 +2260,21 @@ func TestFetchChanInfos(t *testing.T) { edges := make([]ChannelEdge, 0, numChans) edgeQuery := make([]uint64, 0, numChans) for i := 0; i < numChans; i++ { - channel, chanID := createEdge( - uint32(i*10), 0, 0, 0, node1, node2, + var ( + channel models.ChannelEdgeInfo + chanID lnwire.ShortChannelID ) + if i%2 == 0 { + channel, chanID = createEdge( + uint32(i*10), 0, 0, 0, node1, node2, + ) + } else { + channel, chanID = createEdge2( + uint32(i*10), 0, 0, 0, node1, node2, + ) + } - if err := graph.AddChannelEdge(&channel); err != nil { + if err := graph.AddChannelEdge(channel); err != nil { t.Fatalf("unable to create channel edge: %v", err) } @@ -2214,9 +2287,7 @@ func TestFetchChanInfos(t *testing.T) { edge1.ChannelFlags = 0 edge1.Node = node2 edge1.SigBytes = testSig.Serialize() - if err := graph.UpdateEdgePolicy(edge1); err != nil { - t.Fatalf("unable to update edge: %v", err) - } + require.NoError(t, graph.UpdateEdgePolicy(edge1)) edge2 := newEdgePolicy( chanID.ToUint64(), graph.db, updateTime.Unix(), @@ -2229,7 +2300,7 @@ func TestFetchChanInfos(t *testing.T) { } edges = append(edges, ChannelEdge{ - Info: &channel, + Info: channel, Policy1: edge1, Policy2: edge2, }) @@ -2246,7 +2317,7 @@ func TestFetchChanInfos(t *testing.T) { zombieChan, zombieChanID := createEdge( 666, 0, 0, 0, node1, node2, ) - if err := graph.AddChannelEdge(&zombieChan); err != nil { + if err := graph.AddChannelEdge(zombieChan); err != nil { t.Fatalf("unable to create channel edge: %v", err) } err = graph.DeleteChannelEdges(false, true, zombieChan.ChannelID) @@ -2301,7 +2372,7 @@ func TestIncompleteChannelPolicies(t *testing.T) { uint32(0), 0, 0, 0, node1, node2, ) - if err := graph.AddChannelEdge(&channel); err != nil { + if err := graph.AddChannelEdge(channel); err != nil { t.Fatalf("unable to create channel edge: %v", err) } @@ -2408,7 +2479,7 @@ func TestChannelEdgePruningUpdateIndexDeletion(t *testing.T) { // With the two nodes created, we'll now create a random channel, as // well as two edges in the database with distinct update times. edgeInfo, chanID := createEdge(100, 0, 0, 0, node1, node2) - if err := graph.AddChannelEdge(&edgeInfo); err != nil { + if err := graph.AddChannelEdge(edgeInfo); err != nil { t.Fatalf("unable to add edge: %v", err) } @@ -2559,7 +2630,7 @@ func TestPruneGraphNodes(t *testing.T) { // We'll now add a new edge to the graph, but only actually advertise // the edge of *one* of the nodes. edgeInfo, chanID := createEdge(100, 0, 0, 0, node1, node2) - if err := graph.AddChannelEdge(&edgeInfo); err != nil { + if err := graph.AddChannelEdge(edgeInfo); err != nil { t.Fatalf("unable to add edge: %v", err) } @@ -2614,7 +2685,7 @@ func TestAddChannelEdgeShellNodes(t *testing.T) { // We'll now create an edge between the two nodes, as a result, node2 // should be inserted into the database as a shell node. edgeInfo, _ := createEdge(100, 0, 0, 0, node1, node2) - if err := graph.AddChannelEdge(&edgeInfo); err != nil { + if err := graph.AddChannelEdge(edgeInfo); err != nil { t.Fatalf("unable to add edge: %v", err) } @@ -2728,7 +2799,7 @@ func TestNodeIsPublic(t *testing.T) { // After creating all of our nodes and edges, we'll add them to each // participant's graph. nodes := []*LightningNode{aliceNode, bobNode, carolNode} - edges := []*ChannelEdgeInfo1{&aliceBobEdge, &bobCarolEdge} + edges := []*ChannelEdgeInfo1{aliceBobEdge, bobCarolEdge} graphs := []*ChannelGraph{aliceGraph, bobGraph, carolGraph} for _, graph := range graphs { for _, node := range nodes { @@ -2809,7 +2880,7 @@ func TestNodeIsPublic(t *testing.T) { } bobCarolEdge.AuthProof = nil - if err := graph.AddChannelEdge(&bobCarolEdge); err != nil { + if err := graph.AddChannelEdge(bobCarolEdge); err != nil { t.Fatalf("unable to add edge: %v", err) } } @@ -3344,7 +3415,7 @@ func TestBatchedAddChannelEdge(t *testing.T) { // Create a third edge, this with a block height of 155. edgeInfo3, _ := createEdge(height-1, 0, 0, 2, node1, node2) - edges := []ChannelEdgeInfo1{edgeInfo, edgeInfo2, edgeInfo3} + edges := []*ChannelEdgeInfo1{edgeInfo, edgeInfo2, edgeInfo3} errChan := make(chan error, len(edges)) errTimeout := errors.New("timeout adding batched channel") @@ -3352,11 +3423,11 @@ func TestBatchedAddChannelEdge(t *testing.T) { var wg sync.WaitGroup for _, edge := range edges { wg.Add(1) - go func(edge ChannelEdgeInfo1) { + go func(edge *ChannelEdgeInfo1) { defer wg.Done() select { - case errChan <- graph.AddChannelEdge(&edge): + case errChan <- graph.AddChannelEdge(edge): case <-time.After(2 * time.Second): errChan <- errTimeout } diff --git a/lnwire/channel_announcement_2.go b/lnwire/channel_announcement_2.go index 3cfed664ca..2dd1506b74 100644 --- a/lnwire/channel_announcement_2.go +++ b/lnwire/channel_announcement_2.go @@ -117,6 +117,10 @@ func (c *ChannelAnnouncement2) Decode(r io.Reader, _ uint32) error { } c.Signature.ForceSchnorr() + return c.DecodeTLVRecords(r) +} + +func (c *ChannelAnnouncement2) DecodeTLVRecords(r io.Reader) error { // First extract into extra opaque data. var tlvRecords ExtraOpaqueData if err := ReadElements(r, &tlvRecords); err != nil { From 923f9c38f7d2da0eb8387da3393b984faeb45231 Mon Sep 17 00:00:00 2001 From: Elle Mouton <elle.mouton@gmail.com> Date: Thu, 19 Oct 2023 14:07:05 +0200 Subject: [PATCH 27/33] introduce ChannelEdgePolicyWithNode This is so that we can take the Node variable out of ChannelEdgePolicy1 so that it can be common amoung future ChannelEdgePolicy (interface) implementations. --- autopilot/graph.go | 54 +++++---- channeldb/graph.go | 169 ++++++++++++++++----------- channeldb/graph_cache.go | 13 ++- channeldb/graph_cache_test.go | 33 ++++-- channeldb/graph_test.go | 120 ++++++++++--------- discovery/chan_series.go | 8 +- discovery/gossiper.go | 44 ++++--- discovery/gossiper_test.go | 27 +++-- funding/manager.go | 8 +- funding/manager_test.go | 2 +- lnrpc/devrpc/dev_server.go | 6 +- lnrpc/invoicesrpc/addinvoice.go | 11 +- lnrpc/invoicesrpc/addinvoice_test.go | 84 ++++++------- netann/chan_status_manager_test.go | 53 +++++---- netann/channel_announcement.go | 10 +- netann/channel_update.go | 9 +- netann/interface.go | 3 +- routing/localchans/manager.go | 6 +- routing/localchans/manager_test.go | 12 +- routing/notifications.go | 2 +- routing/notifications_test.go | 28 ++--- routing/pathfind_test.go | 72 ++++++------ routing/router.go | 23 ++-- routing/router_test.go | 126 +++++++++++--------- rpcserver.go | 21 +++- server.go | 6 +- 26 files changed, 537 insertions(+), 413 deletions(-) diff --git a/autopilot/graph.go b/autopilot/graph.go index aba7d8365a..708a2ff8ce 100644 --- a/autopilot/graph.go +++ b/autopilot/graph.go @@ -91,7 +91,7 @@ func (d dbNode) Addrs() []net.Addr { func (d dbNode) ForEachChannel(cb func(ChannelEdge) error) error { return d.node.ForEachChannel(d.db, d.tx, func(db kvdb.Backend, tx kvdb.RTx, ei models.ChannelEdgeInfo, ep, - _ *channeldb.ChannelEdgePolicy1) error { + _ *channeldb.ChannelEdgePolicyWithNode) error { // Skip channels for which no outgoing edge policy is available. // @@ -236,33 +236,41 @@ func (d *databaseChannelGraph) addRandChannel(node1, node2 *btcec.PublicKey, if err := d.db.AddChannelEdge(edge); err != nil { return nil, nil, err } - edgePolicy := &channeldb.ChannelEdgePolicy1{ - SigBytes: testSig.Serialize(), - ChannelID: chanID.ToUint64(), - LastUpdate: time.Now(), - TimeLockDelta: 10, - MinHTLC: 1, - MaxHTLC: lnwire.NewMSatFromSatoshis(capacity), - FeeBaseMSat: 10, - FeeProportionalMillionths: 10000, - MessageFlags: 1, - ChannelFlags: 0, + edgePolicy := &channeldb.ChannelEdgePolicyWithNode{ + ChannelEdgePolicy1: channeldb.ChannelEdgePolicy1{ + SigBytes: testSig.Serialize(), + ChannelID: chanID.ToUint64(), + LastUpdate: time.Now(), + TimeLockDelta: 10, + MinHTLC: 1, + MaxHTLC: lnwire.NewMSatFromSatoshis( + capacity, + ), + FeeBaseMSat: 10, + FeeProportionalMillionths: 10000, + MessageFlags: 1, + ChannelFlags: 0, + }, } if err := d.db.UpdateEdgePolicy(edgePolicy); err != nil { return nil, nil, err } - edgePolicy = &channeldb.ChannelEdgePolicy1{ - SigBytes: testSig.Serialize(), - ChannelID: chanID.ToUint64(), - LastUpdate: time.Now(), - TimeLockDelta: 10, - MinHTLC: 1, - MaxHTLC: lnwire.NewMSatFromSatoshis(capacity), - FeeBaseMSat: 10, - FeeProportionalMillionths: 10000, - MessageFlags: 1, - ChannelFlags: 1, + edgePolicy = &channeldb.ChannelEdgePolicyWithNode{ + ChannelEdgePolicy1: channeldb.ChannelEdgePolicy1{ + SigBytes: testSig.Serialize(), + ChannelID: chanID.ToUint64(), + LastUpdate: time.Now(), + TimeLockDelta: 10, + MinHTLC: 1, + MaxHTLC: lnwire.NewMSatFromSatoshis( + capacity, + ), + FeeBaseMSat: 10, + FeeProportionalMillionths: 10000, + MessageFlags: 1, + ChannelFlags: 1, + }, } if err := d.db.UpdateEdgePolicy(edgePolicy); err != nil { return nil, nil, err diff --git a/channeldb/graph.go b/channeldb/graph.go index 1f4a41ea60..992d0998fe 100644 --- a/channeldb/graph.go +++ b/channeldb/graph.go @@ -241,7 +241,7 @@ func NewChannelGraph(db kvdb.Backend, rejectCacheSize, chanCacheSize int, } err = g.ForEachChannel(func(info models.ChannelEdgeInfo, - policy1, policy2 *ChannelEdgePolicy1) error { + policy1, policy2 *ChannelEdgePolicyWithNode) error { g.graphCache.AddChannel(info, policy1, policy2) @@ -271,10 +271,10 @@ type channelMapKey struct { // getChannelMap loads all channel edge policies from the database and stores // them in a map. func (c *ChannelGraph) getChannelMap(edges kvdb.RBucket) ( - map[channelMapKey]*ChannelEdgePolicy1, error) { + map[channelMapKey]*ChannelEdgePolicyWithNode, error) { // Create a map to store all channel edge policies. - channelMap := make(map[channelMapKey]*ChannelEdgePolicy1) + channelMap := make(map[channelMapKey]*ChannelEdgePolicyWithNode) err := kvdb.ForAll(edges, func(k, edgeBytes []byte) error { // Skip embedded buckets. @@ -302,7 +302,7 @@ func (c *ChannelGraph) getChannelMap(edges kvdb.RBucket) ( } edgeReader := bytes.NewReader(edgeBytes) - edge, err := deserializeChanEdgePolicyRaw( + edge, pubKey, err := deserializeChanEdgePolicyRaw( edgeReader, ) @@ -316,7 +316,15 @@ func (c *ChannelGraph) getChannelMap(edges kvdb.RBucket) ( return err } - channelMap[key] = edge + var pub [33]byte + copy(pub[:], pubKey) + + channelMap[key] = &ChannelEdgePolicyWithNode{ + ChannelEdgePolicy1: *edge, + Node: &LightningNode{ + PubKeyBytes: pub, + }, + } return nil }) @@ -425,7 +433,7 @@ func (c *ChannelGraph) NewPathFindTx() (kvdb.RTx, error) { // for that particular channel edge routing policy will be passed into the // callback. func (c *ChannelGraph) ForEachChannel(cb func(models.ChannelEdgeInfo, - *ChannelEdgePolicy1, *ChannelEdgePolicy1) error) error { + *ChannelEdgePolicyWithNode, *ChannelEdgePolicyWithNode) error) error { return c.db.View(func(tx kvdb.RTx) error { edges := tx.ReadBucket(edgeBucket) @@ -497,7 +505,8 @@ func (c *ChannelGraph) ForEachNodeChannel(tx kvdb.RTx, node route.Vertex, } dbCallback := func(_ kvdb.Backend, tx kvdb.RTx, - e models.ChannelEdgeInfo, p1, p2 *ChannelEdgePolicy1) error { + e models.ChannelEdgeInfo, p1, + p2 *ChannelEdgePolicyWithNode) error { var cachedInPolicy *CachedEdgePolicy if p2 != nil { @@ -572,7 +581,7 @@ func (c *ChannelGraph) ForEachNodeCached(cb func(node route.Vertex, err := node.ForEachChannel(c.db, tx, func(_ kvdb.Backend, tx kvdb.RTx, e models.ChannelEdgeInfo, - p1 *ChannelEdgePolicy1, p2 *ChannelEdgePolicy1) error { + p1, p2 *ChannelEdgePolicyWithNode) error { toNodeCallback := func() route.Vertex { return node.PubKeyBytes @@ -1875,11 +1884,11 @@ type ChannelEdge struct { // Policy1 points to the "first" edge policy of the channel containing // the dynamic information required to properly route through the edge. - Policy1 *ChannelEdgePolicy1 + Policy1 *ChannelEdgePolicyWithNode // Policy2 points to the "second" edge policy of the channel containing // the dynamic information required to properly route through the edge. - Policy2 *ChannelEdgePolicy1 + Policy2 *ChannelEdgePolicyWithNode } // ChanUpdatesInHorizon returns all the known channel edges which have at least @@ -2319,7 +2328,7 @@ func (c *ChannelGraph) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) { } func delEdgeUpdateIndexEntry(edgesBucket kvdb.RwBucket, chanID uint64, - edge1, edge2 *ChannelEdgePolicy1) error { + edge1, edge2 *ChannelEdgePolicyWithNode) error { // First, we'll fetch the edge update index bucket which currently // stores an entry for the channel we're about to delete. @@ -2464,7 +2473,7 @@ func (c *ChannelGraph) delChannelEdge(edges, edgeIndex, chanIndex, zombieIndex, // marked with the correct lagging channel since we received an update from only // one side. func makeZombiePubkeys(info models.ChannelEdgeInfo, - e1, e2 *ChannelEdgePolicy1) ([33]byte, [33]byte) { + e1, e2 *ChannelEdgePolicyWithNode) ([33]byte, [33]byte) { var ( node1Bytes = info.Node1Bytes() @@ -2499,7 +2508,7 @@ func makeZombiePubkeys(info models.ChannelEdgeInfo, // updated, otherwise it's the second node's information. The node ordering is // determined by the lexicographical ordering of the identity public keys of the // nodes on either side of the channel. -func (c *ChannelGraph) UpdateEdgePolicy(edge *ChannelEdgePolicy1, +func (c *ChannelGraph) UpdateEdgePolicy(edge *ChannelEdgePolicyWithNode, op ...batch.SchedulerOption) error { var ( @@ -2547,7 +2556,9 @@ func (c *ChannelGraph) UpdateEdgePolicy(edge *ChannelEdgePolicy1, return c.chanScheduler.Execute(r) } -func (c *ChannelGraph) updateEdgeCache(e *ChannelEdgePolicy1, isUpdate1 bool) { +func (c *ChannelGraph) updateEdgeCache(e *ChannelEdgePolicyWithNode, + isUpdate1 bool) { + // If an entry for this channel is found in reject cache, we'll modify // the entry with the updated timestamp for the direction that was just // written. If the edge doesn't exist, we'll load the cache entry lazily @@ -2579,7 +2590,7 @@ func (c *ChannelGraph) updateEdgeCache(e *ChannelEdgePolicy1, isUpdate1 bool) { // buckets using an existing database transaction. The returned boolean will be // true if the updated policy belongs to node1, and false if the policy belonged // to node2. -func updateEdgePolicy(tx kvdb.RwTx, edge *ChannelEdgePolicy1, +func updateEdgePolicy(tx kvdb.RwTx, edge *ChannelEdgePolicyWithNode, graphCache *GraphCache) (bool, error) { edges := tx.ReadWriteBucket(edgeBucket) @@ -2632,7 +2643,9 @@ func updateEdgePolicy(tx kvdb.RwTx, edge *ChannelEdgePolicy1, // Finally, with the direction of the edge being updated // identified, we update the on-disk edge representation. - err = putChanEdgePolicy(edges, nodes, edge, fromNode, toNode) + err = putChanEdgePolicy( + edges, nodes, &edge.ChannelEdgePolicy1, fromNode, toNode, + ) if err != nil { return false, err } @@ -2787,7 +2800,8 @@ func (l *LightningNode) isPublic(db kvdb.Backend, tx kvdb.RTx, nodeIsPublic := false errDone := errors.New("done") err := l.ForEachChannel(db, tx, func(_ kvdb.Backend, _ kvdb.RTx, - info models.ChannelEdgeInfo, _, _ *ChannelEdgePolicy1) error { + info models.ChannelEdgeInfo, _, + _ *ChannelEdgePolicyWithNode) error { // If this edge doesn't extend to the source node, we'll // terminate our search as we can now conclude that the node is @@ -2898,7 +2912,8 @@ func (n *graphCacheNode) Features() *lnwire.FeatureVector { // Unknown policies are passed into the callback as nil values. func (n *graphCacheNode) ForEachChannel(db kvdb.Backend, tx kvdb.RTx, cb func(kvdb.Backend, kvdb.RTx, models.ChannelEdgeInfo, - *ChannelEdgePolicy1, *ChannelEdgePolicy1) error) error { + *ChannelEdgePolicyWithNode, + *ChannelEdgePolicyWithNode) error) error { return nodeTraversal(tx, n.pubKeyBytes[:], db, cb) } @@ -2959,7 +2974,8 @@ func (c *ChannelGraph) HasLightningNode(nodePub [33]byte) (time.Time, bool, erro // public key and passes channel information into the specified callback. func nodeTraversal(tx kvdb.RTx, nodePub []byte, db kvdb.Backend, cb func(kvdb.Backend, kvdb.RTx, models.ChannelEdgeInfo, - *ChannelEdgePolicy1, *ChannelEdgePolicy1) error) error { + *ChannelEdgePolicyWithNode, + *ChannelEdgePolicyWithNode) error) error { traversal := func(tx kvdb.RTx) error { nodes := tx.ReadBucket(nodeBucket) @@ -3070,7 +3086,8 @@ func nodeTraversal(tx kvdb.RTx, nodePub []byte, db kvdb.Backend, // traversal. func (l *LightningNode) ForEachChannel(db kvdb.Backend, tx kvdb.RTx, cb func(kvdb.Backend, kvdb.RTx, models.ChannelEdgeInfo, - *ChannelEdgePolicy1, *ChannelEdgePolicy1) error) error { + *ChannelEdgePolicyWithNode, + *ChannelEdgePolicyWithNode) error) error { nodePub := l.PubKeyBytes[:] @@ -3591,6 +3608,12 @@ func (c *ChannelAuthProof1) IsEmpty() bool { len(c.BitcoinSig2Bytes) == 0 } +type ChannelEdgePolicyWithNode struct { + ChannelEdgePolicy1 + + Node *LightningNode +} + // ChannelEdgePolicy1 represents a *directed* edge within the channel graph. For // each channel in the database, there are two distinct edges: one for each // possible direction of travel along the channel. The edges themselves hold @@ -3644,10 +3667,6 @@ type ChannelEdgePolicy1 struct { // HTLCs for each millionth of a satoshi forwarded. FeeProportionalMillionths lnwire.MilliSatoshi - // Node is the LightningNode that this directed edge leads to. Using - // this pointer the channel graph can further be traversed. - Node *LightningNode - // ExtraOpaqueData is the set of data that was appended to this // message, some of which we may not actually know how to iterate or // parse. By holding onto this data, we ensure that we're able to @@ -3714,23 +3733,25 @@ func (c *ChannelEdgePolicy1) ComputeFeeFromIncoming( ) } -func EdgePolicyFromUpdate(update lnwire.ChannelUpdate) (*ChannelEdgePolicy1, - error) { +func EdgePolicyFromUpdate(update lnwire.ChannelUpdate) ( + *ChannelEdgePolicyWithNode, error) { switch upd := update.(type) { case *lnwire.ChannelUpdate1: //nolint:lll - return &ChannelEdgePolicy1{ - SigBytes: upd.Signature.ToSignatureBytes(), - ChannelID: upd.ShortChannelID.ToUint64(), - LastUpdate: time.Unix(int64(upd.Timestamp), 0), - MessageFlags: upd.MessageFlags, - ChannelFlags: upd.ChannelFlags, - TimeLockDelta: upd.TimeLockDelta, - MinHTLC: upd.HtlcMinimumMsat, - MaxHTLC: upd.HtlcMaximumMsat, - FeeBaseMSat: lnwire.MilliSatoshi(upd.BaseFee), - FeeProportionalMillionths: lnwire.MilliSatoshi(upd.FeeRate), + return &ChannelEdgePolicyWithNode{ + ChannelEdgePolicy1: ChannelEdgePolicy1{ + SigBytes: upd.Signature.ToSignatureBytes(), + ChannelID: upd.ShortChannelID.ToUint64(), + LastUpdate: time.Unix(int64(upd.Timestamp), 0), + MessageFlags: upd.MessageFlags, + ChannelFlags: upd.ChannelFlags, + TimeLockDelta: upd.TimeLockDelta, + MinHTLC: upd.HtlcMinimumMsat, + MaxHTLC: upd.HtlcMaximumMsat, + FeeBaseMSat: lnwire.MilliSatoshi(upd.BaseFee), + FeeProportionalMillionths: lnwire.MilliSatoshi(upd.FeeRate), + }, }, nil default: return nil, fmt.Errorf("unhandled implementation of "+ @@ -3744,12 +3765,13 @@ func EdgePolicyFromUpdate(update lnwire.ChannelUpdate) (*ChannelEdgePolicy1, // information for the channel itself is returned as well as two structs that // contain the routing policies for the channel in either direction. func (c *ChannelGraph) FetchChannelEdgesByOutpoint(op *wire.OutPoint, -) (models.ChannelEdgeInfo, *ChannelEdgePolicy1, *ChannelEdgePolicy1, error) { +) (models.ChannelEdgeInfo, *ChannelEdgePolicyWithNode, + *ChannelEdgePolicyWithNode, error) { var ( edgeInfo models.ChannelEdgeInfo - policy1 *ChannelEdgePolicy1 - policy2 *ChannelEdgePolicy1 + policy1 *ChannelEdgePolicyWithNode + policy2 *ChannelEdgePolicyWithNode ) err := kvdb.View(c.db, func(tx kvdb.RTx) error { @@ -3829,13 +3851,14 @@ func (c *ChannelGraph) FetchChannelEdgesByOutpoint(op *wire.OutPoint, // ErrZombieEdge an be returned if the edge is currently marked as a zombie // within the database. In this case, the ChannelEdgePolicy1's will be nil, and // the ChannelEdgeInfo will only include the public keys of each node. -func (c *ChannelGraph) FetchChannelEdgesByID(chanID uint64, -) (models.ChannelEdgeInfo, *ChannelEdgePolicy1, *ChannelEdgePolicy1, error) { +func (c *ChannelGraph) FetchChannelEdgesByID(chanID uint64) ( + models.ChannelEdgeInfo, *ChannelEdgePolicyWithNode, + *ChannelEdgePolicyWithNode, error) { var ( edgeInfo models.ChannelEdgeInfo - policy1 *ChannelEdgePolicy1 - policy2 *ChannelEdgePolicy1 + policy1 *ChannelEdgePolicyWithNode + policy2 *ChannelEdgePolicyWithNode channelID [8]byte ) @@ -4882,7 +4905,8 @@ func putChanEdgePolicyUnknown(edges kvdb.RwBucket, channelID uint64, } func fetchChanEdgePolicy(edges kvdb.RBucket, chanID []byte, - nodePub []byte, nodes kvdb.RBucket) (*ChannelEdgePolicy1, error) { + nodePub []byte, nodes kvdb.RBucket) (*ChannelEdgePolicyWithNode, + error) { var edgeKey [33 + 8]byte copy(edgeKey[:], nodePub) @@ -4915,8 +4939,8 @@ func fetchChanEdgePolicy(edges kvdb.RBucket, chanID []byte, } func fetchChanEdgePolicies(edgeIndex kvdb.RBucket, edges kvdb.RBucket, - nodes kvdb.RBucket, chanID []byte) (*ChannelEdgePolicy1, - *ChannelEdgePolicy1, error) { + nodes kvdb.RBucket, chanID []byte) (*ChannelEdgePolicyWithNode, + *ChannelEdgePolicyWithNode, error) { edgeInfoBytes := edgeIndex.Get(chanID) if edgeInfoBytes == nil { @@ -5016,11 +5040,11 @@ func serializeChanEdgePolicy(w io.Writer, edge *ChannelEdgePolicy1, } func deserializeChanEdgePolicy(r io.Reader, - nodes kvdb.RBucket) (*ChannelEdgePolicy1, error) { + nodes kvdb.RBucket) (*ChannelEdgePolicyWithNode, error) { // Deserialize the policy. Note that in case an optional field is not // found, both an error and a populated policy object are returned. - edge, deserializeErr := deserializeChanEdgePolicyRaw(r) + edge, pubKeyBytes, deserializeErr := deserializeChanEdgePolicyRaw(r) if deserializeErr != nil && deserializeErr != ErrEdgePolicyOptionalFieldNotFound { @@ -5028,68 +5052,71 @@ func deserializeChanEdgePolicy(r io.Reader, } // Populate full LightningNode struct. - pub := edge.Node.PubKeyBytes[:] - node, err := fetchLightningNode(nodes, pub) + node, err := fetchLightningNode(nodes, pubKeyBytes) if err != nil { - return nil, fmt.Errorf("unable to fetch node: %x, %v", pub, err) + return nil, fmt.Errorf("unable to fetch node: %x, %v", + pubKeyBytes, err) + } + + policy := ChannelEdgePolicyWithNode{ + ChannelEdgePolicy1: *edge, + Node: &node, } - edge.Node = &node - return edge, deserializeErr + return &policy, deserializeErr } -func deserializeChanEdgePolicyRaw(r io.Reader) (*ChannelEdgePolicy1, error) { +func deserializeChanEdgePolicyRaw(r io.Reader) (*ChannelEdgePolicy1, []byte, + error) { + edge := &ChannelEdgePolicy1{} var err error edge.SigBytes, err = wire.ReadVarBytes(r, 0, 80, "sig") if err != nil { - return nil, err + return nil, nil, err } if err := binary.Read(r, byteOrder, &edge.ChannelID); err != nil { - return nil, err + return nil, nil, err } var scratch [8]byte if _, err := r.Read(scratch[:]); err != nil { - return nil, err + return nil, nil, err } unix := int64(byteOrder.Uint64(scratch[:])) edge.LastUpdate = time.Unix(unix, 0) if err := binary.Read(r, byteOrder, &edge.MessageFlags); err != nil { - return nil, err + return nil, nil, err } if err := binary.Read(r, byteOrder, &edge.ChannelFlags); err != nil { - return nil, err + return nil, nil, err } if err := binary.Read(r, byteOrder, &edge.TimeLockDelta); err != nil { - return nil, err + return nil, nil, err } var n uint64 if err := binary.Read(r, byteOrder, &n); err != nil { - return nil, err + return nil, nil, err } edge.MinHTLC = lnwire.MilliSatoshi(n) if err := binary.Read(r, byteOrder, &n); err != nil { - return nil, err + return nil, nil, err } edge.FeeBaseMSat = lnwire.MilliSatoshi(n) if err := binary.Read(r, byteOrder, &n); err != nil { - return nil, err + return nil, nil, err } edge.FeeProportionalMillionths = lnwire.MilliSatoshi(n) var pub [33]byte if _, err := r.Read(pub[:]); err != nil { - return nil, err - } - edge.Node = &LightningNode{ - PubKeyBytes: pub, + return nil, nil, err } // We'll try and see if there are any opaque bytes left, if not, then @@ -5101,7 +5128,7 @@ func deserializeChanEdgePolicyRaw(r io.Reader) (*ChannelEdgePolicy1, error) { case err == io.ErrUnexpectedEOF: case err == io.EOF: case err != nil: - return nil, err + return nil, nil, err } // See if optional fields are present. @@ -5114,7 +5141,7 @@ func deserializeChanEdgePolicyRaw(r io.Reader) (*ChannelEdgePolicy1, error) { // stored before this field was validated. We'll return the // edge along with an error. if len(opq) < 8 { - return edge, ErrEdgePolicyOptionalFieldNotFound + return edge, pub[:], ErrEdgePolicyOptionalFieldNotFound } maxHtlc := byteOrder.Uint64(opq[:8]) @@ -5124,7 +5151,7 @@ func deserializeChanEdgePolicyRaw(r io.Reader) (*ChannelEdgePolicy1, error) { edge.ExtraOpaqueData = opq[8:] } - return edge, nil + return edge, pub[:], nil } const ( diff --git a/channeldb/graph_cache.go b/channeldb/graph_cache.go index c52ba8f1d6..8a460bb5c4 100644 --- a/channeldb/graph_cache.go +++ b/channeldb/graph_cache.go @@ -29,7 +29,8 @@ type GraphCacheNode interface { // to the caller. ForEachChannel(kvdb.Backend, kvdb.RTx, func(kvdb.Backend, kvdb.RTx, models.ChannelEdgeInfo, - *ChannelEdgePolicy1, *ChannelEdgePolicy1) error) error + *ChannelEdgePolicyWithNode, + *ChannelEdgePolicyWithNode) error) error } // CachedEdgePolicy is a struct that only caches the information of a @@ -105,7 +106,7 @@ func (c *CachedEdgePolicy) ComputeFeeFromIncoming( } // NewCachedPolicy turns a full policy into a minimal one that can be cached. -func NewCachedPolicy(policy *ChannelEdgePolicy1) *CachedEdgePolicy { +func NewCachedPolicy(policy *ChannelEdgePolicyWithNode) *CachedEdgePolicy { return &CachedEdgePolicy{ ChannelID: policy.ChannelID, MessageFlags: policy.MessageFlags, @@ -224,8 +225,8 @@ func (c *GraphCache) AddNode(tx kvdb.RTx, node GraphCacheNode) error { c.AddNodeFeatures(node) return node.ForEachChannel(nil, tx, func(_ kvdb.Backend, tx kvdb.RTx, - info models.ChannelEdgeInfo, outPolicy *ChannelEdgePolicy1, - inPolicy *ChannelEdgePolicy1) error { + info models.ChannelEdgeInfo, outPolicy, + inPolicy *ChannelEdgePolicyWithNode) error { c.AddChannel(info, outPolicy, inPolicy) @@ -238,7 +239,7 @@ func (c *GraphCache) AddNode(tx kvdb.RTx, node GraphCacheNode) error { // and policy flags automatically. The policy will be set as the outgoing policy // on one node and the incoming policy on the peer's side. func (c *GraphCache) AddChannel(info models.ChannelEdgeInfo, - policy1 *ChannelEdgePolicy1, policy2 *ChannelEdgePolicy1) { + policy1, policy2 *ChannelEdgePolicyWithNode) { if info == nil { return @@ -300,7 +301,7 @@ func (c *GraphCache) updateOrAddEdge(node route.Vertex, edge *DirectedChannel) { // of the from and to node is not strictly important. But we assume that a // channel edge was added beforehand so that the directed channel struct already // exists in the cache. -func (c *GraphCache) UpdatePolicy(policy *ChannelEdgePolicy1, fromNode, +func (c *GraphCache) UpdatePolicy(policy *ChannelEdgePolicyWithNode, fromNode, toNode route.Vertex, edge1 bool) { c.mtx.Lock() diff --git a/channeldb/graph_cache_test.go b/channeldb/graph_cache_test.go index a0ab701940..4f3145cfd7 100644 --- a/channeldb/graph_cache_test.go +++ b/channeldb/graph_cache_test.go @@ -30,8 +30,8 @@ type node struct { features *lnwire.FeatureVector edgeInfos []*ChannelEdgeInfo1 - outPolicies []*ChannelEdgePolicy1 - inPolicies []*ChannelEdgePolicy1 + outPolicies []*ChannelEdgePolicyWithNode + inPolicies []*ChannelEdgePolicyWithNode } func (n *node) PubKey() route.Vertex { @@ -43,7 +43,8 @@ func (n *node) Features() *lnwire.FeatureVector { func (n *node) ForEachChannel(db kvdb.Backend, tx kvdb.RTx, cb func(kvdb.Backend, kvdb.RTx, models.ChannelEdgeInfo, - *ChannelEdgePolicy1, *ChannelEdgePolicy1) error) error { + *ChannelEdgePolicyWithNode, + *ChannelEdgePolicyWithNode) error) error { for idx := range n.edgeInfos { err := cb( @@ -71,17 +72,25 @@ func TestGraphCacheAddNode(t *testing.T) { channelFlagA, channelFlagB = 1, 0 } - outPolicy1 := &ChannelEdgePolicy1{ - ChannelID: 1000, - ChannelFlags: lnwire.ChanUpdateChanFlags(channelFlagA), + outPolicy1 := &ChannelEdgePolicyWithNode{ + ChannelEdgePolicy1: ChannelEdgePolicy1{ + ChannelID: 1000, + ChannelFlags: lnwire.ChanUpdateChanFlags( + channelFlagA, + ), + }, Node: &LightningNode{ PubKeyBytes: nodeB, Features: lnwire.EmptyFeatureVector(), }, } - inPolicy1 := &ChannelEdgePolicy1{ - ChannelID: 1000, - ChannelFlags: lnwire.ChanUpdateChanFlags(channelFlagB), + inPolicy1 := &ChannelEdgePolicyWithNode{ + ChannelEdgePolicy1: ChannelEdgePolicy1{ + ChannelID: 1000, + ChannelFlags: lnwire.ChanUpdateChanFlags( + channelFlagB, + ), + }, Node: &LightningNode{ PubKeyBytes: nodeA, Features: lnwire.EmptyFeatureVector(), @@ -97,8 +106,8 @@ func TestGraphCacheAddNode(t *testing.T) { NodeKey2Bytes: pubKey2, Capacity: 500, }}, - outPolicies: []*ChannelEdgePolicy1{outPolicy1}, - inPolicies: []*ChannelEdgePolicy1{inPolicy1}, + outPolicies: []*ChannelEdgePolicyWithNode{outPolicy1}, + inPolicies: []*ChannelEdgePolicyWithNode{inPolicy1}, } cache := NewGraphCache(10) require.NoError(t, cache.AddNode(nil, node)) @@ -145,7 +154,7 @@ func TestGraphCacheAddNode(t *testing.T) { runTest(pubKey2, pubKey1) } -func assertCachedPolicyEqual(t *testing.T, original *ChannelEdgePolicy1, +func assertCachedPolicyEqual(t *testing.T, original *ChannelEdgePolicyWithNode, cached *CachedEdgePolicy) { require.Equal(t, original.ChannelID, cached.ChannelID) diff --git a/channeldb/graph_test.go b/channeldb/graph_test.go index 260a169980..03f7076108 100644 --- a/channeldb/graph_test.go +++ b/channeldb/graph_test.go @@ -692,7 +692,8 @@ func assertEdgeInfo1Equal(t *testing.T, e1 *ChannelEdgeInfo1, } func createChannelEdge(db kvdb.Backend, node1, node2 *LightningNode) ( - *ChannelEdgeInfo1, *ChannelEdgePolicy1, *ChannelEdgePolicy1) { + *ChannelEdgeInfo1, *ChannelEdgePolicyWithNode, + *ChannelEdgePolicyWithNode) { var ( firstNode *LightningNode @@ -734,33 +735,39 @@ func createChannelEdge(db kvdb.Backend, node1, node2 *LightningNode) ( copy(edgeInfo.BitcoinKey1Bytes[:], firstNode.PubKeyBytes[:]) copy(edgeInfo.BitcoinKey2Bytes[:], secondNode.PubKeyBytes[:]) - edge1 := &ChannelEdgePolicy1{ - SigBytes: testSig.Serialize(), - ChannelID: chanID, - LastUpdate: time.Unix(433453, 0), - MessageFlags: 1, - ChannelFlags: 0, - TimeLockDelta: 99, - MinHTLC: 2342135, - MaxHTLC: 13928598, - FeeBaseMSat: 4352345, - FeeProportionalMillionths: 3452352, - Node: secondNode, - ExtraOpaqueData: []byte("new unknown feature2"), - } - edge2 := &ChannelEdgePolicy1{ - SigBytes: testSig.Serialize(), - ChannelID: chanID, - LastUpdate: time.Unix(124234, 0), - MessageFlags: 1, - ChannelFlags: 1, - TimeLockDelta: 99, - MinHTLC: 2342135, - MaxHTLC: 13928598, - FeeBaseMSat: 4352345, - FeeProportionalMillionths: 90392423, - Node: firstNode, - ExtraOpaqueData: []byte("new unknown feature1"), + edge1 := &ChannelEdgePolicyWithNode{ + ChannelEdgePolicy1: ChannelEdgePolicy1{ + SigBytes: testSig.Serialize(), + ChannelID: chanID, + LastUpdate: time.Unix(433453, 0), + MessageFlags: 1, + ChannelFlags: 0, + TimeLockDelta: 99, + MinHTLC: 2342135, + MaxHTLC: 13928598, + FeeBaseMSat: 4352345, + FeeProportionalMillionths: 3452352, + ExtraOpaqueData: []byte("new unknown feature2"), + }, + Node: secondNode, + } + edge2 := &ChannelEdgePolicyWithNode{ + ChannelEdgePolicy1: ChannelEdgePolicy1{ + SigBytes: testSig.Serialize(), + ChannelID: chanID, + LastUpdate: time.Unix(124234, 0), + MessageFlags: 1, + ChannelFlags: 1, + TimeLockDelta: 99, + MinHTLC: 2342135, + MaxHTLC: 13928598, + FeeBaseMSat: 4352345, + FeeProportionalMillionths: 90392423, + ExtraOpaqueData: []byte( + "new unknown feature1", + ), + }, + Node: firstNode, } return edgeInfo, edge1, edge2 @@ -976,7 +983,7 @@ func assertNoEdge(t *testing.T, g *ChannelGraph, chanID uint64) { } func assertEdgeWithPolicyInCache(t *testing.T, g *ChannelGraph, - e models.ChannelEdgeInfo, p *ChannelEdgePolicy1, policy1 bool) { + e models.ChannelEdgeInfo, p *ChannelEdgePolicyWithNode, policy1 bool) { var ( node1Bytes = e.Node1Bytes() @@ -1058,25 +1065,27 @@ func assertEdgeWithPolicyInCache(t *testing.T, g *ChannelGraph, } } -func randEdgePolicy(chanID uint64, db kvdb.Backend) *ChannelEdgePolicy1 { +func randEdgePolicy(chanID uint64, db kvdb.Backend) *ChannelEdgePolicyWithNode { update := prand.Int63() return newEdgePolicy(chanID, db, update) } func newEdgePolicy(chanID uint64, db kvdb.Backend, - updateTime int64) *ChannelEdgePolicy1 { - - return &ChannelEdgePolicy1{ - ChannelID: chanID, - LastUpdate: time.Unix(updateTime, 0), - MessageFlags: 1, - ChannelFlags: 0, - TimeLockDelta: uint16(prand.Int63()), - MinHTLC: lnwire.MilliSatoshi(prand.Int63()), - MaxHTLC: lnwire.MilliSatoshi(prand.Int63()), - FeeBaseMSat: lnwire.MilliSatoshi(prand.Int63()), - FeeProportionalMillionths: lnwire.MilliSatoshi(prand.Int63()), + updateTime int64) *ChannelEdgePolicyWithNode { + + return &ChannelEdgePolicyWithNode{ + ChannelEdgePolicy1: ChannelEdgePolicy1{ + ChannelID: chanID, + LastUpdate: time.Unix(updateTime, 0), + MessageFlags: 1, + ChannelFlags: 0, + TimeLockDelta: uint16(prand.Int63()), + MinHTLC: lnwire.MilliSatoshi(prand.Int63()), + MaxHTLC: lnwire.MilliSatoshi(prand.Int63()), + FeeBaseMSat: lnwire.MilliSatoshi(prand.Int63()), + FeeProportionalMillionths: lnwire.MilliSatoshi(prand.Int63()), + }, } } @@ -1125,7 +1134,8 @@ func TestGraphTraversal(t *testing.T) { // again if the map is empty that indicates that all edges have // properly been reached. err = graph.ForEachChannel(func(ei models.ChannelEdgeInfo, - _ *ChannelEdgePolicy1, _ *ChannelEdgePolicy1) error { + _ *ChannelEdgePolicyWithNode, + _ *ChannelEdgePolicyWithNode) error { delete(chanIndex, ei.GetChanID()) return nil @@ -1139,7 +1149,7 @@ func TestGraphTraversal(t *testing.T) { firstNode, secondNode := nodeList[0], nodeList[1] err = firstNode.ForEachChannel(graph.DB(), nil, func(_ kvdb.Backend, _ kvdb.RTx, _ models.ChannelEdgeInfo, - outEdge, inEdge *ChannelEdgePolicy1) error { + outEdge, inEdge *ChannelEdgePolicyWithNode) error { // All channels between first and second node should // have fully (both sides) specified policies. @@ -1222,8 +1232,8 @@ func TestGraphTraversalCacheable(t *testing.T) { err := node.ForEachChannel( graph.db, tx, func(_ kvdb.Backend, _ kvdb.RTx, info models.ChannelEdgeInfo, - _ *ChannelEdgePolicy1, - _ *ChannelEdgePolicy1) error { + _ *ChannelEdgePolicyWithNode, + _ *ChannelEdgePolicyWithNode) error { delete(chanIndex, info.GetChanID()) @@ -1407,7 +1417,7 @@ func assertPruneTip(t *testing.T, graph *ChannelGraph, blockHash *chainhash.Hash func assertNumChans(t *testing.T, graph *ChannelGraph, n int) { numChans := 0 if err := graph.ForEachChannel(func(models.ChannelEdgeInfo, - *ChannelEdgePolicy1, *ChannelEdgePolicy1) error { + *ChannelEdgePolicyWithNode, *ChannelEdgePolicyWithNode) error { numChans++ return nil @@ -2382,7 +2392,7 @@ func TestIncompleteChannelPolicies(t *testing.T) { err := node.ForEachChannel(graph.DB(), nil, func(_ kvdb.Backend, _ kvdb.RTx, _ models.ChannelEdgeInfo, outEdge, - inEdge *ChannelEdgePolicy1) error { + inEdge *ChannelEdgePolicyWithNode) error { if !expectedOut && outEdge != nil { t.Fatalf("Expected no outgoing policy") @@ -3014,7 +3024,7 @@ func TestEdgePolicyMissingMaxHtcl(t *testing.T) { edge1.ExtraOpaqueData = nil var b bytes.Buffer - err = serializeChanEdgePolicy(&b, edge1, to) + err = serializeChanEdgePolicy(&b, &edge1.ChannelEdgePolicy1, to) if err != nil { t.Fatalf("unable to serialize policy") } @@ -3024,7 +3034,7 @@ func TestEdgePolicyMissingMaxHtcl(t *testing.T) { edge1.MessageFlags = lnwire.ChanUpdateRequiredMaxHtlc edge1.MaxHTLC = 13928598 var b2 bytes.Buffer - err = serializeChanEdgePolicy(&b2, edge1, to) + err = serializeChanEdgePolicy(&b2, &edge1.ChannelEdgePolicy1, to) if err != nil { t.Fatalf("unable to serialize policy") } @@ -3251,7 +3261,7 @@ func compareNodes(a, b *LightningNode) error { // compareEdgePolicies is used to compare two ChannelEdgePolices using // compareNodes, so as to exclude comparisons of the Nodes' Features struct. -func compareEdgePolicies(a, b *ChannelEdgePolicy1) error { +func compareEdgePolicies(a, b *ChannelEdgePolicyWithNode) error { if a.ChannelID != b.ChannelID { return fmt.Errorf("ChannelID doesn't match: expected %v, "+ "got %v", a.ChannelID, b.ChannelID) @@ -3474,7 +3484,7 @@ func TestBatchedUpdateEdgePolicy(t *testing.T) { errTimeout := errors.New("timeout adding batched channel") - updates := []*ChannelEdgePolicy1{edge1, edge2} + updates := []*ChannelEdgePolicyWithNode{edge1, edge2} errChan := make(chan error, len(updates)) @@ -3482,7 +3492,7 @@ func TestBatchedUpdateEdgePolicy(t *testing.T) { var wg sync.WaitGroup for _, update := range updates { wg.Add(1) - go func(update *ChannelEdgePolicy1) { + go func(update *ChannelEdgePolicyWithNode) { defer wg.Done() select { @@ -3534,8 +3544,8 @@ func BenchmarkForEachChannel(b *testing.B) { graph.db, tx, func(_ kvdb.Backend, _ kvdb.RTx, info models.ChannelEdgeInfo, - policy *ChannelEdgePolicy1, - policy2 *ChannelEdgePolicy1) error { + policy *ChannelEdgePolicyWithNode, + policy2 *ChannelEdgePolicyWithNode) error { // We need to do something with // the data here, otherwise the diff --git a/discovery/chan_series.go b/discovery/chan_series.go index a8cdcd8ce9..fc340a9885 100644 --- a/discovery/chan_series.go +++ b/discovery/chan_series.go @@ -329,7 +329,9 @@ func (c *ChanSeries) FetchChanUpdates(chain chainhash.Hash, chanUpdates := make([]*lnwire.ChannelUpdate1, 0, 2) if e1 != nil { - chanUpdate, err := netann.ChannelUpdateFromEdge(chanInfo, e1) + chanUpdate, err := netann.ChannelUpdateFromEdge( + chanInfo, e1, + ) if err != nil { return nil, err } @@ -337,7 +339,9 @@ func (c *ChanSeries) FetchChanUpdates(chain chainhash.Hash, chanUpdates = append(chanUpdates, chanUpdate) } if e2 != nil { - chanUpdate, err := netann.ChannelUpdateFromEdge(chanInfo, e2) + chanUpdate, err := netann.ChannelUpdateFromEdge( + chanInfo, e2, + ) if err != nil { return nil, err } diff --git a/discovery/gossiper.go b/discovery/gossiper.go index 209d58fadb..941964865c 100644 --- a/discovery/gossiper.go +++ b/discovery/gossiper.go @@ -531,7 +531,7 @@ type EdgeWithInfo struct { Info models.ChannelEdgeInfo // Edge describes the policy in one direction of the channel. - Edge *channeldb.ChannelEdgePolicy1 + Edge *channeldb.ChannelEdgePolicyWithNode } // PropagateChanPolicyUpdate signals the AuthenticatedGossiper to perform the @@ -1581,7 +1581,7 @@ func (d *AuthenticatedGossiper) retransmitStaleAnns(now time.Time) error { // within the prune interval or re-broadcast interval. type updateTuple struct { info models.ChannelEdgeInfo - edge *channeldb.ChannelEdgePolicy1 + edge *channeldb.ChannelEdgePolicyWithNode } var ( @@ -1590,7 +1590,7 @@ func (d *AuthenticatedGossiper) retransmitStaleAnns(now time.Time) error { ) err := d.cfg.Router.ForAllOutgoingChannels(func( _ kvdb.RTx, info models.ChannelEdgeInfo, - edge *channeldb.ChannelEdgePolicy1) error { + edge *channeldb.ChannelEdgePolicyWithNode) error { // If there's no auth proof attached to this edge, it means // that it is a private channel not meant to be announced to @@ -2131,7 +2131,7 @@ func (d *AuthenticatedGossiper) isMsgStale(msg lnwire.Message) bool { // Otherwise, we'll retrieve the correct policy that we // currently have stored within our graph to check if this // message is stale by comparing its timestamp. - var p *channeldb.ChannelEdgePolicy1 + var p *channeldb.ChannelEdgePolicyWithNode if msg.ChannelFlags&lnwire.ChanUpdateDirection == 0 { p = p1 } else { @@ -2157,7 +2157,7 @@ func (d *AuthenticatedGossiper) isMsgStale(msg lnwire.Message) bool { // updateChannel creates a new fully signed update for the channel, and updates // the underlying graph with the new state. func (d *AuthenticatedGossiper) updateChannel(edgeInfo models.ChannelEdgeInfo, - edge *channeldb.ChannelEdgePolicy1) (lnwire.ChannelAnnouncement, + edge *channeldb.ChannelEdgePolicyWithNode) (lnwire.ChannelAnnouncement, *lnwire.ChannelUpdate1, error) { // Parse the unsigned edge into a channel update. @@ -2836,10 +2836,14 @@ func (d *AuthenticatedGossiper) handleChanUpdate(nMsg *networkMsg, switch direction { case 0: pubKey, _ = chanInfo.NodeKey1() - edgeToUpdate = e1 + if e1 != nil { + edgeToUpdate = &e1.ChannelEdgePolicy1 + } case 1: pubKey, _ = chanInfo.NodeKey2() - edgeToUpdate = e2 + if e2 != nil { + edgeToUpdate = &e2.ChannelEdgePolicy1 + } } log.Debugf("Validating ChannelUpdate1: channel=%v, from node=%x, has "+ @@ -2921,18 +2925,20 @@ func (d *AuthenticatedGossiper) handleChanUpdate(nMsg *networkMsg, // different alias. This might mean that SigBytes is incorrect as it // signs a different SCID than the database SCID, but since there will // only be a difference if AuthProof == nil, this is fine. - update := &channeldb.ChannelEdgePolicy1{ - SigBytes: upd.Signature.ToSignatureBytes(), - ChannelID: chanInfo.GetChanID(), - LastUpdate: timestamp, - MessageFlags: upd.MessageFlags, - ChannelFlags: upd.ChannelFlags, - TimeLockDelta: upd.TimeLockDelta, - MinHTLC: upd.HtlcMinimumMsat, - MaxHTLC: upd.HtlcMaximumMsat, - FeeBaseMSat: lnwire.MilliSatoshi(upd.BaseFee), - FeeProportionalMillionths: lnwire.MilliSatoshi(upd.FeeRate), - ExtraOpaqueData: upd.ExtraOpaqueData, + update := &channeldb.ChannelEdgePolicyWithNode{ + ChannelEdgePolicy1: channeldb.ChannelEdgePolicy1{ + SigBytes: upd.Signature.ToSignatureBytes(), + ChannelID: chanInfo.GetChanID(), + LastUpdate: timestamp, + MessageFlags: upd.MessageFlags, + ChannelFlags: upd.ChannelFlags, + TimeLockDelta: upd.TimeLockDelta, + MinHTLC: upd.HtlcMinimumMsat, + MaxHTLC: upd.HtlcMaximumMsat, + FeeBaseMSat: lnwire.MilliSatoshi(upd.BaseFee), + FeeProportionalMillionths: lnwire.MilliSatoshi(upd.FeeRate), + ExtraOpaqueData: upd.ExtraOpaqueData, + }, } if err := d.cfg.Router.UpdateEdge(update, ops...); err != nil { diff --git a/discovery/gossiper_test.go b/discovery/gossiper_test.go index 305d2eb0cd..d880c74358 100644 --- a/discovery/gossiper_test.go +++ b/discovery/gossiper_test.go @@ -93,7 +93,7 @@ type mockGraphSource struct { mu sync.Mutex nodes []channeldb.LightningNode infos map[uint64]models.ChannelEdgeInfo - edges map[uint64][]channeldb.ChannelEdgePolicy1 + edges map[uint64][]channeldb.ChannelEdgePolicyWithNode zombies map[uint64][][33]byte chansToReject map[uint64]struct{} } @@ -102,7 +102,7 @@ func newMockRouter(height uint32) *mockGraphSource { return &mockGraphSource{ bestHeight: height, infos: make(map[uint64]models.ChannelEdgeInfo), - edges: make(map[uint64][]channeldb.ChannelEdgePolicy1), + edges: make(map[uint64][]channeldb.ChannelEdgePolicyWithNode), zombies: make(map[uint64][][33]byte), chansToReject: make(map[uint64]struct{}), } @@ -146,14 +146,16 @@ func (r *mockGraphSource) queueValidationFail(chanID uint64) { r.chansToReject[chanID] = struct{}{} } -func (r *mockGraphSource) UpdateEdge(edge *channeldb.ChannelEdgePolicy1, +func (r *mockGraphSource) UpdateEdge(edge *channeldb.ChannelEdgePolicyWithNode, _ ...batch.SchedulerOption) error { r.mu.Lock() defer r.mu.Unlock() if len(r.edges[edge.ChannelID]) == 0 { - r.edges[edge.ChannelID] = make([]channeldb.ChannelEdgePolicy1, 2) + r.edges[edge.ChannelID] = make( + []channeldb.ChannelEdgePolicyWithNode, 2, + ) } if edge.ChannelFlags&lnwire.ChanUpdateDirection == 0 { @@ -198,7 +200,8 @@ func (r *mockGraphSource) ForEachNode(func(node *channeldb.LightningNode) error) } func (r *mockGraphSource) ForAllOutgoingChannels(cb func(tx kvdb.RTx, - i models.ChannelEdgeInfo, c *channeldb.ChannelEdgePolicy1) error) error { + i models.ChannelEdgeInfo, + c *channeldb.ChannelEdgePolicyWithNode) error) error { r.mu.Lock() defer r.mu.Unlock() @@ -235,8 +238,8 @@ func (r *mockGraphSource) ForEachChannel(_ func(chanInfo models.ChannelEdgeInfo, } func (r *mockGraphSource) GetChannelByID(chanID lnwire.ShortChannelID) ( - models.ChannelEdgeInfo, *channeldb.ChannelEdgePolicy1, - *channeldb.ChannelEdgePolicy1, error) { + models.ChannelEdgeInfo, *channeldb.ChannelEdgePolicyWithNode, + *channeldb.ChannelEdgePolicyWithNode, error) { r.mu.Lock() defer r.mu.Unlock() @@ -262,13 +265,13 @@ func (r *mockGraphSource) GetChannelByID(chanID lnwire.ShortChannelID) ( return chanInfoCP, nil, nil, nil } - var edge1 *channeldb.ChannelEdgePolicy1 - if !reflect.DeepEqual(edges[0], channeldb.ChannelEdgePolicy1{}) { + var edge1 *channeldb.ChannelEdgePolicyWithNode + if !reflect.DeepEqual(edges[0], channeldb.ChannelEdgePolicyWithNode{}) { edge1 = &edges[0] } - var edge2 *channeldb.ChannelEdgePolicy1 - if !reflect.DeepEqual(edges[1], channeldb.ChannelEdgePolicy1{}) { + var edge2 *channeldb.ChannelEdgePolicyWithNode + if !reflect.DeepEqual(edges[1], channeldb.ChannelEdgePolicyWithNode{}) { edge2 = &edges[1] } @@ -3456,7 +3459,7 @@ out: err = ctx.router.ForAllOutgoingChannels(func( _ kvdb.RTx, info models.ChannelEdgeInfo, - edge *channeldb.ChannelEdgePolicy1) error { + edge *channeldb.ChannelEdgePolicyWithNode) error { edge.TimeLockDelta = uint16(newTimeLockDelta) edgesToUpdate = append(edgesToUpdate, EdgeWithInfo{ diff --git a/funding/manager.go b/funding/manager.go index 65b90ec7ba..c94a597059 100644 --- a/funding/manager.go +++ b/funding/manager.go @@ -524,7 +524,7 @@ type Config struct { // DeleteAliasEdge allows the Manager to delete an alias channel edge // from the graph. It also returns our local to-be-deleted policy. DeleteAliasEdge func(scid lnwire.ShortChannelID) ( - *channeldb.ChannelEdgePolicy1, error) + *channeldb.ChannelEdgePolicyWithNode, error) // AliasManager is an implementation of the aliasHandler interface that // abstracts away the handling of many alias functions. @@ -3524,7 +3524,8 @@ func (f *Manager) annAfterSixConfs(completeChan *channeldb.OpenChannel, } err = f.addToRouterGraph( - completeChan, &baseScid, nil, ourPolicy, + completeChan, &baseScid, nil, + &ourPolicy.ChannelEdgePolicy1, ) if err != nil { return fmt.Errorf("failed to re-add to "+ @@ -3615,7 +3616,8 @@ func (f *Manager) waitForZeroConfChannel(c *channeldb.OpenChannel, // alias since we'll be using the confirmed SCID from now on // regardless if it's public or not. err = f.addToRouterGraph( - c, &confChan.shortChanID, nil, ourPolicy, + c, &confChan.shortChanID, nil, + &ourPolicy.ChannelEdgePolicy1, ) if err != nil { return fmt.Errorf("failed adding confirmed zero-conf "+ diff --git a/funding/manager_test.go b/funding/manager_test.go index 8d7f2e95d6..0357d877b4 100644 --- a/funding/manager_test.go +++ b/funding/manager_test.go @@ -550,7 +550,7 @@ func createTestFundingManager(t *testing.T, privKey *btcec.PrivateKey, OpenChannelPredicate: chainedAcceptor, NotifyPendingOpenChannelEvent: evt.NotifyPendingOpenChannelEvent, DeleteAliasEdge: func(scid lnwire.ShortChannelID) ( - *channeldb.ChannelEdgePolicy1, error) { + *channeldb.ChannelEdgePolicyWithNode, error) { return nil, nil }, diff --git a/lnrpc/devrpc/dev_server.go b/lnrpc/devrpc/dev_server.go index 0bb27d0237..2986a91522 100644 --- a/lnrpc/devrpc/dev_server.go +++ b/lnrpc/devrpc/dev_server.go @@ -288,7 +288,7 @@ func (s *Server) ImportGraph(ctx context.Context, rpcEdge.ChanPoint, err) } - makePolicy := func(rpcPolicy *lnrpc.RoutingPolicy) *channeldb.ChannelEdgePolicy1 { + makePolicy := func(rpcPolicy *lnrpc.RoutingPolicy) *channeldb.ChannelEdgePolicyWithNode { policy := &channeldb.ChannelEdgePolicy1{ ChannelID: rpcEdge.ChannelId, LastUpdate: time.Unix( @@ -315,7 +315,9 @@ func (s *Server) ImportGraph(ctx context.Context, lnwire.ChanUpdateRequiredMaxHtlc } - return policy + return &channeldb.ChannelEdgePolicyWithNode{ + ChannelEdgePolicy1: *policy, + } } if rpcEdge.Node1Policy != nil { diff --git a/lnrpc/invoicesrpc/addinvoice.go b/lnrpc/invoicesrpc/addinvoice.go index 7648cc7a20..430b2b4276 100644 --- a/lnrpc/invoicesrpc/addinvoice.go +++ b/lnrpc/invoicesrpc/addinvoice.go @@ -490,7 +490,7 @@ func AddInvoice(ctx context.Context, cfg *AddInvoiceConfig, // chanCanBeHopHint returns true if the target channel is eligible to be a hop // hint. func chanCanBeHopHint(channel *HopHintInfo, cfg *SelectHopHintsCfg) ( - *channeldb.ChannelEdgePolicy1, bool) { + *channeldb.ChannelEdgePolicyWithNode, bool) { // Since we're only interested in our private channels, we'll skip // public ones. @@ -546,7 +546,7 @@ func chanCanBeHopHint(channel *HopHintInfo, cfg *SelectHopHintsCfg) ( // Now, we'll need to determine which is the correct policy for HTLCs // being sent from the remote node. var ( - remotePolicy *channeldb.ChannelEdgePolicy1 + remotePolicy *channeldb.ChannelEdgePolicyWithNode node1Bytes = info.Node1Bytes() ) if bytes.Equal(remotePub[:], node1Bytes[:]) { @@ -632,8 +632,8 @@ type SelectHopHintsCfg struct { // FetchChannelEdgesByID attempts to lookup the two directed edges for // the channel identified by the channel ID. FetchChannelEdgesByID func(chanID uint64) (models.ChannelEdgeInfo, - *channeldb.ChannelEdgePolicy1, *channeldb.ChannelEdgePolicy1, - error) + *channeldb.ChannelEdgePolicyWithNode, + *channeldb.ChannelEdgePolicyWithNode, error) // GetAlias allows the peer's alias SCID to be retrieved for private // option_scid_alias channels. @@ -760,7 +760,8 @@ func shouldIncludeChannel(cfg *SelectHopHintsCfg, // Now that we know this channel use usable, add it as a hop hint and // the indexes we'll use later. - hopHint := newHopHint(hopHintInfo, edgePolicy) + hopHint := newHopHint(hopHintInfo, &edgePolicy.ChannelEdgePolicy1) + return hopHint, hopHintInfo.RemoteBalance, true } diff --git a/lnrpc/invoicesrpc/addinvoice_test.go b/lnrpc/invoicesrpc/addinvoice_test.go index af038b9415..feac9b270a 100644 --- a/lnrpc/invoicesrpc/addinvoice_test.go +++ b/lnrpc/invoicesrpc/addinvoice_test.go @@ -52,8 +52,8 @@ func (h *hopHintsConfigMock) FetchAllChannels() ([]*channeldb.OpenChannel, // FetchChannelEdgesByID attempts to lookup the two directed edges for // the channel identified by the channel ID. func (h *hopHintsConfigMock) FetchChannelEdgesByID(chanID uint64) ( - models.ChannelEdgeInfo, *channeldb.ChannelEdgePolicy1, - *channeldb.ChannelEdgePolicy1, error) { + models.ChannelEdgeInfo, *channeldb.ChannelEdgePolicyWithNode, + *channeldb.ChannelEdgePolicyWithNode, error) { args := h.Mock.Called(chanID) @@ -71,8 +71,8 @@ func (h *hopHintsConfigMock) FetchChannelEdgesByID(chanID uint64) ( "ChannelEdgeInfo impl received: %T", args.Get(0)) } - policy1 := args.Get(1).(*channeldb.ChannelEdgePolicy1) - policy2 := args.Get(2).(*channeldb.ChannelEdgePolicy1) + policy1 := args.Get(1).(*channeldb.ChannelEdgePolicyWithNode) + policy2 := args.Get(2).(*channeldb.ChannelEdgePolicyWithNode) return edgeInfo, policy1, policy2, err } @@ -222,8 +222,8 @@ var shouldIncludeChannelTestCases = []struct { "FetchChannelEdgesByID", mock.Anything, ).Once().Return( &channeldb.ChannelEdgeInfo1{}, - &channeldb.ChannelEdgePolicy1{}, - &channeldb.ChannelEdgePolicy1{}, nil, + &channeldb.ChannelEdgePolicyWithNode{}, + &channeldb.ChannelEdgePolicyWithNode{}, nil, ) h.Mock.On( @@ -260,8 +260,8 @@ var shouldIncludeChannelTestCases = []struct { "FetchChannelEdgesByID", mock.Anything, ).Once().Return( &channeldb.ChannelEdgeInfo1{}, - &channeldb.ChannelEdgePolicy1{}, - &channeldb.ChannelEdgePolicy1{}, nil, + &channeldb.ChannelEdgePolicyWithNode{}, + &channeldb.ChannelEdgePolicyWithNode{}, nil, ) alias := lnwire.ShortChannelID{TxPosition: 5} h.Mock.On( @@ -303,12 +303,14 @@ var shouldIncludeChannelTestCases = []struct { &channeldb.ChannelEdgeInfo1{ NodeKey1Bytes: selectedPolicy, }, - &channeldb.ChannelEdgePolicy1{ - FeeBaseMSat: 1000, - FeeProportionalMillionths: 20, - TimeLockDelta: 13, + &channeldb.ChannelEdgePolicyWithNode{ + ChannelEdgePolicy1: channeldb.ChannelEdgePolicy1{ + FeeBaseMSat: 1000, + FeeProportionalMillionths: 20, + TimeLockDelta: 13, + }, }, - &channeldb.ChannelEdgePolicy1{}, + &channeldb.ChannelEdgePolicyWithNode{}, nil, ) }, @@ -349,11 +351,13 @@ var shouldIncludeChannelTestCases = []struct { "FetchChannelEdgesByID", mock.Anything, ).Once().Return( &channeldb.ChannelEdgeInfo1{}, - &channeldb.ChannelEdgePolicy1{}, - &channeldb.ChannelEdgePolicy1{ - FeeBaseMSat: 1000, - FeeProportionalMillionths: 20, - TimeLockDelta: 13, + &channeldb.ChannelEdgePolicyWithNode{}, + &channeldb.ChannelEdgePolicyWithNode{ + ChannelEdgePolicy1: channeldb.ChannelEdgePolicy1{ + FeeBaseMSat: 1000, + FeeProportionalMillionths: 20, + TimeLockDelta: 13, + }, }, nil, ) }, @@ -394,11 +398,13 @@ var shouldIncludeChannelTestCases = []struct { "FetchChannelEdgesByID", mock.Anything, ).Once().Return( &channeldb.ChannelEdgeInfo1{}, - &channeldb.ChannelEdgePolicy1{}, - &channeldb.ChannelEdgePolicy1{ - FeeBaseMSat: 1000, - FeeProportionalMillionths: 20, - TimeLockDelta: 13, + &channeldb.ChannelEdgePolicyWithNode{}, + &channeldb.ChannelEdgePolicyWithNode{ + ChannelEdgePolicy1: channeldb.ChannelEdgePolicy1{ + FeeBaseMSat: 1000, + FeeProportionalMillionths: 20, + TimeLockDelta: 13, + }, }, nil, ) @@ -561,8 +567,8 @@ var populateHopHintsTestCases = []struct { "FetchChannelEdgesByID", mock.Anything, ).Once().Return( &channeldb.ChannelEdgeInfo1{}, - &channeldb.ChannelEdgePolicy1{}, - &channeldb.ChannelEdgePolicy1{}, nil, + &channeldb.ChannelEdgePolicyWithNode{}, + &channeldb.ChannelEdgePolicyWithNode{}, nil, ) }, maxHopHints: 1, @@ -611,8 +617,8 @@ var populateHopHintsTestCases = []struct { "FetchChannelEdgesByID", mock.Anything, ).Once().Return( &channeldb.ChannelEdgeInfo1{}, - &channeldb.ChannelEdgePolicy1{}, - &channeldb.ChannelEdgePolicy1{}, nil, + &channeldb.ChannelEdgePolicyWithNode{}, + &channeldb.ChannelEdgePolicyWithNode{}, nil, ) }, maxHopHints: 10, @@ -662,8 +668,8 @@ var populateHopHintsTestCases = []struct { "FetchChannelEdgesByID", mock.Anything, ).Once().Return( &channeldb.ChannelEdgeInfo1{}, - &channeldb.ChannelEdgePolicy1{}, - &channeldb.ChannelEdgePolicy1{}, nil, + &channeldb.ChannelEdgePolicyWithNode{}, + &channeldb.ChannelEdgePolicyWithNode{}, nil, ) }, maxHopHints: 1, @@ -695,8 +701,8 @@ var populateHopHintsTestCases = []struct { "FetchChannelEdgesByID", mock.Anything, ).Once().Return( &channeldb.ChannelEdgeInfo1{}, - &channeldb.ChannelEdgePolicy1{}, - &channeldb.ChannelEdgePolicy1{}, nil, + &channeldb.ChannelEdgePolicyWithNode{}, + &channeldb.ChannelEdgePolicyWithNode{}, nil, ) // Prepare the mock for the second channel. @@ -712,8 +718,8 @@ var populateHopHintsTestCases = []struct { "FetchChannelEdgesByID", mock.Anything, ).Once().Return( &channeldb.ChannelEdgeInfo1{}, - &channeldb.ChannelEdgePolicy1{}, - &channeldb.ChannelEdgePolicy1{}, nil, + &channeldb.ChannelEdgePolicyWithNode{}, + &channeldb.ChannelEdgePolicyWithNode{}, nil, ) }, maxHopHints: 10, @@ -749,8 +755,8 @@ var populateHopHintsTestCases = []struct { "FetchChannelEdgesByID", mock.Anything, ).Once().Return( &channeldb.ChannelEdgeInfo1{}, - &channeldb.ChannelEdgePolicy1{}, - &channeldb.ChannelEdgePolicy1{}, nil, + &channeldb.ChannelEdgePolicyWithNode{}, + &channeldb.ChannelEdgePolicyWithNode{}, nil, ) // Prepare the mock for the second channel. @@ -766,8 +772,8 @@ var populateHopHintsTestCases = []struct { "FetchChannelEdgesByID", mock.Anything, ).Once().Return( &channeldb.ChannelEdgeInfo1{}, - &channeldb.ChannelEdgePolicy1{}, - &channeldb.ChannelEdgePolicy1{}, nil, + &channeldb.ChannelEdgePolicyWithNode{}, + &channeldb.ChannelEdgePolicyWithNode{}, nil, ) }, maxHopHints: 10, @@ -804,8 +810,8 @@ var populateHopHintsTestCases = []struct { "FetchChannelEdgesByID", mock.Anything, ).Once().Return( &channeldb.ChannelEdgeInfo1{}, - &channeldb.ChannelEdgePolicy1{}, - &channeldb.ChannelEdgePolicy1{}, nil, + &channeldb.ChannelEdgePolicyWithNode{}, + &channeldb.ChannelEdgePolicyWithNode{}, nil, ) }, maxHopHints: 1, diff --git a/netann/chan_status_manager_test.go b/netann/chan_status_manager_test.go index 3bed4b0a51..39d09ff616 100644 --- a/netann/chan_status_manager_test.go +++ b/netann/chan_status_manager_test.go @@ -67,8 +67,8 @@ func createChannel(t *testing.T) *channeldb.OpenChannel { // update will be created with the disabled bit set if startEnabled is false. func createEdgePolicies(t *testing.T, channel *channeldb.OpenChannel, pubkey *btcec.PublicKey, startEnabled bool) ( - *channeldb.ChannelEdgeInfo1, *channeldb.ChannelEdgePolicy1, - *channeldb.ChannelEdgePolicy1) { + *channeldb.ChannelEdgeInfo1, *channeldb.ChannelEdgePolicyWithNode, + *channeldb.ChannelEdgePolicyWithNode) { var ( pubkey1 [33]byte @@ -105,17 +105,21 @@ func createEdgePolicies(t *testing.T, channel *channeldb.OpenChannel, NodeKey1Bytes: pubkey1, NodeKey2Bytes: pubkey2, }, - &channeldb.ChannelEdgePolicy1{ - ChannelID: channel.ShortChanID().ToUint64(), - ChannelFlags: dir1, - LastUpdate: time.Now(), - SigBytes: testSigBytes, + &channeldb.ChannelEdgePolicyWithNode{ + ChannelEdgePolicy1: channeldb.ChannelEdgePolicy1{ + ChannelID: channel.ShortChanID().ToUint64(), + ChannelFlags: dir1, + LastUpdate: time.Now(), + SigBytes: testSigBytes, + }, }, - &channeldb.ChannelEdgePolicy1{ - ChannelID: channel.ShortChanID().ToUint64(), - ChannelFlags: dir2, - LastUpdate: time.Now(), - SigBytes: testSigBytes, + &channeldb.ChannelEdgePolicyWithNode{ + ChannelEdgePolicy1: channeldb.ChannelEdgePolicy1{ + ChannelID: channel.ShortChanID().ToUint64(), + ChannelFlags: dir2, + LastUpdate: time.Now(), + SigBytes: testSigBytes, + }, } } @@ -123,8 +127,8 @@ type mockGraph struct { mu sync.Mutex channels []*channeldb.OpenChannel chanInfos map[wire.OutPoint]*channeldb.ChannelEdgeInfo1 - chanPols1 map[wire.OutPoint]*channeldb.ChannelEdgePolicy1 - chanPols2 map[wire.OutPoint]*channeldb.ChannelEdgePolicy1 + chanPols1 map[wire.OutPoint]*channeldb.ChannelEdgePolicyWithNode + chanPols2 map[wire.OutPoint]*channeldb.ChannelEdgePolicyWithNode sidToCid map[lnwire.ShortChannelID]wire.OutPoint updates chan *lnwire.ChannelUpdate1 @@ -136,8 +140,8 @@ func newMockGraph(t *testing.T, numChannels int, g := &mockGraph{ channels: make([]*channeldb.OpenChannel, 0, numChannels), chanInfos: make(map[wire.OutPoint]*channeldb.ChannelEdgeInfo1), - chanPols1: make(map[wire.OutPoint]*channeldb.ChannelEdgePolicy1), - chanPols2: make(map[wire.OutPoint]*channeldb.ChannelEdgePolicy1), + chanPols1: make(map[wire.OutPoint]*channeldb.ChannelEdgePolicyWithNode), + chanPols2: make(map[wire.OutPoint]*channeldb.ChannelEdgePolicyWithNode), sidToCid: make(map[lnwire.ShortChannelID]wire.OutPoint), updates: make(chan *lnwire.ChannelUpdate1, 2*numChannels), } @@ -162,7 +166,8 @@ func (g *mockGraph) FetchAllOpenChannels() ([]*channeldb.OpenChannel, error) { func (g *mockGraph) FetchChannelEdgesByOutpoint( op *wire.OutPoint) (models.ChannelEdgeInfo, - *channeldb.ChannelEdgePolicy1, *channeldb.ChannelEdgePolicy1, error) { + *channeldb.ChannelEdgePolicyWithNode, + *channeldb.ChannelEdgePolicyWithNode, error) { g.mu.Lock() defer g.mu.Unlock() @@ -211,11 +216,13 @@ func (g *mockGraph) ApplyChannelUpdate(update *lnwire.ChannelUpdate1, timestamp := time.Unix(int64(update.Timestamp), 0) - policy := &channeldb.ChannelEdgePolicy1{ - ChannelID: update.ShortChannelID.ToUint64(), - ChannelFlags: update.ChannelFlags, - LastUpdate: timestamp, - SigBytes: testSigBytes, + policy := &channeldb.ChannelEdgePolicyWithNode{ + ChannelEdgePolicy1: channeldb.ChannelEdgePolicy1{ + ChannelID: update.ShortChannelID.ToUint64(), + ChannelFlags: update.ChannelFlags, + LastUpdate: timestamp, + SigBytes: testSigBytes, + }, } if update1 { @@ -250,7 +257,7 @@ func (g *mockGraph) addChannel(channel *channeldb.OpenChannel) { func (g *mockGraph) addEdgePolicy(c *channeldb.OpenChannel, info *channeldb.ChannelEdgeInfo1, - pol1, pol2 *channeldb.ChannelEdgePolicy1) { + pol1, pol2 *channeldb.ChannelEdgePolicyWithNode) { g.mu.Lock() defer g.mu.Unlock() diff --git a/netann/channel_announcement.go b/netann/channel_announcement.go index db870d0568..c48cca9dc4 100644 --- a/netann/channel_announcement.go +++ b/netann/channel_announcement.go @@ -16,8 +16,9 @@ import ( // peer's initial routing table upon connect. func CreateChanAnnouncement(chanProof models.ChannelAuthProof, chanInfo models.ChannelEdgeInfo, - e1, e2 *channeldb.ChannelEdgePolicy1) (lnwire.ChannelAnnouncement, - *lnwire.ChannelUpdate1, *lnwire.ChannelUpdate1, error) { + e1, e2 *channeldb.ChannelEdgePolicyWithNode) ( + lnwire.ChannelAnnouncement, *lnwire.ChannelUpdate1, + *lnwire.ChannelUpdate1, error) { switch proof := chanProof.(type) { case *channeldb.ChannelAuthProof1: @@ -38,8 +39,9 @@ func CreateChanAnnouncement(chanProof models.ChannelAuthProof, func createChanAnnouncement1(chanProof *channeldb.ChannelAuthProof1, chanInfo *channeldb.ChannelEdgeInfo1, - e1, e2 *channeldb.ChannelEdgePolicy1) (lnwire.ChannelAnnouncement, - *lnwire.ChannelUpdate1, *lnwire.ChannelUpdate1, error) { + e1, e2 *channeldb.ChannelEdgePolicyWithNode) ( + lnwire.ChannelAnnouncement, *lnwire.ChannelUpdate1, + *lnwire.ChannelUpdate1, error) { // First, using the parameters of the channel, along with the channel // authentication chanProof, we'll create re-create the original diff --git a/netann/channel_update.go b/netann/channel_update.go index 7302fdf561..3e0b971cd5 100644 --- a/netann/channel_update.go +++ b/netann/channel_update.go @@ -87,11 +87,11 @@ func SignChannelUpdate(signer lnwallet.MessageSigner, keyLoc keychain.KeyLocator // NOTE: The passed policies can be nil. func ExtractChannelUpdate(ownerPubKey []byte, info models.ChannelEdgeInfo, - policies ...*channeldb.ChannelEdgePolicy1) ( + policies ...*channeldb.ChannelEdgePolicyWithNode) ( *lnwire.ChannelUpdate1, error) { // Helper function to extract the owner of the given policy. - owner := func(edge *channeldb.ChannelEdgePolicy1) []byte { + owner := func(edge *channeldb.ChannelEdgePolicyWithNode) []byte { var pubKey *btcec.PublicKey if edge.ChannelFlags&lnwire.ChanUpdateDirection == 0 { pubKey, _ = info.NodeKey1() @@ -120,7 +120,7 @@ func ExtractChannelUpdate(ownerPubKey []byte, // UnsignedChannelUpdateFromEdge reconstructs an unsigned ChannelUpdate1 from the // given edge info and policy. func UnsignedChannelUpdateFromEdge(chainHash chainhash.Hash, - policy *channeldb.ChannelEdgePolicy1) *lnwire.ChannelUpdate1 { + policy *channeldb.ChannelEdgePolicyWithNode) *lnwire.ChannelUpdate1 { return &lnwire.ChannelUpdate1{ ChainHash: chainHash, @@ -140,7 +140,8 @@ func UnsignedChannelUpdateFromEdge(chainHash chainhash.Hash, // ChannelUpdateFromEdge reconstructs a signed ChannelUpdate1 from the given edge // info and policy. func ChannelUpdateFromEdge(info models.ChannelEdgeInfo, - policy *channeldb.ChannelEdgePolicy1) (*lnwire.ChannelUpdate1, error) { + policy *channeldb.ChannelEdgePolicyWithNode) (*lnwire.ChannelUpdate1, + error) { update := UnsignedChannelUpdateFromEdge(info.GetChainHash(), policy) diff --git a/netann/interface.go b/netann/interface.go index 91b1eb7196..f60dfee9ac 100644 --- a/netann/interface.go +++ b/netann/interface.go @@ -20,5 +20,6 @@ type ChannelGraph interface { // FetchChannelEdgesByOutpoint returns the channel edge info and most // recent channel edge policies for a given outpoint. FetchChannelEdgesByOutpoint(*wire.OutPoint) (models.ChannelEdgeInfo, - *channeldb.ChannelEdgePolicy1, *channeldb.ChannelEdgePolicy1, error) + *channeldb.ChannelEdgePolicyWithNode, + *channeldb.ChannelEdgePolicyWithNode, error) } diff --git a/routing/localchans/manager.go b/routing/localchans/manager.go index 39e799043e..0d72451841 100644 --- a/routing/localchans/manager.go +++ b/routing/localchans/manager.go @@ -32,7 +32,7 @@ type Manager struct { // channels. ForAllOutgoingChannels func(cb func(kvdb.RTx, models.ChannelEdgeInfo, - *channeldb.ChannelEdgePolicy1) error) error + *channeldb.ChannelEdgePolicyWithNode) error) error // FetchChannel is used to query local channel parameters. Optionally an // existing db tx can be supplied. @@ -74,7 +74,7 @@ func (r *Manager) UpdatePolicy(newSchema routing.ChannelPolicy, err := r.ForAllOutgoingChannels(func( tx kvdb.RTx, info models.ChannelEdgeInfo, - edge *channeldb.ChannelEdgePolicy1) error { + edge *channeldb.ChannelEdgePolicyWithNode) error { chanPoint := info.GetChanPoint() @@ -174,7 +174,7 @@ func (r *Manager) UpdatePolicy(newSchema routing.ChannelPolicy, // updateEdge updates the given edge with the new schema. func (r *Manager) updateEdge(tx kvdb.RTx, chanPoint wire.OutPoint, - edge *channeldb.ChannelEdgePolicy1, + edge *channeldb.ChannelEdgePolicyWithNode, newSchema routing.ChannelPolicy) error { // Update forwarding fee scheme and required time lock delta. diff --git a/routing/localchans/manager_test.go b/routing/localchans/manager_test.go index 9bd9528d0d..1f484892d8 100644 --- a/routing/localchans/manager_test.go +++ b/routing/localchans/manager_test.go @@ -44,9 +44,11 @@ func TestManager(t *testing.T) { MaxHTLC: 5000, } - currentPolicy := channeldb.ChannelEdgePolicy1{ - MinHTLC: minHTLC, - MessageFlags: lnwire.ChanUpdateRequiredMaxHtlc, + currentPolicy := channeldb.ChannelEdgePolicyWithNode{ + ChannelEdgePolicy1: channeldb.ChannelEdgePolicy1{ + MinHTLC: minHTLC, + MessageFlags: lnwire.ChanUpdateRequiredMaxHtlc, + }, } updateForwardingPolicies := func( @@ -108,7 +110,7 @@ func TestManager(t *testing.T) { forAllOutgoingChannels := func(cb func(kvdb.RTx, models.ChannelEdgeInfo, - *channeldb.ChannelEdgePolicy1) error) error { + *channeldb.ChannelEdgePolicyWithNode) error) error { for _, c := range channelSet { if err := cb(nil, c.edgeInfo, ¤tPolicy); err != nil { @@ -152,7 +154,7 @@ func TestManager(t *testing.T) { tests := []struct { name string - currentPolicy channeldb.ChannelEdgePolicy1 + currentPolicy channeldb.ChannelEdgePolicyWithNode newPolicy routing.ChannelPolicy channelSet []channel specifiedChanPoints []wire.OutPoint diff --git a/routing/notifications.go b/routing/notifications.go index 6ee494299b..181477862f 100644 --- a/routing/notifications.go +++ b/routing/notifications.go @@ -339,7 +339,7 @@ func addToTopologyChange(graph *channeldb.ChannelGraph, update *TopologyChange, // Any new ChannelUpdateAnnouncements will generate a corresponding // ChannelEdgeUpdate notification. - case *channeldb.ChannelEdgePolicy1: + case *channeldb.ChannelEdgePolicyWithNode: // We'll need to fetch the edge's information from the database // in order to get the information concerning which nodes are // being connected. diff --git a/routing/notifications_test.go b/routing/notifications_test.go index 21fa57fa97..1e69182697 100644 --- a/routing/notifications_test.go +++ b/routing/notifications_test.go @@ -74,18 +74,20 @@ func createTestNode() (*channeldb.LightningNode, error) { } func randEdgePolicy(chanID *lnwire.ShortChannelID, - node *channeldb.LightningNode) *channeldb.ChannelEdgePolicy1 { - - return &channeldb.ChannelEdgePolicy1{ - SigBytes: testSig.Serialize(), - ChannelID: chanID.ToUint64(), - LastUpdate: time.Unix(int64(prand.Int31()), 0), - TimeLockDelta: uint16(prand.Int63()), - MinHTLC: lnwire.MilliSatoshi(prand.Int31()), - MaxHTLC: lnwire.MilliSatoshi(prand.Int31()), - FeeBaseMSat: lnwire.MilliSatoshi(prand.Int31()), - FeeProportionalMillionths: lnwire.MilliSatoshi(prand.Int31()), - Node: node, + node *channeldb.LightningNode) *channeldb.ChannelEdgePolicyWithNode { + + return &channeldb.ChannelEdgePolicyWithNode{ + ChannelEdgePolicy1: channeldb.ChannelEdgePolicy1{ + SigBytes: testSig.Serialize(), + ChannelID: chanID.ToUint64(), + LastUpdate: time.Unix(int64(prand.Int31()), 0), + TimeLockDelta: uint16(prand.Int63()), + MinHTLC: lnwire.MilliSatoshi(prand.Int31()), + MaxHTLC: lnwire.MilliSatoshi(prand.Int31()), + FeeBaseMSat: lnwire.MilliSatoshi(prand.Int31()), + FeeProportionalMillionths: lnwire.MilliSatoshi(prand.Int31()), + }, + Node: node, } } @@ -455,7 +457,7 @@ func TestEdgeUpdateNotification(t *testing.T) { } assertEdgeCorrect := func(t *testing.T, edgeUpdate *ChannelEdgeUpdate, - edgeAnn *channeldb.ChannelEdgePolicy1) { + edgeAnn *channeldb.ChannelEdgePolicyWithNode) { if edgeUpdate.ChanID != edgeAnn.ChannelID { t.Fatalf("channel ID of edge doesn't match: "+ "expected %v, got %v", chanID.ToUint64(), edgeUpdate.ChanID) diff --git a/routing/pathfind_test.go b/routing/pathfind_test.go index ce10db1311..8a1499d48b 100644 --- a/routing/pathfind_test.go +++ b/routing/pathfind_test.go @@ -376,17 +376,19 @@ func parseTestGraph(t *testing.T, useCache bool, path string) ( targetNode = edgeInfo.NodeKey2Bytes } - edgePolicy := &channeldb.ChannelEdgePolicy1{ - SigBytes: testSig.Serialize(), - MessageFlags: lnwire.ChanUpdateMsgFlags(edge.MessageFlags), - ChannelFlags: channelFlags, - ChannelID: edge.ChannelID, - LastUpdate: testTime, - TimeLockDelta: edge.Expiry, - MinHTLC: lnwire.MilliSatoshi(edge.MinHTLC), - MaxHTLC: lnwire.MilliSatoshi(edge.MaxHTLC), - FeeBaseMSat: lnwire.MilliSatoshi(edge.FeeBaseMsat), - FeeProportionalMillionths: lnwire.MilliSatoshi(edge.FeeRate), + edgePolicy := &channeldb.ChannelEdgePolicyWithNode{ + ChannelEdgePolicy1: channeldb.ChannelEdgePolicy1{ + SigBytes: testSig.Serialize(), + MessageFlags: lnwire.ChanUpdateMsgFlags(edge.MessageFlags), + ChannelFlags: channelFlags, + ChannelID: edge.ChannelID, + LastUpdate: testTime, + TimeLockDelta: edge.Expiry, + MinHTLC: lnwire.MilliSatoshi(edge.MinHTLC), + MaxHTLC: lnwire.MilliSatoshi(edge.MaxHTLC), + FeeBaseMSat: lnwire.MilliSatoshi(edge.FeeBaseMsat), + FeeProportionalMillionths: lnwire.MilliSatoshi(edge.FeeRate), + }, Node: &channeldb.LightningNode{ Alias: aliasForNode(targetNode), PubKeyBytes: targetNode, @@ -689,17 +691,19 @@ func createTestGraphFromChannels(t *testing.T, useCache bool, node2Features = node2.Features } - edgePolicy := &channeldb.ChannelEdgePolicy1{ - SigBytes: testSig.Serialize(), - MessageFlags: msgFlags, - ChannelFlags: channelFlags, - ChannelID: channelID, - LastUpdate: node1.LastUpdate, - TimeLockDelta: node1.Expiry, - MinHTLC: node1.MinHTLC, - MaxHTLC: node1.MaxHTLC, - FeeBaseMSat: node1.FeeBaseMsat, - FeeProportionalMillionths: node1.FeeRate, + edgePolicy := &channeldb.ChannelEdgePolicyWithNode{ + ChannelEdgePolicy1: channeldb.ChannelEdgePolicy1{ + SigBytes: testSig.Serialize(), + MessageFlags: msgFlags, + ChannelFlags: channelFlags, + ChannelID: channelID, + LastUpdate: node1.LastUpdate, + TimeLockDelta: node1.Expiry, + MinHTLC: node1.MinHTLC, + MaxHTLC: node1.MaxHTLC, + FeeBaseMSat: node1.FeeBaseMsat, + FeeProportionalMillionths: node1.FeeRate, + }, Node: &channeldb.LightningNode{ Alias: node2.Alias, PubKeyBytes: node2Vertex, @@ -727,17 +731,19 @@ func createTestGraphFromChannels(t *testing.T, useCache bool, node1Features = node1.Features } - edgePolicy := &channeldb.ChannelEdgePolicy1{ - SigBytes: testSig.Serialize(), - MessageFlags: msgFlags, - ChannelFlags: channelFlags, - ChannelID: channelID, - LastUpdate: node2.LastUpdate, - TimeLockDelta: node2.Expiry, - MinHTLC: node2.MinHTLC, - MaxHTLC: node2.MaxHTLC, - FeeBaseMSat: node2.FeeBaseMsat, - FeeProportionalMillionths: node2.FeeRate, + edgePolicy := &channeldb.ChannelEdgePolicyWithNode{ + ChannelEdgePolicy1: channeldb.ChannelEdgePolicy1{ + SigBytes: testSig.Serialize(), + MessageFlags: msgFlags, + ChannelFlags: channelFlags, + ChannelID: channelID, + LastUpdate: node2.LastUpdate, + TimeLockDelta: node2.Expiry, + MinHTLC: node2.MinHTLC, + MaxHTLC: node2.MaxHTLC, + FeeBaseMSat: node2.FeeBaseMsat, + FeeProportionalMillionths: node2.FeeRate, + }, Node: &channeldb.LightningNode{ Alias: node1.Alias, PubKeyBytes: node1Vertex, diff --git a/routing/router.go b/routing/router.go index 1c3511e444..990634b682 100644 --- a/routing/router.go +++ b/routing/router.go @@ -144,7 +144,7 @@ type ChannelGraphSource interface { // UpdateEdge is used to update edge information, without this message // edge considered as not fully constructed. - UpdateEdge(policy *channeldb.ChannelEdgePolicy1, + UpdateEdge(policy *channeldb.ChannelEdgePolicyWithNode, op ...batch.SchedulerOption) error // IsStaleNode returns true if the graph source has a node announcement @@ -176,7 +176,7 @@ type ChannelGraphSource interface { // star-graph. ForAllOutgoingChannels(cb func(tx kvdb.RTx, c models.ChannelEdgeInfo, - e *channeldb.ChannelEdgePolicy1) error) error + e *channeldb.ChannelEdgePolicyWithNode) error) error // CurrentBlockHeight returns the block height from POV of the router // subsystem. @@ -184,8 +184,8 @@ type ChannelGraphSource interface { // GetChannelByID return the channel by the channel id. GetChannelByID(chanID lnwire.ShortChannelID) ( - models.ChannelEdgeInfo, *channeldb.ChannelEdgePolicy1, - *channeldb.ChannelEdgePolicy1, error) + models.ChannelEdgeInfo, *channeldb.ChannelEdgePolicyWithNode, + *channeldb.ChannelEdgePolicyWithNode, error) // FetchLightningNode attempts to look up a target node by its identity // public key. channeldb.ErrGraphNodeNotFound is returned if the node @@ -904,7 +904,7 @@ func (r *ChannelRouter) pruneZombieChans() error { // First, we'll collect all the channels which are eligible for garbage // collection due to being zombies. filterPruneChans := func(info models.ChannelEdgeInfo, - e1, e2 *channeldb.ChannelEdgePolicy1) error { + e1, e2 *channeldb.ChannelEdgePolicyWithNode) error { chanID := info.GetChanID() @@ -1654,7 +1654,7 @@ func (r *ChannelRouter) processUpdate(msg interface{}, "view: %v", err) } - case *channeldb.ChannelEdgePolicy1: + case *channeldb.ChannelEdgePolicyWithNode: log.Debugf("Received ChannelEdgePolicy1 for channel %v", msg.ChannelID) @@ -2682,7 +2682,7 @@ func (r *ChannelRouter) AddEdge(edge models.ChannelEdgeInfo, // considered as not fully constructed. // // NOTE: This method is part of the ChannelGraphSource interface. -func (r *ChannelRouter) UpdateEdge(update *channeldb.ChannelEdgePolicy1, +func (r *ChannelRouter) UpdateEdge(update *channeldb.ChannelEdgePolicyWithNode, op ...batch.SchedulerOption) error { rMsg := &routingMsg{ @@ -2723,8 +2723,8 @@ func (r *ChannelRouter) SyncedHeight() uint32 { // // NOTE: This method is part of the ChannelGraphSource interface. func (r *ChannelRouter) GetChannelByID(chanID lnwire.ShortChannelID) ( - models.ChannelEdgeInfo, *channeldb.ChannelEdgePolicy1, - *channeldb.ChannelEdgePolicy1, error) { + models.ChannelEdgeInfo, *channeldb.ChannelEdgePolicyWithNode, + *channeldb.ChannelEdgePolicyWithNode, error) { return r.cfg.Graph.FetchChannelEdgesByID(chanID.ToUint64()) } @@ -2757,11 +2757,12 @@ func (r *ChannelRouter) ForEachNode( // // NOTE: This method is part of the ChannelGraphSource interface. func (r *ChannelRouter) ForAllOutgoingChannels(cb func(kvdb.RTx, - models.ChannelEdgeInfo, *channeldb.ChannelEdgePolicy1) error) error { + models.ChannelEdgeInfo, + *channeldb.ChannelEdgePolicyWithNode) error) error { return r.selfNode.ForEachChannel(r.cfg.Graph.DB(), nil, func(_ kvdb.Backend, tx kvdb.RTx, c models.ChannelEdgeInfo, - e, _ *channeldb.ChannelEdgePolicy1) error { + e, _ *channeldb.ChannelEdgePolicyWithNode) error { if e == nil { return fmt.Errorf("channel from self node " + diff --git a/routing/router_test.go b/routing/router_test.go index 8b9c88f04c..917bba513f 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -1335,14 +1335,16 @@ func TestIgnoreChannelEdgePolicyForUnknownChannel(t *testing.T) { BitcoinKey2Bytes: pub2, AuthProof: nil, } - edgePolicy := &channeldb.ChannelEdgePolicy1{ - SigBytes: testSig.Serialize(), - ChannelID: edge.ChannelID, - LastUpdate: testTime, - TimeLockDelta: 10, - MinHTLC: 1, - FeeBaseMSat: 10, - FeeProportionalMillionths: 10000, + edgePolicy := &channeldb.ChannelEdgePolicyWithNode{ + ChannelEdgePolicy1: channeldb.ChannelEdgePolicy1{ + SigBytes: testSig.Serialize(), + ChannelID: edge.ChannelID, + LastUpdate: testTime, + TimeLockDelta: 10, + MinHTLC: 1, + FeeBaseMSat: 10, + FeeProportionalMillionths: 10000, + }, } // Attempt to update the edge. This should be ignored, since the edge @@ -1420,14 +1422,16 @@ func TestAddEdgeUnknownVertexes(t *testing.T) { // We must add the edge policy to be able to use the edge for route // finding. - edgePolicy := &channeldb.ChannelEdgePolicy1{ - SigBytes: testSig.Serialize(), - ChannelID: edge.ChannelID, - LastUpdate: testTime, - TimeLockDelta: 10, - MinHTLC: 1, - FeeBaseMSat: 10, - FeeProportionalMillionths: 10000, + edgePolicy := &channeldb.ChannelEdgePolicyWithNode{ + ChannelEdgePolicy1: channeldb.ChannelEdgePolicy1{ + SigBytes: testSig.Serialize(), + ChannelID: edge.ChannelID, + LastUpdate: testTime, + TimeLockDelta: 10, + MinHTLC: 1, + FeeBaseMSat: 10, + FeeProportionalMillionths: 10000, + }, Node: &channeldb.LightningNode{ PubKeyBytes: edge.NodeKey2Bytes, }, @@ -1439,14 +1443,16 @@ func TestAddEdgeUnknownVertexes(t *testing.T) { } // Create edge in the other direction as well. - edgePolicy = &channeldb.ChannelEdgePolicy1{ - SigBytes: testSig.Serialize(), - ChannelID: edge.ChannelID, - LastUpdate: testTime, - TimeLockDelta: 10, - MinHTLC: 1, - FeeBaseMSat: 10, - FeeProportionalMillionths: 10000, + edgePolicy = &channeldb.ChannelEdgePolicyWithNode{ + ChannelEdgePolicy1: channeldb.ChannelEdgePolicy1{ + SigBytes: testSig.Serialize(), + ChannelID: edge.ChannelID, + LastUpdate: testTime, + TimeLockDelta: 10, + MinHTLC: 1, + FeeBaseMSat: 10, + FeeProportionalMillionths: 10000, + }, Node: &channeldb.LightningNode{ PubKeyBytes: edge.NodeKey1Bytes, }, @@ -1517,14 +1523,16 @@ func TestAddEdgeUnknownVertexes(t *testing.T) { t.Fatalf("unable to add edge to the channel graph: %v.", err) } - edgePolicy = &channeldb.ChannelEdgePolicy1{ - SigBytes: testSig.Serialize(), - ChannelID: edge.ChannelID, - LastUpdate: testTime, - TimeLockDelta: 10, - MinHTLC: 1, - FeeBaseMSat: 10, - FeeProportionalMillionths: 10000, + edgePolicy = &channeldb.ChannelEdgePolicyWithNode{ + ChannelEdgePolicy1: channeldb.ChannelEdgePolicy1{ + SigBytes: testSig.Serialize(), + ChannelID: edge.ChannelID, + LastUpdate: testTime, + TimeLockDelta: 10, + MinHTLC: 1, + FeeBaseMSat: 10, + FeeProportionalMillionths: 10000, + }, Node: &channeldb.LightningNode{ PubKeyBytes: edge.NodeKey2Bytes, }, @@ -1535,14 +1543,16 @@ func TestAddEdgeUnknownVertexes(t *testing.T) { t.Fatalf("unable to update edge policy: %v", err) } - edgePolicy = &channeldb.ChannelEdgePolicy1{ - SigBytes: testSig.Serialize(), - ChannelID: edge.ChannelID, - LastUpdate: testTime, - TimeLockDelta: 10, - MinHTLC: 1, - FeeBaseMSat: 10, - FeeProportionalMillionths: 10000, + edgePolicy = &channeldb.ChannelEdgePolicyWithNode{ + ChannelEdgePolicy1: channeldb.ChannelEdgePolicy1{ + SigBytes: testSig.Serialize(), + ChannelID: edge.ChannelID, + LastUpdate: testTime, + TimeLockDelta: 10, + MinHTLC: 1, + FeeBaseMSat: 10, + FeeProportionalMillionths: 10000, + }, Node: &channeldb.LightningNode{ PubKeyBytes: edge.NodeKey1Bytes, }, @@ -2650,28 +2660,32 @@ func TestIsStaleEdgePolicy(t *testing.T) { } // We'll also add two edge policies, one for each direction. - edgePolicy := &channeldb.ChannelEdgePolicy1{ - SigBytes: testSig.Serialize(), - ChannelID: edge.ChannelID, - LastUpdate: updateTimeStamp, - TimeLockDelta: 10, - MinHTLC: 1, - FeeBaseMSat: 10, - FeeProportionalMillionths: 10000, + edgePolicy := &channeldb.ChannelEdgePolicyWithNode{ + ChannelEdgePolicy1: channeldb.ChannelEdgePolicy1{ + SigBytes: testSig.Serialize(), + ChannelID: edge.ChannelID, + LastUpdate: updateTimeStamp, + TimeLockDelta: 10, + MinHTLC: 1, + FeeBaseMSat: 10, + FeeProportionalMillionths: 10000, + }, } edgePolicy.ChannelFlags = 0 if err := ctx.router.UpdateEdge(edgePolicy); err != nil { t.Fatalf("unable to update edge policy: %v", err) } - edgePolicy = &channeldb.ChannelEdgePolicy1{ - SigBytes: testSig.Serialize(), - ChannelID: edge.ChannelID, - LastUpdate: updateTimeStamp, - TimeLockDelta: 10, - MinHTLC: 1, - FeeBaseMSat: 10, - FeeProportionalMillionths: 10000, + edgePolicy = &channeldb.ChannelEdgePolicyWithNode{ + ChannelEdgePolicy1: channeldb.ChannelEdgePolicy1{ + SigBytes: testSig.Serialize(), + ChannelID: edge.ChannelID, + LastUpdate: updateTimeStamp, + TimeLockDelta: 10, + MinHTLC: 1, + FeeBaseMSat: 10, + FeeProportionalMillionths: 10000, + }, } edgePolicy.ChannelFlags = 1 if err := ctx.router.UpdateEdge(edgePolicy); err != nil { diff --git a/rpcserver.go b/rpcserver.go index 6b16e2b686..dec94280b4 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -5924,7 +5924,7 @@ func (r *rpcServer) DescribeGraph(ctx context.Context, // similar response which details both the edge information as well as // the routing policies of th nodes connecting the two edges. err = graph.ForEachChannel(func(edgeInfo models.ChannelEdgeInfo, - c1, c2 *channeldb.ChannelEdgePolicy1) error { + c1, c2 *channeldb.ChannelEdgePolicyWithNode) error { // Do not include unannounced channels unless specifically // requested. Unannounced channels include both private channels as @@ -5934,7 +5934,10 @@ func (r *rpcServer) DescribeGraph(ctx context.Context, return nil } - edge, err := marshalDBEdge(edgeInfo, c1, c2) + edge, err := marshalDBEdge( + edgeInfo, &c1.ChannelEdgePolicy1, + &c2.ChannelEdgePolicy1, + ) if err != nil { return err } @@ -6132,7 +6135,9 @@ func (r *rpcServer) GetChanInfo(ctx context.Context, // Convert the database's edge format into the network/RPC edge format // which couples the edge itself along with the directional node // routing policies of each node involved within the channel. - channelEdge, err := marshalDBEdge(edgeInfo, edge1, edge2) + channelEdge, err := marshalDBEdge( + edgeInfo, &edge1.ChannelEdgePolicy1, &edge2.ChannelEdgePolicy1, + ) if err != nil { return nil, err } @@ -6175,7 +6180,7 @@ func (r *rpcServer) GetNodeInfo(ctx context.Context, if err := node.ForEachChannel(graph.DB(), nil, func(_ kvdb.Backend, _ kvdb.RTx, edge models.ChannelEdgeInfo, - c1, c2 *channeldb.ChannelEdgePolicy1) error { + c1, c2 *channeldb.ChannelEdgePolicyWithNode) error { numChannels++ totalCapacity += edge.GetCapacity() @@ -6191,7 +6196,10 @@ func (r *rpcServer) GetNodeInfo(ctx context.Context, // Convert the database's edge format into the // network/RPC edge format. - channelEdge, err := marshalDBEdge(edge, c1, c2) + channelEdge, err := marshalDBEdge( + edge, &c1.ChannelEdgePolicy1, + &c2.ChannelEdgePolicy1, + ) if err != nil { return err } @@ -6792,7 +6800,8 @@ func (r *rpcServer) FeeReport(ctx context.Context, err = selfNode.ForEachChannel(channelGraph.DB(), nil, func(_ kvdb.Backend, _ kvdb.RTx, chanInfo models.ChannelEdgeInfo, - edgePolicy, _ *channeldb.ChannelEdgePolicy1) error { + edgePolicy, + _ *channeldb.ChannelEdgePolicyWithNode) error { // Self node should always have policies for its // channels. diff --git a/server.go b/server.go index 778e89ee8f..7e3bdf2233 100644 --- a/server.go +++ b/server.go @@ -1235,7 +1235,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, // Wrap the DeleteChannelEdges method so that the funding manager can // use it without depending on several layers of indirection. deleteAliasEdge := func(scid lnwire.ShortChannelID) ( - *channeldb.ChannelEdgePolicy1, error) { + *channeldb.ChannelEdgePolicyWithNode, error) { info, e1, e2, err := s.graphDB.FetchChannelEdgesByID( scid.ToUint64(), @@ -1254,7 +1254,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, var ourKey [33]byte copy(ourKey[:], nodeKeyDesc.PubKey.SerializeCompressed()) - var ourPolicy *channeldb.ChannelEdgePolicy1 + var ourPolicy *channeldb.ChannelEdgePolicyWithNode if info != nil && info.Node1Bytes() == ourKey { ourPolicy = e1 } else { @@ -3098,7 +3098,7 @@ func (s *server) establishPersistentConnections() error { err = sourceNode.ForEachChannel(s.graphDB.DB(), nil, func( db kvdb.Backend, tx kvdb.RTx, chanInfo models.ChannelEdgeInfo, - policy, _ *channeldb.ChannelEdgePolicy1) error { + policy, _ *channeldb.ChannelEdgePolicyWithNode) error { // If the remote party has announced the channel to us, but we // haven't yet, then we won't have a policy. However, we don't From 4317a97e4c7368ff79ec9bae1bcc5403dae9271b Mon Sep 17 00:00:00 2001 From: Elle Mouton <elle.mouton@gmail.com> Date: Thu, 19 Oct 2023 15:20:03 +0200 Subject: [PATCH 28/33] multi: pass MessageSignerKeyring to funding manager Also adds SignMusig2 to the interface. This should be moved to it's own commit. --- funding/manager.go | 24 ++++++------ funding/manager_test.go | 46 +++++++++++----------- keychain/btcwallet.go | 23 +++++++++++ keychain/derivation.go | 17 +++++++++ keychain/signer.go | 22 +++++++++++ lntest/mock/secretkeyring.go | 14 +++++++ lnwallet/rpcwallet/rpcwallet.go | 68 +++++++++++++++++++++++++++++++++ server.go | 8 ++-- 8 files changed, 184 insertions(+), 38 deletions(-) diff --git a/funding/manager.go b/funding/manager.go index c94a597059..61fd469525 100644 --- a/funding/manager.go +++ b/funding/manager.go @@ -10,7 +10,6 @@ import ( "github.com/btcsuite/btcd/blockchain" "github.com/btcsuite/btcd/btcec/v2" - "github.com/btcsuite/btcd/btcec/v2/ecdsa" "github.com/btcsuite/btcd/btcec/v2/schnorr/musig2" "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" @@ -383,8 +382,7 @@ type Config struct { // // TODO(roasbeef): should instead pass on this responsibility to a // distinct sub-system? - SignMessage func(keyLoc keychain.KeyLocator, - msg []byte, doubleHash bool) (*ecdsa.Signature, error) + MessageSigner keychain.MessageSignerRing // CurrentNodeAnnouncement should return the latest, fully signed node // announcement from the backing Lightning Network node with a fresh @@ -3335,7 +3333,7 @@ func (f *Manager) extractAnnounceParams(c *channeldb.OpenChannel) ( func (f *Manager) addToRouterGraph(completeChan *channeldb.OpenChannel, shortChanID *lnwire.ShortChannelID, peerAlias *lnwire.ShortChannelID, - ourPolicy *channeldb.ChannelEdgePolicy1) error { + ourPolicy *channeldb.ChannelEdgePolicyWithNode) error { chanID := lnwire.NewChanIDFromOutPoint(&completeChan.FundingOutpoint) @@ -3524,8 +3522,7 @@ func (f *Manager) annAfterSixConfs(completeChan *channeldb.OpenChannel, } err = f.addToRouterGraph( - completeChan, &baseScid, nil, - &ourPolicy.ChannelEdgePolicy1, + completeChan, &baseScid, nil, ourPolicy, ) if err != nil { return fmt.Errorf("failed to re-add to "+ @@ -3616,8 +3613,7 @@ func (f *Manager) waitForZeroConfChannel(c *channeldb.OpenChannel, // alias since we'll be using the confirmed SCID from now on // regardless if it's public or not. err = f.addToRouterGraph( - c, &confChan.shortChanID, nil, - &ourPolicy.ChannelEdgePolicy1, + c, &confChan.shortChanID, nil, ourPolicy, ) if err != nil { return fmt.Errorf("failed adding confirmed zero-conf "+ @@ -4069,7 +4065,7 @@ func (f *Manager) newChanAnnouncement(localPubKey, remotePubKey *btcec.PublicKey, localFundingKey *keychain.KeyDescriptor, remoteFundingKey *btcec.PublicKey, shortChanID lnwire.ShortChannelID, chanID lnwire.ChannelID, fwdMinHTLC, fwdMaxHTLC lnwire.MilliSatoshi, - ourPolicy *channeldb.ChannelEdgePolicy1, + ourPolicy *channeldb.ChannelEdgePolicyWithNode, chanType channeldb.ChannelType) (*chanAnnouncement, error) { chainHash := *f.cfg.Wallet.Cfg.NetParams.GenesisHash @@ -4208,7 +4204,9 @@ func (f *Manager) newChanAnnouncement(localPubKey, if err != nil { return nil, err } - sig, err := f.cfg.SignMessage(f.cfg.IDKeyLoc, chanUpdateMsg, true) + sig, err := f.cfg.MessageSigner.SignMessage( + f.cfg.IDKeyLoc, chanUpdateMsg, true, + ) if err != nil { return nil, errors.Errorf("unable to generate channel "+ "update announcement signature: %v", err) @@ -4230,12 +4228,14 @@ func (f *Manager) newChanAnnouncement(localPubKey, if err != nil { return nil, err } - nodeSig, err := f.cfg.SignMessage(f.cfg.IDKeyLoc, chanAnnMsg, true) + nodeSig, err := f.cfg.MessageSigner.SignMessage( + f.cfg.IDKeyLoc, chanAnnMsg, true, + ) if err != nil { return nil, errors.Errorf("unable to generate node "+ "signature for channel announcement: %v", err) } - bitcoinSig, err := f.cfg.SignMessage( + bitcoinSig, err := f.cfg.MessageSigner.SignMessage( localFundingKey.KeyLocator, chanAnnMsg, true, ) if err != nil { diff --git a/funding/manager_test.go b/funding/manager_test.go index 0357d877b4..b3da1c75fa 100644 --- a/funding/manager_test.go +++ b/funding/manager_test.go @@ -445,17 +445,13 @@ func createTestFundingManager(t *testing.T, privKey *btcec.PrivateKey, chainedAcceptor := acpt.NewChainedAcceptor() fundingCfg := Config{ - IDKey: privKey.PubKey(), - IDKeyLoc: testKeyLoc, - Wallet: lnw, - Notifier: chainNotifier, - ChannelDB: cdb, - FeeEstimator: estimator, - SignMessage: func(_ keychain.KeyLocator, - _ []byte, _ bool) (*ecdsa.Signature, error) { - - return testSig, nil - }, + IDKey: privKey.PubKey(), + IDKeyLoc: testKeyLoc, + Wallet: lnw, + Notifier: chainNotifier, + ChannelDB: cdb, + FeeEstimator: estimator, + MessageSigner: &mockSigner{}, SendAnnouncement: func(msg lnwire.Message, _ ...discovery.OptionalMsgField) chan error { @@ -608,17 +604,13 @@ func recreateAliceFundingManager(t *testing.T, alice *testNode) { chainedAcceptor := acpt.NewChainedAcceptor() f, err := NewFundingManager(Config{ - IDKey: oldCfg.IDKey, - IDKeyLoc: oldCfg.IDKeyLoc, - Wallet: oldCfg.Wallet, - Notifier: oldCfg.Notifier, - ChannelDB: oldCfg.ChannelDB, - FeeEstimator: oldCfg.FeeEstimator, - SignMessage: func(_ keychain.KeyLocator, - _ []byte, _ bool) (*ecdsa.Signature, error) { - - return testSig, nil - }, + IDKey: oldCfg.IDKey, + IDKeyLoc: oldCfg.IDKeyLoc, + Wallet: oldCfg.Wallet, + Notifier: oldCfg.Notifier, + ChannelDB: oldCfg.ChannelDB, + FeeEstimator: oldCfg.FeeEstimator, + MessageSigner: &mockSigner{}, SendAnnouncement: func(msg lnwire.Message, _ ...discovery.OptionalMsgField) chan error { @@ -4955,3 +4947,13 @@ func TestFundingManagerCoinbase(t *testing.T) { // channel. assertHandleChannelReady(t, alice, bob) } + +type mockSigner struct { + keychain.MessageSignerRing +} + +func (s *mockSigner) SignMessage(keyLoc keychain.KeyLocator, msg []byte, + doubleHash bool) (*ecdsa.Signature, error) { + + return testSig, nil +} diff --git a/keychain/btcwallet.go b/keychain/btcwallet.go index 4efb9fd614..b11696e6b2 100644 --- a/keychain/btcwallet.go +++ b/keychain/btcwallet.go @@ -7,6 +7,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcec/v2/ecdsa" "github.com/btcsuite/btcd/btcec/v2/schnorr" + "github.com/btcsuite/btcd/btcec/v2/schnorr/musig2" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcwallet/waddrmgr" @@ -427,6 +428,28 @@ func (b *BtcWalletKeyRing) SignMessage(keyLoc KeyLocator, return ecdsa.Sign(privKey, digest), nil } +// SignMuSig2 generates a MuSig2 partial signature given the passed key set, +// secret nonce, public nonce, and private keys +// +// NOTE: This is part of the keychain.MessageSignerRing interface. +func (b *BtcWalletKeyRing) SignMuSig2(secNonce [musig2.SecNonceSize]byte, + keyLoc KeyLocator, _ [][musig2.PubNonceSize]byte, + combinedNonce [musig2.PubNonceSize]byte, pubKeys []*btcec.PublicKey, + msg [32]byte, signOpts ...musig2.SignOption) (*musig2.PartialSignature, + error) { + + privKey, err := b.DerivePrivKey(KeyDescriptor{ + KeyLocator: keyLoc, + }) + if err != nil { + return nil, err + } + + return musig2.Sign( + secNonce, privKey, combinedNonce, pubKeys, msg, signOpts..., + ) +} + // SignMessageCompact signs the given message, single or double SHA256 hashing // it first, with the private key described in the key locator and returns // the signature in the compact, public key recoverable format. diff --git a/keychain/derivation.go b/keychain/derivation.go index 21996c6509..024eda3413 100644 --- a/keychain/derivation.go +++ b/keychain/derivation.go @@ -6,6 +6,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcec/v2/ecdsa" "github.com/btcsuite/btcd/btcec/v2/schnorr" + "github.com/btcsuite/btcd/btcec/v2/schnorr/musig2" ) const ( @@ -241,6 +242,14 @@ type MessageSignerRing interface { SignMessageSchnorr(keyLoc KeyLocator, msg []byte, doubleHash bool, taprootTweak []byte) (*schnorr.Signature, error) + + // SignMuSig2 generates a MuSig2 partial signature given the passed key + // set, secret nonce, public nonce, and private keys + SignMuSig2(secNonce [musig2.SecNonceSize]byte, + keyLoc KeyLocator, otherNonces [][musig2.PubNonceSize]byte, + combinedNonce [musig2.PubNonceSize]byte, + pubKeys []*btcec.PublicKey, msg [32]byte, + opts ...musig2.SignOption) (*musig2.PartialSignature, error) } // SingleKeyMessageSigner is an abstraction interface that hides the @@ -262,6 +271,14 @@ type SingleKeyMessageSigner interface { // hashing it first, with the wrapped private key and returns the // signature in the compact, public key recoverable format. SignMessageCompact(message []byte, doubleHash bool) ([]byte, error) + + // SignMuSig2 generates a MuSig2 partial signature given the passed key + // set, secret nonce, public nonce, and private keys. + SignMuSig2(secNonce [musig2.SecNonceSize]byte, + keyLoc KeyLocator, otherNonces [][musig2.PubNonceSize]byte, + combinedNonce [musig2.PubNonceSize]byte, + pubKeys []*btcec.PublicKey, msg [32]byte, + opts ...musig2.SignOption) (*musig2.PartialSignature, error) } // ECDHRing is an interface that abstracts away basic low-level ECDH shared key diff --git a/keychain/signer.go b/keychain/signer.go index 9605e72ec1..d105af9838 100644 --- a/keychain/signer.go +++ b/keychain/signer.go @@ -3,6 +3,7 @@ package keychain import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcec/v2/ecdsa" + "github.com/btcsuite/btcd/btcec/v2/schnorr/musig2" "github.com/btcsuite/btcd/chaincfg/chainhash" ) @@ -42,6 +43,17 @@ func (p *PubKeyMessageSigner) SignMessageCompact(msg []byte, return p.digestSigner.SignMessageCompact(p.keyLoc, msg, doubleHash) } +func (p *PubKeyMessageSigner) SignMuSig2(secNonce [97]byte, keyLoc KeyLocator, + otherNonces [][66]byte, combinedNonce [66]byte, + pubKeys []*btcec.PublicKey, msg [32]byte, opts ...musig2.SignOption) ( + *musig2.PartialSignature, error) { + + return p.digestSigner.SignMuSig2( + secNonce, keyLoc, otherNonces, combinedNonce, pubKeys, msg, + opts..., + ) +} + func NewPrivKeyMessageSigner(privKey *btcec.PrivateKey, keyLoc KeyLocator) *PrivKeyMessageSigner { @@ -88,5 +100,15 @@ func (p *PrivKeyMessageSigner) SignMessageCompact(msg []byte, return ecdsa.SignCompact(p.privKey, digest, true) } +func (p *PrivKeyMessageSigner) SignMuSig2(secNonce [97]byte, _ KeyLocator, + _ [][66]byte, combinedNonce [66]byte, + pubKeys []*btcec.PublicKey, msg [32]byte, opts ...musig2.SignOption) ( + *musig2.PartialSignature, error) { + + return musig2.Sign( + secNonce, p.privKey, combinedNonce, pubKeys, msg, opts..., + ) +} + var _ SingleKeyMessageSigner = (*PubKeyMessageSigner)(nil) var _ SingleKeyMessageSigner = (*PrivKeyMessageSigner)(nil) diff --git a/lntest/mock/secretkeyring.go b/lntest/mock/secretkeyring.go index 770b0e4d35..062cee1042 100644 --- a/lntest/mock/secretkeyring.go +++ b/lntest/mock/secretkeyring.go @@ -4,6 +4,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcec/v2/ecdsa" "github.com/btcsuite/btcd/btcec/v2/schnorr" + "github.com/btcsuite/btcd/btcec/v2/schnorr/musig2" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/txscript" "github.com/lightningnetwork/lnd/keychain" @@ -59,6 +60,19 @@ func (s *SecretKeyRing) SignMessage(_ keychain.KeyLocator, return ecdsa.Sign(s.RootKey, digest), nil } +// SignMuSig2 generates a MuSig2 partial signature given the passed key set, +// secret nonce, public nonce, and private keys. +func (s *SecretKeyRing) SignMuSig2(secNonce [musig2.SecNonceSize]byte, + _ keychain.KeyLocator, _ [][musig2.PubNonceSize]byte, + combinedNonce [musig2.PubNonceSize]byte, pubKeys []*btcec.PublicKey, + msg [32]byte, opts ...musig2.SignOption) (*musig2.PartialSignature, + error) { + + return musig2.Sign( + secNonce, s.RootKey, combinedNonce, pubKeys, msg, opts..., + ) +} + // SignMessageCompact signs the passed message. func (s *SecretKeyRing) SignMessageCompact(_ keychain.KeyLocator, msg []byte, doubleHash bool) ([]byte, error) { diff --git a/lnwallet/rpcwallet/rpcwallet.go b/lnwallet/rpcwallet/rpcwallet.go index 69e1f25b76..53719d647c 100644 --- a/lnwallet/rpcwallet/rpcwallet.go +++ b/lnwallet/rpcwallet/rpcwallet.go @@ -470,6 +470,74 @@ func (r *RPCKeyRing) SignMessage(keyLoc keychain.KeyLocator, return ecdsaSig, nil } +// SignMuSig2 generates a MuSig2 partial signature given the passed key set, +// secret nonce, public nonce, and private keys. +// +// NOTE: This method is part of the keychain.MessageSignerRing interface. +func (r *RPCKeyRing) SignMuSig2(secNonce [musig2.SecNonceSize]byte, + keyLoc keychain.KeyLocator, otherNonces [][musig2.PubNonceSize]byte, + _ [musig2.PubNonceSize]byte, pubKeys []*btcec.PublicKey, msg [32]byte, + _ ...musig2.SignOption) (*musig2.PartialSignature, error) { + + ctxt, cancel := context.WithTimeout(context.Background(), r.rpcTimeout) + defer cancel() + + serialisedPubKeys := make([][]byte, len(pubKeys)) + for i, pub := range pubKeys { + serialisedPubKeys[i] = pub.SerializeCompressed() + } + + otherSignerNonces := make([][]byte, len(otherNonces)) + for i, nonce := range otherNonces { + otherSignerNonces[i] = nonce[:] + } + + musigVersion := signrpc.MuSig2Version_MUSIG2_VERSION_V100RC2 + + // TODO: what to do with the musig sign options here? + + resp, err := r.signerClient.MuSig2CreateSession( + ctxt, &signrpc.MuSig2SessionRequest{ + KeyLoc: &signrpc.KeyLocator{ + KeyFamily: int32(keyLoc.Family), + KeyIndex: int32(keyLoc.Index), + }, + OtherSignerPublicNonces: otherSignerNonces, + PregeneratedLocalNonce: secNonce[:], + AllSignerPubkeys: serialisedPubKeys, + Version: musigVersion, + }, + ) + if err != nil { + considerShutdown(err) + return nil, fmt.Errorf("error signing message in remote "+ + "signer instance: %v", err) + } + + signResp, err := r.signerClient.MuSig2Sign( + ctxt, &signrpc.MuSig2SignRequest{ + SessionId: resp.SessionId, + MessageDigest: msg[:], + Cleanup: true, + }, + ) + if err != nil { + considerShutdown(err) + return nil, fmt.Errorf("error signing message in remote "+ + "signer instance: %v", err) + } + + partialSig, err := input.DeserializePartialSignature( + signResp.LocalPartialSignature, + ) + if err != nil { + return nil, fmt.Errorf("error parsing partial signature from "+ + "remote signer: %v", err) + } + + return partialSig, nil +} + // SignMessageCompact signs the given message, single or double SHA256 hashing // it first, with the private key described in the key locator and returns the // signature in the compact, public key recoverable format. diff --git a/server.go b/server.go index 7e3bdf2233..e1e9a87558 100644 --- a/server.go +++ b/server.go @@ -1292,10 +1292,10 @@ func newServer(cfg *Config, listenAddrs []net.Addr, UpdateLabel: func(hash chainhash.Hash, label string) error { return cc.Wallet.LabelTransaction(hash, label, true) }, - Notifier: cc.ChainNotifier, - ChannelDB: s.chanStateDB, - FeeEstimator: cc.FeeEstimator, - SignMessage: cc.MsgSigner.SignMessage, + Notifier: cc.ChainNotifier, + ChannelDB: s.chanStateDB, + FeeEstimator: cc.FeeEstimator, + MessageSigner: cc.KeyRing, CurrentNodeAnnouncement: func() (lnwire.NodeAnnouncement, error) { From dd6d924d50974861e76834496dc38d8c34afae2b Mon Sep 17 00:00:00 2001 From: Elle Mouton <elle.mouton@gmail.com> Date: Thu, 19 Oct 2023 16:36:54 +0200 Subject: [PATCH 29/33] multi: add SignSchnorr to MessageSignerKeyring --- discovery/gossiper.go | 3 +- keychain/derivation.go | 3 ++ keychain/signer.go | 28 +++++++++++++++++ lnrpc/invoicesrpc/addinvoice.go | 2 +- lntest/mock/signer.go | 21 +++++++++++++ netann/chan_status_manager.go | 3 +- netann/channel_update.go | 6 ++-- netann/channel_update_test.go | 7 +++-- netann/node_signer.go | 53 ++++++++++++++++++++++++++++++--- rpcserver.go | 2 +- 10 files changed, 112 insertions(+), 16 deletions(-) diff --git a/discovery/gossiper.go b/discovery/gossiper.go index 941964865c..7e66369d3b 100644 --- a/discovery/gossiper.go +++ b/discovery/gossiper.go @@ -24,7 +24,6 @@ import ( "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnpeer" "github.com/lightningnetwork/lnd/lnutils" - "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/multimutex" "github.com/lightningnetwork/lnd/netann" @@ -254,7 +253,7 @@ type Config struct { // // TODO(roasbeef): extract ann crafting + sign from fundingMgr into // here? - AnnSigner lnwallet.MessageSigner + AnnSigner keychain.MessageSignerRing // NumActiveSyncers is the number of peers for which we should have // active syncers with. After reaching NumActiveSyncers, any future diff --git a/keychain/derivation.go b/keychain/derivation.go index 024eda3413..69e557818d 100644 --- a/keychain/derivation.go +++ b/keychain/derivation.go @@ -279,6 +279,9 @@ type SingleKeyMessageSigner interface { combinedNonce [musig2.PubNonceSize]byte, pubKeys []*btcec.PublicKey, msg [32]byte, opts ...musig2.SignOption) (*musig2.PartialSignature, error) + + SignMessageSchnorr(keyLoc KeyLocator, msg []byte, doubleHash bool, + taprootTweak []byte) (*schnorr.Signature, error) } // ECDHRing is an interface that abstracts away basic low-level ECDH shared key diff --git a/keychain/signer.go b/keychain/signer.go index d105af9838..9bdf3f9aeb 100644 --- a/keychain/signer.go +++ b/keychain/signer.go @@ -3,8 +3,10 @@ package keychain import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcec/v2/ecdsa" + "github.com/btcsuite/btcd/btcec/v2/schnorr" "github.com/btcsuite/btcd/btcec/v2/schnorr/musig2" "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/txscript" ) func NewPubKeyMessageSigner(pubKey *btcec.PublicKey, keyLoc KeyLocator, @@ -54,6 +56,14 @@ func (p *PubKeyMessageSigner) SignMuSig2(secNonce [97]byte, keyLoc KeyLocator, ) } +func (p *PubKeyMessageSigner) SignMessageSchnorr(keyLoc KeyLocator, msg []byte, + doubleHash bool, taprootTweak []byte) (*schnorr.Signature, error) { + + return p.digestSigner.SignMessageSchnorr( + keyLoc, msg, doubleHash, taprootTweak, + ) +} + func NewPrivKeyMessageSigner(privKey *btcec.PrivateKey, keyLoc KeyLocator) *PrivKeyMessageSigner { @@ -110,5 +120,23 @@ func (p *PrivKeyMessageSigner) SignMuSig2(secNonce [97]byte, _ KeyLocator, ) } +func (p *PrivKeyMessageSigner) SignMessageSchnorr(_ KeyLocator, msg []byte, + doubleHash bool, taprootTweak []byte) (*schnorr.Signature, error) { + + var digest []byte + if doubleHash { + digest = chainhash.DoubleHashB(msg) + } else { + digest = chainhash.HashB(msg) + } + + privKey := p.privKey + if len(taprootTweak) > 0 { + privKey = txscript.TweakTaprootPrivKey(*privKey, taprootTweak) + } + + return schnorr.Sign(privKey, digest) +} + var _ SingleKeyMessageSigner = (*PubKeyMessageSigner)(nil) var _ SingleKeyMessageSigner = (*PrivKeyMessageSigner)(nil) diff --git a/lnrpc/invoicesrpc/addinvoice.go b/lnrpc/invoicesrpc/addinvoice.go index 430b2b4276..e04096bfe3 100644 --- a/lnrpc/invoicesrpc/addinvoice.go +++ b/lnrpc/invoicesrpc/addinvoice.go @@ -450,7 +450,7 @@ func AddInvoice(ctx context.Context, cfg *AddInvoiceConfig, payReqString, err := payReq.Encode(zpay32.MessageSigner{ SignCompact: func(msg []byte) ([]byte, error) { - return cfg.NodeSigner.SignMessageCompact(msg, false) + return cfg.NodeSigner.SignMsgCompact(msg, false) }, }) if err != nil { diff --git a/lntest/mock/signer.go b/lntest/mock/signer.go index 1d30204ea9..866115aab7 100644 --- a/lntest/mock/signer.go +++ b/lntest/mock/signer.go @@ -109,6 +109,27 @@ type SingleSigner struct { *input.MusigSessionManager } +func (s *SingleSigner) SignMessageCompact(keyLoc keychain.KeyLocator, + msg []byte, doubleHash bool) ([]byte, error) { + + return nil, nil +} + +func (s *SingleSigner) SignMessageSchnorr(keyLoc keychain.KeyLocator, + msg []byte, doubleHash bool, taprootTweak []byte) (*schnorr.Signature, + error) { + + return nil, nil +} + +func (s *SingleSigner) SignMuSig2(secNonce [97]byte, + keyLoc keychain.KeyLocator, otherNonces [][66]byte, + combinedNonce [66]byte, pubKeys []*btcec.PublicKey, msg [32]byte, + opts ...musig2.SignOption) (*musig2.PartialSignature, error) { + + return nil, nil +} + func NewSingleSigner(privkey *btcec.PrivateKey) *SingleSigner { signer := &SingleSigner{ Privkey: privkey, diff --git a/netann/chan_status_manager.go b/netann/chan_status_manager.go index 1d3aa6994b..1dc4ba7e4e 100644 --- a/netann/chan_status_manager.go +++ b/netann/chan_status_manager.go @@ -9,7 +9,6 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/keychain" - "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" ) @@ -49,7 +48,7 @@ type ChanStatusConfig struct { OurKeyLoc keychain.KeyLocator // MessageSigner signs messages that validate under OurPubKey. - MessageSigner lnwallet.MessageSigner + MessageSigner keychain.MessageSignerRing // IsChannelActive checks whether the channel identified by the provided // ChannelID is considered active. This should only return true if the diff --git a/netann/channel_update.go b/netann/channel_update.go index 3e0b971cd5..dcd32ac5e1 100644 --- a/netann/channel_update.go +++ b/netann/channel_update.go @@ -10,7 +10,6 @@ import ( "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/keychain" - "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" ) @@ -58,8 +57,9 @@ func ChanUpdSetTimestamp(update *lnwire.ChannelUpdate1) { // monotonically increase from the prior. // // NOTE: This method modifies the given update. -func SignChannelUpdate(signer lnwallet.MessageSigner, keyLoc keychain.KeyLocator, - update *lnwire.ChannelUpdate1, mods ...ChannelUpdateModifier) error { +func SignChannelUpdate(signer keychain.MessageSignerRing, + keyLoc keychain.KeyLocator, update *lnwire.ChannelUpdate1, + mods ...ChannelUpdateModifier) error { // Apply the requested changes to the channel update. for _, modifier := range mods { diff --git a/netann/channel_update_test.go b/netann/channel_update_test.go index a32d96d88b..6836ebb957 100644 --- a/netann/channel_update_test.go +++ b/netann/channel_update_test.go @@ -8,7 +8,6 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcec/v2/ecdsa" "github.com/lightningnetwork/lnd/keychain" - "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/netann" "github.com/lightningnetwork/lnd/routing" @@ -16,6 +15,8 @@ import ( type mockSigner struct { err error + + keychain.MessageSignerRing } func (m *mockSigner) SignMessage(_ keychain.KeyLocator, @@ -28,7 +29,7 @@ func (m *mockSigner) SignMessage(_ keychain.KeyLocator, return nil, nil } -var _ lnwallet.MessageSigner = (*mockSigner)(nil) +var _ keychain.MessageSignerRing = (*mockSigner)(nil) var ( privKey, _ = btcec.NewPrivateKey() @@ -44,7 +45,7 @@ type updateDisableTest struct { startEnabled bool disable bool startTime time.Time - signer lnwallet.MessageSigner + signer keychain.MessageSignerRing expErr error } diff --git a/netann/node_signer.go b/netann/node_signer.go index e0b439b14d..29a86e6b32 100644 --- a/netann/node_signer.go +++ b/netann/node_signer.go @@ -3,9 +3,11 @@ package netann import ( "fmt" + "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcec/v2/ecdsa" + "github.com/btcsuite/btcd/btcec/v2/schnorr" + "github.com/btcsuite/btcd/btcec/v2/schnorr/musig2" "github.com/lightningnetwork/lnd/keychain" - "github.com/lightningnetwork/lnd/lnwallet" ) // NodeSigner is an implementation of the MessageSigner interface backed by the @@ -43,15 +45,58 @@ func (n *NodeSigner) SignMessage(keyLoc keychain.KeyLocator, return sig, nil } -// SignMessageCompact signs a single or double sha256 digest of the msg +// SignMsgCompact signs a single or double sha256 digest of the msg // parameter under the resident node's private key. The returned signature is a // pubkey-recoverable signature. -func (n *NodeSigner) SignMessageCompact(msg []byte, doubleHash bool) ([]byte, +func (n *NodeSigner) SignMsgCompact(msg []byte, doubleHash bool) ([]byte, error) { return n.keySigner.SignMessageCompact(msg, doubleHash) } +func (n *NodeSigner) SignMessageCompact(keyLoc keychain.KeyLocator, msg []byte, + doubleHash bool) ([]byte, error) { + + // If this isn't our identity public key, then we'll exit early with an + // error as we can't sign with this key. + if keyLoc != n.keySigner.KeyLocator() { + return nil, fmt.Errorf("unknown public key locator") + } + + return n.keySigner.SignMessageCompact(msg, doubleHash) +} + +func (n *NodeSigner) SignMessageSchnorr(keyLoc keychain.KeyLocator, msg []byte, + doubleHash bool, taprootTweak []byte) (*schnorr.Signature, error) { + + // If this isn't our identity public key, then we'll exit early with an + // error as we can't sign with this key. + if keyLoc != n.keySigner.KeyLocator() { + return nil, fmt.Errorf("unknown public key locator") + } + + return n.keySigner.SignMessageSchnorr( + keyLoc, msg, doubleHash, taprootTweak, + ) +} + +func (n *NodeSigner) SignMuSig2(secNonce [97]byte, keyLoc keychain.KeyLocator, + otherNonces [][66]byte, combinedNonce [66]byte, + pubKeys []*btcec.PublicKey, msg [32]byte, opts ...musig2.SignOption) ( + *musig2.PartialSignature, error) { + + // If this isn't our identity public key, then we'll exit early with an + // error as we can't sign with this key. + if keyLoc != n.keySigner.KeyLocator() { + return nil, fmt.Errorf("unknown public key locator") + } + + return n.keySigner.SignMuSig2( + secNonce, keyLoc, otherNonces, combinedNonce, pubKeys, msg, + opts..., + ) +} + // A compile time check to ensure that NodeSigner implements the MessageSigner // interface. -var _ lnwallet.MessageSigner = (*NodeSigner)(nil) +var _ keychain.MessageSignerRing = (*NodeSigner)(nil) diff --git a/rpcserver.go b/rpcserver.go index dec94280b4..c51bdff418 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -1573,7 +1573,7 @@ func (r *rpcServer) SignMessage(_ context.Context, } in.Msg = append(signedMsgPrefix, in.Msg...) - sigBytes, err := r.server.nodeSigner.SignMessageCompact( + sigBytes, err := r.server.nodeSigner.SignMsgCompact( in.Msg, !in.SingleHash, ) if err != nil { From 33630f2c84ec5bb79fa2b299d924f7a286d8b96c Mon Sep 17 00:00:00 2001 From: Elle Mouton <elle.mouton@gmail.com> Date: Thu, 19 Oct 2023 19:37:13 +0200 Subject: [PATCH 30/33] lnwire: let ChannelUpdate2 implement ChannelUpdate --- lnwire/channel_update_2.go | 87 +++++++++++++++++++++++++++++++++++--- lnwire/msg_hash.go | 20 ++++++++- 2 files changed, 99 insertions(+), 8 deletions(-) diff --git a/lnwire/channel_update_2.go b/lnwire/channel_update_2.go index 88452dd9a5..ffcfc11b58 100644 --- a/lnwire/channel_update_2.go +++ b/lnwire/channel_update_2.go @@ -7,6 +7,7 @@ import ( "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/tlv" ) @@ -120,6 +121,58 @@ type ChannelUpdate2 struct { ExtraOpaqueData ExtraOpaqueData } +func (c *ChannelUpdate2) SCID() ShortChannelID { + return c.ShortChannelID +} + +func (c *ChannelUpdate2) IsNode1() bool { + return c.Direction == false +} + +func (c *ChannelUpdate2) SetDisabled(b bool) { + c.DisabledFlags |= ChanUpdateDisableIncoming + c.DisabledFlags |= ChanUpdateDisableOutgoing +} + +func (c *ChannelUpdate2) IsDisabled() bool { + return !c.DisabledFlags.IsEnabled() +} + +func (c *ChannelUpdate2) GetChainHash() chainhash.Hash { + return c.ChainHash +} + +func (c *ChannelUpdate2) SetSig(signature input.Signature) error { + sig, err := NewSigFromSignature(signature) + if err != nil { + return err + } + + c.Signature = sig + + return nil +} + +func (c *ChannelUpdate2) SetSCID(scid ShortChannelID) { + c.ShortChannelID = scid +} + +func (c *ChannelUpdate2) GetTimeLock() uint16 { + return c.CLTVExpiryDelta +} + +func (c *ChannelUpdate2) GetBaseFee() MilliSatoshi { + return MilliSatoshi(c.FeeBaseMsat) +} + +func (c *ChannelUpdate2) GetFeeRate() MilliSatoshi { + return MilliSatoshi(c.FeeProportionalMillionths) +} + +func (c *ChannelUpdate2) GetSignature() Sig { + return c.Signature +} + // Decode deserializes a serialized AnnounceSignatures stored in the passed // io.Reader observing the specified protocol version. // @@ -245,10 +298,30 @@ func (c *ChannelUpdate2) Encode(w *bytes.Buffer, _ uint32) error { return err } - var records []tlv.Record + _, err = c.DataToSign() + if err != nil { + return err + } + return WriteBytes(w, c.ExtraOpaqueData) +} + +// DigestToSignNoHash computes the digest of the message to be signed. +func (c *ChannelUpdate2) DigestToSignNoHash() ([]byte, error) { + data, err := c.DataToSign() + if err != nil { + return nil, err + } + + return MsgHashPreFinalHash( + "channel_announcement_2", "announcement_signature", data, + ), nil +} + +func (c *ChannelUpdate2) DataToSign() ([]byte, error) { // The chain-hash record is only included if it is _not_ equal to the // bitcoin mainnet genisis block hash. + var records []tlv.Record if !c.ChainHash.IsEqual(chaincfg.MainNetParams.GenesisHash) { chainHash := [32]byte(c.ChainHash) records = append(records, tlv.MakePrimitiveRecord( @@ -318,12 +391,12 @@ func (c *ChannelUpdate2) Encode(w *bytes.Buffer, _ uint32) error { )) } - err = EncodeMessageExtraDataFromRecords(&c.ExtraOpaqueData, records...) + err := EncodeMessageExtraDataFromRecords(&c.ExtraOpaqueData, records...) if err != nil { - return err + return nil, err } - return WriteBytes(w, c.ExtraOpaqueData) + return c.ExtraOpaqueData, nil } // MsgType returns the integer uniquely identifying this message type on the @@ -334,9 +407,9 @@ func (c *ChannelUpdate2) MsgType() MessageType { return MsgChannelUpdate2 } -// A compile time check to ensure ChannelUpdate2 implements the lnwire.Message -// interface. -var _ Message = (*ChannelUpdate2)(nil) +// A compile time check to ensure ChannelUpdate2 implements the +// lnwire.ChannelUpdate interface. +var _ ChannelUpdate = (*ChannelUpdate2)(nil) // ChanUpdateDisableFlags is a bit vector that can be used to indicate various // reasons for the channel being marked as disabled. diff --git a/lnwire/msg_hash.go b/lnwire/msg_hash.go index a3f05b8db5..598e1fdf7a 100644 --- a/lnwire/msg_hash.go +++ b/lnwire/msg_hash.go @@ -1,6 +1,11 @@ package lnwire -import "github.com/btcsuite/btcd/chaincfg/chainhash" +import ( + "bytes" + "crypto/sha256" + + "github.com/btcsuite/btcd/chaincfg/chainhash" +) const MsgHashTag = "lightning" @@ -11,3 +16,16 @@ func MsgHash(msgName, fieldName string, msg []byte) *chainhash.Hash { return chainhash.TaggedHash(tag, msg) } + +func MsgHashPreFinalHash(msgName, fieldName string, msg []byte) []byte { + tag := []byte(MsgHashTag) + shaTag := sha256.Sum256(tag) + + var b bytes.Buffer + + b.Write(shaTag[:]) + b.Write(shaTag[:]) + b.Write(msg) + + return b.Bytes() +} From 548942821ca371cdd233bbc5f96d3c68acdfb49a Mon Sep 17 00:00:00 2001 From: Elle Mouton <elle.mouton@gmail.com> Date: Thu, 19 Oct 2023 19:37:35 +0200 Subject: [PATCH 31/33] netann: update to use ChannelUpdate interface --- netann/channel_update.go | 84 ++++++++++++++++++++++++++-------------- 1 file changed, 55 insertions(+), 29 deletions(-) diff --git a/netann/channel_update.go b/netann/channel_update.go index dcd32ac5e1..43ce8ecf32 100644 --- a/netann/channel_update.go +++ b/netann/channel_update.go @@ -19,36 +19,35 @@ var ErrUnableToExtractChanUpdate = fmt.Errorf("unable to extract ChannelUpdate1" // ChannelUpdateModifier is a closure that makes in-place modifications to an // lnwire.ChannelUpdate1. -type ChannelUpdateModifier func(*lnwire.ChannelUpdate1) +type ChannelUpdateModifier func(lnwire.ChannelUpdate) // ChanUpdSetDisable is a functional option that sets the disabled channel flag // if disabled is true, and clears the bit otherwise. func ChanUpdSetDisable(disabled bool) ChannelUpdateModifier { - return func(update *lnwire.ChannelUpdate1) { - if disabled { - // Set the bit responsible for marking a channel as - // disabled. - update.ChannelFlags |= lnwire.ChanUpdateDisabled - } else { - // Clear the bit responsible for marking a channel as - // disabled. - update.ChannelFlags &= ^lnwire.ChanUpdateDisabled - } + return func(update lnwire.ChannelUpdate) { + update.SetDisabled(disabled) } } // ChanUpdSetTimestamp is a functional option that sets the timestamp of the // update to the current time, or increments it if the timestamp is already in // the future. -func ChanUpdSetTimestamp(update *lnwire.ChannelUpdate1) { - newTimestamp := uint32(time.Now().Unix()) - if newTimestamp <= update.Timestamp { - // Increment the prior value to ensure the timestamp - // monotonically increases, otherwise the update won't - // propagate. - newTimestamp = update.Timestamp + 1 +func ChanUpdSetTimestamp(update lnwire.ChannelUpdate) { + switch upd := update.(type) { + case *lnwire.ChannelUpdate1: + newTimestamp := uint32(time.Now().Unix()) + if newTimestamp <= upd.Timestamp { + // Increment the prior value to ensure the timestamp + // monotonically increases, otherwise the update won't + // propagate. + newTimestamp = upd.Timestamp + 1 + } + upd.Timestamp = newTimestamp + + default: + log.Errorf("unhandled implementation of "+ + "lnwire.ChannelUpdate: %T", update) } - update.Timestamp = newTimestamp } // SignChannelUpdate applies the given modifiers to the passed @@ -58,7 +57,7 @@ func ChanUpdSetTimestamp(update *lnwire.ChannelUpdate1) { // // NOTE: This method modifies the given update. func SignChannelUpdate(signer keychain.MessageSignerRing, - keyLoc keychain.KeyLocator, update *lnwire.ChannelUpdate1, + keyLoc keychain.KeyLocator, update lnwire.ChannelUpdate, mods ...ChannelUpdateModifier) error { // Apply the requested changes to the channel update. @@ -66,16 +65,43 @@ func SignChannelUpdate(signer keychain.MessageSignerRing, modifier(update) } - // Create the DER-encoded ECDSA signature over the message digest. - sig, err := SignAnnouncement(signer, keyLoc, update) - if err != nil { - return err - } + switch upd := update.(type) { + case *lnwire.ChannelUpdate1: + data, err := upd.DataToSign() + if err != nil { + return err + } - // Parse the DER-encoded signature into a fixed-size 64-byte array. - update.Signature, err = lnwire.NewSigFromSignature(sig) - if err != nil { - return err + sig, err := signer.SignMessage(keyLoc, data, true) + if err != nil { + return err + } + + // Parse the DER-encoded signature into a fixed-size 64-byte + // array. + upd.Signature, err = lnwire.NewSigFromSignature(sig) + if err != nil { + return err + } + + case *lnwire.ChannelUpdate2: + data, err := upd.DigestToSignNoHash() + if err != nil { + return err + } + + sig, err := signer.SignMessageSchnorr(keyLoc, data, false, nil) + if err != nil { + return err + } + + upd.Signature, err = lnwire.NewSigFromSignature(sig) + if err != nil { + return err + } + default: + return fmt.Errorf("unhandled implementaion of "+ + "ChannelUpdate: %T", update) } return nil From 7adbe6afb159ded9bfbbc83761b09eb4b7ea0464 Mon Sep 17 00:00:00 2001 From: Elle Mouton <elle.mouton@gmail.com> Date: Fri, 20 Oct 2023 10:27:52 +0200 Subject: [PATCH 32/33] channeldb: prep for new ChanEdgePolicy encoding --- channeldb/graph.go | 59 ++++++++++++++++++++++++++-------- channeldb/models/interfaces.go | 3 ++ 2 files changed, 49 insertions(+), 13 deletions(-) diff --git a/channeldb/graph.go b/channeldb/graph.go index 992d0998fe..17d490febb 100644 --- a/channeldb/graph.go +++ b/channeldb/graph.go @@ -173,6 +173,13 @@ const ( // with either 0x02 or 0x03 due to the fact that the encoding would // start with a node's compressed public key. chanEdgeNewEncodingPrefix = 0xff + + // chanEdgePolicyNewEncodingPrefix is a byte used in the channel edge + // policy encoding to signal that the new style encoding which is + // prefixed with a type byte is being used instead of the legacy + // encoding which would start with 0x02 due to the fact that the + // encoding would start with a DER encoded ecdsa signature. + chanEdgePolicyNewEncodingPrefix = 0xff ) // ChannelGraph is a persistent, on-disk graph representation of the Lightning @@ -2601,10 +2608,6 @@ func updateEdgePolicy(tx kvdb.RwTx, edge *ChannelEdgePolicyWithNode, if edgeIndex == nil { return false, ErrEdgeNotFound } - nodes, err := tx.CreateTopLevelBucket(nodeBucket) - if err != nil { - return false, err - } // Create the channelID key be converting the channel ID // integer into a byte slice. @@ -2644,7 +2647,7 @@ func updateEdgePolicy(tx kvdb.RwTx, edge *ChannelEdgePolicyWithNode, // Finally, with the direction of the edge being updated // identified, we update the on-disk edge representation. err = putChanEdgePolicy( - edges, nodes, &edge.ChannelEdgePolicy1, fromNode, toNode, + edges, &edge.ChannelEdgePolicy1, fromNode, toNode, ) if err != nil { return false, err @@ -3676,6 +3679,8 @@ type ChannelEdgePolicy1 struct { ExtraOpaqueData []byte } +var _ models.ChannelEdgePolicy = (*ChannelEdgePolicy1)(nil) + // Signature is a channel announcement signature, which is needed for proper // edge policy announcement. // @@ -4784,9 +4789,15 @@ func deserializeChanEdgeInfo1(r io.Reader) (*ChannelEdgeInfo1, error) { return &edgeInfo, nil } -func putChanEdgePolicy(edges, nodes kvdb.RwBucket, edge *ChannelEdgePolicy1, +func putChanEdgePolicy(edges kvdb.RwBucket, edgePolicy models.ChannelEdgePolicy, from, to []byte) error { + edge, ok := edgePolicy.(*ChannelEdgePolicy1) + if !ok { + return fmt.Errorf("unhandled implementation of "+ + "ChannelEdgePolicy: %T", edgePolicy) + } + var edgeKey [33 + 8]byte copy(edgeKey[:], from) byteOrder.PutUint64(edgeKey[33:], edge.ChannelID) @@ -4824,8 +4835,8 @@ func putChanEdgePolicy(edges, nodes kvdb.RwBucket, edge *ChannelEdgePolicy1, // the channel ID and update time to delete the entry. // TODO(halseth): get rid of these invalid policies in a // migration. - oldEdgePolicy, err := deserializeChanEdgePolicy( - bytes.NewReader(edgeBytes), nodes, + oldEdgePolicy, _, err := deserializeChanEdgePolicyRaw( + bytes.NewReader(edgeBytes), ) if err != nil && err != ErrEdgePolicyOptionalFieldNotFound { return err @@ -5066,12 +5077,34 @@ func deserializeChanEdgePolicy(r io.Reader, return &policy, deserializeErr } -func deserializeChanEdgePolicyRaw(r io.Reader) (*ChannelEdgePolicy1, []byte, +func deserializeChanEdgePolicyRaw(reader io.Reader) (*ChannelEdgePolicy1, + []byte, error) { + + // Wrap the io.Reader in a bufio.Reader so that we can peak the first + // byte of the stream without actually consuming from the stream. + r := bufio.NewReader(reader) + + firstByte, err := r.Peek(1) + if err != nil { + return nil, nil, err + } + + if firstByte[0] != chanEdgePolicyNewEncodingPrefix { + return deserializeChanEdgePolicy1Raw(r) + } + + return nil, nil, fmt.Errorf("unknown channel edge policy encoding "+ + "type: %x", firstByte[0]) +} + +func deserializeChanEdgePolicy1Raw(r io.Reader) (*ChannelEdgePolicy1, []byte, error) { - edge := &ChannelEdgePolicy1{} + var ( + edge ChannelEdgePolicy1 + err error + ) - var err error edge.SigBytes, err = wire.ReadVarBytes(r, 0, 80, "sig") if err != nil { return nil, nil, err @@ -5141,7 +5174,7 @@ func deserializeChanEdgePolicyRaw(r io.Reader) (*ChannelEdgePolicy1, []byte, // stored before this field was validated. We'll return the // edge along with an error. if len(opq) < 8 { - return edge, pub[:], ErrEdgePolicyOptionalFieldNotFound + return &edge, pub[:], ErrEdgePolicyOptionalFieldNotFound } maxHtlc := byteOrder.Uint64(opq[:8]) @@ -5151,7 +5184,7 @@ func deserializeChanEdgePolicyRaw(r io.Reader) (*ChannelEdgePolicy1, []byte, edge.ExtraOpaqueData = opq[8:] } - return edge, pub[:], nil + return &edge, pub[:], nil } const ( diff --git a/channeldb/models/interfaces.go b/channeldb/models/interfaces.go index 4a3d82fc98..8c4953d1b1 100644 --- a/channeldb/models/interfaces.go +++ b/channeldb/models/interfaces.go @@ -50,3 +50,6 @@ type ChannelEdgeInfo interface { //nolint:interfacebloat type ChannelAuthProof interface { } + +type ChannelEdgePolicy interface { +} From 2b691cd9727fa7d0c50dbd436610f4049ebbfbe0 Mon Sep 17 00:00:00 2001 From: Elle Mouton <elle.mouton@gmail.com> Date: Fri, 20 Oct 2023 11:35:25 +0200 Subject: [PATCH 33/33] channeldb: ChannelEdgePolicy2 encoding --- channeldb/graph.go | 236 +++++++++++++++++++++++++++++---- channeldb/graph_test.go | 4 +- channeldb/models/interfaces.go | 4 + 3 files changed, 217 insertions(+), 27 deletions(-) diff --git a/channeldb/graph.go b/channeldb/graph.go index 17d490febb..fc645ddb16 100644 --- a/channeldb/graph.go +++ b/channeldb/graph.go @@ -107,6 +107,9 @@ var ( // maps: updateTime || chanID -> nil edgeUpdateIndexBucket = []byte("edge-update-index") + // maps: blockHeight || chanID -> nil + edgeUpdateIndex2Bucket = []byte("edge-update-index-2") + // channelPointBucket maps a channel's full outpoint (txid:index) to // its short 8-byte channel ID. This bucket resides within the // edgeBucket above, and can be used to quickly remove an edge due to @@ -323,11 +326,17 @@ func (c *ChannelGraph) getChannelMap(edges kvdb.RBucket) ( return err } + e, ok := edge.(*ChannelEdgePolicy1) + if !ok { + return fmt.Errorf("expected *ChannelEdgePolicy1, "+ + "got: %T", edge) + } + var pub [33]byte copy(pub[:], pubKey) channelMap[key] = &ChannelEdgePolicyWithNode{ - ChannelEdgePolicy1: *edge, + ChannelEdgePolicy1: *e, Node: &LightningNode{ PubKeyBytes: pub, }, @@ -3617,6 +3626,13 @@ type ChannelEdgePolicyWithNode struct { Node *LightningNode } +type edgePolicyEncodingInfo interface { + updateBucketKey() []byte + updateKey() []byte + serialize(w io.Writer, toNode []byte) error + typeByte() (edgePolicyEncodingType, bool) +} + // ChannelEdgePolicy1 represents a *directed* edge within the channel graph. For // each channel in the database, there are two distinct edges: one for each // possible direction of travel along the channel. The edges themselves hold @@ -3679,7 +3695,37 @@ type ChannelEdgePolicy1 struct { ExtraOpaqueData []byte } +func (c *ChannelEdgePolicy1) typeByte() (edgePolicyEncodingType, bool) { + return 0, false +} + +func (c *ChannelEdgePolicy1) IsNode1() bool { + return c.ChannelFlags&lnwire.ChanUpdateDirection == 0 +} + +func (c *ChannelEdgePolicy1) SCID() lnwire.ShortChannelID { + return lnwire.NewShortChanIDFromInt(c.ChannelID) +} + +func (c *ChannelEdgePolicy1) serialize(w io.Writer, toNode []byte) error { + return serializeChanEdgePolicy1(w, c, toNode) +} + +func (c *ChannelEdgePolicy1) updateBucketKey() []byte { + return edgeUpdateIndexBucket +} + +func (c *ChannelEdgePolicy1) updateKey() []byte { + updateUnix := uint64(c.LastUpdate.Unix()) + var indexKey [8 + 8]byte + byteOrder.PutUint64(indexKey[:8], updateUnix) + byteOrder.PutUint64(indexKey[8:], c.ChannelID) + + return indexKey[:] +} + var _ models.ChannelEdgePolicy = (*ChannelEdgePolicy1)(nil) +var _ edgePolicyEncodingInfo = (*ChannelEdgePolicy1)(nil) // Signature is a channel announcement signature, which is needed for proper // edge policy announcement. @@ -3738,6 +3784,105 @@ func (c *ChannelEdgePolicy1) ComputeFeeFromIncoming( ) } +// edgePolicyEncoding indicates how the bytes for a channel edge policy have +// been serialised. +type edgePolicyEncodingType uint8 + +const ( + // edgePolicy2EncodingType will be used as a prefix for edge policies + // advertised using the ChannelUpdate2 message. The type indicates how + // the bytes following should be deserialized. + edgePolicy2EncodingType edgePolicyEncodingType = 0 +) + +type ChannelEdgePolicy2 struct { + lnwire.ChannelUpdate2 +} + +func (c *ChannelEdgePolicy2) typeByte() (edgePolicyEncodingType, bool) { + return edgePolicy2EncodingType, true +} + +func (c *ChannelEdgePolicy2) updateBucketKey() []byte { + return edgeUpdateIndex2Bucket +} + +func (c *ChannelEdgePolicy2) updateKey() []byte { + indexKey := make([]byte, 4+8) + byteOrder.PutUint32(indexKey[:4], c.BlockHeight) + byteOrder.PutUint64(indexKey[8:], c.ShortChannelID.ToUint64()) + + return indexKey +} + +const ( + EdgePolicy2MsgType = tlv.Type(0) + EdgePolicy2ToNode = tlv.Type(1) +) + +func (c *ChannelEdgePolicy2) serialize(w io.Writer, toNode []byte) error { + if len(c.ExtraOpaqueData) > MaxAllowedExtraOpaqueBytes { + return ErrTooManyExtraOpaqueBytes(len(c.ExtraOpaqueData)) + } + + var b bytes.Buffer + if err := c.Encode(&b, 0); err != nil { + return err + } + + msg := b.Bytes() + + records := []tlv.Record{ + tlv.MakePrimitiveRecord(EdgePolicy2MsgType, &msg), + tlv.MakePrimitiveRecord(EdgePolicy2ToNode, &toNode), + } + + stream, err := tlv.NewStream(records...) + if err != nil { + return err + } + + return stream.Encode(w) +} + +func deserializeChanEdgePolicy2Raw(r io.Reader) (*ChannelEdgePolicy2, []byte, + error) { + + var ( + msgBytes []byte + toNode []byte + ) + + records := []tlv.Record{ + tlv.MakePrimitiveRecord(EdgePolicy2MsgType, &msgBytes), + tlv.MakePrimitiveRecord(EdgePolicy2ToNode, &toNode), + } + + stream, err := tlv.NewStream(records...) + if err != nil { + return nil, nil, err + } + + err = stream.Decode(r) + if err != nil { + return nil, nil, err + } + + var ( + policy ChannelEdgePolicy2 + reader = bytes.NewReader(msgBytes) + ) + err = policy.Decode(reader, 0) + if err != nil { + return nil, nil, err + } + + return &policy, toNode, nil +} + +var _ models.ChannelEdgePolicy = (*ChannelEdgePolicy2)(nil) +var _ edgePolicyEncodingInfo = (*ChannelEdgePolicy2)(nil) + func EdgePolicyFromUpdate(update lnwire.ChannelUpdate) ( *ChannelEdgePolicyWithNode, error) { @@ -4789,32 +4934,46 @@ func deserializeChanEdgeInfo1(r io.Reader) (*ChannelEdgeInfo1, error) { return &edgeInfo, nil } -func putChanEdgePolicy(edges kvdb.RwBucket, edgePolicy models.ChannelEdgePolicy, +func putChanEdgePolicy(edges kvdb.RwBucket, edge models.ChannelEdgePolicy, from, to []byte) error { - edge, ok := edgePolicy.(*ChannelEdgePolicy1) + encodingInfo, ok := edge.(edgePolicyEncodingInfo) if !ok { - return fmt.Errorf("unhandled implementation of "+ - "ChannelEdgePolicy: %T", edgePolicy) + return fmt.Errorf("type %T does not implement "+ + "edgePolicyEncodingInfo", edge) } + chanID := edge.SCID().ToUint64() + var edgeKey [33 + 8]byte copy(edgeKey[:], from) - byteOrder.PutUint64(edgeKey[33:], edge.ChannelID) + byteOrder.PutUint64(edgeKey[33:], chanID) var b bytes.Buffer - if err := serializeChanEdgePolicy(&b, edge, to); err != nil { + + if typeByte, ok := encodingInfo.typeByte(); ok { + _, err := b.Write([]byte{chanEdgePolicyNewEncodingPrefix}) + if err != nil { + return err + } + + _, err = b.Write([]byte{byte(typeByte)}) + if err != nil { + return err + } + } + + if err := encodingInfo.serialize(&b, to); err != nil { return err } // Before we write out the new edge, we'll create a new entry in the // update index in order to keep it fresh. - updateUnix := uint64(edge.LastUpdate.Unix()) - var indexKey [8 + 8]byte - byteOrder.PutUint64(indexKey[:8], updateUnix) - byteOrder.PutUint64(indexKey[8:], edge.ChannelID) + indexKey := encodingInfo.updateKey() - updateIndex, err := edges.CreateBucketIfNotExists(edgeUpdateIndexBucket) + updateIndex, err := edges.CreateBucketIfNotExists( + encodingInfo.updateBucketKey(), + ) if err != nil { return err } @@ -4842,11 +5001,13 @@ func putChanEdgePolicy(edges kvdb.RwBucket, edgePolicy models.ChannelEdgePolicy, return err } - oldUpdateTime := uint64(oldEdgePolicy.LastUpdate.Unix()) + oldPolicy, ok := oldEdgePolicy.(edgePolicyEncodingInfo) + if !ok { + return fmt.Errorf("type %T does not implement "+ + "edgePolicyEncodingInfo", oldEdgePolicy) + } - var oldIndexKey [8 + 8]byte - byteOrder.PutUint64(oldIndexKey[:8], oldUpdateTime) - byteOrder.PutUint64(oldIndexKey[8:], edge.ChannelID) + oldIndexKey := oldPolicy.updateKey() if err := updateIndex.Delete(oldIndexKey[:]); err != nil { return err @@ -4857,11 +5018,12 @@ func putChanEdgePolicy(edges kvdb.RwBucket, edgePolicy models.ChannelEdgePolicy, return err } - updateEdgePolicyDisabledIndex( - edges, edge.ChannelID, - edge.ChannelFlags&lnwire.ChanUpdateDirection > 0, - edge.IsDisabled(), + err = updateEdgePolicyDisabledIndex( + edges, chanID, !edge.IsNode1(), edge.IsDisabled(), ) + if err != nil { + return err + } return edges.Put(edgeKey[:], b.Bytes()[:]) } @@ -4983,7 +5145,7 @@ func fetchChanEdgePolicies(edgeIndex kvdb.RBucket, edges kvdb.RBucket, return edge1, edge2, nil } -func serializeChanEdgePolicy(w io.Writer, edge *ChannelEdgePolicy1, +func serializeChanEdgePolicy1(w io.Writer, edge *ChannelEdgePolicy1, to []byte) error { err := wire.WriteVarBytes(w, 0, edge.SigBytes) @@ -5069,15 +5231,21 @@ func deserializeChanEdgePolicy(r io.Reader, pubKeyBytes, err) } + e, ok := edge.(*ChannelEdgePolicy1) + if !ok { + return nil, fmt.Errorf("expected type ChannelEdgePolicy1, "+ + "got: %T", edge) + } + policy := ChannelEdgePolicyWithNode{ - ChannelEdgePolicy1: *edge, + ChannelEdgePolicy1: *e, Node: &node, } return &policy, deserializeErr } -func deserializeChanEdgePolicyRaw(reader io.Reader) (*ChannelEdgePolicy1, +func deserializeChanEdgePolicyRaw(reader io.Reader) (models.ChannelEdgePolicy, []byte, error) { // Wrap the io.Reader in a bufio.Reader so that we can peak the first @@ -5093,8 +5261,26 @@ func deserializeChanEdgePolicyRaw(reader io.Reader) (*ChannelEdgePolicy1, return deserializeChanEdgePolicy1Raw(r) } - return nil, nil, fmt.Errorf("unknown channel edge policy encoding "+ - "type: %x", firstByte[0]) + // Pop the encoding type byte. + var scratch [1]byte + if _, err = r.Read(scratch[:]); err != nil { + return nil, nil, err + } + + // Now, read the encoding type byte. + if _, err = r.Read(scratch[:]); err != nil { + return nil, nil, err + } + + encoding := edgePolicyEncodingType(scratch[0]) + switch encoding { + case edgePolicy2EncodingType: + return deserializeChanEdgePolicy2Raw(r) + + default: + return nil, nil, fmt.Errorf("unknown edge policy encoding "+ + "type: %d", encoding) + } } func deserializeChanEdgePolicy1Raw(r io.Reader) (*ChannelEdgePolicy1, []byte, diff --git a/channeldb/graph_test.go b/channeldb/graph_test.go index 03f7076108..d4d0e0a12d 100644 --- a/channeldb/graph_test.go +++ b/channeldb/graph_test.go @@ -3024,7 +3024,7 @@ func TestEdgePolicyMissingMaxHtcl(t *testing.T) { edge1.ExtraOpaqueData = nil var b bytes.Buffer - err = serializeChanEdgePolicy(&b, &edge1.ChannelEdgePolicy1, to) + err = serializeChanEdgePolicy1(&b, &edge1.ChannelEdgePolicy1, to) if err != nil { t.Fatalf("unable to serialize policy") } @@ -3034,7 +3034,7 @@ func TestEdgePolicyMissingMaxHtcl(t *testing.T) { edge1.MessageFlags = lnwire.ChanUpdateRequiredMaxHtlc edge1.MaxHTLC = 13928598 var b2 bytes.Buffer - err = serializeChanEdgePolicy(&b2, &edge1.ChannelEdgePolicy1, to) + err = serializeChanEdgePolicy1(&b2, &edge1.ChannelEdgePolicy1, to) if err != nil { t.Fatalf("unable to serialize policy") } diff --git a/channeldb/models/interfaces.go b/channeldb/models/interfaces.go index 8c4953d1b1..38ebdb150f 100644 --- a/channeldb/models/interfaces.go +++ b/channeldb/models/interfaces.go @@ -5,6 +5,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/lnwire" ) type ChannelEdgeInfo interface { //nolint:interfacebloat @@ -52,4 +53,7 @@ type ChannelAuthProof interface { } type ChannelEdgePolicy interface { + SCID() lnwire.ShortChannelID + IsDisabled() bool + IsNode1() bool }