diff --git a/channeldb/edge_info.go b/channeldb/edge_info.go index 694f0c746a8..92fa0a657e3 100644 --- a/channeldb/edge_info.go +++ b/channeldb/edge_info.go @@ -1,34 +1,136 @@ package channeldb import ( + "bufio" "bytes" "encoding/binary" + "fmt" "io" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/kvdb" + "github.com/lightningnetwork/lnd/tlv" ) +// edgeInfoEncodingType indicate how the bytes for a channel edge have been +// serialised. +type edgeInfoEncodingType uint8 + +const ( + // edgeInfo2EncodingType will be used as a prefix for edge's advertised + // using the ChannelAnnouncement2 message. The type indicates how the + // bytes following should be deserialized. + edgeInfo2EncodingType edgeInfoEncodingType = 0 +) + +const ( + // EdgeInfo2MsgType is the tlv type used within the serialisation of + // ChannelEdgeInfo2 for storing the serialisation of the associated + // lnwire.ChannelAnnouncement2 message. + EdgeInfo2MsgType = tlv.Type(0) + + // EdgeInfo2Sig is the tlv type used within the serialisation of + // ChannelEdgeInfo2 for storing the signature of the + // lnwire.ChannelAnnouncement2 message. + EdgeInfo2Sig = tlv.Type(1) + + // EdgeInfo2ChanPoint is the tlv type used within the serialisation of + // ChannelEdgeInfo2 for storing channel point. + EdgeInfo2ChanPoint = tlv.Type(2) + + // EdgeInfo2PKScript is the tlv type used within the serialisation of + // ChannelEdgeInfo2 for storing the funding pk script of the channel. + EdgeInfo2PKScript = tlv.Type(3) +) + +const ( + // chanEdgeNewEncodingPrefix is a byte used in the channel edge encoding + // to signal that the new style encoding which is prefixed with a type + // byte is being used instead of the legacy encoding which would start + // with either 0x02 or 0x03 due to the fact that the encoding would + // start with a node's compressed public key. + chanEdgeNewEncodingPrefix = 0xff +) + +// putChanEdgeInfo serialises the given ChannelEdgeInfo and writes the result +// to the edgeIndex using the channel ID as a key. func putChanEdgeInfo(edgeIndex kvdb.RwBucket, - edgeInfo *models.ChannelEdgeInfo1, chanID [8]byte) error { + edgeInfo models.ChannelEdgeInfo) error { + + var ( + chanID [8]byte + b bytes.Buffer + ) + + binary.BigEndian.PutUint64(chanID[:], edgeInfo.GetChanID()) + + if err := serializeChanEdgeInfo(&b, edgeInfo); err != nil { + return err + } + + return edgeIndex.Put(chanID[:], b.Bytes()) +} + +func serializeChanEdgeInfo(w io.Writer, edgeInfo models.ChannelEdgeInfo) error { + var ( + withTypeByte bool + typeByte edgeInfoEncodingType + serialize func(w io.Writer) error + ) - var b bytes.Buffer + switch info := edgeInfo.(type) { + case *models.ChannelEdgeInfo1: + serialize = func(w io.Writer) error { + return serializeChanEdgeInfo1(w, info) + } + case *models.ChannelEdgeInfo2: + withTypeByte = true + typeByte = edgeInfo2EncodingType - if _, err := b.Write(edgeInfo.NodeKey1Bytes[:]); err != nil { + serialize = func(w io.Writer) error { + return serializeChanEdgeInfo2(w, info) + } + default: + return fmt.Errorf("unhandled implementation of "+ + "ChannelEdgeInfo: %T", edgeInfo) + } + + if withTypeByte { + // First, write the identifying encoding byte to signal that + // this is not using the legacy encoding. + _, err := w.Write([]byte{chanEdgeNewEncodingPrefix}) + if err != nil { + return err + } + + // Now, write the encoding type. + _, err = w.Write([]byte{byte(typeByte)}) + if err != nil { + return err + } + } + + return serialize(w) +} + +func serializeChanEdgeInfo1(w io.Writer, + edgeInfo *models.ChannelEdgeInfo1) error { + + if _, err := w.Write(edgeInfo.NodeKey1Bytes[:]); err != nil { return err } - if _, err := b.Write(edgeInfo.NodeKey2Bytes[:]); err != nil { + if _, err := w.Write(edgeInfo.NodeKey2Bytes[:]); err != nil { return err } - if _, err := b.Write(edgeInfo.BitcoinKey1Bytes[:]); err != nil { + if _, err := w.Write(edgeInfo.BitcoinKey1Bytes[:]); err != nil { return err } - if _, err := b.Write(edgeInfo.BitcoinKey2Bytes[:]); err != nil { + if _, err := w.Write(edgeInfo.BitcoinKey2Bytes[:]); err != nil { return err } - if err := wire.WriteVarBytes(&b, 0, edgeInfo.Features); err != nil { + if err := wire.WriteVarBytes(w, 0, edgeInfo.Features); err != nil { return err } @@ -41,96 +143,176 @@ func putChanEdgeInfo(edgeIndex kvdb.RwBucket, bitcoinSig2 = authProof.BitcoinSig2Bytes } - if err := wire.WriteVarBytes(&b, 0, nodeSig1); err != nil { + if err := wire.WriteVarBytes(w, 0, nodeSig1); err != nil { return err } - if err := wire.WriteVarBytes(&b, 0, nodeSig2); err != nil { + if err := wire.WriteVarBytes(w, 0, nodeSig2); err != nil { return err } - if err := wire.WriteVarBytes(&b, 0, bitcoinSig1); err != nil { + if err := wire.WriteVarBytes(w, 0, bitcoinSig1); err != nil { return err } - if err := wire.WriteVarBytes(&b, 0, bitcoinSig2); err != nil { + if err := wire.WriteVarBytes(w, 0, bitcoinSig2); err != nil { return err } - if err := writeOutpoint(&b, &edgeInfo.ChannelPoint); err != nil { + if err := writeOutpoint(w, &edgeInfo.ChannelPoint); err != nil { return err } - if err := binary.Write(&b, byteOrder, uint64(edgeInfo.Capacity)); err != nil { + err := binary.Write(w, byteOrder, uint64(edgeInfo.Capacity)) + if err != nil { return err } - if _, err := b.Write(chanID[:]); err != nil { + + var chanID [8]byte + binary.BigEndian.PutUint64(chanID[:], edgeInfo.ChannelID) + if _, err := w.Write(chanID[:]); err != nil { return err } - if _, err := b.Write(edgeInfo.ChainHash[:]); err != nil { + if _, err := w.Write(edgeInfo.ChainHash[:]); err != nil { return err } if len(edgeInfo.ExtraOpaqueData) > MaxAllowedExtraOpaqueBytes { return ErrTooManyExtraOpaqueBytes(len(edgeInfo.ExtraOpaqueData)) } - err := wire.WriteVarBytes(&b, 0, edgeInfo.ExtraOpaqueData) + + return wire.WriteVarBytes(w, 0, edgeInfo.ExtraOpaqueData) +} + +func serializeChanEdgeInfo2(w io.Writer, edge *models.ChannelEdgeInfo2) error { + if len(edge.ExtraOpaqueData) > MaxAllowedExtraOpaqueBytes { + return ErrTooManyExtraOpaqueBytes(len(edge.ExtraOpaqueData)) + } + + serializedMsg, err := edge.DataToSign() if err != nil { return err } - return edgeIndex.Put(chanID[:], b.Bytes()) + records := []tlv.Record{ + tlv.MakePrimitiveRecord(EdgeInfo2MsgType, &serializedMsg), + } + + if edge.AuthProof != nil { + records = append( + records, tlv.MakePrimitiveRecord( + EdgeInfo2Sig, &edge.AuthProof.SchnorrSigBytes, + ), + ) + } + + records = append(records, tlv.MakeStaticRecord( + EdgeInfo2ChanPoint, &edge.ChannelPoint, 34, + encodeOutpoint, decodeOutpoint, + )) + + if len(edge.FundingPkScript) != 0 { + records = append( + records, tlv.MakePrimitiveRecord( + EdgeInfo2PKScript, &edge.FundingPkScript, + ), + ) + } + + stream, err := tlv.NewStream(records...) + if err != nil { + return err + } + + return stream.Encode(w) } func fetchChanEdgeInfo(edgeIndex kvdb.RBucket, - chanID []byte) (models.ChannelEdgeInfo1, error) { + chanID []byte) (models.ChannelEdgeInfo, error) { edgeInfoBytes := edgeIndex.Get(chanID) if edgeInfoBytes == nil { - return models.ChannelEdgeInfo1{}, ErrEdgeNotFound + return nil, ErrEdgeNotFound } edgeInfoReader := bytes.NewReader(edgeInfoBytes) + return deserializeChanEdgeInfo(edgeInfoReader) } -func deserializeChanEdgeInfo(r io.Reader) (models.ChannelEdgeInfo1, error) { +func deserializeChanEdgeInfo(reader io.Reader) (models.ChannelEdgeInfo, error) { + // Wrap the io.Reader in a bufio.Reader so that we can peak the first + // byte of the stream without actually consuming from the stream. + r := bufio.NewReader(reader) + + firstByte, err := r.Peek(1) + if err != nil { + return nil, err + } + + if firstByte[0] != chanEdgeNewEncodingPrefix { + return deserializeChanEdgeInfo1(r) + } + + // Pop the encoding type byte. + var scratch [1]byte + if _, err = r.Read(scratch[:]); err != nil { + return nil, err + } + + // Now, read the encoding type byte. + if _, err = r.Read(scratch[:]); err != nil { + return nil, err + } + + encoding := edgeInfoEncodingType(scratch[0]) + switch encoding { + case edgeInfo2EncodingType: + return deserializeChanEdgeInfo2(r) + + default: + return nil, fmt.Errorf("unknown edge info encoding type: %d", + encoding) + } +} + +func deserializeChanEdgeInfo1(r io.Reader) (*models.ChannelEdgeInfo1, error) { var ( err error edgeInfo models.ChannelEdgeInfo1 ) if _, err := io.ReadFull(r, edgeInfo.NodeKey1Bytes[:]); err != nil { - return models.ChannelEdgeInfo1{}, err + return nil, err } if _, err := io.ReadFull(r, edgeInfo.NodeKey2Bytes[:]); err != nil { - return models.ChannelEdgeInfo1{}, err + return nil, err } if _, err := io.ReadFull(r, edgeInfo.BitcoinKey1Bytes[:]); err != nil { - return models.ChannelEdgeInfo1{}, err + return nil, err } if _, err := io.ReadFull(r, edgeInfo.BitcoinKey2Bytes[:]); err != nil { - return models.ChannelEdgeInfo1{}, err + return nil, err } edgeInfo.Features, err = wire.ReadVarBytes(r, 0, 900, "features") if err != nil { - return models.ChannelEdgeInfo1{}, err + return nil, err } proof := &models.ChannelAuthProof1{} proof.NodeSig1Bytes, err = wire.ReadVarBytes(r, 0, 80, "sigs") if err != nil { - return models.ChannelEdgeInfo1{}, err + return nil, err } proof.NodeSig2Bytes, err = wire.ReadVarBytes(r, 0, 80, "sigs") if err != nil { - return models.ChannelEdgeInfo1{}, err + return nil, err } proof.BitcoinSig1Bytes, err = wire.ReadVarBytes(r, 0, 80, "sigs") if err != nil { - return models.ChannelEdgeInfo1{}, err + return nil, err } proof.BitcoinSig2Bytes, err = wire.ReadVarBytes(r, 0, 80, "sigs") if err != nil { - return models.ChannelEdgeInfo1{}, err + return nil, err } if !proof.IsEmpty() { @@ -139,17 +321,17 @@ func deserializeChanEdgeInfo(r io.Reader) (models.ChannelEdgeInfo1, error) { edgeInfo.ChannelPoint = wire.OutPoint{} if err := readOutpoint(r, &edgeInfo.ChannelPoint); err != nil { - return models.ChannelEdgeInfo1{}, err + return nil, err } if err := binary.Read(r, byteOrder, &edgeInfo.Capacity); err != nil { - return models.ChannelEdgeInfo1{}, err + return nil, err } if err := binary.Read(r, byteOrder, &edgeInfo.ChannelID); err != nil { - return models.ChannelEdgeInfo1{}, err + return nil, err } if _, err := io.ReadFull(r, edgeInfo.ChainHash[:]); err != nil { - return models.ChannelEdgeInfo1{}, err + return nil, err } // We'll try and see if there are any opaque bytes left, if not, then @@ -161,8 +343,68 @@ func deserializeChanEdgeInfo(r io.Reader) (models.ChannelEdgeInfo1, error) { case err == io.ErrUnexpectedEOF: case err == io.EOF: case err != nil: - return models.ChannelEdgeInfo1{}, err + return nil, err + } + + return &edgeInfo, nil +} + +func deserializeChanEdgeInfo2(r io.Reader) (*models.ChannelEdgeInfo2, error) { + var ( + edgeInfo models.ChannelEdgeInfo2 + msgBytes []byte + sigBytes []byte + ) + + records := []tlv.Record{ + tlv.MakePrimitiveRecord(EdgeInfo2MsgType, &msgBytes), + tlv.MakePrimitiveRecord(EdgeInfo2Sig, &sigBytes), + tlv.MakeStaticRecord( + EdgeInfo2ChanPoint, &edgeInfo.ChannelPoint, 34, + encodeOutpoint, decodeOutpoint, + ), + tlv.MakePrimitiveRecord( + EdgeInfo2PKScript, &edgeInfo.FundingPkScript, + ), + } + + stream, err := tlv.NewStream(records...) + if err != nil { + return nil, err + } + + typeMap, err := stream.DecodeWithParsedTypes(r) + if err != nil { + return nil, err + } + + reader := bytes.NewReader(msgBytes) + err = edgeInfo.ChannelAnnouncement2.DecodeTLVRecords(reader) + if err != nil { + return nil, err + } + + if _, ok := typeMap[EdgeInfo2Sig]; ok { + edgeInfo.AuthProof = &models.ChannelAuthProof2{ + SchnorrSigBytes: sigBytes, + } + } + + return &edgeInfo, nil +} + +func encodeOutpoint(w io.Writer, val interface{}, _ *[8]byte) error { + if o, ok := val.(*wire.OutPoint); ok { + return writeOutpoint(w, o) + } + + return tlv.NewTypeForEncodingErr(val, "*wire.Outpoint") +} + +func decodeOutpoint(r io.Reader, val interface{}, _ *[8]byte, l uint64) error { + if o, ok := val.(*wire.OutPoint); ok { + return readOutpoint(r, o) } - return edgeInfo, nil + return tlv.NewTypeForDecodingErr(val, "*wire.Outpoint", l, l) } diff --git a/channeldb/edge_info_test.go b/channeldb/edge_info_test.go new file mode 100644 index 00000000000..4fa6d5df189 --- /dev/null +++ b/channeldb/edge_info_test.go @@ -0,0 +1,264 @@ +package channeldb + +import ( + "bytes" + "encoding/hex" + "math/rand" + "reflect" + "testing" + "testing/quick" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/chaincfg" + "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/tlv" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var ( + testSchnorrSigStr, _ = hex.DecodeString("04E7F9037658A92AFEB4F2" + + "5BAE5339E3DDCA81A353493827D26F16D92308E49E2A25E9220867" + + "8A2DF86970DA91B03A8AF8815A8A60498B358DAF560B347AA557") + testSchnorrSig, _ = lnwire.NewSigFromSchnorrRawSignature( + testSchnorrSigStr, + ) +) + +// TestEdgeInfoSerialisation tests the serialisation and deserialization logic +// for models.ChannelEdgeInfo. +func TestEdgeInfoSerialisation(t *testing.T) { + t.Parallel() + + mainScenario := func(info models.ChannelEdgeInfo) bool { + var b bytes.Buffer + err := serializeChanEdgeInfo(&b, info) + require.NoError(t, err) + + newInfo, err := deserializeChanEdgeInfo(&b) + require.NoError(t, err) + + return assert.Equal(t, info, newInfo) + } + + tests := []struct { + name string + genValue func([]reflect.Value, *rand.Rand) + scenario any + }{ + { + name: "ChannelEdgeInfo1", + scenario: func(m models.ChannelEdgeInfo1) bool { + return mainScenario(&m) + }, + genValue: func(v []reflect.Value, r *rand.Rand) { + info := &models.ChannelEdgeInfo1{ + ChannelID: r.Uint64(), + NodeKey1Bytes: randRawKey(t), + NodeKey2Bytes: randRawKey(t), + BitcoinKey1Bytes: randRawKey(t), + BitcoinKey2Bytes: randRawKey(t), + ChannelPoint: wire.OutPoint{ + Index: r.Uint32(), + }, + Capacity: btcutil.Amount( + r.Uint32(), + ), + ExtraOpaqueData: make([]byte, 0), + } + + _, err := r.Read(info.ChainHash[:]) + require.NoError(t, err) + + _, err = r.Read(info.ChannelPoint.Hash[:]) + require.NoError(t, err) + + info.Features = make([]byte, r.Intn(900)) + _, err = r.Read(info.Features) + require.NoError(t, err) + + // Sometimes add an AuthProof. + if r.Intn(2)%2 == 0 { + n := r.Intn(80) + + //nolint:lll + authProof := &models.ChannelAuthProof1{ + NodeSig1Bytes: make([]byte, n), + NodeSig2Bytes: make([]byte, n), + BitcoinSig1Bytes: make([]byte, n), + BitcoinSig2Bytes: make([]byte, n), + } + + _, err = r.Read( + authProof.NodeSig1Bytes, + ) + require.NoError(t, err) + + _, err = r.Read( + authProof.NodeSig2Bytes, + ) + require.NoError(t, err) + + _, err = r.Read( + authProof.BitcoinSig1Bytes, + ) + require.NoError(t, err) + + _, err = r.Read( + authProof.BitcoinSig2Bytes, + ) + require.NoError(t, err) + } + + numExtraBytes := r.Int31n(1000) + if numExtraBytes > 0 { + info.ExtraOpaqueData = make( + []byte, numExtraBytes, + ) + _, err := r.Read( + info.ExtraOpaqueData, + ) + require.NoError(t, err) + } + + v[0] = reflect.ValueOf(*info) + }, + }, + { + name: "ChannelEdgeInfo2", + scenario: func(m models.ChannelEdgeInfo2) bool { + return mainScenario(&m) + }, + genValue: func(v []reflect.Value, r *rand.Rand) { + ann := lnwire.ChannelAnnouncement2{ + ExtraOpaqueData: make([]byte, 0), + } + + features := randFeatureVector(r) + ann.Features.Val = *features + + scid := lnwire.NewShortChanIDFromInt(r.Uint64()) + ann.ShortChannelID.Val = scid + ann.Capacity.Val = rand.Uint64() + ann.NodeID1.Val = randRawKey(t) + ann.NodeID2.Val = randRawKey(t) + + // Sometimes set chain hash to bitcoin mainnet + // genesis hash. + ann.ChainHash.Val = *chaincfg.MainNetParams. + GenesisHash + if r.Int31()%2 == 0 { + _, err := r.Read(ann.ChainHash.Val[:]) + require.NoError(t, err) + } + + if r.Intn(2)%2 == 0 { + btcKey1 := tlv.ZeroRecordT[ + tlv.TlvType12, [33]byte, + ]() + btcKey1.Val = randRawKey(t) + ann.BitcoinKey1 = tlv.SomeRecordT( + btcKey1, + ) + + btcKey2 := tlv.ZeroRecordT[ + tlv.TlvType14, [33]byte, + ]() + btcKey2.Val = randRawKey(t) + ann.BitcoinKey2 = tlv.SomeRecordT( + btcKey2, + ) + } + + if r.Intn(2)%2 == 0 { + hash := tlv.ZeroRecordT[ + tlv.TlvType16, [32]byte, + ]() + + _, err := r.Read(hash.Val[:]) + require.NoError(t, err) + + ann.MerkleRootHash = tlv.SomeRecordT( + hash, + ) + } + + numExtraBytes := r.Int31n(1000) + if numExtraBytes > 0 { + ann.ExtraOpaqueData = make( + []byte, numExtraBytes, + ) + _, err := r.Read(ann.ExtraOpaqueData[:]) + require.NoError(t, err) + } + + info := &models.ChannelEdgeInfo2{ + ChannelAnnouncement2: ann, + ChannelPoint: wire.OutPoint{ + Index: r.Uint32(), + }, + } + + _, err := r.Read(info.ChannelPoint.Hash[:]) + require.NoError(t, err) + + if r.Intn(2)%2 == 0 { + authProof := &models.ChannelAuthProof2{ + SchnorrSigBytes: testSchnorrSigStr, //nolint:lll + } + + info.AuthProof = authProof + } + + if r.Intn(2)%2 == 0 { + var pkScript [34]byte + _, err := r.Read(pkScript[:]) + require.NoError(t, err) + + info.FundingPkScript = pkScript[:] + } + + v[0] = reflect.ValueOf(*info) + }, + }, + } + + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + config := &quick.Config{ + Values: test.genValue, + } + + err := quick.Check(test.scenario, config) + require.NoError(t, err) + }) + } +} + +func randRawKey(t *testing.T) [33]byte { + var n [33]byte + + priv, err := btcec.NewPrivateKey() + require.NoError(t, err) + + copy(n[:], priv.PubKey().SerializeCompressed()) + + return n +} + +func randFeatureVector(r *rand.Rand) *lnwire.FeatureVector { + featureVec := lnwire.NewRawFeatureVector() + for i := 0; i < 10000; i++ { + if r.Int31n(2) == 0 { + featureVec.Set(lnwire.FeatureBit(i)) + } + } + + return lnwire.NewFeatureVector(featureVec, lnwire.Features) +} diff --git a/channeldb/graph.go b/channeldb/graph.go index 82abbe3f256..77a65fa8ef9 100644 --- a/channeldb/graph.go +++ b/channeldb/graph.go @@ -444,16 +444,23 @@ func (c *ChannelGraph) ForEachChannel(cb func(*models.ChannelEdgeInfo1, } policy1 := channelMap[channelMapKey{ - nodeKey: info.NodeKey1Bytes, + nodeKey: info.Node1Bytes(), chanID: chanID, }] policy2 := channelMap[channelMapKey{ - nodeKey: info.NodeKey2Bytes, + nodeKey: info.Node2Bytes(), chanID: chanID, }] - return cb(&info, policy1, policy2) + edgeInfo, ok := info.(*models.ChannelEdgeInfo1) + if !ok { + return fmt.Errorf("expected "+ + "*models.ChannelEdgeInfo1, got %T", + edgeInfo) + } + + return cb(edgeInfo, policy1, policy2) }) }, func() {}) } @@ -1095,7 +1102,7 @@ func (c *ChannelGraph) addChannelEdge(tx kvdb.RwTx, // If the edge hasn't been created yet, then we'll first add it to the // edge index in order to associate the edge between two nodes and also // store the static components of the channel. - if err := putChanEdgeInfo(edgeIndex, edge, chanKey); err != nil { + if err := putChanEdgeInfo(edgeIndex, edge); err != nil { return err } @@ -1259,7 +1266,7 @@ func (c *ChannelGraph) UpdateChannelEdge(edge *models.ChannelEdgeInfo1) error { c.graphCache.UpdateChannel(edge) } - return putChanEdgeInfo(edgeIndex, edge, chanKey) + return putChanEdgeInfo(edgeIndex, edge) }, func() {}) } @@ -1353,7 +1360,14 @@ func (c *ChannelGraph) PruneGraph(spentOutputs []*wire.OutPoint, return err } - chansClosed = append(chansClosed, &edgeInfo) + info, ok := edgeInfo.(*models.ChannelEdgeInfo1) + if !ok { + return fmt.Errorf("expected "+ + "*models.ChannelEdgeInfo1, got %T", + edgeInfo) + } + + chansClosed = append(chansClosed, info) } metaBucket, err := tx.CreateTopLevelBucket(graphMetaBucket) @@ -1472,17 +1486,17 @@ func (c *ChannelGraph) pruneGraphNodes(nodes kvdb.RwBucket, // edge info. We'll use this scan to populate our reference count map // above. err = edgeIndex.ForEach(func(chanID, edgeInfoBytes []byte) error { - // The first 66 bytes of the edge info contain the pubkeys of - // the nodes that this edge attaches. We'll extract them, and - // add them to the ref count map. - var node1, node2 [33]byte - copy(node1[:], edgeInfoBytes[:33]) - copy(node2[:], edgeInfoBytes[33:]) + edge, err := deserializeChanEdgeInfo( + bytes.NewReader(edgeInfoBytes), + ) + if err != nil { + return err + } // With the nodes extracted, we'll increase the ref count of // each of the nodes. - nodeRefCounts[node1]++ - nodeRefCounts[node2]++ + nodeRefCounts[edge.Node1Bytes()]++ + nodeRefCounts[edge.Node2Bytes()]++ return nil }) @@ -1602,7 +1616,15 @@ func (c *ChannelGraph) DisconnectBlockAtHeight(height uint32) ( } keys = append(keys, k) - removedChans = append(removedChans, &edgeInfo) + info, ok := edgeInfo.(*models.ChannelEdgeInfo1) + if !ok { + return fmt.Errorf("expected "+ + "*models.ChannelEdgeInfo1, got %T", + edgeInfo) + } + + keys = append(keys, k) + removedChans = append(removedChans, info) } for _, k := range keys { @@ -1958,13 +1980,20 @@ func (c *ChannelGraph) ChanUpdatesInHorizon(startTime, } // First, we'll fetch the static edge information. - edgeInfo, err := fetchChanEdgeInfo(edgeIndex, chanID) + info, err := fetchChanEdgeInfo(edgeIndex, chanID) if err != nil { chanID := byteOrder.Uint64(chanID) return fmt.Errorf("unable to fetch info for "+ "edge with chan_id=%v: %v", chanID, err) } + edgeInfo, ok := info.(*models.ChannelEdgeInfo1) + if !ok { + return fmt.Errorf("expected "+ + "*models.ChannelEdgeInfo1, got %T", + edgeInfo) + } + // With the static information obtained, we'll now // fetch the dynamic policy info. edge1, edge2, err := fetchChanEdgePolicies( @@ -1995,7 +2024,7 @@ func (c *ChannelGraph) ChanUpdatesInHorizon(startTime, // edges to be returned. edgesSeen[chanIDInt] = struct{}{} channel := ChannelEdge{ - Info: &edgeInfo, + Info: edgeInfo, Policy1: edge1, Policy2: edge2, Node1: &node1, @@ -2313,7 +2342,7 @@ func (c *ChannelGraph) FilterChannelRange(startHeight, return err } - if edgeInfo.AuthProof == nil { + if edgeInfo.GetAuthProof() == nil { continue } @@ -2335,7 +2364,7 @@ func (c *ChannelGraph) FilterChannelRange(startHeight, continue } - node1Key, node2Key := computeEdgePolicyKeys(&edgeInfo) + node1Key, node2Key := computeEdgePolicyKeys(edgeInfo) rawPolicy := edges.Get(node1Key) if len(rawPolicy) != 0 { @@ -2453,7 +2482,7 @@ func (c *ChannelGraph) fetchChanInfos(tx kvdb.RTx, chanIDs []uint64) ( // First, we'll fetch the static edge information. If // the edge is unknown, we will skip the edge and // continue gathering all known edges. - edgeInfo, err := fetchChanEdgeInfo( + info, err := fetchChanEdgeInfo( edgeIndex, cidBytes[:], ) switch { @@ -2472,6 +2501,13 @@ func (c *ChannelGraph) fetchChanInfos(tx kvdb.RTx, chanIDs []uint64) ( return err } + edgeInfo, ok := info.(*models.ChannelEdgeInfo1) + if !ok { + return fmt.Errorf("expected "+ + "*models.ChannelEdgeInfo1, got %T", + info) + } + node1, err := fetchLightningNode( nodes, edgeInfo.NodeKey1Bytes[:], ) @@ -2487,7 +2523,7 @@ func (c *ChannelGraph) fetchChanInfos(tx kvdb.RTx, chanIDs []uint64) ( } chanEdges = append(chanEdges, ChannelEdge{ - Info: &edgeInfo, + Info: edgeInfo, Policy1: edge1, Policy2: edge2, Node1: &node1, @@ -2565,11 +2601,17 @@ func (c *ChannelGraph) delChannelEdgeUnsafe(edges, edgeIndex, chanIndex, zombieIndex kvdb.RwBucket, chanID []byte, isZombie, strictZombie bool) error { - edgeInfo, err := fetchChanEdgeInfo(edgeIndex, chanID) + info, err := fetchChanEdgeInfo(edgeIndex, chanID) if err != nil { return err } + edgeInfo, ok := info.(*models.ChannelEdgeInfo1) + if !ok { + return fmt.Errorf("expected *models.ChannelEdgeInfo1, got %T", + info) + } + if c.graphCache != nil { c.graphCache.RemoveChannel( edgeInfo.NodeKey1Bytes, edgeInfo.NodeKey2Bytes, @@ -2638,7 +2680,7 @@ func (c *ChannelGraph) delChannelEdgeUnsafe(edges, edgeIndex, chanIndex, nodeKey1, nodeKey2 := edgeInfo.NodeKey1Bytes, edgeInfo.NodeKey2Bytes if strictZombie { - nodeKey1, nodeKey2 = makeZombiePubkeys(&edgeInfo, edge1, edge2) + nodeKey1, nodeKey2 = makeZombiePubkeys(edgeInfo, edge1, edge2) } return markEdgeZombie( @@ -2799,23 +2841,32 @@ func updateEdgePolicy(tx kvdb.RwTx, edge *models.ChannelEdgePolicy1, return false, ErrEdgeNotFound } + edgeInfo, err := deserializeChanEdgeInfo(bytes.NewReader(nodeInfo)) + if err != nil { + return false, err + } + // Depending on the flags value passed above, either the first // or second edge policy is being updated. - var fromNode, toNode []byte - var isUpdate1 bool - if edge.ChannelFlags&lnwire.ChanUpdateDirection == 0 { - fromNode = nodeInfo[:33] - toNode = nodeInfo[33:66] + var ( + fromNode, toNode []byte + isUpdate1 bool + node1Bytes = edgeInfo.Node1Bytes() + node2Bytes = edgeInfo.Node2Bytes() + ) + if edge.IsNode1() { + fromNode = node1Bytes[:] + toNode = node2Bytes[:] isUpdate1 = true } else { - fromNode = nodeInfo[33:66] - toNode = nodeInfo[:33] + fromNode = node2Bytes[:] + toNode = node1Bytes[:] isUpdate1 = false } // Finally, with the direction of the edge being updated // identified, we update the on-disk edge representation. - err := putChanEdgePolicy(edges, edge, fromNode, toNode) + err = putChanEdgePolicy(edges, edge, fromNode, toNode) if err != nil { return false, err } @@ -3209,11 +3260,18 @@ func nodeTraversal(tx kvdb.RTx, nodePub []byte, db kvdb.Backend, // the node at the other end of the channel and both // edge policies. chanID := nodeEdge[33:] - edgeInfo, err := fetchChanEdgeInfo(edgeIndex, chanID) + info, err := fetchChanEdgeInfo(edgeIndex, chanID) if err != nil { return err } + edgeInfo, ok := info.(*models.ChannelEdgeInfo1) + if !ok { + return fmt.Errorf("expected "+ + "*models.ChannelEdgeInfo1, got %T", + edgeInfo) + } + outgoingPolicy, err := fetchChanEdgePolicy( edges, chanID, nodePub, ) @@ -3234,7 +3292,7 @@ func nodeTraversal(tx kvdb.RTx, nodePub []byte, db kvdb.Backend, } // Finally, we execute the callback. - err = cb(tx, &edgeInfo, outgoingPolicy, incomingPolicy) + err = cb(tx, edgeInfo, outgoingPolicy, incomingPolicy) if err != nil { return err } @@ -3343,17 +3401,20 @@ func (c *ChannelGraph) FetchOtherNode(tx kvdb.RTx, // computeEdgePolicyKeys is a helper function that can be used to compute the // keys used to index the channel edge policy info for the two nodes of the // edge. The keys for node 1 and node 2 are returned respectively. -func computeEdgePolicyKeys(info *models.ChannelEdgeInfo1) ([]byte, []byte) { +func computeEdgePolicyKeys(info models.ChannelEdgeInfo) ([]byte, []byte) { var ( node1Key [33 + 8]byte node2Key [33 + 8]byte + + node1Bytes = info.Node1Bytes() + node2Bytes = info.Node2Bytes() ) - copy(node1Key[:], info.NodeKey1Bytes[:]) - copy(node2Key[:], info.NodeKey2Bytes[:]) + copy(node1Key[:], node1Bytes[:]) + copy(node2Key[:], node2Bytes[:]) - byteOrder.PutUint64(node1Key[33:], info.ChannelID) - byteOrder.PutUint64(node2Key[33:], info.ChannelID) + byteOrder.PutUint64(node1Key[33:], info.GetChanID()) + byteOrder.PutUint64(node2Key[33:], info.GetChanID()) return node1Key[:], node2Key[:] } @@ -3414,7 +3475,15 @@ func (c *ChannelGraph) FetchChannelEdgesByOutpoint(op *wire.OutPoint) ( if err != nil { return fmt.Errorf("%w: chanID=%x", err, chanID) } - edgeInfo = &edge + + info, ok := edge.(*models.ChannelEdgeInfo1) + if !ok { + return fmt.Errorf("expected "+ + "*models.ChannelEdgeInfo1, got %T", + edge) + } + + edgeInfo = info // Once we have the information about the channels' parameters, // we'll fetch the routing policies for each for the directed @@ -3518,7 +3587,13 @@ func (c *ChannelGraph) FetchChannelEdgesByID(chanID uint64) ( return err } - edgeInfo = &edge + info, ok := edge.(*models.ChannelEdgeInfo1) + if !ok { + return fmt.Errorf("expected "+ + "*models.ChannelEdgeInfo1, got %T", edge) + } + + edgeInfo = info // Then we'll attempt to fetch the accompanying policies of this // edge. @@ -3658,10 +3733,7 @@ func (c *ChannelGraph) ChannelView() ([]EdgePoint, error) { return err } - pkScript, err := genMultiSigP2WSH( - edgeInfo.BitcoinKey1Bytes[:], - edgeInfo.BitcoinKey2Bytes[:], - ) + pkScript, err := edgeInfo.FundingScript() if err != nil { return err }