diff --git a/autopilot/graph.go b/autopilot/graph.go index 2ce49c1272..3939d38629 100644 --- a/autopilot/graph.go +++ b/autopilot/graph.go @@ -90,8 +90,8 @@ func (d *dbNode) Addrs() []net.Addr { // NOTE: Part of the autopilot.Node interface. func (d *dbNode) ForEachChannel(cb func(ChannelEdge) error) error { return d.db.ForEachNodeChannelTx(d.tx, d.node.PubKeyBytes, - func(tx kvdb.RTx, ei *models.ChannelEdgeInfo, ep, - _ *models.ChannelEdgePolicy) error { + func(tx kvdb.RTx, ei *models.ChannelEdgeInfo1, ep, + _ *models.ChannelEdgePolicy1) error { // Skip channels for which no outgoing edge policy is // available. @@ -238,7 +238,7 @@ func (d *databaseChannelGraph) addRandChannel(node1, node2 *btcec.PublicKey, } chanID := randChanID() - edge := &models.ChannelEdgeInfo{ + edge := &models.ChannelEdgeInfo1{ ChannelID: chanID.ToUint64(), Capacity: capacity, } @@ -246,7 +246,7 @@ func (d *databaseChannelGraph) addRandChannel(node1, node2 *btcec.PublicKey, if err := d.db.AddChannelEdge(edge); err != nil { return nil, nil, err } - edgePolicy := &models.ChannelEdgePolicy{ + edgePolicy := &models.ChannelEdgePolicy1{ SigBytes: testSig.Serialize(), ChannelID: chanID.ToUint64(), LastUpdate: time.Now(), @@ -262,7 +262,7 @@ func (d *databaseChannelGraph) addRandChannel(node1, node2 *btcec.PublicKey, if err := d.db.UpdateEdgePolicy(edgePolicy); err != nil { return nil, nil, err } - edgePolicy = &models.ChannelEdgePolicy{ + edgePolicy = &models.ChannelEdgePolicy1{ SigBytes: testSig.Serialize(), ChannelID: chanID.ToUint64(), LastUpdate: time.Now(), diff --git a/channeldb/channel_cache_test.go b/channeldb/channel_cache_test.go index 7cb857293b..d32548a281 100644 --- a/channeldb/channel_cache_test.go +++ b/channeldb/channel_cache_test.go @@ -100,7 +100,7 @@ func assertHasChanEntries(t *testing.T, c *channelCache, start, end uint64) { // channelForInt generates a unique ChannelEdge given an integer. func channelForInt(i uint64) ChannelEdge { return ChannelEdge{ - Info: &models.ChannelEdgeInfo{ + Info: &models.ChannelEdgeInfo1{ ChannelID: i, }, } diff --git a/channeldb/edge_info.go b/channeldb/edge_info.go new file mode 100644 index 0000000000..fbf5525770 --- /dev/null +++ b/channeldb/edge_info.go @@ -0,0 +1,411 @@ +package channeldb + +import ( + "bufio" + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + + "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/kvdb" + "github.com/lightningnetwork/lnd/tlv" +) + +// edgeInfoEncodingType indicate how the bytes for a channel edge have been +// serialised. +type edgeInfoEncodingType uint8 + +const ( + // edgeInfo2EncodingType will be used as a prefix for edge's advertised + // using the ChannelAnnouncement2 message. The type indicates how the + // bytes following should be deserialized. + edgeInfo2EncodingType edgeInfoEncodingType = 0 +) + +const ( + // EdgeInfo2MsgType is the tlv type used within the serialisation of + // ChannelEdgeInfo2 for storing the serialisation of the associated + // lnwire.ChannelAnnouncement2 message. + EdgeInfo2MsgType = tlv.Type(0) + + // EdgeInfo2Sig is the tlv type used within the serialisation of + // ChannelEdgeInfo2 for storing the signature of the + // lnwire.ChannelAnnouncement2 message. + EdgeInfo2Sig = tlv.Type(1) + + // EdgeInfo2ChanPoint is the tlv type used within the serialisation of + // ChannelEdgeInfo2 for storing channel point. + EdgeInfo2ChanPoint = tlv.Type(2) + + // EdgeInfo2PKScript is the tlv type used within the serialisation of + // ChannelEdgeInfo2 for storing the funding pk script of the channel. + EdgeInfo2PKScript = tlv.Type(3) +) + +const ( + // chanEdgeNewEncodingPrefix is a byte used in the channel edge encoding + // to signal that the new style encoding which is prefixed with a type + // byte is being used instead of the legacy encoding which would start + // with either 0x02 or 0x03 due to the fact that the encoding would + // start with a node's compressed public key. + chanEdgeNewEncodingPrefix = 0xff +) + +// putChanEdgeInfo serialises the given ChannelEdgeInfo and writes the result +// to the edgeIndex using the channel ID as a key. +func putChanEdgeInfo(edgeIndex kvdb.RwBucket, + edgeInfo models.ChannelEdgeInfo) error { + + var ( + chanID [8]byte + b bytes.Buffer + ) + + binary.BigEndian.PutUint64(chanID[:], edgeInfo.GetChanID()) + + if err := serializeChanEdgeInfo(&b, edgeInfo); err != nil { + return err + } + + return edgeIndex.Put(chanID[:], b.Bytes()) +} + +func serializeChanEdgeInfo(w io.Writer, edgeInfo models.ChannelEdgeInfo) error { + var ( + withTypeByte bool + typeByte edgeInfoEncodingType + serialize func(w io.Writer) error + ) + + switch info := edgeInfo.(type) { + case *models.ChannelEdgeInfo1: + serialize = func(w io.Writer) error { + return serializeChanEdgeInfo1(w, info) + } + case *models.ChannelEdgeInfo2: + withTypeByte = true + typeByte = edgeInfo2EncodingType + + serialize = func(w io.Writer) error { + return serializeChanEdgeInfo2(w, info) + } + default: + return fmt.Errorf("unhandled implementation of "+ + "ChannelEdgeInfo: %T", edgeInfo) + } + + if withTypeByte { + // First, write the identifying encoding byte to signal that + // this is not using the legacy encoding. + _, err := w.Write([]byte{chanEdgeNewEncodingPrefix}) + if err != nil { + return err + } + + // Now, write the encoding type. + _, err = w.Write([]byte{byte(typeByte)}) + if err != nil { + return err + } + } + + return serialize(w) +} + +func serializeChanEdgeInfo1(w io.Writer, + edgeInfo *models.ChannelEdgeInfo1) error { + + if _, err := w.Write(edgeInfo.NodeKey1Bytes[:]); err != nil { + return err + } + if _, err := w.Write(edgeInfo.NodeKey2Bytes[:]); err != nil { + return err + } + if _, err := w.Write(edgeInfo.BitcoinKey1Bytes[:]); err != nil { + return err + } + if _, err := w.Write(edgeInfo.BitcoinKey2Bytes[:]); err != nil { + return err + } + + if err := wire.WriteVarBytes(w, 0, edgeInfo.Features); err != nil { + return err + } + + authProof := edgeInfo.AuthProof + var nodeSig1, nodeSig2, bitcoinSig1, bitcoinSig2 []byte + if authProof != nil { + nodeSig1 = authProof.NodeSig1Bytes + nodeSig2 = authProof.NodeSig2Bytes + bitcoinSig1 = authProof.BitcoinSig1Bytes + bitcoinSig2 = authProof.BitcoinSig2Bytes + } + + if err := wire.WriteVarBytes(w, 0, nodeSig1); err != nil { + return err + } + if err := wire.WriteVarBytes(w, 0, nodeSig2); err != nil { + return err + } + if err := wire.WriteVarBytes(w, 0, bitcoinSig1); err != nil { + return err + } + if err := wire.WriteVarBytes(w, 0, bitcoinSig2); err != nil { + return err + } + + if err := writeOutpoint(w, &edgeInfo.ChannelPoint); err != nil { + return err + } + err := binary.Write(w, byteOrder, uint64(edgeInfo.Capacity)) + if err != nil { + return err + } + + var chanID [8]byte + binary.BigEndian.PutUint64(chanID[:], edgeInfo.ChannelID) + if _, err := w.Write(chanID[:]); err != nil { + return err + } + if _, err := w.Write(edgeInfo.ChainHash[:]); err != nil { + return err + } + + if len(edgeInfo.ExtraOpaqueData) > MaxAllowedExtraOpaqueBytes { + return ErrTooManyExtraOpaqueBytes(len(edgeInfo.ExtraOpaqueData)) + } + + return wire.WriteVarBytes(w, 0, edgeInfo.ExtraOpaqueData) +} + +func serializeChanEdgeInfo2(w io.Writer, edge *models.ChannelEdgeInfo2) error { + if len(edge.ExtraOpaqueData) > MaxAllowedExtraOpaqueBytes { + return ErrTooManyExtraOpaqueBytes(len(edge.ExtraOpaqueData)) + } + + serializedMsg, err := edge.DataToSign() + if err != nil { + return err + } + + records := []tlv.Record{ + tlv.MakePrimitiveRecord(EdgeInfo2MsgType, &serializedMsg), + } + + if edge.AuthProof != nil { + records = append( + records, tlv.MakePrimitiveRecord( + EdgeInfo2Sig, &edge.AuthProof.SchnorrSigBytes, + ), + ) + } + + records = append(records, tlv.MakeStaticRecord( + EdgeInfo2ChanPoint, &edge.ChannelPoint, 34, + encodeOutpoint, decodeOutpoint, + )) + + if len(edge.FundingPkScript) != 0 { + records = append( + records, tlv.MakePrimitiveRecord( + EdgeInfo2PKScript, &edge.FundingPkScript, + ), + ) + } + + stream, err := tlv.NewStream(records...) + if err != nil { + return err + } + + return stream.Encode(w) +} + +func fetchChanEdgeInfo(edgeIndex kvdb.RBucket, + chanID []byte) (models.ChannelEdgeInfo, error) { + + edgeInfoBytes := edgeIndex.Get(chanID) + if edgeInfoBytes == nil { + return nil, ErrEdgeNotFound + } + + edgeInfoReader := bytes.NewReader(edgeInfoBytes) + + return deserializeChanEdgeInfo(edgeInfoReader) +} + +func deserializeChanEdgeInfo(reader io.Reader) (models.ChannelEdgeInfo, error) { + // Wrap the io.Reader in a bufio.Reader so that we can peak the first + // byte of the stream without actually consuming from the stream. + r := bufio.NewReader(reader) + + firstByte, err := r.Peek(1) + if err != nil { + return nil, err + } + + if firstByte[0] != chanEdgeNewEncodingPrefix { + return deserializeChanEdgeInfo1(r) + } + + // Pop the encoding type byte. + var scratch [1]byte + if _, err = r.Read(scratch[:]); err != nil { + return nil, err + } + + // Now, read the encoding type byte. + if _, err = r.Read(scratch[:]); err != nil { + return nil, err + } + + encoding := edgeInfoEncodingType(scratch[0]) + switch encoding { + case edgeInfo2EncodingType: + return deserializeChanEdgeInfo2(r) + + default: + return nil, fmt.Errorf("unknown edge info encoding type: %d", + encoding) + } +} + +func deserializeChanEdgeInfo1(r io.Reader) (*models.ChannelEdgeInfo1, error) { + var ( + err error + edgeInfo models.ChannelEdgeInfo1 + ) + + if _, err := io.ReadFull(r, edgeInfo.NodeKey1Bytes[:]); err != nil { + return nil, err + } + if _, err := io.ReadFull(r, edgeInfo.NodeKey2Bytes[:]); err != nil { + return nil, err + } + if _, err := io.ReadFull(r, edgeInfo.BitcoinKey1Bytes[:]); err != nil { + return nil, err + } + if _, err := io.ReadFull(r, edgeInfo.BitcoinKey2Bytes[:]); err != nil { + return nil, err + } + + edgeInfo.Features, err = wire.ReadVarBytes(r, 0, 900, "features") + if err != nil { + return nil, err + } + + proof := &models.ChannelAuthProof1{} + + proof.NodeSig1Bytes, err = wire.ReadVarBytes(r, 0, 80, "sigs") + if err != nil { + return nil, err + } + proof.NodeSig2Bytes, err = wire.ReadVarBytes(r, 0, 80, "sigs") + if err != nil { + return nil, err + } + proof.BitcoinSig1Bytes, err = wire.ReadVarBytes(r, 0, 80, "sigs") + if err != nil { + return nil, err + } + proof.BitcoinSig2Bytes, err = wire.ReadVarBytes(r, 0, 80, "sigs") + if err != nil { + return nil, err + } + + if !proof.IsEmpty() { + edgeInfo.AuthProof = proof + } + + edgeInfo.ChannelPoint = wire.OutPoint{} + if err := readOutpoint(r, &edgeInfo.ChannelPoint); err != nil { + return nil, err + } + if err := binary.Read(r, byteOrder, &edgeInfo.Capacity); err != nil { + return nil, err + } + if err := binary.Read(r, byteOrder, &edgeInfo.ChannelID); err != nil { + return nil, err + } + + if _, err := io.ReadFull(r, edgeInfo.ChainHash[:]); err != nil { + return nil, err + } + + // We'll try and see if there are any opaque bytes left, if not, then + // we'll ignore the EOF error and return the edge as is. + edgeInfo.ExtraOpaqueData, err = wire.ReadVarBytes( + r, 0, MaxAllowedExtraOpaqueBytes, "blob", + ) + switch { + case errors.Is(err, io.ErrUnexpectedEOF): + case errors.Is(err, io.EOF): + case err != nil: + return nil, err + } + + return &edgeInfo, nil +} + +func deserializeChanEdgeInfo2(r io.Reader) (*models.ChannelEdgeInfo2, error) { + var ( + edgeInfo models.ChannelEdgeInfo2 + msgBytes []byte + sigBytes []byte + ) + + records := []tlv.Record{ + tlv.MakePrimitiveRecord(EdgeInfo2MsgType, &msgBytes), + tlv.MakePrimitiveRecord(EdgeInfo2Sig, &sigBytes), + tlv.MakeStaticRecord( + EdgeInfo2ChanPoint, &edgeInfo.ChannelPoint, 34, + encodeOutpoint, decodeOutpoint, + ), + tlv.MakePrimitiveRecord( + EdgeInfo2PKScript, &edgeInfo.FundingPkScript, + ), + } + + stream, err := tlv.NewStream(records...) + if err != nil { + return nil, err + } + + typeMap, err := stream.DecodeWithParsedTypes(r) + if err != nil { + return nil, err + } + + reader := bytes.NewReader(msgBytes) + err = edgeInfo.ChannelAnnouncement2.DecodeTLVRecords(reader) + if err != nil { + return nil, err + } + + if _, ok := typeMap[EdgeInfo2Sig]; ok { + edgeInfo.AuthProof = &models.ChannelAuthProof2{ + SchnorrSigBytes: sigBytes, + } + } + + return &edgeInfo, nil +} + +func encodeOutpoint(w io.Writer, val interface{}, _ *[8]byte) error { + if o, ok := val.(*wire.OutPoint); ok { + return writeOutpoint(w, o) + } + + return tlv.NewTypeForEncodingErr(val, "*wire.Outpoint") +} + +func decodeOutpoint(r io.Reader, val interface{}, _ *[8]byte, l uint64) error { + if o, ok := val.(*wire.OutPoint); ok { + return readOutpoint(r, o) + } + + return tlv.NewTypeForDecodingErr(val, "*wire.Outpoint", l, l) +} diff --git a/channeldb/edge_info_test.go b/channeldb/edge_info_test.go new file mode 100644 index 0000000000..9a0ecf732c --- /dev/null +++ b/channeldb/edge_info_test.go @@ -0,0 +1,264 @@ +package channeldb + +import ( + "bytes" + "encoding/hex" + "math/rand" + "reflect" + "testing" + "testing/quick" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/chaincfg" + "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/tlv" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var ( + testSchnorrSigStr, _ = hex.DecodeString("04E7F9037658A92AFEB4F2" + + "5BAE5339E3DDCA81A353493827D26F16D92308E49E2A25E9220867" + + "8A2DF86970DA91B03A8AF8815A8A60498B358DAF560B347AA557") + testSchnorrSig, _ = lnwire.NewSigFromSchnorrRawSignature( + testSchnorrSigStr, + ) +) + +// TestEdgeInfoSerialisation tests the serialisation and deserialization logic +// for models.ChannelEdgeInfo. +func TestEdgeInfoSerialisation(t *testing.T) { + t.Parallel() + + mainScenario := func(info models.ChannelEdgeInfo) bool { + var b bytes.Buffer + err := serializeChanEdgeInfo(&b, info) + require.NoError(t, err) + + newInfo, err := deserializeChanEdgeInfo(&b) + require.NoError(t, err) + + return assert.Equal(t, info, newInfo) + } + + tests := []struct { + name string + genValue func([]reflect.Value, *rand.Rand) + scenario any + }{ + { + name: "ChannelEdgeInfo1", + scenario: func(m models.ChannelEdgeInfo1) bool { + return mainScenario(&m) + }, + genValue: func(v []reflect.Value, r *rand.Rand) { + info := &models.ChannelEdgeInfo1{ + ChannelID: r.Uint64(), + NodeKey1Bytes: randRawKey(t), + NodeKey2Bytes: randRawKey(t), + BitcoinKey1Bytes: randRawKey(t), + BitcoinKey2Bytes: randRawKey(t), + ChannelPoint: wire.OutPoint{ + Index: r.Uint32(), + }, + Capacity: btcutil.Amount( + r.Uint32(), + ), + ExtraOpaqueData: make([]byte, 0), + } + + _, err := r.Read(info.ChainHash[:]) + require.NoError(t, err) + + _, err = r.Read(info.ChannelPoint.Hash[:]) + require.NoError(t, err) + + info.Features = make([]byte, r.Intn(900)) + _, err = r.Read(info.Features) + require.NoError(t, err) + + // Sometimes add an AuthProof. + if r.Intn(2)%2 == 0 { + n := r.Intn(80) + + //nolint:lll + authProof := &models.ChannelAuthProof1{ + NodeSig1Bytes: make([]byte, n), + NodeSig2Bytes: make([]byte, n), + BitcoinSig1Bytes: make([]byte, n), + BitcoinSig2Bytes: make([]byte, n), + } + + _, err = r.Read( + authProof.NodeSig1Bytes, + ) + require.NoError(t, err) + + _, err = r.Read( + authProof.NodeSig2Bytes, + ) + require.NoError(t, err) + + _, err = r.Read( + authProof.BitcoinSig1Bytes, + ) + require.NoError(t, err) + + _, err = r.Read( + authProof.BitcoinSig2Bytes, + ) + require.NoError(t, err) + } + + numExtraBytes := r.Int31n(1000) + if numExtraBytes > 0 { + info.ExtraOpaqueData = make( + []byte, numExtraBytes, + ) + _, err := r.Read( + info.ExtraOpaqueData, + ) + require.NoError(t, err) + } + + v[0] = reflect.ValueOf(*info) + }, + }, + { + name: "ChannelEdgeInfo2", + scenario: func(m models.ChannelEdgeInfo2) bool { + return mainScenario(&m) + }, + genValue: func(v []reflect.Value, r *rand.Rand) { + ann := lnwire.ChannelAnnouncement2{ + ExtraOpaqueData: make([]byte, 0), + } + + features := randRawFeatureVector(r) + ann.Features.Val = *features + + scid := lnwire.NewShortChanIDFromInt(r.Uint64()) + ann.ShortChannelID.Val = scid + ann.Capacity.Val = rand.Uint64() + ann.NodeID1.Val = randRawKey(t) + ann.NodeID2.Val = randRawKey(t) + + // Sometimes set chain hash to bitcoin mainnet + // genesis hash. + ann.ChainHash.Val = *chaincfg.MainNetParams. + GenesisHash + if r.Int31()%2 == 0 { + _, err := r.Read(ann.ChainHash.Val[:]) + require.NoError(t, err) + } + + if r.Intn(2)%2 == 0 { + btcKey1 := tlv.ZeroRecordT[ + tlv.TlvType12, [33]byte, + ]() + btcKey1.Val = randRawKey(t) + ann.BitcoinKey1 = tlv.SomeRecordT( + btcKey1, + ) + + btcKey2 := tlv.ZeroRecordT[ + tlv.TlvType14, [33]byte, + ]() + btcKey2.Val = randRawKey(t) + ann.BitcoinKey2 = tlv.SomeRecordT( + btcKey2, + ) + } + + if r.Intn(2)%2 == 0 { + hash := tlv.ZeroRecordT[ + tlv.TlvType16, [32]byte, + ]() + + _, err := r.Read(hash.Val[:]) + require.NoError(t, err) + + ann.MerkleRootHash = tlv.SomeRecordT( + hash, + ) + } + + numExtraBytes := r.Int31n(1000) + if numExtraBytes > 0 { + ann.ExtraOpaqueData = make( + []byte, numExtraBytes, + ) + _, err := r.Read(ann.ExtraOpaqueData[:]) + require.NoError(t, err) + } + + info := &models.ChannelEdgeInfo2{ + ChannelAnnouncement2: ann, + ChannelPoint: wire.OutPoint{ + Index: r.Uint32(), + }, + } + + _, err := r.Read(info.ChannelPoint.Hash[:]) + require.NoError(t, err) + + if r.Intn(2)%2 == 0 { + authProof := &models.ChannelAuthProof2{ + SchnorrSigBytes: testSchnorrSigStr, //nolint:lll + } + + info.AuthProof = authProof + } + + if r.Intn(2)%2 == 0 { + var pkScript [34]byte + _, err := r.Read(pkScript[:]) + require.NoError(t, err) + + info.FundingPkScript = pkScript[:] + } + + v[0] = reflect.ValueOf(*info) + }, + }, + } + + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + config := &quick.Config{ + Values: test.genValue, + } + + err := quick.Check(test.scenario, config) + require.NoError(t, err) + }) + } +} + +func randRawKey(t *testing.T) [33]byte { + var n [33]byte + + priv, err := btcec.NewPrivateKey() + require.NoError(t, err) + + copy(n[:], priv.PubKey().SerializeCompressed()) + + return n +} + +func randRawFeatureVector(r *rand.Rand) *lnwire.RawFeatureVector { + featureVec := lnwire.NewRawFeatureVector() + for i := 0; i < 10000; i++ { + if r.Int31n(2) == 0 { + featureVec.Set(lnwire.FeatureBit(i)) + } + } + + return featureVec +} diff --git a/channeldb/edge_policy.go b/channeldb/edge_policy.go new file mode 100644 index 0000000000..cc00875010 --- /dev/null +++ b/channeldb/edge_policy.go @@ -0,0 +1,566 @@ +package channeldb + +import ( + "bufio" + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + "time" + + "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/kvdb" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/tlv" +) + +const ( + EdgePolicy2MsgType = tlv.Type(0) + EdgePolicy2ToNode = tlv.Type(1) + + // chanEdgePolicyNewEncodingPrefix is a byte used in the channel edge + // policy encoding to signal that the new style encoding which is + // prefixed with a type byte is being used instead of the legacy + // encoding which would start with 0x02 due to the fact that the + // encoding would start with a DER encoded ecdsa signature. + chanEdgePolicyNewEncodingPrefix = 0xff +) + +// edgePolicyEncoding indicates how the bytes for a channel edge policy have +// been serialised. +type edgePolicyEncodingType uint8 + +const ( + // edgePolicy2EncodingType will be used as a prefix for edge policies + // advertised using the ChannelUpdate2 message. The type indicates how + // the bytes following should be deserialized. + edgePolicy2EncodingType edgePolicyEncodingType = 0 +) + +func putChanEdgePolicy(edges kvdb.RwBucket, edge *models.ChannelEdgePolicy1, + from, to []byte) error { + + var edgeKey [33 + 8]byte + copy(edgeKey[:], from) + byteOrder.PutUint64(edgeKey[33:], edge.ChannelID) + + var b bytes.Buffer + if err := serializeChanEdgePolicy(&b, edge, to); err != nil { + return err + } + + // Before we write out the new edge, we'll create a new entry in the + // update index in order to keep it fresh. + updateUnix := uint64(edge.LastUpdate.Unix()) + var indexKey [8 + 8]byte + byteOrder.PutUint64(indexKey[:8], updateUnix) + byteOrder.PutUint64(indexKey[8:], edge.ChannelID) + + updateIndex, err := edges.CreateBucketIfNotExists(edgeUpdateIndexBucket) + if err != nil { + return err + } + + // If there was already an entry for this edge, then we'll need to + // delete the old one to ensure we don't leave around any after-images. + // An unknown policy value does not have a update time recorded, so + // it also does not need to be removed. + if edgeBytes := edges.Get(edgeKey[:]); edgeBytes != nil && + !bytes.Equal(edgeBytes, unknownPolicy) { + + // In order to delete the old entry, we'll need to obtain the + // *prior* update time in order to delete it. To do this, we'll + // need to deserialize the existing policy within the database + // (now outdated by the new one), and delete its corresponding + // entry within the update index. We'll ignore any + // ErrEdgePolicyOptionalFieldNotFound error, as we only need + // the channel ID and update time to delete the entry. + // TODO(halseth): get rid of these invalid policies in a + // migration. + oldEdgePolicy, err := deserializeChanEdgePolicy( + bytes.NewReader(edgeBytes), + ) + if err != nil && + !errors.Is(err, ErrEdgePolicyOptionalFieldNotFound) { + + return err + } + + oldPol, ok := oldEdgePolicy.(*models.ChannelEdgePolicy1) + if !ok { + return fmt.Errorf("expected "+ + "*models.ChannelEdgePolicy1, got: %T", + oldEdgePolicy) + } + + oldUpdateTime := uint64(oldPol.LastUpdate.Unix()) + + var oldIndexKey [8 + 8]byte + byteOrder.PutUint64(oldIndexKey[:8], oldUpdateTime) + byteOrder.PutUint64(oldIndexKey[8:], edge.ChannelID) + + if err := updateIndex.Delete(oldIndexKey[:]); err != nil { + return err + } + } + + if err := updateIndex.Put(indexKey[:], nil); err != nil { + return err + } + + err = updateEdgePolicyDisabledIndex( + edges, edge.ChannelID, + edge.ChannelFlags&lnwire.ChanUpdateDirection > 0, + edge.IsDisabled(), + ) + if err != nil { + return err + } + + return edges.Put(edgeKey[:], b.Bytes()) +} + +// updateEdgePolicyDisabledIndex is used to update the disabledEdgePolicyIndex +// bucket by either add a new disabled ChannelEdgePolicy1 or remove an existing +// one. +// The direction represents the direction of the edge and disabled is used for +// deciding whether to remove or add an entry to the bucket. +// In general a channel is disabled if two entries for the same chanID exist +// in this bucket. +// Maintaining the bucket this way allows a fast retrieval of disabled +// channels, for example when prune is needed. +func updateEdgePolicyDisabledIndex(edges kvdb.RwBucket, chanID uint64, + direction bool, disabled bool) error { + + var disabledEdgeKey [8 + 1]byte + byteOrder.PutUint64(disabledEdgeKey[0:], chanID) + if direction { + disabledEdgeKey[8] = 1 + } + + disabledEdgePolicyIndex, err := edges.CreateBucketIfNotExists( + disabledEdgePolicyBucket, + ) + if err != nil { + return err + } + + if disabled { + return disabledEdgePolicyIndex.Put(disabledEdgeKey[:], []byte{}) + } + + return disabledEdgePolicyIndex.Delete(disabledEdgeKey[:]) +} + +// putChanEdgePolicyUnknown marks the edge policy as unknown +// in the edges bucket. +func putChanEdgePolicyUnknown(edges kvdb.RwBucket, channelID uint64, + from []byte) error { + + var edgeKey [33 + 8]byte + copy(edgeKey[:], from) + byteOrder.PutUint64(edgeKey[33:], channelID) + + if edges.Get(edgeKey[:]) != nil { + return fmt.Errorf("cannot write unknown policy for channel %v "+ + " when there is already a policy present", channelID) + } + + return edges.Put(edgeKey[:], unknownPolicy) +} + +func fetchChanEdgePolicy(edges kvdb.RBucket, chanID []byte, + nodePub []byte) (*models.ChannelEdgePolicy1, error) { + + var edgeKey [33 + 8]byte + copy(edgeKey[:], nodePub) + copy(edgeKey[33:], chanID) + + edgeBytes := edges.Get(edgeKey[:]) + if edgeBytes == nil { + return nil, ErrEdgeNotFound + } + + // No need to deserialize unknown policy. + if bytes.Equal(edgeBytes, unknownPolicy) { + return nil, nil + } + + edgeReader := bytes.NewReader(edgeBytes) + + ep, err := deserializeChanEdgePolicy(edgeReader) + switch { + // If the db policy was missing an expected optional field, we return + // nil as if the policy was unknown. + case errors.Is(err, ErrEdgePolicyOptionalFieldNotFound): + return nil, nil + + case err != nil: + return nil, err + } + + pol, ok := ep.(*models.ChannelEdgePolicy1) + if !ok { + return nil, fmt.Errorf("expected *models.ChannelEdgePolicy1, "+ + "got: %T", ep) + } + + return pol, nil +} + +func fetchChanEdgePolicies(edgeIndex kvdb.RBucket, edges kvdb.RBucket, + chanID []byte) (*models.ChannelEdgePolicy1, *models.ChannelEdgePolicy1, + error) { + + edgeInfo := edgeIndex.Get(chanID) + if edgeInfo == nil { + return nil, nil, fmt.Errorf("%w: chanID=%x", ErrEdgeNotFound, + chanID) + } + + // The first node is contained within the first half of the edge + // information. We only propagate the error here and below if it's + // something other than edge non-existence. + node1Pub := edgeInfo[:33] + edge1, err := fetchChanEdgePolicy(edges, chanID, node1Pub) + if err != nil { + return nil, nil, fmt.Errorf("%w: node1Pub=%x", ErrEdgeNotFound, + node1Pub) + } + + // Similarly, the second node is contained within the latter + // half of the edge information. + node2Pub := edgeInfo[33:66] + edge2, err := fetchChanEdgePolicy(edges, chanID, node2Pub) + if err != nil { + return nil, nil, fmt.Errorf("%w: node2Pub=%x", ErrEdgeNotFound, + node2Pub) + } + + return edge1, edge2, nil +} + +func serializeChanEdgePolicy(w io.Writer, + edgePolicy models.ChannelEdgePolicy, toNode []byte) error { + + var ( + withTypeByte bool + typeByte edgePolicyEncodingType + serialize func(w io.Writer) error + ) + + switch policy := edgePolicy.(type) { + case *models.ChannelEdgePolicy1: + serialize = func(w io.Writer) error { + copy(policy.ToNode[:], toNode) + + return serializeChanEdgePolicy1(w, policy) + } + case *models.ChannelEdgePolicy2: + withTypeByte = true + typeByte = edgePolicy2EncodingType + + serialize = func(w io.Writer) error { + copy(policy.ToNode[:], toNode) + + return serializeChanEdgePolicy2(w, policy) + } + default: + return fmt.Errorf("unhandled implementation of "+ + "ChannelEdgePolicy: %T", edgePolicy) + } + + if withTypeByte { + // First, write the identifying encoding byte to signal that + // this is not using the legacy encoding. + _, err := w.Write([]byte{chanEdgePolicyNewEncodingPrefix}) + if err != nil { + return err + } + + // Now, write the encoding type. + _, err = w.Write([]byte{byte(typeByte)}) + if err != nil { + return err + } + } + + return serialize(w) +} + +func serializeChanEdgePolicy1(w io.Writer, + edge *models.ChannelEdgePolicy1) error { + + err := wire.WriteVarBytes(w, 0, edge.SigBytes) + if err != nil { + return err + } + + if err := binary.Write(w, byteOrder, edge.ChannelID); err != nil { + return err + } + + var scratch [8]byte + updateUnix := uint64(edge.LastUpdate.Unix()) + byteOrder.PutUint64(scratch[:], updateUnix) + if _, err := w.Write(scratch[:]); err != nil { + return err + } + + if err := binary.Write(w, byteOrder, edge.MessageFlags); err != nil { + return err + } + if err := binary.Write(w, byteOrder, edge.ChannelFlags); err != nil { + return err + } + if err := binary.Write(w, byteOrder, edge.TimeLockDelta); err != nil { + return err + } + if err := binary.Write(w, byteOrder, uint64(edge.MinHTLC)); err != nil { + return err + } + err = binary.Write(w, byteOrder, uint64(edge.FeeBaseMSat)) + if err != nil { + return err + } + err = binary.Write(w, byteOrder, uint64(edge.FeeProportionalMillionths)) + if err != nil { + return err + } + + if _, err := w.Write(edge.ToNode[:]); err != nil { + return err + } + + // If the max_htlc field is present, we write it. To be compatible with + // older versions that wasn't aware of this field, we write it as part + // of the opaque data. + // TODO(halseth): clean up when moving to TLV. + var opaqueBuf bytes.Buffer + if edge.MessageFlags.HasMaxHtlc() { + err := binary.Write(&opaqueBuf, byteOrder, uint64(edge.MaxHTLC)) + if err != nil { + return err + } + } + + if len(edge.ExtraOpaqueData) > MaxAllowedExtraOpaqueBytes { + return ErrTooManyExtraOpaqueBytes(len(edge.ExtraOpaqueData)) + } + if _, err := opaqueBuf.Write(edge.ExtraOpaqueData); err != nil { + return err + } + + if err := wire.WriteVarBytes(w, 0, opaqueBuf.Bytes()); err != nil { + return err + } + + return nil +} + +func serializeChanEdgePolicy2(w io.Writer, + edge *models.ChannelEdgePolicy2) error { + + if len(edge.ExtraOpaqueData) > MaxAllowedExtraOpaqueBytes { + return ErrTooManyExtraOpaqueBytes(len(edge.ExtraOpaqueData)) + } + + var b bytes.Buffer + if err := edge.Encode(&b, 0); err != nil { + return err + } + + msg := b.Bytes() + + records := []tlv.Record{ + tlv.MakePrimitiveRecord(EdgePolicy2MsgType, &msg), + tlv.MakePrimitiveRecord(EdgePolicy2ToNode, &edge.ToNode), + } + + stream, err := tlv.NewStream(records...) + if err != nil { + return err + } + + return stream.Encode(w) +} + +func deserializeChanEdgePolicy(r io.Reader) (models.ChannelEdgePolicy, + error) { + + // Deserialize the policy. Note that in case an optional field is not + // found, both an error and a populated policy object are returned. + edge, deserializeErr := deserializeChanEdgePolicyRaw(r) + if deserializeErr != nil && + !errors.Is(deserializeErr, ErrEdgePolicyOptionalFieldNotFound) { + + return nil, deserializeErr + } + + return edge, deserializeErr +} + +func deserializeChanEdgePolicyRaw(reader io.Reader) (models.ChannelEdgePolicy, + error) { + + // Wrap the io.Reader in a bufio.Reader so that we can peak the first + // byte of the stream without actually consuming from the stream. + r := bufio.NewReader(reader) + + firstByte, err := r.Peek(1) + if err != nil { + return nil, err + } + + if firstByte[0] != chanEdgePolicyNewEncodingPrefix { + return deserializeChanEdgePolicy1Raw(r) + } + + // Pop the encoding type byte. + var scratch [1]byte + if _, err = r.Read(scratch[:]); err != nil { + return nil, err + } + + // Now, read the encoding type byte. + if _, err = r.Read(scratch[:]); err != nil { + return nil, err + } + + encoding := edgePolicyEncodingType(scratch[0]) + switch encoding { + case edgePolicy2EncodingType: + return deserializeChanEdgePolicy2Raw(r) + + default: + return nil, fmt.Errorf("unknown edge policy encoding type: %d", + encoding) + } +} + +func deserializeChanEdgePolicy1Raw(r io.Reader) (*models.ChannelEdgePolicy1, + error) { + + edge := &models.ChannelEdgePolicy1{} + + var err error + edge.SigBytes, err = wire.ReadVarBytes(r, 0, 80, "sig") + if err != nil { + return nil, err + } + + if err := binary.Read(r, byteOrder, &edge.ChannelID); err != nil { + return nil, err + } + + var scratch [8]byte + if _, err := r.Read(scratch[:]); err != nil { + return nil, err + } + unix := int64(byteOrder.Uint64(scratch[:])) + edge.LastUpdate = time.Unix(unix, 0) + + if err := binary.Read(r, byteOrder, &edge.MessageFlags); err != nil { + return nil, err + } + if err := binary.Read(r, byteOrder, &edge.ChannelFlags); err != nil { + return nil, err + } + if err := binary.Read(r, byteOrder, &edge.TimeLockDelta); err != nil { + return nil, err + } + + var n uint64 + if err := binary.Read(r, byteOrder, &n); err != nil { + return nil, err + } + edge.MinHTLC = lnwire.MilliSatoshi(n) + + if err := binary.Read(r, byteOrder, &n); err != nil { + return nil, err + } + edge.FeeBaseMSat = lnwire.MilliSatoshi(n) + + if err := binary.Read(r, byteOrder, &n); err != nil { + return nil, err + } + edge.FeeProportionalMillionths = lnwire.MilliSatoshi(n) + + if _, err := r.Read(edge.ToNode[:]); err != nil { + return nil, err + } + + // We'll try and see if there are any opaque bytes left, if not, then + // we'll ignore the EOF error and return the edge as is. + edge.ExtraOpaqueData, err = wire.ReadVarBytes( + r, 0, MaxAllowedExtraOpaqueBytes, "blob", + ) + switch { + case errors.Is(err, io.ErrUnexpectedEOF): + case errors.Is(err, io.EOF): + case err != nil: + return nil, err + } + + // See if optional fields are present. + if edge.MessageFlags.HasMaxHtlc() { + // The max_htlc field should be at the beginning of the opaque + // bytes. + opq := edge.ExtraOpaqueData + + // If the max_htlc field is not present, it might be old data + // stored before this field was validated. We'll return the + // edge along with an error. + if len(opq) < 8 { + return edge, ErrEdgePolicyOptionalFieldNotFound + } + + maxHtlc := byteOrder.Uint64(opq[:8]) + edge.MaxHTLC = lnwire.MilliSatoshi(maxHtlc) + + // Exclude the parsed field from the rest of the opaque data. + edge.ExtraOpaqueData = opq[8:] + } + + return edge, nil +} + +func deserializeChanEdgePolicy2Raw(r io.Reader) (*models.ChannelEdgePolicy2, + error) { + + var ( + msgBytes []byte + toNode [33]byte + ) + + records := []tlv.Record{ + tlv.MakePrimitiveRecord(EdgePolicy2MsgType, &msgBytes), + tlv.MakePrimitiveRecord(EdgePolicy2ToNode, &toNode), + } + + stream, err := tlv.NewStream(records...) + if err != nil { + return nil, err + } + + err = stream.Decode(r) + if err != nil { + return nil, err + } + + var ( + chanUpdate lnwire.ChannelUpdate2 + reader = bytes.NewReader(msgBytes) + ) + err = chanUpdate.Decode(reader, 0) + if err != nil { + return nil, err + } + + return &models.ChannelEdgePolicy2{ + ChannelUpdate2: chanUpdate, + ToNode: toNode, + }, nil +} diff --git a/channeldb/edge_policy_test.go b/channeldb/edge_policy_test.go new file mode 100644 index 0000000000..6c8b1e2217 --- /dev/null +++ b/channeldb/edge_policy_test.go @@ -0,0 +1,171 @@ +package channeldb + +import ( + "bytes" + "math/rand" + "reflect" + "testing" + "testing/quick" + "time" + + "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestEdgePolicySerialisation tests the serialisation and deserialization logic +// for models.ChannelEdgePolicy. +func TestEdgePolicySerialisation(t *testing.T) { + t.Parallel() + + mainScenario := func(info models.ChannelEdgePolicy) bool { + var ( + b bytes.Buffer + toNode = info.GetToNode() + ) + + err := serializeChanEdgePolicy(&b, info, toNode[:]) + require.NoError(t, err) + + newInfo, err := deserializeChanEdgePolicy(&b) + require.NoError(t, err) + + return assert.Equal(t, info, newInfo) + } + + tests := []struct { + name string + genValue func([]reflect.Value, *rand.Rand) + scenario any + }{ + { + name: "ChannelEdgePolicy1", + scenario: func(m models.ChannelEdgePolicy1) bool { + return mainScenario(&m) + }, + genValue: func(v []reflect.Value, r *rand.Rand) { + //nolint:lll + policy := &models.ChannelEdgePolicy1{ + ChannelID: r.Uint64(), + LastUpdate: time.Unix(r.Int63(), 0), + MessageFlags: lnwire.ChanUpdateMsgFlags(r.Uint32()), + ChannelFlags: lnwire.ChanUpdateChanFlags(r.Uint32()), + TimeLockDelta: uint16(r.Uint32()), + MinHTLC: lnwire.MilliSatoshi(r.Uint64()), + FeeBaseMSat: lnwire.MilliSatoshi(r.Uint64()), + FeeProportionalMillionths: lnwire.MilliSatoshi(r.Uint64()), + ExtraOpaqueData: make([]byte, 0), + } + + policy.SigBytes = make([]byte, r.Intn(80)) + _, err := r.Read(policy.SigBytes) + require.NoError(t, err) + + _, err = r.Read(policy.ToNode[:]) + require.NoError(t, err) + + numExtraBytes := r.Int31n(1000) + if numExtraBytes > 0 { + policy.ExtraOpaqueData = make( + []byte, numExtraBytes, + ) + _, err := r.Read( + policy.ExtraOpaqueData, + ) + require.NoError(t, err) + } + + // Sometimes add an MaxHTLC. + if r.Intn(2)%2 == 0 { + policy.MessageFlags |= + lnwire.ChanUpdateRequiredMaxHtlc + policy.MaxHTLC = lnwire.MilliSatoshi( + r.Uint64(), + ) + } else { + policy.MessageFlags ^= + lnwire.ChanUpdateRequiredMaxHtlc + } + + v[0] = reflect.ValueOf(*policy) + }, + }, + { + name: "ChannelEdgePolicy2", + scenario: func(m models.ChannelEdgePolicy2) bool { + return mainScenario(&m) + }, + genValue: func(v []reflect.Value, r *rand.Rand) { + policy := &models.ChannelEdgePolicy2{ + //nolint:lll + ChannelUpdate2: lnwire.ChannelUpdate2{ + Signature: testSchnorrSig, + ExtraOpaqueData: make([]byte, 0), + }, + ToNode: [33]byte{}, + } + + policy.ShortChannelID.Val = lnwire.NewShortChanIDFromInt( //nolint:lll + uint64(r.Int63()), + ) + policy.BlockHeight.Val = r.Uint32() + policy.HTLCMaximumMsat.Val = lnwire.MilliSatoshi( //nolint:lll + r.Uint64(), + ) + policy.HTLCMinimumMsat.Val = lnwire.MilliSatoshi( //nolint:lll + r.Uint64(), + ) + policy.CLTVExpiryDelta.Val = uint16(r.Int31()) + policy.FeeBaseMsat.Val = r.Uint32() + policy.FeeProportionalMillionths.Val = r.Uint32() //nolint:lll + + if r.Intn(2) == 0 { + policy.Direction.Val.B = true + } + + // Sometimes set the incoming disabled flag. + if r.Int31()%2 == 0 { + policy.DisabledFlags.Val |= + lnwire.ChanUpdateDisableIncoming + } + + // Sometimes set the outgoing disabled flag. + if r.Int31()%2 == 0 { + policy.DisabledFlags.Val |= + lnwire.ChanUpdateDisableOutgoing + } + + _, err := r.Read(policy.ToNode[:]) + require.NoError(t, err) + + numExtraBytes := r.Int31n(1000) + if numExtraBytes > 0 { + policy.ExtraOpaqueData = make( + []byte, numExtraBytes, + ) + _, err := r.Read( + policy.ExtraOpaqueData, + ) + require.NoError(t, err) + } + + v[0] = reflect.ValueOf(*policy) + }, + }, + } + + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + config := &quick.Config{ + Values: test.genValue, + } + + err := quick.Check(test.scenario, config) + require.NoError(t, err) + }) + } +} diff --git a/channeldb/graph.go b/channeldb/graph.go index 464398b40f..30ac93a3b2 100644 --- a/channeldb/graph.go +++ b/channeldb/graph.go @@ -231,8 +231,8 @@ func NewChannelGraph(db kvdb.Backend, rejectCacheSize, chanCacheSize int, return nil, err } - err = g.ForEachChannel(func(info *models.ChannelEdgeInfo, - policy1, policy2 *models.ChannelEdgePolicy) error { + err = g.ForEachChannel(func(info *models.ChannelEdgeInfo1, + policy1, policy2 *models.ChannelEdgePolicy1) error { g.graphCache.AddChannel(info, policy1, policy2) @@ -258,10 +258,10 @@ type channelMapKey struct { // getChannelMap loads all channel edge policies from the database and stores // them in a map. func (c *ChannelGraph) getChannelMap(edges kvdb.RBucket) ( - map[channelMapKey]*models.ChannelEdgePolicy, error) { + map[channelMapKey]*models.ChannelEdgePolicy1, error) { // Create a map to store all channel edge policies. - channelMap := make(map[channelMapKey]*models.ChannelEdgePolicy) + channelMap := make(map[channelMapKey]*models.ChannelEdgePolicy1) err := kvdb.ForAll(edges, func(k, edgeBytes []byte) error { // Skip embedded buckets. @@ -303,7 +303,13 @@ func (c *ChannelGraph) getChannelMap(edges kvdb.RBucket) ( return err } - channelMap[key] = edge + e, ok := edge.(*models.ChannelEdgePolicy1) + if !ok { + return fmt.Errorf("expected "+ + "*models.ChannelEdgePolicy1, got: %T", edge) + } + + channelMap[key] = e return nil }) @@ -410,8 +416,8 @@ func (c *ChannelGraph) NewPathFindTx() (kvdb.RTx, error) { // NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer // for that particular channel edge routing policy will be passed into the // callback. -func (c *ChannelGraph) ForEachChannel(cb func(*models.ChannelEdgeInfo, - *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error { +func (c *ChannelGraph) ForEachChannel(cb func(*models.ChannelEdgeInfo1, + *models.ChannelEdgePolicy1, *models.ChannelEdgePolicy1) error) error { return c.db.View(func(tx kvdb.RTx) error { edges := tx.ReadBucket(edgeBucket) @@ -444,16 +450,23 @@ func (c *ChannelGraph) ForEachChannel(cb func(*models.ChannelEdgeInfo, } policy1 := channelMap[channelMapKey{ - nodeKey: info.NodeKey1Bytes, + nodeKey: info.Node1Bytes(), chanID: chanID, }] policy2 := channelMap[channelMapKey{ - nodeKey: info.NodeKey2Bytes, + nodeKey: info.Node2Bytes(), chanID: chanID, }] - return cb(&info, policy1, policy2) + edgeInfo, ok := info.(*models.ChannelEdgeInfo1) + if !ok { + return fmt.Errorf("expected "+ + "*models.ChannelEdgeInfo1, got %T", + edgeInfo) + } + + return cb(edgeInfo, policy1, policy2) }) }, func() {}) } @@ -480,8 +493,8 @@ func (c *ChannelGraph) ForEachNodeDirectedChannel(tx kvdb.RTx, return err } - dbCallback := func(tx kvdb.RTx, e *models.ChannelEdgeInfo, p1, - p2 *models.ChannelEdgePolicy) error { + dbCallback := func(tx kvdb.RTx, e *models.ChannelEdgeInfo1, p1, + p2 *models.ChannelEdgePolicy1) error { var cachedInPolicy *models.CachedEdgePolicy if p2 != nil { @@ -566,9 +579,9 @@ func (c *ChannelGraph) ForEachNodeCached(cb func(node route.Vertex, channels := make(map[uint64]*DirectedChannel) err := c.ForEachNodeChannelTx(tx, node.PubKeyBytes, - func(tx kvdb.RTx, e *models.ChannelEdgeInfo, - p1 *models.ChannelEdgePolicy, - p2 *models.ChannelEdgePolicy) error { + func(tx kvdb.RTx, e *models.ChannelEdgeInfo1, + p1 *models.ChannelEdgePolicy1, + p2 *models.ChannelEdgePolicy1) error { toNodeCallback := func() route.Vertex { return node.PubKeyBytes @@ -979,7 +992,7 @@ func (c *ChannelGraph) deleteLightningNode(nodes kvdb.RwBucket, // involved in creation of the channel, and the set of features that the channel // supports. The chanPoint and chanID are used to uniquely identify the edge // globally within the database. -func (c *ChannelGraph) AddChannelEdge(edge *models.ChannelEdgeInfo, +func (c *ChannelGraph) AddChannelEdge(edge *models.ChannelEdgeInfo1, op ...batch.SchedulerOption) error { var alreadyExists bool @@ -1023,7 +1036,7 @@ func (c *ChannelGraph) AddChannelEdge(edge *models.ChannelEdgeInfo, // addChannelEdge is the private form of AddChannelEdge that allows callers to // utilize an existing db transaction. func (c *ChannelGraph) addChannelEdge(tx kvdb.RwTx, - edge *models.ChannelEdgeInfo) error { + edge *models.ChannelEdgeInfo1) error { // Construct the channel's primary key which is the 8-byte channel ID. var chanKey [8]byte @@ -1095,7 +1108,7 @@ func (c *ChannelGraph) addChannelEdge(tx kvdb.RwTx, // If the edge hasn't been created yet, then we'll first add it to the // edge index in order to associate the edge between two nodes and also // store the static components of the channel. - if err := putChanEdgeInfo(edgeIndex, edge, chanKey); err != nil { + if err := putChanEdgeInfo(edgeIndex, edge); err != nil { return err } @@ -1235,7 +1248,7 @@ func (c *ChannelGraph) HasChannelEdge( // In order to maintain this constraints, we return an error in the scenario // that an edge info hasn't yet been created yet, but someone attempts to update // it. -func (c *ChannelGraph) UpdateChannelEdge(edge *models.ChannelEdgeInfo) error { +func (c *ChannelGraph) UpdateChannelEdge(edge *models.ChannelEdgeInfo1) error { // Construct the channel's primary key which is the 8-byte channel ID. var chanKey [8]byte binary.BigEndian.PutUint64(chanKey[:], edge.ChannelID) @@ -1259,7 +1272,7 @@ func (c *ChannelGraph) UpdateChannelEdge(edge *models.ChannelEdgeInfo) error { c.graphCache.UpdateChannel(edge) } - return putChanEdgeInfo(edgeIndex, edge, chanKey) + return putChanEdgeInfo(edgeIndex, edge) }, func() {}) } @@ -1281,12 +1294,12 @@ const ( // the target block are returned if the function succeeds without error. func (c *ChannelGraph) PruneGraph(spentOutputs []*wire.OutPoint, blockHash *chainhash.Hash, blockHeight uint32) ( - []*models.ChannelEdgeInfo, error) { + []*models.ChannelEdgeInfo1, error) { c.cacheMu.Lock() defer c.cacheMu.Unlock() - var chansClosed []*models.ChannelEdgeInfo + var chansClosed []*models.ChannelEdgeInfo1 err := kvdb.Update(c.db, func(tx kvdb.RwTx) error { // First grab the edges bucket which houses the information @@ -1353,7 +1366,14 @@ func (c *ChannelGraph) PruneGraph(spentOutputs []*wire.OutPoint, return err } - chansClosed = append(chansClosed, &edgeInfo) + info, ok := edgeInfo.(*models.ChannelEdgeInfo1) + if !ok { + return fmt.Errorf("expected "+ + "*models.ChannelEdgeInfo1, got %T", + edgeInfo) + } + + chansClosed = append(chansClosed, info) } metaBucket, err := tx.CreateTopLevelBucket(graphMetaBucket) @@ -1472,17 +1492,17 @@ func (c *ChannelGraph) pruneGraphNodes(nodes kvdb.RwBucket, // edge info. We'll use this scan to populate our reference count map // above. err = edgeIndex.ForEach(func(chanID, edgeInfoBytes []byte) error { - // The first 66 bytes of the edge info contain the pubkeys of - // the nodes that this edge attaches. We'll extract them, and - // add them to the ref count map. - var node1, node2 [33]byte - copy(node1[:], edgeInfoBytes[:33]) - copy(node2[:], edgeInfoBytes[33:]) + edge, err := deserializeChanEdgeInfo( + bytes.NewReader(edgeInfoBytes), + ) + if err != nil { + return err + } // With the nodes extracted, we'll increase the ref count of // each of the nodes. - nodeRefCounts[node1]++ - nodeRefCounts[node2]++ + nodeRefCounts[edge.Node1Bytes()]++ + nodeRefCounts[edge.Node2Bytes()]++ return nil }) @@ -1541,7 +1561,7 @@ func (c *ChannelGraph) pruneGraphNodes(nodes kvdb.RwBucket, // Channels that were removed from the graph resulting from the // disconnected block are returned. func (c *ChannelGraph) DisconnectBlockAtHeight(height uint32) ( - []*models.ChannelEdgeInfo, error) { + []*models.ChannelEdgeInfo1, error) { // Every channel having a ShortChannelID starting at 'height' // will no longer be confirmed. @@ -1563,7 +1583,7 @@ func (c *ChannelGraph) DisconnectBlockAtHeight(height uint32) ( defer c.cacheMu.Unlock() // Keep track of the channels that are removed from the graph. - var removedChans []*models.ChannelEdgeInfo + var removedChans []*models.ChannelEdgeInfo1 if err := kvdb.Update(c.db, func(tx kvdb.RwTx) error { edges, err := tx.CreateTopLevelBucket(edgeBucket) @@ -1602,7 +1622,15 @@ func (c *ChannelGraph) DisconnectBlockAtHeight(height uint32) ( } keys = append(keys, k) - removedChans = append(removedChans, &edgeInfo) + info, ok := edgeInfo.(*models.ChannelEdgeInfo1) + if !ok { + return fmt.Errorf("expected "+ + "*models.ChannelEdgeInfo1, got %T", + edgeInfo) + } + + keys = append(keys, k) + removedChans = append(removedChans, info) } for _, k := range keys { @@ -1866,15 +1894,15 @@ func (c *ChannelGraph) HighestChanID() (uint64, error) { // edge as well as each of the known advertised edge policies. type ChannelEdge struct { // Info contains all the static information describing the channel. - Info *models.ChannelEdgeInfo + Info *models.ChannelEdgeInfo1 // Policy1 points to the "first" edge policy of the channel containing // the dynamic information required to properly route through the edge. - Policy1 *models.ChannelEdgePolicy + Policy1 *models.ChannelEdgePolicy1 // Policy2 points to the "second" edge policy of the channel containing // the dynamic information required to properly route through the edge. - Policy2 *models.ChannelEdgePolicy + Policy2 *models.ChannelEdgePolicy1 // Node1 is "node 1" in the channel. This is the node that would have // produced Policy1 if it exists. @@ -1958,13 +1986,20 @@ func (c *ChannelGraph) ChanUpdatesInHorizon(startTime, } // First, we'll fetch the static edge information. - edgeInfo, err := fetchChanEdgeInfo(edgeIndex, chanID) + info, err := fetchChanEdgeInfo(edgeIndex, chanID) if err != nil { chanID := byteOrder.Uint64(chanID) return fmt.Errorf("unable to fetch info for "+ "edge with chan_id=%v: %v", chanID, err) } + edgeInfo, ok := info.(*models.ChannelEdgeInfo1) + if !ok { + return fmt.Errorf("expected "+ + "*models.ChannelEdgeInfo1, got %T", + edgeInfo) + } + // With the static information obtained, we'll now // fetch the dynamic policy info. edge1, edge2, err := fetchChanEdgePolicies( @@ -1995,7 +2030,7 @@ func (c *ChannelGraph) ChanUpdatesInHorizon(startTime, // edges to be returned. edgesSeen[chanIDInt] = struct{}{} channel := ChannelEdge{ - Info: &edgeInfo, + Info: edgeInfo, Policy1: edge1, Policy2: edge2, Node1: &node1, @@ -2313,7 +2348,7 @@ func (c *ChannelGraph) FilterChannelRange(startHeight, return err } - if edgeInfo.AuthProof == nil { + if edgeInfo.GetAuthProof() == nil { continue } @@ -2335,7 +2370,7 @@ func (c *ChannelGraph) FilterChannelRange(startHeight, continue } - node1Key, node2Key := computeEdgePolicyKeys(&edgeInfo) + node1Key, node2Key := computeEdgePolicyKeys(edgeInfo) rawPolicy := edges.Get(node1Key) if len(rawPolicy) != 0 { @@ -2349,7 +2384,14 @@ func (c *ChannelGraph) FilterChannelRange(startHeight, return err } - chanInfo.Node1UpdateTimestamp = edge.LastUpdate + e, ok := edge.(*models.ChannelEdgePolicy1) + if !ok { + return fmt.Errorf("expected "+ + "*models.ChannelEdgePolicy1, "+ + "got %T", edge) + } + + chanInfo.Node1UpdateTimestamp = e.LastUpdate } rawPolicy = edges.Get(node2Key) @@ -2364,7 +2406,14 @@ func (c *ChannelGraph) FilterChannelRange(startHeight, return err } - chanInfo.Node2UpdateTimestamp = edge.LastUpdate + e, ok := edge.(*models.ChannelEdgePolicy1) + if !ok { + return fmt.Errorf("expected "+ + "*models.ChannelEdgePolicy1, "+ + "got %T", edge) + } + + chanInfo.Node2UpdateTimestamp = e.LastUpdate } channelsPerBlock[cid.BlockHeight] = append( @@ -2453,7 +2502,7 @@ func (c *ChannelGraph) fetchChanInfos(tx kvdb.RTx, chanIDs []uint64) ( // First, we'll fetch the static edge information. If // the edge is unknown, we will skip the edge and // continue gathering all known edges. - edgeInfo, err := fetchChanEdgeInfo( + info, err := fetchChanEdgeInfo( edgeIndex, cidBytes[:], ) switch { @@ -2472,6 +2521,13 @@ func (c *ChannelGraph) fetchChanInfos(tx kvdb.RTx, chanIDs []uint64) ( return err } + edgeInfo, ok := info.(*models.ChannelEdgeInfo1) + if !ok { + return fmt.Errorf("expected "+ + "*models.ChannelEdgeInfo1, got %T", + info) + } + node1, err := fetchLightningNode( nodes, edgeInfo.NodeKey1Bytes[:], ) @@ -2487,7 +2543,7 @@ func (c *ChannelGraph) fetchChanInfos(tx kvdb.RTx, chanIDs []uint64) ( } chanEdges = append(chanEdges, ChannelEdge{ - Info: &edgeInfo, + Info: edgeInfo, Policy1: edge1, Policy2: edge2, Node1: &node1, @@ -2517,7 +2573,7 @@ func (c *ChannelGraph) fetchChanInfos(tx kvdb.RTx, chanIDs []uint64) ( } func delEdgeUpdateIndexEntry(edgesBucket kvdb.RwBucket, chanID uint64, - edge1, edge2 *models.ChannelEdgePolicy) error { + edge1, edge2 *models.ChannelEdgePolicy1) error { // First, we'll fetch the edge update index bucket which currently // stores an entry for the channel we're about to delete. @@ -2565,11 +2621,17 @@ func (c *ChannelGraph) delChannelEdgeUnsafe(edges, edgeIndex, chanIndex, zombieIndex kvdb.RwBucket, chanID []byte, isZombie, strictZombie bool) error { - edgeInfo, err := fetchChanEdgeInfo(edgeIndex, chanID) + info, err := fetchChanEdgeInfo(edgeIndex, chanID) if err != nil { return err } + edgeInfo, ok := info.(*models.ChannelEdgeInfo1) + if !ok { + return fmt.Errorf("expected *models.ChannelEdgeInfo1, got %T", + info) + } + if c.graphCache != nil { c.graphCache.RemoveChannel( edgeInfo.NodeKey1Bytes, edgeInfo.NodeKey2Bytes, @@ -2638,7 +2700,7 @@ func (c *ChannelGraph) delChannelEdgeUnsafe(edges, edgeIndex, chanIndex, nodeKey1, nodeKey2 := edgeInfo.NodeKey1Bytes, edgeInfo.NodeKey2Bytes if strictZombie { - nodeKey1, nodeKey2 = makeZombiePubkeys(&edgeInfo, edge1, edge2) + nodeKey1, nodeKey2 = makeZombiePubkeys(edgeInfo, edge1, edge2) } return markEdgeZombie( @@ -2662,8 +2724,8 @@ func (c *ChannelGraph) delChannelEdgeUnsafe(edges, edgeIndex, chanIndex, // the channel. If the channel were to be marked zombie again, it would be // marked with the correct lagging channel since we received an update from only // one side. -func makeZombiePubkeys(info *models.ChannelEdgeInfo, - e1, e2 *models.ChannelEdgePolicy) ([33]byte, [33]byte) { +func makeZombiePubkeys(info *models.ChannelEdgeInfo1, + e1, e2 *models.ChannelEdgePolicy1) ([33]byte, [33]byte) { switch { // If we don't have either edge policy, we'll return both pubkeys so @@ -2688,12 +2750,12 @@ func makeZombiePubkeys(info *models.ChannelEdgeInfo, // UpdateEdgePolicy updates the edge routing policy for a single directed edge // within the database for the referenced channel. The `flags` attribute within -// the ChannelEdgePolicy determines which of the directed edges are being +// the ChannelEdgePolicy1 determines which of the directed edges are being // updated. If the flag is 1, then the first node's information is being // updated, otherwise it's the second node's information. The node ordering is // determined by the lexicographical ordering of the identity public keys of the // nodes on either side of the channel. -func (c *ChannelGraph) UpdateEdgePolicy(edge *models.ChannelEdgePolicy, +func (c *ChannelGraph) UpdateEdgePolicy(edge *models.ChannelEdgePolicy1, op ...batch.SchedulerOption) error { var ( @@ -2741,7 +2803,7 @@ func (c *ChannelGraph) UpdateEdgePolicy(edge *models.ChannelEdgePolicy, return c.chanScheduler.Execute(r) } -func (c *ChannelGraph) updateEdgeCache(e *models.ChannelEdgePolicy, +func (c *ChannelGraph) updateEdgeCache(e *models.ChannelEdgePolicy1, isUpdate1 bool) { // If an entry for this channel is found in reject cache, we'll modify @@ -2775,7 +2837,7 @@ func (c *ChannelGraph) updateEdgeCache(e *models.ChannelEdgePolicy, // buckets using an existing database transaction. The returned boolean will be // true if the updated policy belongs to node1, and false if the policy belonged // to node2. -func updateEdgePolicy(tx kvdb.RwTx, edge *models.ChannelEdgePolicy, +func updateEdgePolicy(tx kvdb.RwTx, edge *models.ChannelEdgePolicy1, graphCache *GraphCache) (bool, error) { edges := tx.ReadWriteBucket(edgeBucket) @@ -2799,23 +2861,32 @@ func updateEdgePolicy(tx kvdb.RwTx, edge *models.ChannelEdgePolicy, return false, ErrEdgeNotFound } + edgeInfo, err := deserializeChanEdgeInfo(bytes.NewReader(nodeInfo)) + if err != nil { + return false, err + } + // Depending on the flags value passed above, either the first // or second edge policy is being updated. - var fromNode, toNode []byte - var isUpdate1 bool - if edge.ChannelFlags&lnwire.ChanUpdateDirection == 0 { - fromNode = nodeInfo[:33] - toNode = nodeInfo[33:66] + var ( + fromNode, toNode []byte + isUpdate1 bool + node1Bytes = edgeInfo.Node1Bytes() + node2Bytes = edgeInfo.Node2Bytes() + ) + if edge.IsNode1() { + fromNode = node1Bytes[:] + toNode = node2Bytes[:] isUpdate1 = true } else { - fromNode = nodeInfo[33:66] - toNode = nodeInfo[:33] + fromNode = node2Bytes[:] + toNode = node1Bytes[:] isUpdate1 = false } // Finally, with the direction of the edge being updated // identified, we update the on-disk edge representation. - err := putChanEdgePolicy(edges, edge, fromNode, toNode) + err = putChanEdgePolicy(edges, edge, fromNode, toNode) if err != nil { return false, err } @@ -2970,8 +3041,8 @@ func (c *ChannelGraph) isPublic(tx kvdb.RTx, nodePub route.Vertex, nodeIsPublic := false errDone := errors.New("done") err := c.ForEachNodeChannelTx(tx, nodePub, func(tx kvdb.RTx, - info *models.ChannelEdgeInfo, _ *models.ChannelEdgePolicy, - _ *models.ChannelEdgePolicy) error { + info *models.ChannelEdgeInfo1, _ *models.ChannelEdgePolicy1, + _ *models.ChannelEdgePolicy1) error { // If this edge doesn't extend to the source node, we'll // terminate our search as we can now conclude that the node is @@ -3113,8 +3184,8 @@ func (n *graphCacheNode) Features() *lnwire.FeatureVector { // // Unknown policies are passed into the callback as nil values. func (n *graphCacheNode) ForEachChannel(tx kvdb.RTx, - cb func(kvdb.RTx, *models.ChannelEdgeInfo, *models.ChannelEdgePolicy, - *models.ChannelEdgePolicy) error) error { + cb func(kvdb.RTx, *models.ChannelEdgeInfo1, *models.ChannelEdgePolicy1, + *models.ChannelEdgePolicy1) error) error { return nodeTraversal(tx, n.pubKeyBytes[:], nil, cb) } @@ -3174,8 +3245,8 @@ func (c *ChannelGraph) HasLightningNode(nodePub [33]byte) (time.Time, bool, erro // nodeTraversal is used to traverse all channels of a node given by its // public key and passes channel information into the specified callback. func nodeTraversal(tx kvdb.RTx, nodePub []byte, db kvdb.Backend, - cb func(kvdb.RTx, *models.ChannelEdgeInfo, *models.ChannelEdgePolicy, - *models.ChannelEdgePolicy) error) error { + cb func(kvdb.RTx, *models.ChannelEdgeInfo1, *models.ChannelEdgePolicy1, + *models.ChannelEdgePolicy1) error) error { traversal := func(tx kvdb.RTx) error { edges := tx.ReadBucket(edgeBucket) @@ -3209,11 +3280,18 @@ func nodeTraversal(tx kvdb.RTx, nodePub []byte, db kvdb.Backend, // the node at the other end of the channel and both // edge policies. chanID := nodeEdge[33:] - edgeInfo, err := fetchChanEdgeInfo(edgeIndex, chanID) + info, err := fetchChanEdgeInfo(edgeIndex, chanID) if err != nil { return err } + edgeInfo, ok := info.(*models.ChannelEdgeInfo1) + if !ok { + return fmt.Errorf("expected "+ + "*models.ChannelEdgeInfo1, got %T", + edgeInfo) + } + outgoingPolicy, err := fetchChanEdgePolicy( edges, chanID, nodePub, ) @@ -3234,7 +3312,7 @@ func nodeTraversal(tx kvdb.RTx, nodePub []byte, db kvdb.Backend, } // Finally, we execute the callback. - err = cb(tx, &edgeInfo, outgoingPolicy, incomingPolicy) + err = cb(tx, edgeInfo, outgoingPolicy, incomingPolicy) if err != nil { return err } @@ -3263,8 +3341,8 @@ func nodeTraversal(tx kvdb.RTx, nodePub []byte, db kvdb.Backend, // // Unknown policies are passed into the callback as nil values. func (c *ChannelGraph) ForEachNodeChannel(nodePub route.Vertex, - cb func(kvdb.RTx, *models.ChannelEdgeInfo, *models.ChannelEdgePolicy, - *models.ChannelEdgePolicy) error) error { + cb func(kvdb.RTx, *models.ChannelEdgeInfo1, *models.ChannelEdgePolicy1, + *models.ChannelEdgePolicy1) error) error { return nodeTraversal(nil, nodePub[:], c.db, cb) } @@ -3283,9 +3361,9 @@ func (c *ChannelGraph) ForEachNodeChannel(nodePub route.Vertex, // be nil and a fresh transaction will be created to execute the graph // traversal. func (c *ChannelGraph) ForEachNodeChannelTx(tx kvdb.RTx, - nodePub route.Vertex, cb func(kvdb.RTx, *models.ChannelEdgeInfo, - *models.ChannelEdgePolicy, - *models.ChannelEdgePolicy) error) error { + nodePub route.Vertex, cb func(kvdb.RTx, *models.ChannelEdgeInfo1, + *models.ChannelEdgePolicy1, + *models.ChannelEdgePolicy1) error) error { return nodeTraversal(tx, nodePub[:], c.db, cb) } @@ -3295,7 +3373,7 @@ func (c *ChannelGraph) ForEachNodeChannelTx(tx kvdb.RTx, // one of the nodes, and wishes to obtain the full LightningNode for the other // end of the channel. func (c *ChannelGraph) FetchOtherNode(tx kvdb.RTx, - channel *models.ChannelEdgeInfo, thisNodeKey []byte) (*LightningNode, + channel *models.ChannelEdgeInfo1, thisNodeKey []byte) (*LightningNode, error) { // Ensure that the node passed in is actually a member of the channel. @@ -3343,17 +3421,20 @@ func (c *ChannelGraph) FetchOtherNode(tx kvdb.RTx, // computeEdgePolicyKeys is a helper function that can be used to compute the // keys used to index the channel edge policy info for the two nodes of the // edge. The keys for node 1 and node 2 are returned respectively. -func computeEdgePolicyKeys(info *models.ChannelEdgeInfo) ([]byte, []byte) { +func computeEdgePolicyKeys(info models.ChannelEdgeInfo) ([]byte, []byte) { var ( node1Key [33 + 8]byte node2Key [33 + 8]byte + + node1Bytes = info.Node1Bytes() + node2Bytes = info.Node2Bytes() ) - copy(node1Key[:], info.NodeKey1Bytes[:]) - copy(node2Key[:], info.NodeKey2Bytes[:]) + copy(node1Key[:], node1Bytes[:]) + copy(node2Key[:], node2Bytes[:]) - byteOrder.PutUint64(node1Key[33:], info.ChannelID) - byteOrder.PutUint64(node2Key[33:], info.ChannelID) + byteOrder.PutUint64(node1Key[33:], info.GetChanID()) + byteOrder.PutUint64(node2Key[33:], info.GetChanID()) return node1Key[:], node2Key[:] } @@ -3364,13 +3445,13 @@ func computeEdgePolicyKeys(info *models.ChannelEdgeInfo) ([]byte, []byte) { // information for the channel itself is returned as well as two structs that // contain the routing policies for the channel in either direction. func (c *ChannelGraph) FetchChannelEdgesByOutpoint(op *wire.OutPoint) ( - *models.ChannelEdgeInfo, *models.ChannelEdgePolicy, - *models.ChannelEdgePolicy, error) { + *models.ChannelEdgeInfo1, *models.ChannelEdgePolicy1, + *models.ChannelEdgePolicy1, error) { var ( - edgeInfo *models.ChannelEdgeInfo - policy1 *models.ChannelEdgePolicy - policy2 *models.ChannelEdgePolicy + edgeInfo *models.ChannelEdgeInfo1 + policy1 *models.ChannelEdgePolicy1 + policy2 *models.ChannelEdgePolicy1 ) err := kvdb.View(c.db, func(tx kvdb.RTx) error { @@ -3414,7 +3495,15 @@ func (c *ChannelGraph) FetchChannelEdgesByOutpoint(op *wire.OutPoint) ( if err != nil { return fmt.Errorf("%w: chanID=%x", err, chanID) } - edgeInfo = &edge + + info, ok := edge.(*models.ChannelEdgeInfo1) + if !ok { + return fmt.Errorf("expected "+ + "*models.ChannelEdgeInfo1, got %T", + edge) + } + + edgeInfo = info // Once we have the information about the channels' parameters, // we'll fetch the routing policies for each for the directed @@ -3446,16 +3535,16 @@ func (c *ChannelGraph) FetchChannelEdgesByOutpoint(op *wire.OutPoint) ( // routing policies for the channel in either direction. // // ErrZombieEdge an be returned if the edge is currently marked as a zombie -// within the database. In this case, the ChannelEdgePolicy's will be nil, and -// the ChannelEdgeInfo will only include the public keys of each node. +// within the database. In this case, the ChannelEdgePolicy1's will be nil, and +// the ChannelEdgeInfo1 will only include the public keys of each node. func (c *ChannelGraph) FetchChannelEdgesByID(chanID uint64) ( - *models.ChannelEdgeInfo, *models.ChannelEdgePolicy, - *models.ChannelEdgePolicy, error) { + *models.ChannelEdgeInfo1, *models.ChannelEdgePolicy1, + *models.ChannelEdgePolicy1, error) { var ( - edgeInfo *models.ChannelEdgeInfo - policy1 *models.ChannelEdgePolicy - policy2 *models.ChannelEdgePolicy + edgeInfo *models.ChannelEdgeInfo1 + policy1 *models.ChannelEdgePolicy1 + policy2 *models.ChannelEdgePolicy1 channelID [8]byte ) @@ -3506,7 +3595,7 @@ func (c *ChannelGraph) FetchChannelEdgesByID(chanID uint64) ( // populate the edge info with the public keys of each // party as this is the only information we have about // it and return an error signaling so. - edgeInfo = &models.ChannelEdgeInfo{ + edgeInfo = &models.ChannelEdgeInfo1{ NodeKey1Bytes: pubKey1, NodeKey2Bytes: pubKey2, } @@ -3518,7 +3607,13 @@ func (c *ChannelGraph) FetchChannelEdgesByID(chanID uint64) ( return err } - edgeInfo = &edge + info, ok := edge.(*models.ChannelEdgeInfo1) + if !ok { + return fmt.Errorf("expected "+ + "*models.ChannelEdgeInfo1, got %T", edge) + } + + edgeInfo = info // Then we'll attempt to fetch the accompanying policies of this // edge. @@ -3658,10 +3753,7 @@ func (c *ChannelGraph) ChannelView() ([]EdgePoint, error) { return err } - pkScript, err := genMultiSigP2WSH( - edgeInfo.BitcoinKey1Bytes[:], - edgeInfo.BitcoinKey2Bytes[:], - ) + pkScript, err := edgeInfo.FundingScript() if err != nil { return err } @@ -4176,512 +4268,3 @@ func deserializeLightningNode(r io.Reader) (LightningNode, error) { return node, nil } - -func putChanEdgeInfo(edgeIndex kvdb.RwBucket, - edgeInfo *models.ChannelEdgeInfo, chanID [8]byte) error { - - var b bytes.Buffer - - if _, err := b.Write(edgeInfo.NodeKey1Bytes[:]); err != nil { - return err - } - if _, err := b.Write(edgeInfo.NodeKey2Bytes[:]); err != nil { - return err - } - if _, err := b.Write(edgeInfo.BitcoinKey1Bytes[:]); err != nil { - return err - } - if _, err := b.Write(edgeInfo.BitcoinKey2Bytes[:]); err != nil { - return err - } - - if err := wire.WriteVarBytes(&b, 0, edgeInfo.Features); err != nil { - return err - } - - authProof := edgeInfo.AuthProof - var nodeSig1, nodeSig2, bitcoinSig1, bitcoinSig2 []byte - if authProof != nil { - nodeSig1 = authProof.NodeSig1Bytes - nodeSig2 = authProof.NodeSig2Bytes - bitcoinSig1 = authProof.BitcoinSig1Bytes - bitcoinSig2 = authProof.BitcoinSig2Bytes - } - - if err := wire.WriteVarBytes(&b, 0, nodeSig1); err != nil { - return err - } - if err := wire.WriteVarBytes(&b, 0, nodeSig2); err != nil { - return err - } - if err := wire.WriteVarBytes(&b, 0, bitcoinSig1); err != nil { - return err - } - if err := wire.WriteVarBytes(&b, 0, bitcoinSig2); err != nil { - return err - } - - if err := writeOutpoint(&b, &edgeInfo.ChannelPoint); err != nil { - return err - } - if err := binary.Write(&b, byteOrder, uint64(edgeInfo.Capacity)); err != nil { - return err - } - if _, err := b.Write(chanID[:]); err != nil { - return err - } - if _, err := b.Write(edgeInfo.ChainHash[:]); err != nil { - return err - } - - if len(edgeInfo.ExtraOpaqueData) > MaxAllowedExtraOpaqueBytes { - return ErrTooManyExtraOpaqueBytes(len(edgeInfo.ExtraOpaqueData)) - } - err := wire.WriteVarBytes(&b, 0, edgeInfo.ExtraOpaqueData) - if err != nil { - return err - } - - return edgeIndex.Put(chanID[:], b.Bytes()) -} - -func fetchChanEdgeInfo(edgeIndex kvdb.RBucket, - chanID []byte) (models.ChannelEdgeInfo, error) { - - edgeInfoBytes := edgeIndex.Get(chanID) - if edgeInfoBytes == nil { - return models.ChannelEdgeInfo{}, ErrEdgeNotFound - } - - edgeInfoReader := bytes.NewReader(edgeInfoBytes) - return deserializeChanEdgeInfo(edgeInfoReader) -} - -func deserializeChanEdgeInfo(r io.Reader) (models.ChannelEdgeInfo, error) { - var ( - err error - edgeInfo models.ChannelEdgeInfo - ) - - if _, err := io.ReadFull(r, edgeInfo.NodeKey1Bytes[:]); err != nil { - return models.ChannelEdgeInfo{}, err - } - if _, err := io.ReadFull(r, edgeInfo.NodeKey2Bytes[:]); err != nil { - return models.ChannelEdgeInfo{}, err - } - if _, err := io.ReadFull(r, edgeInfo.BitcoinKey1Bytes[:]); err != nil { - return models.ChannelEdgeInfo{}, err - } - if _, err := io.ReadFull(r, edgeInfo.BitcoinKey2Bytes[:]); err != nil { - return models.ChannelEdgeInfo{}, err - } - - edgeInfo.Features, err = wire.ReadVarBytes(r, 0, 900, "features") - if err != nil { - return models.ChannelEdgeInfo{}, err - } - - proof := &models.ChannelAuthProof{} - - proof.NodeSig1Bytes, err = wire.ReadVarBytes(r, 0, 80, "sigs") - if err != nil { - return models.ChannelEdgeInfo{}, err - } - proof.NodeSig2Bytes, err = wire.ReadVarBytes(r, 0, 80, "sigs") - if err != nil { - return models.ChannelEdgeInfo{}, err - } - proof.BitcoinSig1Bytes, err = wire.ReadVarBytes(r, 0, 80, "sigs") - if err != nil { - return models.ChannelEdgeInfo{}, err - } - proof.BitcoinSig2Bytes, err = wire.ReadVarBytes(r, 0, 80, "sigs") - if err != nil { - return models.ChannelEdgeInfo{}, err - } - - if !proof.IsEmpty() { - edgeInfo.AuthProof = proof - } - - edgeInfo.ChannelPoint = wire.OutPoint{} - if err := readOutpoint(r, &edgeInfo.ChannelPoint); err != nil { - return models.ChannelEdgeInfo{}, err - } - if err := binary.Read(r, byteOrder, &edgeInfo.Capacity); err != nil { - return models.ChannelEdgeInfo{}, err - } - if err := binary.Read(r, byteOrder, &edgeInfo.ChannelID); err != nil { - return models.ChannelEdgeInfo{}, err - } - - if _, err := io.ReadFull(r, edgeInfo.ChainHash[:]); err != nil { - return models.ChannelEdgeInfo{}, err - } - - // We'll try and see if there are any opaque bytes left, if not, then - // we'll ignore the EOF error and return the edge as is. - edgeInfo.ExtraOpaqueData, err = wire.ReadVarBytes( - r, 0, MaxAllowedExtraOpaqueBytes, "blob", - ) - switch { - case err == io.ErrUnexpectedEOF: - case err == io.EOF: - case err != nil: - return models.ChannelEdgeInfo{}, err - } - - return edgeInfo, nil -} - -func putChanEdgePolicy(edges kvdb.RwBucket, edge *models.ChannelEdgePolicy, - from, to []byte) error { - - var edgeKey [33 + 8]byte - copy(edgeKey[:], from) - byteOrder.PutUint64(edgeKey[33:], edge.ChannelID) - - var b bytes.Buffer - if err := serializeChanEdgePolicy(&b, edge, to); err != nil { - return err - } - - // Before we write out the new edge, we'll create a new entry in the - // update index in order to keep it fresh. - updateUnix := uint64(edge.LastUpdate.Unix()) - var indexKey [8 + 8]byte - byteOrder.PutUint64(indexKey[:8], updateUnix) - byteOrder.PutUint64(indexKey[8:], edge.ChannelID) - - updateIndex, err := edges.CreateBucketIfNotExists(edgeUpdateIndexBucket) - if err != nil { - return err - } - - // If there was already an entry for this edge, then we'll need to - // delete the old one to ensure we don't leave around any after-images. - // An unknown policy value does not have a update time recorded, so - // it also does not need to be removed. - if edgeBytes := edges.Get(edgeKey[:]); edgeBytes != nil && - !bytes.Equal(edgeBytes[:], unknownPolicy) { - - // In order to delete the old entry, we'll need to obtain the - // *prior* update time in order to delete it. To do this, we'll - // need to deserialize the existing policy within the database - // (now outdated by the new one), and delete its corresponding - // entry within the update index. We'll ignore any - // ErrEdgePolicyOptionalFieldNotFound error, as we only need - // the channel ID and update time to delete the entry. - // TODO(halseth): get rid of these invalid policies in a - // migration. - oldEdgePolicy, err := deserializeChanEdgePolicy( - bytes.NewReader(edgeBytes), - ) - if err != nil && err != ErrEdgePolicyOptionalFieldNotFound { - return err - } - - oldUpdateTime := uint64(oldEdgePolicy.LastUpdate.Unix()) - - var oldIndexKey [8 + 8]byte - byteOrder.PutUint64(oldIndexKey[:8], oldUpdateTime) - byteOrder.PutUint64(oldIndexKey[8:], edge.ChannelID) - - if err := updateIndex.Delete(oldIndexKey[:]); err != nil { - return err - } - } - - if err := updateIndex.Put(indexKey[:], nil); err != nil { - return err - } - - updateEdgePolicyDisabledIndex( - edges, edge.ChannelID, - edge.ChannelFlags&lnwire.ChanUpdateDirection > 0, - edge.IsDisabled(), - ) - - return edges.Put(edgeKey[:], b.Bytes()[:]) -} - -// updateEdgePolicyDisabledIndex is used to update the disabledEdgePolicyIndex -// bucket by either add a new disabled ChannelEdgePolicy or remove an existing -// one. -// The direction represents the direction of the edge and disabled is used for -// deciding whether to remove or add an entry to the bucket. -// In general a channel is disabled if two entries for the same chanID exist -// in this bucket. -// Maintaining the bucket this way allows a fast retrieval of disabled -// channels, for example when prune is needed. -func updateEdgePolicyDisabledIndex(edges kvdb.RwBucket, chanID uint64, - direction bool, disabled bool) error { - - var disabledEdgeKey [8 + 1]byte - byteOrder.PutUint64(disabledEdgeKey[0:], chanID) - if direction { - disabledEdgeKey[8] = 1 - } - - disabledEdgePolicyIndex, err := edges.CreateBucketIfNotExists( - disabledEdgePolicyBucket, - ) - if err != nil { - return err - } - - if disabled { - return disabledEdgePolicyIndex.Put(disabledEdgeKey[:], []byte{}) - } - - return disabledEdgePolicyIndex.Delete(disabledEdgeKey[:]) -} - -// putChanEdgePolicyUnknown marks the edge policy as unknown -// in the edges bucket. -func putChanEdgePolicyUnknown(edges kvdb.RwBucket, channelID uint64, - from []byte) error { - - var edgeKey [33 + 8]byte - copy(edgeKey[:], from) - byteOrder.PutUint64(edgeKey[33:], channelID) - - if edges.Get(edgeKey[:]) != nil { - return fmt.Errorf("cannot write unknown policy for channel %v "+ - " when there is already a policy present", channelID) - } - - return edges.Put(edgeKey[:], unknownPolicy) -} - -func fetchChanEdgePolicy(edges kvdb.RBucket, chanID []byte, - nodePub []byte) (*models.ChannelEdgePolicy, error) { - - var edgeKey [33 + 8]byte - copy(edgeKey[:], nodePub) - copy(edgeKey[33:], chanID[:]) - - edgeBytes := edges.Get(edgeKey[:]) - if edgeBytes == nil { - return nil, ErrEdgeNotFound - } - - // No need to deserialize unknown policy. - if bytes.Equal(edgeBytes[:], unknownPolicy) { - return nil, nil - } - - edgeReader := bytes.NewReader(edgeBytes) - - ep, err := deserializeChanEdgePolicy(edgeReader) - switch { - // If the db policy was missing an expected optional field, we return - // nil as if the policy was unknown. - case err == ErrEdgePolicyOptionalFieldNotFound: - return nil, nil - - case err != nil: - return nil, err - } - - return ep, nil -} - -func fetchChanEdgePolicies(edgeIndex kvdb.RBucket, edges kvdb.RBucket, - chanID []byte) (*models.ChannelEdgePolicy, *models.ChannelEdgePolicy, - error) { - - edgeInfo := edgeIndex.Get(chanID) - if edgeInfo == nil { - return nil, nil, fmt.Errorf("%w: chanID=%x", ErrEdgeNotFound, - chanID) - } - - // The first node is contained within the first half of the edge - // information. We only propagate the error here and below if it's - // something other than edge non-existence. - node1Pub := edgeInfo[:33] - edge1, err := fetchChanEdgePolicy(edges, chanID, node1Pub) - if err != nil { - return nil, nil, fmt.Errorf("%w: node1Pub=%x", ErrEdgeNotFound, - node1Pub) - } - - // Similarly, the second node is contained within the latter - // half of the edge information. - node2Pub := edgeInfo[33:66] - edge2, err := fetchChanEdgePolicy(edges, chanID, node2Pub) - if err != nil { - return nil, nil, fmt.Errorf("%w: node2Pub=%x", ErrEdgeNotFound, - node2Pub) - } - - return edge1, edge2, nil -} - -func serializeChanEdgePolicy(w io.Writer, edge *models.ChannelEdgePolicy, - to []byte) error { - - err := wire.WriteVarBytes(w, 0, edge.SigBytes) - if err != nil { - return err - } - - if err := binary.Write(w, byteOrder, edge.ChannelID); err != nil { - return err - } - - var scratch [8]byte - updateUnix := uint64(edge.LastUpdate.Unix()) - byteOrder.PutUint64(scratch[:], updateUnix) - if _, err := w.Write(scratch[:]); err != nil { - return err - } - - if err := binary.Write(w, byteOrder, edge.MessageFlags); err != nil { - return err - } - if err := binary.Write(w, byteOrder, edge.ChannelFlags); err != nil { - return err - } - if err := binary.Write(w, byteOrder, edge.TimeLockDelta); err != nil { - return err - } - if err := binary.Write(w, byteOrder, uint64(edge.MinHTLC)); err != nil { - return err - } - if err := binary.Write(w, byteOrder, uint64(edge.FeeBaseMSat)); err != nil { - return err - } - if err := binary.Write(w, byteOrder, uint64(edge.FeeProportionalMillionths)); err != nil { - return err - } - - if _, err := w.Write(to); err != nil { - return err - } - - // If the max_htlc field is present, we write it. To be compatible with - // older versions that wasn't aware of this field, we write it as part - // of the opaque data. - // TODO(halseth): clean up when moving to TLV. - var opaqueBuf bytes.Buffer - if edge.MessageFlags.HasMaxHtlc() { - err := binary.Write(&opaqueBuf, byteOrder, uint64(edge.MaxHTLC)) - if err != nil { - return err - } - } - - if len(edge.ExtraOpaqueData) > MaxAllowedExtraOpaqueBytes { - return ErrTooManyExtraOpaqueBytes(len(edge.ExtraOpaqueData)) - } - if _, err := opaqueBuf.Write(edge.ExtraOpaqueData); err != nil { - return err - } - - if err := wire.WriteVarBytes(w, 0, opaqueBuf.Bytes()); err != nil { - return err - } - return nil -} - -func deserializeChanEdgePolicy(r io.Reader) (*models.ChannelEdgePolicy, error) { - // Deserialize the policy. Note that in case an optional field is not - // found, both an error and a populated policy object are returned. - edge, deserializeErr := deserializeChanEdgePolicyRaw(r) - if deserializeErr != nil && - deserializeErr != ErrEdgePolicyOptionalFieldNotFound { - - return nil, deserializeErr - } - - return edge, deserializeErr -} - -func deserializeChanEdgePolicyRaw(r io.Reader) (*models.ChannelEdgePolicy, - error) { - - edge := &models.ChannelEdgePolicy{} - - var err error - edge.SigBytes, err = wire.ReadVarBytes(r, 0, 80, "sig") - if err != nil { - return nil, err - } - - if err := binary.Read(r, byteOrder, &edge.ChannelID); err != nil { - return nil, err - } - - var scratch [8]byte - if _, err := r.Read(scratch[:]); err != nil { - return nil, err - } - unix := int64(byteOrder.Uint64(scratch[:])) - edge.LastUpdate = time.Unix(unix, 0) - - if err := binary.Read(r, byteOrder, &edge.MessageFlags); err != nil { - return nil, err - } - if err := binary.Read(r, byteOrder, &edge.ChannelFlags); err != nil { - return nil, err - } - if err := binary.Read(r, byteOrder, &edge.TimeLockDelta); err != nil { - return nil, err - } - - var n uint64 - if err := binary.Read(r, byteOrder, &n); err != nil { - return nil, err - } - edge.MinHTLC = lnwire.MilliSatoshi(n) - - if err := binary.Read(r, byteOrder, &n); err != nil { - return nil, err - } - edge.FeeBaseMSat = lnwire.MilliSatoshi(n) - - if err := binary.Read(r, byteOrder, &n); err != nil { - return nil, err - } - edge.FeeProportionalMillionths = lnwire.MilliSatoshi(n) - - if _, err := r.Read(edge.ToNode[:]); err != nil { - return nil, err - } - - // We'll try and see if there are any opaque bytes left, if not, then - // we'll ignore the EOF error and return the edge as is. - edge.ExtraOpaqueData, err = wire.ReadVarBytes( - r, 0, MaxAllowedExtraOpaqueBytes, "blob", - ) - switch { - case err == io.ErrUnexpectedEOF: - case err == io.EOF: - case err != nil: - return nil, err - } - - // See if optional fields are present. - if edge.MessageFlags.HasMaxHtlc() { - // The max_htlc field should be at the beginning of the opaque - // bytes. - opq := edge.ExtraOpaqueData - - // If the max_htlc field is not present, it might be old data - // stored before this field was validated. We'll return the - // edge along with an error. - if len(opq) < 8 { - return edge, ErrEdgePolicyOptionalFieldNotFound - } - - maxHtlc := byteOrder.Uint64(opq[:8]) - edge.MaxHTLC = lnwire.MilliSatoshi(maxHtlc) - - // Exclude the parsed field from the rest of the opaque data. - edge.ExtraOpaqueData = opq[8:] - } - - return edge, nil -} diff --git a/channeldb/graph_cache.go b/channeldb/graph_cache.go index 9bd2a82658..7df72b209e 100644 --- a/channeldb/graph_cache.go +++ b/channeldb/graph_cache.go @@ -28,9 +28,9 @@ type GraphCacheNode interface { // error, then the iteration is halted with the error propagated back up // to the caller. ForEachChannel(kvdb.RTx, - func(kvdb.RTx, *models.ChannelEdgeInfo, - *models.ChannelEdgePolicy, - *models.ChannelEdgePolicy) error) error + func(kvdb.RTx, *models.ChannelEdgeInfo1, + *models.ChannelEdgePolicy1, + *models.ChannelEdgePolicy1) error) error } // DirectedChannel is a type that stores the channel information as seen from @@ -142,9 +142,9 @@ func (c *GraphCache) AddNode(tx kvdb.RTx, node GraphCacheNode) error { c.AddNodeFeatures(node) return node.ForEachChannel( - tx, func(tx kvdb.RTx, info *models.ChannelEdgeInfo, - outPolicy *models.ChannelEdgePolicy, - inPolicy *models.ChannelEdgePolicy) error { + tx, func(tx kvdb.RTx, info *models.ChannelEdgeInfo1, + outPolicy *models.ChannelEdgePolicy1, + inPolicy *models.ChannelEdgePolicy1) error { c.AddChannel(info, outPolicy, inPolicy) @@ -157,8 +157,8 @@ func (c *GraphCache) AddNode(tx kvdb.RTx, node GraphCacheNode) error { // and policy 2 does not matter, the directionality is extracted from the info // and policy flags automatically. The policy will be set as the outgoing policy // on one node and the incoming policy on the peer's side. -func (c *GraphCache) AddChannel(info *models.ChannelEdgeInfo, - policy1 *models.ChannelEdgePolicy, policy2 *models.ChannelEdgePolicy) { +func (c *GraphCache) AddChannel(info *models.ChannelEdgeInfo1, policy1, + policy2 *models.ChannelEdgePolicy1) { if info == nil { return @@ -220,7 +220,7 @@ func (c *GraphCache) updateOrAddEdge(node route.Vertex, edge *DirectedChannel) { // of the from and to node is not strictly important. But we assume that a // channel edge was added beforehand so that the directed channel struct already // exists in the cache. -func (c *GraphCache) UpdatePolicy(policy *models.ChannelEdgePolicy, fromNode, +func (c *GraphCache) UpdatePolicy(policy *models.ChannelEdgePolicy1, fromNode, toNode route.Vertex, edge1 bool) { // Extract inbound fee if possible and available. If there is a decoding @@ -309,7 +309,7 @@ func (c *GraphCache) removeChannelIfFound(node route.Vertex, chanID uint64) { // UpdateChannel updates the channel edge information for a specific edge. We // expect the edge to already exist and be known. If it does not yet exist, this // call is a no-op. -func (c *GraphCache) UpdateChannel(info *models.ChannelEdgeInfo) { +func (c *GraphCache) UpdateChannel(info *models.ChannelEdgeInfo1) { c.mtx.Lock() defer c.mtx.Unlock() diff --git a/channeldb/graph_cache_test.go b/channeldb/graph_cache_test.go index f7ed5cee60..65e639be55 100644 --- a/channeldb/graph_cache_test.go +++ b/channeldb/graph_cache_test.go @@ -29,9 +29,9 @@ type node struct { pubKey route.Vertex features *lnwire.FeatureVector - edgeInfos []*models.ChannelEdgeInfo - outPolicies []*models.ChannelEdgePolicy - inPolicies []*models.ChannelEdgePolicy + edgeInfos []*models.ChannelEdgeInfo1 + outPolicies []*models.ChannelEdgePolicy1 + inPolicies []*models.ChannelEdgePolicy1 } func (n *node) PubKey() route.Vertex { @@ -42,8 +42,8 @@ func (n *node) Features() *lnwire.FeatureVector { } func (n *node) ForEachChannel(tx kvdb.RTx, - cb func(kvdb.RTx, *models.ChannelEdgeInfo, *models.ChannelEdgePolicy, - *models.ChannelEdgePolicy) error) error { + cb func(kvdb.RTx, *models.ChannelEdgeInfo1, *models.ChannelEdgePolicy1, + *models.ChannelEdgePolicy1) error) error { for idx := range n.edgeInfos { err := cb( @@ -71,7 +71,7 @@ func TestGraphCacheAddNode(t *testing.T) { channelFlagA, channelFlagB = 1, 0 } - outPolicy1 := &models.ChannelEdgePolicy{ + outPolicy1 := &models.ChannelEdgePolicy1{ ChannelID: 1000, ChannelFlags: lnwire.ChanUpdateChanFlags(channelFlagA), ToNode: nodeB, @@ -80,7 +80,7 @@ func TestGraphCacheAddNode(t *testing.T) { 253, 217, 3, 8, 0, 0, 0, 10, 0, 0, 0, 20, }, } - inPolicy1 := &models.ChannelEdgePolicy{ + inPolicy1 := &models.ChannelEdgePolicy1{ ChannelID: 1000, ChannelFlags: lnwire.ChanUpdateChanFlags(channelFlagB), ToNode: nodeA, @@ -88,15 +88,15 @@ func TestGraphCacheAddNode(t *testing.T) { node := &node{ pubKey: nodeA, features: lnwire.EmptyFeatureVector(), - edgeInfos: []*models.ChannelEdgeInfo{{ + edgeInfos: []*models.ChannelEdgeInfo1{{ ChannelID: 1000, // Those are direction independent! NodeKey1Bytes: pubKey1, NodeKey2Bytes: pubKey2, Capacity: 500, }}, - outPolicies: []*models.ChannelEdgePolicy{outPolicy1}, - inPolicies: []*models.ChannelEdgePolicy{inPolicy1}, + outPolicies: []*models.ChannelEdgePolicy1{outPolicy1}, + inPolicies: []*models.ChannelEdgePolicy1{inPolicy1}, } cache := NewGraphCache(10) require.NoError(t, cache.AddNode(nil, node)) @@ -153,7 +153,7 @@ func TestGraphCacheAddNode(t *testing.T) { runTest(pubKey2, pubKey1) } -func assertCachedPolicyEqual(t *testing.T, original *models.ChannelEdgePolicy, +func assertCachedPolicyEqual(t *testing.T, original *models.ChannelEdgePolicy1, cached *models.CachedEdgePolicy) { require.Equal(t, original.ChannelID, cached.ChannelID) diff --git a/channeldb/graph_test.go b/channeldb/graph_test.go index b05f3daaab..ecc6ffc879 100644 --- a/channeldb/graph_test.go +++ b/channeldb/graph_test.go @@ -326,10 +326,10 @@ func TestEdgeInsertionDeletion(t *testing.T) { require.NoError(t, err, "unable to generate node key") node2Pub, err := node2.PubKey() require.NoError(t, err, "unable to generate node key") - edgeInfo := models.ChannelEdgeInfo{ + edgeInfo := models.ChannelEdgeInfo1{ ChannelID: chanID, ChainHash: key, - AuthProof: &models.ChannelAuthProof{ + AuthProof: &models.ChannelAuthProof1{ NodeSig1Bytes: testSig.Serialize(), NodeSig2Bytes: testSig.Serialize(), BitcoinSig1Bytes: testSig.Serialize(), @@ -386,7 +386,7 @@ func TestEdgeInsertionDeletion(t *testing.T) { } func createEdge(height, txIndex uint32, txPosition uint16, outPointIndex uint32, - node1, node2 *LightningNode) (models.ChannelEdgeInfo, + node1, node2 *LightningNode) (models.ChannelEdgeInfo1, lnwire.ShortChannelID) { shortChanID := lnwire.ShortChannelID{ @@ -401,10 +401,10 @@ func createEdge(height, txIndex uint32, txPosition uint16, outPointIndex uint32, node1Pub, _ := node1.PubKey() node2Pub, _ := node2.PubKey() - edgeInfo := models.ChannelEdgeInfo{ + edgeInfo := models.ChannelEdgeInfo1{ ChannelID: shortChanID.ToUint64(), ChainHash: key, - AuthProof: &models.ChannelAuthProof{ + AuthProof: &models.ChannelAuthProof1{ NodeSig1Bytes: testSig.Serialize(), NodeSig2Bytes: testSig.Serialize(), BitcoinSig1Bytes: testSig.Serialize(), @@ -556,8 +556,8 @@ func TestDisconnectBlockAtHeight(t *testing.T) { } } -func assertEdgeInfoEqual(t *testing.T, e1 *models.ChannelEdgeInfo, - e2 *models.ChannelEdgeInfo) { +func assertEdgeInfoEqual(t *testing.T, e1 *models.ChannelEdgeInfo1, + e2 *models.ChannelEdgeInfo1) { if e1.ChannelID != e2.ChannelID { t.Fatalf("chan id's don't match: %v vs %v", e1.ChannelID, @@ -619,8 +619,8 @@ func assertEdgeInfoEqual(t *testing.T, e1 *models.ChannelEdgeInfo, } func createChannelEdge(db kvdb.Backend, node1, node2 *LightningNode) ( - *models.ChannelEdgeInfo, *models.ChannelEdgePolicy, - *models.ChannelEdgePolicy) { + *models.ChannelEdgeInfo1, *models.ChannelEdgePolicy1, + *models.ChannelEdgePolicy1) { var ( firstNode [33]byte @@ -644,10 +644,10 @@ func createChannelEdge(db kvdb.Backend, node1, node2 *LightningNode) ( // Add the new edge to the database, this should proceed without any // errors. - edgeInfo := &models.ChannelEdgeInfo{ + edgeInfo := &models.ChannelEdgeInfo1{ ChannelID: chanID, ChainHash: key, - AuthProof: &models.ChannelAuthProof{ + AuthProof: &models.ChannelAuthProof1{ NodeSig1Bytes: testSig.Serialize(), NodeSig2Bytes: testSig.Serialize(), BitcoinSig1Bytes: testSig.Serialize(), @@ -662,7 +662,7 @@ func createChannelEdge(db kvdb.Backend, node1, node2 *LightningNode) ( copy(edgeInfo.BitcoinKey1Bytes[:], firstNode[:]) copy(edgeInfo.BitcoinKey2Bytes[:], secondNode[:]) - edge1 := &models.ChannelEdgePolicy{ + edge1 := &models.ChannelEdgePolicy1{ SigBytes: testSig.Serialize(), ChannelID: chanID, LastUpdate: time.Unix(433453, 0), @@ -676,7 +676,7 @@ func createChannelEdge(db kvdb.Backend, node1, node2 *LightningNode) ( ToNode: secondNode, ExtraOpaqueData: []byte{1, 0}, } - edge2 := &models.ChannelEdgePolicy{ + edge2 := &models.ChannelEdgePolicy1{ SigBytes: testSig.Serialize(), ChannelID: chanID, LastUpdate: time.Unix(124234, 0), @@ -821,7 +821,7 @@ func assertNodeNotInCache(t *testing.T, g *ChannelGraph, n route.Vertex) { } func assertEdgeWithNoPoliciesInCache(t *testing.T, g *ChannelGraph, - e *models.ChannelEdgeInfo) { + e *models.ChannelEdgeInfo1) { // Let's check the internal view first. require.NotEmpty(t, g.graphCache.nodeChannels[e.NodeKey1Bytes]) @@ -899,7 +899,8 @@ func assertNoEdge(t *testing.T, g *ChannelGraph, chanID uint64) { } func assertEdgeWithPolicyInCache(t *testing.T, g *ChannelGraph, - e *models.ChannelEdgeInfo, p *models.ChannelEdgePolicy, policy1 bool) { + e *models.ChannelEdgeInfo1, p *models.ChannelEdgePolicy1, + policy1 bool) { // Check the internal state first. c1, ok := g.graphCache.nodeChannels[e.NodeKey1Bytes][e.ChannelID] @@ -975,16 +976,16 @@ func assertEdgeWithPolicyInCache(t *testing.T, g *ChannelGraph, } } -func randEdgePolicy(chanID uint64, db kvdb.Backend) *models.ChannelEdgePolicy { +func randEdgePolicy(chanID uint64, db kvdb.Backend) *models.ChannelEdgePolicy1 { update := prand.Int63() return newEdgePolicy(chanID, db, update) } func newEdgePolicy(chanID uint64, db kvdb.Backend, - updateTime int64) *models.ChannelEdgePolicy { + updateTime int64) *models.ChannelEdgePolicy1 { - return &models.ChannelEdgePolicy{ + return &models.ChannelEdgePolicy1{ ChannelID: chanID, LastUpdate: time.Unix(updateTime, 0), MessageFlags: 1, @@ -1041,9 +1042,9 @@ func TestGraphTraversal(t *testing.T) { // Iterate through all the known channels within the graph DB, once // again if the map is empty that indicates that all edges have // properly been reached. - err = graph.ForEachChannel(func(ei *models.ChannelEdgeInfo, - _ *models.ChannelEdgePolicy, - _ *models.ChannelEdgePolicy) error { + err = graph.ForEachChannel(func(ei *models.ChannelEdgeInfo1, + _ *models.ChannelEdgePolicy1, + _ *models.ChannelEdgePolicy1) error { delete(chanIndex, ei.ChannelID) return nil @@ -1056,8 +1057,8 @@ func TestGraphTraversal(t *testing.T) { numNodeChans := 0 firstNode, secondNode := nodeList[0], nodeList[1] err = graph.ForEachNodeChannel(firstNode.PubKeyBytes, - func(_ kvdb.RTx, _ *models.ChannelEdgeInfo, outEdge, - inEdge *models.ChannelEdgePolicy) error { + func(_ kvdb.RTx, _ *models.ChannelEdgeInfo1, outEdge, + inEdge *models.ChannelEdgePolicy1) error { // All channels between first and second node should // have fully (both sides) specified policies. @@ -1137,9 +1138,9 @@ func TestGraphTraversalCacheable(t *testing.T) { for _, node := range nodes { err := node.ForEachChannel( tx, func(tx kvdb.RTx, - info *models.ChannelEdgeInfo, - policy *models.ChannelEdgePolicy, - policy2 *models.ChannelEdgePolicy) error { //nolint:lll + info *models.ChannelEdgeInfo1, + policy *models.ChannelEdgePolicy1, + policy2 *models.ChannelEdgePolicy1) error { //nolint:lll delete(chanIndex, info.ChannelID) return nil @@ -1257,10 +1258,10 @@ func fillTestGraph(t require.TestingT, graph *ChannelGraph, numNodes, Index: 0, } - edgeInfo := models.ChannelEdgeInfo{ + edgeInfo := models.ChannelEdgeInfo1{ ChannelID: chanID, ChainHash: key, - AuthProof: &models.ChannelAuthProof{ + AuthProof: &models.ChannelAuthProof1{ NodeSig1Bytes: testSig.Serialize(), NodeSig2Bytes: testSig.Serialize(), BitcoinSig1Bytes: testSig.Serialize(), @@ -1321,9 +1322,9 @@ func assertPruneTip(t *testing.T, graph *ChannelGraph, blockHash *chainhash.Hash func assertNumChans(t *testing.T, graph *ChannelGraph, n int) { numChans := 0 - if err := graph.ForEachChannel(func(*models.ChannelEdgeInfo, - *models.ChannelEdgePolicy, - *models.ChannelEdgePolicy) error { + if err := graph.ForEachChannel(func(*models.ChannelEdgeInfo1, + *models.ChannelEdgePolicy1, + *models.ChannelEdgePolicy1) error { numChans++ return nil @@ -1439,10 +1440,10 @@ func TestGraphPruning(t *testing.T) { channelPoints = append(channelPoints, &op) - edgeInfo := models.ChannelEdgeInfo{ + edgeInfo := models.ChannelEdgeInfo1{ ChannelID: chanID, ChainHash: key, - AuthProof: &models.ChannelAuthProof{ + AuthProof: &models.ChannelAuthProof1{ NodeSig1Bytes: testSig.Serialize(), NodeSig2Bytes: testSig.Serialize(), BitcoinSig1Bytes: testSig.Serialize(), @@ -2120,7 +2121,7 @@ func TestStressTestChannelGraphAPI(t *testing.T) { require.NoError(t, err) type chanInfo struct { - info models.ChannelEdgeInfo + info models.ChannelEdgeInfo1 id lnwire.ShortChannelID } @@ -2439,7 +2440,7 @@ func TestFilterChannelRange(t *testing.T) { var updateTime = time.Unix(0, 0) if rand.Int31n(2) == 0 { updateTime = time.Unix(updateTimeSeed, 0) - err = graph.UpdateEdgePolicy(&models.ChannelEdgePolicy{ + err = graph.UpdateEdgePolicy(&models.ChannelEdgePolicy1{ ToNode: node.PubKeyBytes, ChannelFlags: chanFlags, ChannelID: chanID, @@ -2748,8 +2749,8 @@ func TestIncompleteChannelPolicies(t *testing.T) { checkPolicies := func(node *LightningNode, expectedIn, expectedOut bool) { calls := 0 err := graph.ForEachNodeChannel(node.PubKeyBytes, - func(_ kvdb.RTx, _ *models.ChannelEdgeInfo, outEdge, - inEdge *models.ChannelEdgePolicy) error { + func(_ kvdb.RTx, _ *models.ChannelEdgeInfo1, outEdge, + inEdge *models.ChannelEdgePolicy1) error { if !expectedOut && outEdge != nil { t.Fatalf("Expected no outgoing policy") @@ -3166,7 +3167,7 @@ func TestNodeIsPublic(t *testing.T) { // After creating all of our nodes and edges, we'll add them to each // participant's graph. nodes := []*LightningNode{aliceNode, bobNode, carolNode} - edges := []*models.ChannelEdgeInfo{&aliceBobEdge, &bobCarolEdge} + edges := []*models.ChannelEdgeInfo1{&aliceBobEdge, &bobCarolEdge} graphs := []*ChannelGraph{aliceGraph, bobGraph, carolGraph} for _, graph := range graphs { for _, node := range nodes { @@ -3342,7 +3343,7 @@ func TestDisabledChannelIDs(t *testing.T) { } } -// TestEdgePolicyMissingMaxHtcl tests that if we find a ChannelEdgePolicy in +// TestEdgePolicyMissingMaxHtcl tests that if we find a ChannelEdgePolicy1 in // the DB that indicates that it should support the htlc_maximum_value_msat // field, but it is not part of the opaque data, then we'll handle it as it is // unknown. It also checks that we are correctly able to overwrite it when we @@ -3609,7 +3610,7 @@ func compareNodes(a, b *LightningNode) error { // compareEdgePolicies is used to compare two ChannelEdgePolices using // compareNodes, so as to exclude comparisons of the Nodes' Features struct. -func compareEdgePolicies(a, b *models.ChannelEdgePolicy) error { +func compareEdgePolicies(a, b *models.ChannelEdgePolicy1) error { if a.ChannelID != b.ChannelID { return fmt.Errorf("ChannelID doesn't match: expected %v, "+ "got %v", a.ChannelID, b.ChannelID) @@ -3701,7 +3702,7 @@ func TestLightningNodeSigVerification(t *testing.T) { // TestComputeFee tests fee calculation based on the outgoing amt. func TestComputeFee(t *testing.T) { var ( - policy = models.ChannelEdgePolicy{ + policy = models.ChannelEdgePolicy1{ FeeBaseMSat: 10000, FeeProportionalMillionths: 30000, } @@ -3770,7 +3771,7 @@ func TestBatchedAddChannelEdge(t *testing.T) { // Create a third edge, this with a block height of 155. edgeInfo3, _ := createEdge(height-1, 0, 0, 2, node1, node2) - edges := []models.ChannelEdgeInfo{edgeInfo, edgeInfo2, edgeInfo3} + edges := []models.ChannelEdgeInfo1{edgeInfo, edgeInfo2, edgeInfo3} errChan := make(chan error, len(edges)) errTimeout := errors.New("timeout adding batched channel") @@ -3778,7 +3779,7 @@ func TestBatchedAddChannelEdge(t *testing.T) { var wg sync.WaitGroup for _, edge := range edges { wg.Add(1) - go func(edge models.ChannelEdgeInfo) { + go func(edge models.ChannelEdgeInfo1) { defer wg.Done() select { @@ -3829,7 +3830,7 @@ func TestBatchedUpdateEdgePolicy(t *testing.T) { errTimeout := errors.New("timeout adding batched channel") - updates := []*models.ChannelEdgePolicy{edge1, edge2} + updates := []*models.ChannelEdgePolicy1{edge1, edge2} errChan := make(chan error, len(updates)) @@ -3837,7 +3838,7 @@ func TestBatchedUpdateEdgePolicy(t *testing.T) { var wg sync.WaitGroup for _, update := range updates { wg.Add(1) - go func(update *models.ChannelEdgePolicy) { + go func(update *models.ChannelEdgePolicy1) { defer wg.Done() select { @@ -3886,9 +3887,9 @@ func BenchmarkForEachChannel(b *testing.B) { err = graph.db.View(func(tx kvdb.RTx) error { for _, n := range nodes { cb := func(tx kvdb.RTx, - info *models.ChannelEdgeInfo, - policy *models.ChannelEdgePolicy, - policy2 *models.ChannelEdgePolicy) error { //nolint:lll + info *models.ChannelEdgeInfo1, + policy *models.ChannelEdgePolicy1, + policy2 *models.ChannelEdgePolicy1) error { //nolint:lll // We need to do something with // the data here, otherwise the @@ -3939,7 +3940,7 @@ func TestGraphCacheForEachNodeChannel(t *testing.T) { // Because of lexigraphical sorting and the usage of random node keys in // this test, we need to determine which edge belongs to node 1 at // runtime. - var edge1 *models.ChannelEdgePolicy + var edge1 *models.ChannelEdgePolicy1 if e1.ToNode == node2.PubKeyBytes { edge1 = e1 } else { diff --git a/channeldb/models/cached_edge_policy.go b/channeldb/models/cached_edge_policy.go index b770ec1fbe..89f9a98a0b 100644 --- a/channeldb/models/cached_edge_policy.go +++ b/channeldb/models/cached_edge_policy.go @@ -11,7 +11,7 @@ const ( ) // CachedEdgePolicy is a struct that only caches the information of a -// ChannelEdgePolicy that we actually use for pathfinding and therefore need to +// ChannelEdgePolicy1 that we actually use for pathfinding and therefore need to // store in the cache. type CachedEdgePolicy struct { // ChannelID is the unique channel ID for the channel. The first 3 @@ -72,7 +72,7 @@ func (c *CachedEdgePolicy) ComputeFee( } // NewCachedPolicy turns a full policy into a minimal one that can be cached. -func NewCachedPolicy(policy *ChannelEdgePolicy) *CachedEdgePolicy { +func NewCachedPolicy(policy *ChannelEdgePolicy1) *CachedEdgePolicy { return &CachedEdgePolicy{ ChannelID: policy.ChannelID, MessageFlags: policy.MessageFlags, diff --git a/channeldb/models/channel_auth_proof.go b/channeldb/models/channel_auth_proof.go index 1341394674..c4b1b4c2e3 100644 --- a/channeldb/models/channel_auth_proof.go +++ b/channeldb/models/channel_auth_proof.go @@ -2,14 +2,14 @@ package models import "github.com/btcsuite/btcd/btcec/v2/ecdsa" -// ChannelAuthProof is the authentication proof (the signature portion) for a +// ChannelAuthProof1 is the authentication proof (the signature portion) for a // channel. Using the four signatures contained in the struct, and some // auxiliary knowledge (the funding script, node identities, and outpoint) nodes // on the network are able to validate the authenticity and existence of a // channel. Each of these signatures signs the following digest: chanID || // nodeID1 || nodeID2 || bitcoinKey1|| bitcoinKey2 || 2-byte-feature-len || // features. -type ChannelAuthProof struct { +type ChannelAuthProof1 struct { // nodeSig1 is a cached instance of the first node signature. nodeSig1 *ecdsa.Signature @@ -45,7 +45,7 @@ type ChannelAuthProof struct { // // NOTE: By having this method to access an attribute, we ensure we only need // to fully deserialize the signature if absolutely necessary. -func (c *ChannelAuthProof) Node1Sig() (*ecdsa.Signature, error) { +func (c *ChannelAuthProof1) Node1Sig() (*ecdsa.Signature, error) { if c.nodeSig1 != nil { return c.nodeSig1, nil } @@ -66,7 +66,7 @@ func (c *ChannelAuthProof) Node1Sig() (*ecdsa.Signature, error) { // // NOTE: By having this method to access an attribute, we ensure we only need // to fully deserialize the signature if absolutely necessary. -func (c *ChannelAuthProof) Node2Sig() (*ecdsa.Signature, error) { +func (c *ChannelAuthProof1) Node2Sig() (*ecdsa.Signature, error) { if c.nodeSig2 != nil { return c.nodeSig2, nil } @@ -86,7 +86,7 @@ func (c *ChannelAuthProof) Node2Sig() (*ecdsa.Signature, error) { // // NOTE: By having this method to access an attribute, we ensure we only need // to fully deserialize the signature if absolutely necessary. -func (c *ChannelAuthProof) BitcoinSig1() (*ecdsa.Signature, error) { +func (c *ChannelAuthProof1) BitcoinSig1() (*ecdsa.Signature, error) { if c.bitcoinSig1 != nil { return c.bitcoinSig1, nil } @@ -106,7 +106,7 @@ func (c *ChannelAuthProof) BitcoinSig1() (*ecdsa.Signature, error) { // // NOTE: By having this method to access an attribute, we ensure we only need // to fully deserialize the signature if absolutely necessary. -func (c *ChannelAuthProof) BitcoinSig2() (*ecdsa.Signature, error) { +func (c *ChannelAuthProof1) BitcoinSig2() (*ecdsa.Signature, error) { if c.bitcoinSig2 != nil { return c.bitcoinSig2, nil } @@ -123,9 +123,36 @@ func (c *ChannelAuthProof) BitcoinSig2() (*ecdsa.Signature, error) { // IsEmpty check is the authentication proof is empty Proof is empty if at // least one of the signatures are equal to nil. -func (c *ChannelAuthProof) IsEmpty() bool { +func (c *ChannelAuthProof1) IsEmpty() bool { return len(c.NodeSig1Bytes) == 0 || len(c.NodeSig2Bytes) == 0 || len(c.BitcoinSig1Bytes) == 0 || len(c.BitcoinSig2Bytes) == 0 } + +// isChanAuthProof is a no-op method used to ensure that a struct must +// explicitly inherit this interface to be considered a ChannelAuthProof type. +// +// NOTE: this is part of the ChannelAuthProof interface. +func (c *ChannelAuthProof1) isChanAuthProof() {} + +// A compile-time check to ensure that ChannelAutoProof1 implements the +// ChannelAuthProof interface. +var _ ChannelAuthProof = (*ChannelAuthProof1)(nil) + +// ChannelAuthProof2 is the authentication proof required for a taproot channel +// announcement. It contains a single Schnorr signature. +type ChannelAuthProof2 struct { + // SchnorrSigBytes are the raw bytes of the encoded schnorr signature. + SchnorrSigBytes []byte +} + +// isChanAuthProof is a no-op method used to ensure that a struct must +// explicitly inherit this interface to be considered a ChannelAuthProof type. +// +// NOTE: this is part of the ChannelAuthProof interface. +func (c *ChannelAuthProof2) isChanAuthProof() {} + +// A compile-time check to ensure that ChannelAutoProof2 implements the +// ChannelAuthProof interface. +var _ ChannelAuthProof = (*ChannelAuthProof2)(nil) diff --git a/channeldb/models/channel_edge_info.go b/channeldb/models/channel_edge_info.go index 1afa2d6272..d5645bf13d 100644 --- a/channeldb/models/channel_edge_info.go +++ b/channeldb/models/channel_edge_info.go @@ -5,18 +5,23 @@ import ( "fmt" "github.com/btcsuite/btcd/btcec/v2" + "github.com/btcsuite/btcd/btcec/v2/schnorr" + "github.com/btcsuite/btcd/btcec/v2/schnorr/musig2" "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/input" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/tlv" ) -// ChannelEdgeInfo represents a fully authenticated channel along with all its +// ChannelEdgeInfo1 represents a fully authenticated channel along with all its // unique attributes. Once an authenticated channel announcement has been -// processed on the network, then an instance of ChannelEdgeInfo encapsulating +// processed on the network, then an instance of ChannelEdgeInfo1 encapsulating // the channels attributes is stored. The other portions relevant to routing -// policy of a channel are stored within a ChannelEdgePolicy for each direction +// policy of a channel are stored within a ChannelEdgePolicy1 for each direction // of the channel. -type ChannelEdgeInfo struct { +type ChannelEdgeInfo1 struct { // ChannelID is the unique channel ID for the channel. The first 3 // bytes are the block height, the next 3 the index within the block, // and the last 2 bytes are the output index for the channel. @@ -52,7 +57,7 @@ type ChannelEdgeInfo struct { // AuthProof is the authentication proof for this channel. This proof // contains a set of signatures binding four identities, which attests // to the legitimacy of the advertised channel. - AuthProof *ChannelAuthProof + AuthProof *ChannelAuthProof1 // ChannelPoint is the funding outpoint of the channel. This can be // used to uniquely identify the channel within the channel graph. @@ -72,8 +77,8 @@ type ChannelEdgeInfo struct { } // AddNodeKeys is a setter-like method that can be used to replace the set of -// keys for the target ChannelEdgeInfo. -func (c *ChannelEdgeInfo) AddNodeKeys(nodeKey1, nodeKey2, bitcoinKey1, +// keys for the target ChannelEdgeInfo1. +func (c *ChannelEdgeInfo1) AddNodeKeys(nodeKey1, nodeKey2, bitcoinKey1, bitcoinKey2 *btcec.PublicKey) { c.nodeKey1 = nodeKey1 @@ -96,7 +101,7 @@ func (c *ChannelEdgeInfo) AddNodeKeys(nodeKey1, nodeKey2, bitcoinKey1, // // NOTE: By having this method to access an attribute, we ensure we only need // to fully deserialize the pubkey if absolutely necessary. -func (c *ChannelEdgeInfo) NodeKey1() (*btcec.PublicKey, error) { +func (c *ChannelEdgeInfo1) NodeKey1() (*btcec.PublicKey, error) { if c.nodeKey1 != nil { return c.nodeKey1, nil } @@ -117,7 +122,7 @@ func (c *ChannelEdgeInfo) NodeKey1() (*btcec.PublicKey, error) { // // NOTE: By having this method to access an attribute, we ensure we only need // to fully deserialize the pubkey if absolutely necessary. -func (c *ChannelEdgeInfo) NodeKey2() (*btcec.PublicKey, error) { +func (c *ChannelEdgeInfo1) NodeKey2() (*btcec.PublicKey, error) { if c.nodeKey2 != nil { return c.nodeKey2, nil } @@ -137,7 +142,7 @@ func (c *ChannelEdgeInfo) NodeKey2() (*btcec.PublicKey, error) { // // NOTE: By having this method to access an attribute, we ensure we only need // to fully deserialize the pubkey if absolutely necessary. -func (c *ChannelEdgeInfo) BitcoinKey1() (*btcec.PublicKey, error) { +func (c *ChannelEdgeInfo1) BitcoinKey1() (*btcec.PublicKey, error) { if c.bitcoinKey1 != nil { return c.bitcoinKey1, nil } @@ -157,7 +162,7 @@ func (c *ChannelEdgeInfo) BitcoinKey1() (*btcec.PublicKey, error) { // // NOTE: By having this method to access an attribute, we ensure we only need // to fully deserialize the pubkey if absolutely necessary. -func (c *ChannelEdgeInfo) BitcoinKey2() (*btcec.PublicKey, error) { +func (c *ChannelEdgeInfo1) BitcoinKey2() (*btcec.PublicKey, error) { if c.bitcoinKey2 != nil { return c.bitcoinKey2, nil } @@ -172,7 +177,7 @@ func (c *ChannelEdgeInfo) BitcoinKey2() (*btcec.PublicKey, error) { } // OtherNodeKeyBytes returns the node key bytes of the other end of the channel. -func (c *ChannelEdgeInfo) OtherNodeKeyBytes(thisNodeKey []byte) ( +func (c *ChannelEdgeInfo1) OtherNodeKeyBytes(thisNodeKey []byte) ( [33]byte, error) { switch { @@ -185,3 +190,427 @@ func (c *ChannelEdgeInfo) OtherNodeKeyBytes(thisNodeKey []byte) ( "this channel") } } + +// Copy returns a copy of the ChannelEdgeInfo. +// +// NOTE: this is part of the ChannelEdgeInfo interface. +func (c *ChannelEdgeInfo1) Copy() ChannelEdgeInfo { + return &ChannelEdgeInfo1{ + ChannelID: c.ChannelID, + ChainHash: c.ChainHash, + NodeKey1Bytes: c.NodeKey1Bytes, + NodeKey2Bytes: c.NodeKey2Bytes, + BitcoinKey1Bytes: c.BitcoinKey1Bytes, + BitcoinKey2Bytes: c.BitcoinKey2Bytes, + Features: c.Features, + AuthProof: c.AuthProof, + ChannelPoint: c.ChannelPoint, + Capacity: c.Capacity, + ExtraOpaqueData: c.ExtraOpaqueData, + } +} + +// Node1Bytes returns bytes of the public key of node 1. +// +// NOTE: this is part of the ChannelEdgeInfo interface. +func (c *ChannelEdgeInfo1) Node1Bytes() [33]byte { + return c.NodeKey1Bytes +} + +// Node2Bytes returns bytes of the public key of node 2. +// +// NOTE: this is part of the ChannelEdgeInfo interface. +func (c *ChannelEdgeInfo1) Node2Bytes() [33]byte { + return c.NodeKey2Bytes +} + +// GetChainHash returns the hash of the genesis block of the chain that the edge +// is on. +// +// NOTE: this is part of the ChannelEdgeInfo interface. +func (c *ChannelEdgeInfo1) GetChainHash() chainhash.Hash { + return c.ChainHash +} + +// GetChanID returns the channel ID. +// +// NOTE: this is part of the ChannelEdgeInfo interface. +func (c *ChannelEdgeInfo1) GetChanID() uint64 { + return c.ChannelID +} + +// GetAuthProof returns the ChannelAuthProof for the edge. +// +// NOTE: this is part of the ChannelEdgeInfo interface. +func (c *ChannelEdgeInfo1) GetAuthProof() ChannelAuthProof { + // Cant just return AuthProof cause you then run into the + // nil interface gotcha. + if c.AuthProof == nil { + return nil + } + + return c.AuthProof +} + +// GetCapacity returns the capacity of the channel. +// +// NOTE: this is part of the ChannelEdgeInfo interface. +func (c *ChannelEdgeInfo1) GetCapacity() btcutil.Amount { + return c.Capacity +} + +// SetAuthProof sets the proof of the channel. +// +// NOTE: this is part of the ChannelEdgeInfo interface. +func (c *ChannelEdgeInfo1) SetAuthProof(proof ChannelAuthProof) error { + if proof == nil { + c.AuthProof = nil + + return nil + } + + p, ok := proof.(*ChannelAuthProof1) + if !ok { + return fmt.Errorf("expected type ChannelAuthProof1 for "+ + "ChannelEdgeInfo1, got %T", proof) + } + + c.AuthProof = p + + return nil +} + +// GetChanPoint returns the outpoint of the funding transaction of the channel. +// +// NOTE: this is part of the ChannelEdgeInfo interface. +func (c *ChannelEdgeInfo1) GetChanPoint() wire.OutPoint { + return c.ChannelPoint +} + +// FundingScript returns the pk script for the funding output of the +// channel. +// +// NOTE: this is part of the ChannelEdgeInfo interface. +func (c *ChannelEdgeInfo1) FundingScript() ([]byte, error) { + legacyFundingScript := func() ([]byte, error) { + witnessScript, err := input.GenMultiSigScript( + c.BitcoinKey1Bytes[:], c.BitcoinKey2Bytes[:], + ) + if err != nil { + return nil, err + } + pkScript, err := input.WitnessScriptHash(witnessScript) + if err != nil { + return nil, err + } + + return pkScript, nil + } + + if len(c.Features) == 0 { + return legacyFundingScript() + } + + // TODO(elle): remove this taproot funding script logic once + // ChannelEdgeInfo2 is being used. + + // In order to make the correct funding script, we'll need to parse the + // chanFeatures bytes into a feature vector we can interact with. + rawFeatures := lnwire.NewRawFeatureVector() + err := rawFeatures.Decode(bytes.NewReader(c.Features)) + if err != nil { + return nil, fmt.Errorf("unable to parse chan feature "+ + "bits: %w", err) + } + + chanFeatureBits := lnwire.NewFeatureVector( + rawFeatures, lnwire.Features, + ) + if chanFeatureBits.HasFeature( + lnwire.SimpleTaprootChannelsOptionalStaging, + ) { + + pubKey1, err := btcec.ParsePubKey(c.BitcoinKey1Bytes[:]) + if err != nil { + return nil, err + } + pubKey2, err := btcec.ParsePubKey(c.BitcoinKey2Bytes[:]) + if err != nil { + return nil, err + } + + fundingScript, _, err := input.GenTaprootFundingScript( + pubKey1, pubKey2, 0, + ) + if err != nil { + return nil, err + } + + return fundingScript, nil + } + + return legacyFundingScript() +} + +// A compile-time check to ensure that ChannelEdgeInfo1 implements the +// ChannelEdgeInfo interface. +var _ ChannelEdgeInfo = (*ChannelEdgeInfo1)(nil) + +// ChannelEdgeInfo2 describes the information about a channel announced with +// lnwire.ChannelAnnouncement2 that we will persist. +type ChannelEdgeInfo2 struct { + lnwire.ChannelAnnouncement2 + + // ChannelPoint is the funding outpoint of the channel. This can be + // used to uniquely identify the channel within the channel graph. + ChannelPoint wire.OutPoint + + // FundingPkScript is the funding transaction's pk script. We persist + // this since there are some cases in which this will not be derivable + // using the contents of the announcement. In that case, we still want + // quick access to the funding script so that we can register for spend + // notifications. + FundingPkScript []byte + + // AuthProof is the authentication proof for this channel. + AuthProof *ChannelAuthProof2 + + nodeKey1 *btcec.PublicKey + nodeKey2 *btcec.PublicKey +} + +// Copy returns a copy of the ChannelEdgeInfo. +// +// NOTE: this is part of the ChannelEdgeInfo interface. +func (c *ChannelEdgeInfo2) Copy() ChannelEdgeInfo { + return &ChannelEdgeInfo2{ + ChannelAnnouncement2: lnwire.ChannelAnnouncement2{ + ChainHash: c.ChainHash, + Features: c.Features, + ShortChannelID: c.ShortChannelID, + Capacity: c.Capacity, + NodeID1: c.NodeID1, + NodeID2: c.NodeID2, + BitcoinKey1: c.BitcoinKey1, + BitcoinKey2: c.BitcoinKey2, + MerkleRootHash: c.MerkleRootHash, + ExtraOpaqueData: c.ExtraOpaqueData, + }, + ChannelPoint: c.ChannelPoint, + AuthProof: c.AuthProof, + } +} + +// Node1Bytes returns bytes of the public key of node 1. +// +// NOTE: this is part of the ChannelEdgeInfo interface. +func (c *ChannelEdgeInfo2) Node1Bytes() [33]byte { + return c.NodeID1.Val +} + +// Node2Bytes returns bytes of the public key of node 2. +// +// NOTE: this is part of the ChannelEdgeInfo interface. +func (c *ChannelEdgeInfo2) Node2Bytes() [33]byte { + return c.NodeID2.Val +} + +// GetChainHash returns the hash of the genesis block of the chain that the edge +// is on. +// +// NOTE: this is part of the ChannelEdgeInfo interface. +func (c *ChannelEdgeInfo2) GetChainHash() chainhash.Hash { + return c.ChainHash.Val +} + +// GetChanID returns the channel ID. +// +// NOTE: this is part of the ChannelEdgeInfo interface. +func (c *ChannelEdgeInfo2) GetChanID() uint64 { + return c.ShortChannelID.Val.ToUint64() +} + +// GetAuthProof returns the ChannelAuthProof for the edge. +// +// NOTE: this is part of the models.ChannelEdgeInfo interface. +func (c *ChannelEdgeInfo2) GetAuthProof() ChannelAuthProof { + // Cant just return AuthProof cause you then run into the + // nil interface gotcha. + if c.AuthProof == nil { + return nil + } + + return c.AuthProof +} + +// GetCapacity returns the capacity of the channel. +// +// NOTE: this is part of the models.ChannelEdgeInfo interface. +func (c *ChannelEdgeInfo2) GetCapacity() btcutil.Amount { + return btcutil.Amount(c.Capacity.Val) +} + +// SetAuthProof sets the proof of the channel. +// +// NOTE: this is part of the ChannelEdgeInfo interface. +func (c *ChannelEdgeInfo2) SetAuthProof(proof ChannelAuthProof) error { + if proof == nil { + c.AuthProof = nil + + return nil + } + + p, ok := proof.(*ChannelAuthProof2) + if !ok { + return fmt.Errorf("expected type ChannelAuthProof2 for "+ + "ChannelEdgeInfo2, got %T", proof) + } + + c.AuthProof = p + + return nil +} + +// GetChanPoint returns the outpoint of the funding transaction of the channel. +// +// NOTE: this is part of the ChannelEdgeInfo interface. +func (c *ChannelEdgeInfo2) GetChanPoint() wire.OutPoint { + return c.ChannelPoint +} + +// FundingScript returns the pk script for the funding output of the +// channel. +// +// NOTE: this is part of the ChannelEdgeInfo interface. +func (c *ChannelEdgeInfo2) FundingScript() ([]byte, error) { + var ( + pubKey1 *btcec.PublicKey + pubKey2 *btcec.PublicKey + err error + ) + c.BitcoinKey1.WhenSome(func(key tlv.RecordT[tlv.TlvType12, [33]byte]) { + pubKey1, err = btcec.ParsePubKey(key.Val[:]) + }) + if err != nil { + return nil, err + } + + c.BitcoinKey2.WhenSome(func(key tlv.RecordT[tlv.TlvType14, [33]byte]) { + pubKey2, err = btcec.ParsePubKey(key.Val[:]) + }) + if err != nil { + return nil, err + } + + // If both bitcoin keys are not present in the announcement, then we + // should previously have stored the funding script found on-chain. + if pubKey1 == nil || pubKey2 == nil { + if len(c.FundingPkScript) == 0 { + return nil, fmt.Errorf("expected a funding pk script " + + "since no bitcoin keys were provided") + } + + return c.FundingPkScript, nil + } + + // Initially we set the tweak to an empty byte array. If a merkle root + // hash is provided in the announcement then we use that to set the + // tweak but otherwise, the empty tweak will have the same effect as a + // BIP86 tweak. + var tweak []byte + c.MerkleRootHash.WhenSome( + func(hash tlv.RecordT[tlv.TlvType16, [32]byte]) { + tweak = hash.Val[:] + }, + ) + + // Calculate the internal key by computing the MuSig2 combination of the + // two public keys. + internalKey, _, _, err := musig2.AggregateKeys( + []*btcec.PublicKey{pubKey1, pubKey2}, true, + ) + if err != nil { + return nil, err + } + + // Now, determine the tweak to be added to the internal key. If the + // tweak is empty, then this will effectively be a BIP86 tweak. + tapTweakHash := chainhash.TaggedHash( + chainhash.TagTapTweak, schnorr.SerializePubKey( + internalKey.FinalKey, + ), tweak, + ) + + // Compute the final output key. + combinedKey, _, _, err := musig2.AggregateKeys( + []*btcec.PublicKey{pubKey1, pubKey2}, true, + musig2.WithKeyTweaks(musig2.KeyTweakDesc{ + Tweak: *tapTweakHash, + IsXOnly: true, + }), + ) + if err != nil { + return nil, err + } + + // Now that we have the combined key, we can create a taproot pkScript + // from this, and then make the txout given the amount. + fundingScript, err := input.PayToTaprootScript(combinedKey.FinalKey) + if err != nil { + return nil, fmt.Errorf("unable to make taproot pkscript: %w", + err) + } + + return fundingScript, nil +} + +// NodeKey1 is the identity public key of the "first" node that was involved in +// the creation of this channel. A node is considered "first" if the +// lexicographical ordering the its serialized public key is "smaller" than +// that of the other node involved in channel creation. +// +// NOTE: By having this method to access an attribute, we ensure we only need +// to fully deserialize the pubkey if absolutely necessary. +// +// NOTE: this is part of the ChannelEdgeInfo interface. +func (c *ChannelEdgeInfo2) NodeKey1() (*btcec.PublicKey, error) { + if c.nodeKey1 != nil { + return c.nodeKey1, nil + } + + key, err := btcec.ParsePubKey(c.NodeID1.Val[:]) + if err != nil { + return nil, err + } + c.nodeKey1 = key + + return key, nil +} + +// NodeKey2 is the identity public key of the "second" node that was +// involved in the creation of this channel. A node is considered +// "second" if the lexicographical ordering the its serialized public +// key is "larger" than that of the other node involved in channel +// creation. +// +// NOTE: By having this method to access an attribute, we ensure we only need +// to fully deserialize the pubkey if absolutely necessary. +// +// NOTE: this is part of the ChannelEdgeInfo interface. +func (c *ChannelEdgeInfo2) NodeKey2() (*btcec.PublicKey, error) { + if c.nodeKey2 != nil { + return c.nodeKey2, nil + } + + key, err := btcec.ParsePubKey(c.NodeID2.Val[:]) + if err != nil { + return nil, err + } + c.nodeKey2 = key + + return key, nil +} + +// A compile-time check to ensure that ChannelEdgeInfo2 implements the +// ChannelEdgeInfo interface. +var _ ChannelEdgeInfo = (*ChannelEdgeInfo2)(nil) diff --git a/channeldb/models/channel_edge_policy.go b/channeldb/models/channel_edge_policy.go index 322ce3cd09..1393b94772 100644 --- a/channeldb/models/channel_edge_policy.go +++ b/channeldb/models/channel_edge_policy.go @@ -1,18 +1,20 @@ package models import ( + "fmt" "time" "github.com/btcsuite/btcd/btcec/v2/ecdsa" + "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lnwire" ) -// ChannelEdgePolicy represents a *directed* edge within the channel graph. For +// ChannelEdgePolicy1 represents a *directed* edge within the channel graph. For // each channel in the database, there are two distinct edges: one for each // possible direction of travel along the channel. The edges themselves hold // information concerning fees, and minimum time-lock information which is // utilized during path finding. -type ChannelEdgePolicy struct { +type ChannelEdgePolicy1 struct { // SigBytes is the raw bytes of the signature of the channel edge // policy. We'll only parse these if the caller needs to access the // signature for validation purposes. Do not set SigBytes directly, but @@ -78,7 +80,7 @@ type ChannelEdgePolicy struct { // // NOTE: By having this method to access an attribute, we ensure we only need // to fully deserialize the signature if absolutely necessary. -func (c *ChannelEdgePolicy) Signature() (*ecdsa.Signature, error) { +func (c *ChannelEdgePolicy1) Signature() (*ecdsa.Signature, error) { if c.sig != nil { return c.sig, nil } @@ -95,21 +97,158 @@ func (c *ChannelEdgePolicy) Signature() (*ecdsa.Signature, error) { // SetSigBytes updates the signature and invalidates the cached parsed // signature. -func (c *ChannelEdgePolicy) SetSigBytes(sig []byte) { +func (c *ChannelEdgePolicy1) SetSigBytes(sig []byte) { c.SigBytes = sig c.sig = nil } // IsDisabled determines whether the edge has the disabled bit set. -func (c *ChannelEdgePolicy) IsDisabled() bool { +// +// NOTE: This is part of the ChannelEdgePolicy interface. +func (c *ChannelEdgePolicy1) IsDisabled() bool { return c.ChannelFlags.IsDisabled() } // ComputeFee computes the fee to forward an HTLC of `amt` milli-satoshis over // the passed active payment channel. This value is currently computed as // specified in BOLT07, but will likely change in the near future. -func (c *ChannelEdgePolicy) ComputeFee( +func (c *ChannelEdgePolicy1) ComputeFee( amt lnwire.MilliSatoshi) lnwire.MilliSatoshi { return c.FeeBaseMSat + (amt*c.FeeProportionalMillionths)/feeRateParts } + +// SCID returns the short channel ID of the channel being referred to. +// +// NOTE: This is part of the ChannelEdgePolicy interface. +func (c *ChannelEdgePolicy1) SCID() lnwire.ShortChannelID { + return lnwire.NewShortChanIDFromInt(c.ChannelID) +} + +// IsNode1 returns true if the update was constructed by node 1 of the +// channel. +// +// NOTE: This is part of the ChannelEdgePolicy interface. +func (c *ChannelEdgePolicy1) IsNode1() bool { + return c.ChannelFlags&lnwire.ChanUpdateDirection == 0 +} + +// GetToNode returns the pub key of the node that did not produce the update. +// +// NOTE: This is part of the ChannelEdgePolicy interface. +func (c *ChannelEdgePolicy1) GetToNode() [33]byte { + return c.ToNode +} + +// ForwardingPolicy return the various forwarding policy rules set by the +// update. +// +// NOTE: This is part of the ChannelEdgePolicy interface. +func (c *ChannelEdgePolicy1) ForwardingPolicy() *lnwire.ForwardingPolicy { + return &lnwire.ForwardingPolicy{ + TimeLockDelta: c.TimeLockDelta, + BaseFee: c.FeeBaseMSat, + FeeRate: c.FeeProportionalMillionths, + MinHTLC: c.MinHTLC, + HasMaxHTLC: c.MessageFlags.HasMaxHtlc(), + MaxHTLC: c.MaxHTLC, + } +} + +// Before compares this update against the passed update and returns true if +// this update has a lower timestamp than the passed one. +// +// NOTE: This is part of the ChannelEdgePolicy interface. +func (c *ChannelEdgePolicy1) Before(policy ChannelEdgePolicy) (bool, error) { + other, ok := policy.(*ChannelEdgePolicy1) + if !ok { + return false, fmt.Errorf("can't compare type %T to type "+ + "ChannelEdgePolicy1", policy) + } + + return c.LastUpdate.Before(other.LastUpdate), nil +} + +// AfterUpdateMsg compares this update against the passed +// lnwire.ChannelUpdate message and returns true if this update is newer than +// the passed one. +// +// NOTE: This is part of the ChannelEdgePolicy interface. +func (c *ChannelEdgePolicy1) AfterUpdateMsg(msg lnwire.ChannelUpdate) (bool, + error) { + + upd, ok := msg.(*lnwire.ChannelUpdate1) + if !ok { + return false, fmt.Errorf("expected *lnwire.ChannelUpdate1 to "+ + "be coupled with ChannelEdgePolicy1, got: %T", msg) + } + + timestamp := time.Unix(int64(upd.Timestamp), 0) + + return c.LastUpdate.After(timestamp), nil +} + +// Sig returns the signature of the update message. +// +// NOTE: This is part of the ChannelEdgePolicy interface. +func (c *ChannelEdgePolicy1) Sig() (input.Signature, error) { + return c.Signature() +} + +// A compile-time check to ensure that ChannelEdgePolicy1 implements the +// ChannelEdgePolicy interface. +var _ ChannelEdgePolicy = (*ChannelEdgePolicy1)(nil) + +type ChannelEdgePolicy2 struct { + lnwire.ChannelUpdate2 + + ToNode [33]byte +} + +// Sig returns the signature of the update message. +// +// NOTE: This is part of the ChannelEdgePolicy interface. +func (c *ChannelEdgePolicy2) Sig() (input.Signature, error) { + return c.Signature.ToSignature() +} + +// AfterUpdateMsg compares this update against the passed lnwire.ChannelUpdate +// message and returns true if this update is newer than the passed one. +// +// NOTE: This is part of the ChannelEdgePolicy interface. +func (c *ChannelEdgePolicy2) AfterUpdateMsg(msg lnwire.ChannelUpdate) (bool, + error) { + + upd, ok := msg.(*lnwire.ChannelUpdate2) + if !ok { + return false, fmt.Errorf("expected *lnwire.ChannelUpdate2 to "+ + "be coupled with ChannelEdgePolicy2, got: %T", msg) + } + + return c.BlockHeight.Val > upd.BlockHeight.Val, nil +} + +// Before compares this update against the passed update and returns true if +// this update has a lower timestamp than the passed one. +// +// NOTE: This is part of the ChannelEdgePolicy interface. +func (c *ChannelEdgePolicy2) Before(policy ChannelEdgePolicy) (bool, error) { + other, ok := policy.(*ChannelEdgePolicy2) + if !ok { + return false, fmt.Errorf("can't compare type %T to type "+ + "ChannelEdgePolicy2", policy) + } + + return c.BlockHeight.Val < other.BlockHeight.Val, nil +} + +// GetToNode returns the pub key of the node that did not produce the update. +// +// NOTE: This is part of the ChannelEdgePolicy interface. +func (c *ChannelEdgePolicy2) GetToNode() [33]byte { + return c.ToNode +} + +// A compile-time check to ensure that ChannelEdgePolicy2 implements the +// ChannelEdgePolicy interface. +var _ ChannelEdgePolicy = (*ChannelEdgePolicy2)(nil) diff --git a/channeldb/models/interfaces.go b/channeldb/models/interfaces.go new file mode 100644 index 0000000000..f250bc95c6 --- /dev/null +++ b/channeldb/models/interfaces.go @@ -0,0 +1,97 @@ +package models + +import ( + "github.com/btcsuite/btcd/btcec/v2" + "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/input" + "github.com/lightningnetwork/lnd/lnwire" +) + +// ChannelEdgeInfo is an interface that describes a channel announcement. +type ChannelEdgeInfo interface { //nolint:interfacebloat + // GetChainHash returns the hash of the genesis block of the chain that + // the edge is on. + GetChainHash() chainhash.Hash + + // GetChanID returns the channel ID. + GetChanID() uint64 + + // GetAuthProof returns the ChannelAuthProof for the edge. + GetAuthProof() ChannelAuthProof + + // GetCapacity returns the capacity of the channel. + GetCapacity() btcutil.Amount + + // SetAuthProof sets the proof of the channel. + SetAuthProof(ChannelAuthProof) error + + // NodeKey1 returns the public key of node 1. + NodeKey1() (*btcec.PublicKey, error) + + // NodeKey2 returns the public key of node 2. + NodeKey2() (*btcec.PublicKey, error) + + // Node1Bytes returns bytes of the public key of node 1. + Node1Bytes() [33]byte + + // Node2Bytes returns bytes the public key of node 2. + Node2Bytes() [33]byte + + // GetChanPoint returns the outpoint of the funding transaction of the + // channel. + GetChanPoint() wire.OutPoint + + // FundingScript returns the pk script for the funding output of the + // channel. + FundingScript() ([]byte, error) + + // Copy returns a copy of the ChannelEdgeInfo. + Copy() ChannelEdgeInfo +} + +// ChannelAuthProof is an interface that describes the proof of ownership of +// a channel. +type ChannelAuthProof interface { + // isChanAuthProof is a no-op method used to ensure that a struct must + // explicitly inherit this interface to be considered a + // ChannelAuthProof type. + isChanAuthProof() +} + +// ChannelEdgePolicy is an interface that describes an update to the forwarding +// rules of a channel. +type ChannelEdgePolicy interface { + // SCID returns the short channel ID of the channel being referred to. + SCID() lnwire.ShortChannelID + + // IsDisabled returns true if the update is indicating that the channel + // should be considered disabled. + IsDisabled() bool + + // IsNode1 returns true if the update was constructed by node 1 of the + // channel. + IsNode1() bool + + // GetToNode returns the pub key of the node that did not produce the + // update. + GetToNode() [33]byte + + // ForwardingPolicy return the various forwarding policy rules set by + // the update. + ForwardingPolicy() *lnwire.ForwardingPolicy + + // Before compares this update against the passed update and returns + // true if this update has a lower timestamp than the passed one. + Before(policy ChannelEdgePolicy) (bool, error) + + // AfterUpdateMsg compares this update against the passed + // lnwire.ChannelUpdate message and returns true if this update is + // newer than the passed one. + // TODO(elle): combine with Before? + AfterUpdateMsg(msg lnwire.ChannelUpdate) (bool, error) + + // Sig returns the signature of the update message. + Sig() (input.Signature, error) +} diff --git a/channeldb/waitingproof.go b/channeldb/waitingproof.go index c6b2b9df52..faefc620ce 100644 --- a/channeldb/waitingproof.go +++ b/channeldb/waitingproof.go @@ -191,15 +191,17 @@ type WaitingProofKey [9]byte // needed to make channel proof exchange persistent, so that after client // restart we may receive remote/local half proof and process it. type WaitingProof struct { - *lnwire.AnnounceSignatures + *lnwire.AnnounceSignatures1 isRemote bool } // NewWaitingProof constructs a new waiting prof instance. -func NewWaitingProof(isRemote bool, proof *lnwire.AnnounceSignatures) *WaitingProof { +func NewWaitingProof(isRemote bool, + proof *lnwire.AnnounceSignatures1) *WaitingProof { + return &WaitingProof{ - AnnounceSignatures: proof, - isRemote: isRemote, + AnnounceSignatures1: proof, + isRemote: isRemote, } } @@ -238,7 +240,7 @@ func (p *WaitingProof) Encode(w io.Writer) error { return fmt.Errorf("expect io.Writer to be *bytes.Buffer") } - if err := p.AnnounceSignatures.Encode(buf, 0); err != nil { + if err := p.AnnounceSignatures1.Encode(buf, 0); err != nil { return err } @@ -252,11 +254,12 @@ func (p *WaitingProof) Decode(r io.Reader) error { return err } - msg := &lnwire.AnnounceSignatures{} + msg := &lnwire.AnnounceSignatures1{} if err := msg.Decode(r, 0); err != nil { return err } - (*p).AnnounceSignatures = msg + p.AnnounceSignatures1 = msg + return nil } diff --git a/channeldb/waitingproof_test.go b/channeldb/waitingproof_test.go index 1a00c829e0..d7113d9e75 100644 --- a/channeldb/waitingproof_test.go +++ b/channeldb/waitingproof_test.go @@ -18,7 +18,7 @@ func TestWaitingProofStore(t *testing.T) { db, err := MakeTestDB(t) require.NoError(t, err, "failed to make test database") - proof1 := NewWaitingProof(true, &lnwire.AnnounceSignatures{ + proof1 := NewWaitingProof(true, &lnwire.AnnounceSignatures1{ NodeSignature: wireSig, BitcoinSignature: wireSig, ExtraOpaqueData: make([]byte, 0), diff --git a/discovery/chan_series.go b/discovery/chan_series.go index 8cbca1277d..c0d9900b29 100644 --- a/discovery/chan_series.go +++ b/discovery/chan_series.go @@ -5,7 +5,6 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/graph" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/netann" "github.com/lightningnetwork/lnd/routing/route" @@ -61,7 +60,8 @@ type ChannelGraphTimeSeries interface { // specified short channel ID. If no channel updates are known for the // channel, then an empty slice will be returned. FetchChanUpdates(chain chainhash.Hash, - shortChanID lnwire.ShortChannelID) ([]*lnwire.ChannelUpdate, error) + shortChanID lnwire.ShortChannelID) ([]*lnwire.ChannelUpdate1, + error) } // ChanSeries is an implementation of the ChannelGraphTimeSeries @@ -136,7 +136,7 @@ func (c *ChanSeries) UpdatesInHorizon(chain chainhash.Hash, if edge1 != nil { // We don't want to send channel updates that don't // conform to the spec (anymore). - err := graph.ValidateChannelUpdateFields(0, edge1) + err := edge1.Validate(0) if err != nil { log.Errorf("not sending invalid channel "+ "update %v: %v", edge1, err) @@ -145,7 +145,7 @@ func (c *ChanSeries) UpdatesInHorizon(chain chainhash.Hash, } } if edge2 != nil { - err := graph.ValidateChannelUpdateFields(0, edge2) + err := edge2.Validate(0) if err != nil { log.Errorf("not sending invalid channel "+ "update %v: %v", edge2, err) @@ -326,7 +326,7 @@ func (c *ChanSeries) FetchChanAnns(chain chainhash.Hash, // // NOTE: This is part of the ChannelGraphTimeSeries interface. func (c *ChanSeries) FetchChanUpdates(chain chainhash.Hash, - shortChanID lnwire.ShortChannelID) ([]*lnwire.ChannelUpdate, error) { + shortChanID lnwire.ShortChannelID) ([]*lnwire.ChannelUpdate1, error) { chanInfo, e1, e2, err := c.graph.FetchChannelEdgesByID( shortChanID.ToUint64(), @@ -335,7 +335,7 @@ func (c *ChanSeries) FetchChanUpdates(chain chainhash.Hash, return nil, err } - chanUpdates := make([]*lnwire.ChannelUpdate, 0, 2) + chanUpdates := make([]*lnwire.ChannelUpdate1, 0, 2) if e1 != nil { chanUpdate, err := netann.ChannelUpdateFromEdge(chanInfo, e1) if err != nil { diff --git a/discovery/gossiper.go b/discovery/gossiper.go index 80fd576a23..f16e502bd7 100644 --- a/discovery/gossiper.go +++ b/discovery/gossiper.go @@ -26,6 +26,7 @@ import ( "github.com/lightningnetwork/lnd/lnpeer" "github.com/lightningnetwork/lnd/lnutils" "github.com/lightningnetwork/lnd/lnwallet" + "github.com/lightningnetwork/lnd/lnwallet/chanvalidate" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/multimutex" "github.com/lightningnetwork/lnd/netann" @@ -317,7 +318,7 @@ type Config struct { // SignAliasUpdate is used to re-sign a channel update using the // remote's alias if the option-scid-alias feature bit was negotiated. - SignAliasUpdate func(u *lnwire.ChannelUpdate) (*ecdsa.Signature, + SignAliasUpdate func(u *lnwire.ChannelUpdate1) (*ecdsa.Signature, error) // FindBaseByAlias finds the SCID stored in the graph by an alias SCID. @@ -340,6 +341,11 @@ type Config struct { // updates for a channel and returns true if the channel should be // considered a zombie based on these timestamps. IsStillZombieChannel func(time.Time, time.Time) bool + + // FetchTxBySCID queries the chain for the transaction with the given + // SCID. A quit channel can be passed in to cancel the query. + FetchTxBySCID func(chanID *lnwire.ShortChannelID, quit chan struct{}) ( + *wire.MsgTx, error) } // processedNetworkMsg is a wrapper around networkMsg and a boolean. It is @@ -540,10 +546,10 @@ func New(cfg Config, selfKeyDesc *keychain.KeyDescriptor) *AuthenticatedGossiper // EdgeWithInfo contains the information that is required to update an edge. type EdgeWithInfo struct { // Info describes the channel. - Info *models.ChannelEdgeInfo + Info *models.ChannelEdgeInfo1 // Edge describes the policy in one direction of the channel. - Edge *models.ChannelEdgePolicy + Edge *models.ChannelEdgePolicy1 } // PropagateChanPolicyUpdate signals the AuthenticatedGossiper to perform the @@ -836,9 +842,9 @@ func (d *AuthenticatedGossiper) ProcessRemoteAnnouncement(msg lnwire.Message, // To avoid inserting edges in the graph for our own channels that we // have already closed, we ignore such channel announcements coming // from the remote. - case *lnwire.ChannelAnnouncement: + case *lnwire.ChannelAnnouncement1: ownKey := d.selfKey.SerializeCompressed() - ownErr := fmt.Errorf("ignoring remote ChannelAnnouncement " + + ownErr := fmt.Errorf("ignoring remote ChannelAnnouncement1 " + "for own channel") if bytes.Equal(m.NodeID1[:], ownKey) || @@ -998,7 +1004,7 @@ func (d *deDupedAnnouncements) addMsg(message networkMsg) { switch msg := message.msg.(type) { // Channel announcements are identified by the short channel id field. - case *lnwire.ChannelAnnouncement: + case *lnwire.ChannelAnnouncement1: deDupKey := msg.ShortChannelID sender := route.NewVertex(message.source) @@ -1022,7 +1028,7 @@ func (d *deDupedAnnouncements) addMsg(message networkMsg) { // Channel updates are identified by the (short channel id, // channelflags) tuple. - case *lnwire.ChannelUpdate: + case *lnwire.ChannelUpdate1: sender := route.NewVertex(message.source) deDupKey := channelUpdateID{ msg.ShortChannelID, @@ -1034,7 +1040,15 @@ func (d *deDupedAnnouncements) addMsg(message networkMsg) { if ok { // If we already have seen this message, record its // timestamp. - oldTimestamp = mws.msg.(*lnwire.ChannelUpdate).Timestamp + update, ok := mws.msg.(*lnwire.ChannelUpdate1) + if !ok { + log.Errorf("Expected *lnwire.ChannelUpdate1, "+ + "got: %T", mws.msg) + + return + } + + oldTimestamp = update.Timestamp } // If we already had this message with a strictly newer @@ -1403,7 +1417,7 @@ func (d *AuthenticatedGossiper) networkHandler() { switch announcement.msg.(type) { // Channel announcement signatures are amongst the only // messages that we'll process serially. - case *lnwire.AnnounceSignatures: + case *lnwire.AnnounceSignatures1: emittedAnnouncements, _ := d.processNetworkAnnouncement( announcement, ) @@ -1569,10 +1583,10 @@ func (d *AuthenticatedGossiper) isRecentlyRejectedMsg(msg lnwire.Message, var scid uint64 switch m := msg.(type) { - case *lnwire.ChannelUpdate: + case *lnwire.ChannelUpdate1: scid = m.ShortChannelID.ToUint64() - case *lnwire.ChannelAnnouncement: + case *lnwire.ChannelAnnouncement1: scid = m.ShortChannelID.ToUint64() default: @@ -1592,8 +1606,8 @@ func (d *AuthenticatedGossiper) retransmitStaleAnns(now time.Time) error { // Iterate over all of our channels and check if any of them fall // within the prune interval or re-broadcast interval. type updateTuple struct { - info *models.ChannelEdgeInfo - edge *models.ChannelEdgePolicy + info *models.ChannelEdgeInfo1 + edge *models.ChannelEdgePolicy1 } var ( @@ -1602,8 +1616,8 @@ func (d *AuthenticatedGossiper) retransmitStaleAnns(now time.Time) error { ) err := d.cfg.Graph.ForAllOutgoingChannels(func( _ kvdb.RTx, - info *models.ChannelEdgeInfo, - edge *models.ChannelEdgePolicy) error { + info *models.ChannelEdgeInfo1, + edge *models.ChannelEdgePolicy1) error { // If there's no auth proof attached to this edge, it means // that it is a private channel not meant to be announced to @@ -1808,8 +1822,8 @@ func (d *AuthenticatedGossiper) processChanPolicyUpdate( } // remotePubFromChanInfo returns the public key of the remote peer given a -// ChannelEdgeInfo that describe a channel we have with them. -func remotePubFromChanInfo(chanInfo *models.ChannelEdgeInfo, +// ChannelEdgeInfo1 that describe a channel we have with them. +func remotePubFromChanInfo(chanInfo *models.ChannelEdgeInfo1, chanFlags lnwire.ChanUpdateChanFlags) [33]byte { var remotePubKey [33]byte @@ -1831,8 +1845,8 @@ func remotePubFromChanInfo(chanInfo *models.ChannelEdgeInfo, // to receive the remote peer's proof, while the remote peer is able to fully // assemble the proof and craft the ChannelAnnouncement. func (d *AuthenticatedGossiper) processRejectedEdge( - chanAnnMsg *lnwire.ChannelAnnouncement, - proof *models.ChannelAuthProof) ([]networkMsg, error) { + chanAnnMsg *lnwire.ChannelAnnouncement1, + proof *models.ChannelAuthProof1) ([]networkMsg, error) { // First, we'll fetch the state of the channel as we know if from the // database. @@ -1865,7 +1879,7 @@ func (d *AuthenticatedGossiper) processRejectedEdge( if err != nil { return nil, err } - err = graph.ValidateChannelAnn(chanAnn) + err = chanAnn.Validate(d.fetchPKScript) if err != nil { err := fmt.Errorf("assembled channel announcement proof "+ "for shortChanID=%v isn't valid: %v", @@ -1909,6 +1923,27 @@ func (d *AuthenticatedGossiper) processRejectedEdge( return announcements, nil } +// fetchPKScript fetches the output script for the given SCID. +func (d *AuthenticatedGossiper) fetchPKScript(chanID *lnwire.ShortChannelID) ( + []byte, error) { + + tx, err := d.cfg.FetchTxBySCID(chanID, d.quit) + if err != nil { + return nil, err + } + + outputLocator := chanvalidate.ShortChanIDChanLocator{ + ID: *chanID, + } + + output, _, err := outputLocator.Locate(tx) + if err != nil { + return nil, err + } + + return output.PkScript, nil +} + // addNode processes the given node announcement, and adds it to our channel // graph. func (d *AuthenticatedGossiper) addNode(msg *lnwire.NodeAnnouncement, @@ -2016,19 +2051,19 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement( // *creation* of a new channel within the network. This only advertises // the existence of a channel and not yet the routing policies in // either direction of the channel. - case *lnwire.ChannelAnnouncement: + case *lnwire.ChannelAnnouncement1: return d.handleChanAnnouncement(nMsg, msg, schedulerOp) // A new authenticated channel edge update has arrived. This indicates // that the directional information for an already known channel has // been updated. - case *lnwire.ChannelUpdate: + case *lnwire.ChannelUpdate1: return d.handleChanUpdate(nMsg, msg, schedulerOp) // A new signature announcement has been received. This indicates // willingness of nodes involved in the funding of a channel to // announce this new channel to the rest of the world. - case *lnwire.AnnounceSignatures: + case *lnwire.AnnounceSignatures1: return d.handleAnnSig(nMsg, msg) default: @@ -2041,11 +2076,11 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement( // processZombieUpdate determines whether the provided channel update should // resurrect a given zombie edge. // -// NOTE: only the NodeKey1Bytes and NodeKey2Bytes members of the ChannelEdgeInfo -// should be inspected. +// NOTE: only the NodeKey1Bytes and NodeKey2Bytes members of the +// ChannelEdgeInfo1 should be inspected. func (d *AuthenticatedGossiper) processZombieUpdate( - chanInfo *models.ChannelEdgeInfo, scid lnwire.ShortChannelID, - msg *lnwire.ChannelUpdate) error { + chanInfo *models.ChannelEdgeInfo1, scid lnwire.ShortChannelID, + msg *lnwire.ChannelUpdate1) error { // The least-significant bit in the flag on the channel update tells us // which edge is being updated. @@ -2068,7 +2103,7 @@ func (d *AuthenticatedGossiper) processZombieUpdate( "with chan_id=%v", msg.ShortChannelID) } - err := graph.VerifyChannelUpdateSignature(msg, pubKey) + err := msg.VerifySig(pubKey) if err != nil { return fmt.Errorf("unable to verify channel "+ "update signature: %v", err) @@ -2116,7 +2151,7 @@ func (d *AuthenticatedGossiper) fetchNodeAnn( // MessageStore is seen as stale by the current graph. func (d *AuthenticatedGossiper) isMsgStale(msg lnwire.Message) bool { switch msg := msg.(type) { - case *lnwire.AnnounceSignatures: + case *lnwire.AnnounceSignatures1: chanInfo, _, _, err := d.cfg.Graph.GetChannelByID( msg.ShortChannelID, ) @@ -2138,7 +2173,7 @@ func (d *AuthenticatedGossiper) isMsgStale(msg lnwire.Message) bool { // can safely delete the local proof from the database. return chanInfo.AuthProof != nil - case *lnwire.ChannelUpdate: + case *lnwire.ChannelUpdate1: _, p1, p2, err := d.cfg.Graph.GetChannelByID(msg.ShortChannelID) // If the channel cannot be found, it is most likely a leftover @@ -2156,7 +2191,7 @@ func (d *AuthenticatedGossiper) isMsgStale(msg lnwire.Message) bool { // Otherwise, we'll retrieve the correct policy that we // currently have stored within our graph to check if this // message is stale by comparing its timestamp. - var p *models.ChannelEdgePolicy + var p *models.ChannelEdgePolicy1 if msg.ChannelFlags&lnwire.ChanUpdateDirection == 0 { p = p1 } else { @@ -2181,9 +2216,9 @@ func (d *AuthenticatedGossiper) isMsgStale(msg lnwire.Message) bool { // updateChannel creates a new fully signed update for the channel, and updates // the underlying graph with the new state. -func (d *AuthenticatedGossiper) updateChannel(info *models.ChannelEdgeInfo, - edge *models.ChannelEdgePolicy) (*lnwire.ChannelAnnouncement, - *lnwire.ChannelUpdate, error) { +func (d *AuthenticatedGossiper) updateChannel(info *models.ChannelEdgeInfo1, + edge *models.ChannelEdgePolicy1) (*lnwire.ChannelAnnouncement1, + *lnwire.ChannelUpdate1, error) { // Parse the unsigned edge into a channel update. chanUpdate := netann.UnsignedChannelUpdateFromEdge(info, edge) @@ -2205,7 +2240,7 @@ func (d *AuthenticatedGossiper) updateChannel(info *models.ChannelEdgeInfo, // To ensure that our signature is valid, we'll verify it ourself // before committing it to the slice returned. - err = graph.ValidateChannelUpdateAnn( + err = lnwire.ValidateChannelUpdateAnn( d.selfKey, info.Capacity, chanUpdate, ) if err != nil { @@ -2221,10 +2256,10 @@ func (d *AuthenticatedGossiper) updateChannel(info *models.ChannelEdgeInfo, // We'll also create the original channel announcement so the two can // be broadcast along side each other (if necessary), but only if we // have a full channel announcement for this channel. - var chanAnn *lnwire.ChannelAnnouncement + var chanAnn *lnwire.ChannelAnnouncement1 if info.AuthProof != nil { chanID := lnwire.NewShortChanIDFromInt(info.ChannelID) - chanAnn = &lnwire.ChannelAnnouncement{ + chanAnn = &lnwire.ChannelAnnouncement1{ ShortChannelID: chanID, NodeID1: info.NodeKey1Bytes, NodeID2: info.NodeKey2Bytes, @@ -2271,8 +2306,8 @@ func (d *AuthenticatedGossiper) SyncManager() *SyncManager { // IsKeepAliveUpdate determines whether this channel update is considered a // keep-alive update based on the previous channel update processed for the same // direction. -func IsKeepAliveUpdate(update *lnwire.ChannelUpdate, - prev *models.ChannelEdgePolicy) bool { +func IsKeepAliveUpdate(update *lnwire.ChannelUpdate1, + prev *models.ChannelEdgePolicy1) bool { // Both updates should be from the same direction. if update.ChannelFlags&lnwire.ChanUpdateDirection != @@ -2396,16 +2431,16 @@ func (d *AuthenticatedGossiper) handleNodeAnnouncement(nMsg *networkMsg, // handleChanAnnouncement processes a new channel announcement. func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, - ann *lnwire.ChannelAnnouncement, + ann *lnwire.ChannelAnnouncement1, ops []batch.SchedulerOption) ([]networkMsg, bool) { - log.Debugf("Processing ChannelAnnouncement: peer=%v, short_chan_id=%v", + log.Debugf("Processing ChannelAnnouncement1: peer=%v, short_chan_id=%v", nMsg.peer, ann.ShortChannelID.ToUint64()) // We'll ignore any channel announcements that target any chain other // than the set of chains we know of. if !bytes.Equal(ann.ChainHash[:], d.cfg.ChainHash[:]) { - err := fmt.Errorf("ignoring ChannelAnnouncement from chain=%v"+ + err := fmt.Errorf("ignoring ChannelAnnouncement1 from chain=%v"+ ", gossiper on chain=%v", ann.ChainHash, d.cfg.ChainHash) log.Errorf(err.Error()) @@ -2461,9 +2496,9 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, // If this is a remote channel announcement, then we'll validate all // the signatures within the proof as it should be well formed. - var proof *models.ChannelAuthProof + var proof *models.ChannelAuthProof1 if nMsg.isRemote { - if err := graph.ValidateChannelAnn(ann); err != nil { + if err := ann.Validate(d.fetchPKScript); err != nil { err := fmt.Errorf("unable to validate announcement: "+ "%v", err) @@ -2481,7 +2516,7 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, // If the proof checks out, then we'll save the proof itself to // the database so we can fetch it later when gossiping with // other nodes. - proof = &models.ChannelAuthProof{ + proof = &models.ChannelAuthProof1{ NodeSig1Bytes: ann.NodeSig1.ToSignatureBytes(), NodeSig2Bytes: ann.NodeSig2.ToSignatureBytes(), BitcoinSig1Bytes: ann.BitcoinSig1.ToSignatureBytes(), @@ -2498,7 +2533,7 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, return nil, false } - edge := &models.ChannelEdgeInfo{ + edge := &models.ChannelEdgeInfo1{ ChannelID: ann.ShortChannelID.ToUint64(), ChainHash: ann.ChainHash, NodeKey1Bytes: ann.NodeID1, @@ -2627,7 +2662,7 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, // Reprocess the message, making sure we return an // error to the original caller in case the gossiper // shuts down. - case *lnwire.ChannelUpdate: + case *lnwire.ChannelUpdate1: log.Debugf("Reprocessing ChannelUpdate for "+ "shortChanID=%v", msg.ShortChannelID.ToUint64()) @@ -2663,7 +2698,7 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, nMsg.err <- nil - log.Debugf("Processed ChannelAnnouncement: peer=%v, short_chan_id=%v", + log.Debugf("Processed ChannelAnnouncement1: peer=%v, short_chan_id=%v", nMsg.peer, ann.ShortChannelID.ToUint64()) return announcements, true @@ -2671,7 +2706,7 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, // handleChanUpdate processes a new channel update. func (d *AuthenticatedGossiper) handleChanUpdate(nMsg *networkMsg, - upd *lnwire.ChannelUpdate, + upd *lnwire.ChannelUpdate1, ops []batch.SchedulerOption) ([]networkMsg, bool) { log.Debugf("Processing ChannelUpdate: peer=%v, short_chan_id=%v, ", @@ -2865,7 +2900,7 @@ func (d *AuthenticatedGossiper) handleChanUpdate(nMsg *networkMsg, // being updated. var ( pubKey *btcec.PublicKey - edgeToUpdate *models.ChannelEdgePolicy + edgeToUpdate *models.ChannelEdgePolicy1 ) direction := upd.ChannelFlags & lnwire.ChanUpdateDirection switch direction { @@ -2884,7 +2919,7 @@ func (d *AuthenticatedGossiper) handleChanUpdate(nMsg *networkMsg, // Validate the channel announcement with the expected public key and // channel capacity. In the case of an invalid channel update, we'll // return an error to the caller and exit early. - err = graph.ValidateChannelUpdateAnn(pubKey, chanInfo.Capacity, upd) + err = lnwire.ValidateChannelUpdateAnn(pubKey, chanInfo.Capacity, upd) if err != nil { rErr := fmt.Errorf("unable to validate channel update "+ "announcement for short_chan_id=%v: %v", @@ -2954,7 +2989,7 @@ func (d *AuthenticatedGossiper) handleChanUpdate(nMsg *networkMsg, // different alias. This might mean that SigBytes is incorrect as it // signs a different SCID than the database SCID, but since there will // only be a difference if AuthProof == nil, this is fine. - update := &models.ChannelEdgePolicy{ + update := &models.ChannelEdgePolicy1{ SigBytes: upd.Signature.ToSignatureBytes(), ChannelID: chanInfo.ChannelID, LastUpdate: timestamp, @@ -3075,7 +3110,7 @@ func (d *AuthenticatedGossiper) handleChanUpdate(nMsg *networkMsg, // handleAnnSig processes a new announcement signatures message. func (d *AuthenticatedGossiper) handleAnnSig(nMsg *networkMsg, - ann *lnwire.AnnounceSignatures) ([]networkMsg, bool) { + ann *lnwire.AnnounceSignatures1) ([]networkMsg, bool) { needBlockHeight := ann.ShortChannelID.BlockHeight + d.cfg.ProofMatureDelta @@ -3266,7 +3301,7 @@ func (d *AuthenticatedGossiper) handleAnnSig(nMsg *networkMsg, // We now have both halves of the channel announcement proof, then // we'll reconstruct the initial announcement so we can validate it // shortly below. - var dbProof models.ChannelAuthProof + var dbProof models.ChannelAuthProof1 if isFirstNode { dbProof.NodeSig1Bytes = ann.NodeSignature.ToSignatureBytes() dbProof.NodeSig2Bytes = oppProof.NodeSignature.ToSignatureBytes() @@ -3290,7 +3325,7 @@ func (d *AuthenticatedGossiper) handleAnnSig(nMsg *networkMsg, // With all the necessary components assembled validate the full // channel announcement proof. - if err := graph.ValidateChannelAnn(chanAnn); err != nil { + if err := chanAnn.Validate(d.fetchPKScript); err != nil { err := fmt.Errorf("channel announcement proof for "+ "short_chan_id=%v isn't valid: %v", shortChanID, err) diff --git a/discovery/gossiper_test.go b/discovery/gossiper_test.go index 7cfc7bce8f..6edd04c9a2 100644 --- a/discovery/gossiper_test.go +++ b/discovery/gossiper_test.go @@ -92,8 +92,8 @@ type mockGraphSource struct { mu sync.Mutex nodes []channeldb.LightningNode - infos map[uint64]models.ChannelEdgeInfo - edges map[uint64][]models.ChannelEdgePolicy + infos map[uint64]models.ChannelEdgeInfo1 + edges map[uint64][]models.ChannelEdgePolicy1 zombies map[uint64][][33]byte chansToReject map[uint64]struct{} } @@ -101,8 +101,8 @@ type mockGraphSource struct { func newMockRouter(height uint32) *mockGraphSource { return &mockGraphSource{ bestHeight: height, - infos: make(map[uint64]models.ChannelEdgeInfo), - edges: make(map[uint64][]models.ChannelEdgePolicy), + infos: make(map[uint64]models.ChannelEdgeInfo1), + edges: make(map[uint64][]models.ChannelEdgePolicy1), zombies: make(map[uint64][][33]byte), chansToReject: make(map[uint64]struct{}), } @@ -120,7 +120,7 @@ func (r *mockGraphSource) AddNode(node *channeldb.LightningNode, return nil } -func (r *mockGraphSource) AddEdge(info *models.ChannelEdgeInfo, +func (r *mockGraphSource) AddEdge(info *models.ChannelEdgeInfo1, _ ...batch.SchedulerOption) error { r.mu.Lock() @@ -145,14 +145,14 @@ func (r *mockGraphSource) queueValidationFail(chanID uint64) { r.chansToReject[chanID] = struct{}{} } -func (r *mockGraphSource) UpdateEdge(edge *models.ChannelEdgePolicy, +func (r *mockGraphSource) UpdateEdge(edge *models.ChannelEdgePolicy1, _ ...batch.SchedulerOption) error { r.mu.Lock() defer r.mu.Unlock() if len(r.edges[edge.ChannelID]) == 0 { - r.edges[edge.ChannelID] = make([]models.ChannelEdgePolicy, 2) + r.edges[edge.ChannelID] = make([]models.ChannelEdgePolicy1, 2) } if edge.ChannelFlags&lnwire.ChanUpdateDirection == 0 { @@ -169,7 +169,7 @@ func (r *mockGraphSource) CurrentBlockHeight() (uint32, error) { } func (r *mockGraphSource) AddProof(chanID lnwire.ShortChannelID, - proof *models.ChannelAuthProof) error { + proof *models.ChannelAuthProof1) error { r.mu.Lock() defer r.mu.Unlock() @@ -191,8 +191,8 @@ func (r *mockGraphSource) ForEachNode(func(node *channeldb.LightningNode) error) } func (r *mockGraphSource) ForAllOutgoingChannels(cb func(tx kvdb.RTx, - i *models.ChannelEdgeInfo, - c *models.ChannelEdgePolicy) error) error { + i *models.ChannelEdgeInfo1, + c *models.ChannelEdgePolicy1) error) error { r.mu.Lock() defer r.mu.Unlock() @@ -223,9 +223,9 @@ func (r *mockGraphSource) ForAllOutgoingChannels(cb func(tx kvdb.RTx, } func (r *mockGraphSource) GetChannelByID(chanID lnwire.ShortChannelID) ( - *models.ChannelEdgeInfo, - *models.ChannelEdgePolicy, - *models.ChannelEdgePolicy, error) { + *models.ChannelEdgeInfo1, + *models.ChannelEdgePolicy1, + *models.ChannelEdgePolicy1, error) { r.mu.Lock() defer r.mu.Unlock() @@ -238,7 +238,7 @@ func (r *mockGraphSource) GetChannelByID(chanID lnwire.ShortChannelID) ( return nil, nil, nil, channeldb.ErrEdgeNotFound } - return &models.ChannelEdgeInfo{ + return &models.ChannelEdgeInfo1{ NodeKey1Bytes: pubKeys[0], NodeKey2Bytes: pubKeys[1], }, nil, nil, channeldb.ErrZombieEdge @@ -249,13 +249,13 @@ func (r *mockGraphSource) GetChannelByID(chanID lnwire.ShortChannelID) ( return &chanInfo, nil, nil, nil } - var edge1 *models.ChannelEdgePolicy - if !reflect.DeepEqual(edges[0], models.ChannelEdgePolicy{}) { + var edge1 *models.ChannelEdgePolicy1 + if !reflect.DeepEqual(edges[0], models.ChannelEdgePolicy1{}) { edge1 = &edges[0] } - var edge2 *models.ChannelEdgePolicy - if !reflect.DeepEqual(edges[1], models.ChannelEdgePolicy{}) { + var edge2 *models.ChannelEdgePolicy1 + if !reflect.DeepEqual(edges[1], models.ChannelEdgePolicy1{}) { edge2 = &edges[1] } @@ -355,12 +355,12 @@ func (r *mockGraphSource) IsStaleEdgePolicy(chanID lnwire.ShortChannelID, switch { case flags&lnwire.ChanUpdateDirection == 0 && - !reflect.DeepEqual(edges[0], models.ChannelEdgePolicy{}): + !reflect.DeepEqual(edges[0], models.ChannelEdgePolicy1{}): return !timestamp.After(edges[0].LastUpdate) case flags&lnwire.ChanUpdateDirection == 1 && - !reflect.DeepEqual(edges[1], models.ChannelEdgePolicy{}): + !reflect.DeepEqual(edges[1], models.ChannelEdgePolicy1{}): return !timestamp.After(edges[1].LastUpdate) @@ -460,13 +460,13 @@ type annBatch struct { nodeAnn1 *lnwire.NodeAnnouncement nodeAnn2 *lnwire.NodeAnnouncement - chanAnn *lnwire.ChannelAnnouncement + chanAnn *lnwire.ChannelAnnouncement1 - chanUpdAnn1 *lnwire.ChannelUpdate - chanUpdAnn2 *lnwire.ChannelUpdate + chanUpdAnn1 *lnwire.ChannelUpdate1 + chanUpdAnn2 *lnwire.ChannelUpdate1 - localProofAnn *lnwire.AnnounceSignatures - remoteProofAnn *lnwire.AnnounceSignatures + localProofAnn *lnwire.AnnounceSignatures1 + remoteProofAnn *lnwire.AnnounceSignatures1 } func createLocalAnnouncements(blockHeight uint32) (*annBatch, error) { @@ -497,7 +497,7 @@ func createAnnouncements(blockHeight uint32, key1, key2 *btcec.PrivateKey) (*ann return nil, err } - batch.remoteProofAnn = &lnwire.AnnounceSignatures{ + batch.remoteProofAnn = &lnwire.AnnounceSignatures1{ ShortChannelID: lnwire.ShortChannelID{ BlockHeight: blockHeight, }, @@ -505,7 +505,7 @@ func createAnnouncements(blockHeight uint32, key1, key2 *btcec.PrivateKey) (*ann BitcoinSignature: batch.chanAnn.BitcoinSig2, } - batch.localProofAnn = &lnwire.AnnounceSignatures{ + batch.localProofAnn = &lnwire.AnnounceSignatures1{ ShortChannelID: lnwire.ShortChannelID{ BlockHeight: blockHeight, }, @@ -569,12 +569,12 @@ func createNodeAnnouncement(priv *btcec.PrivateKey, func createUpdateAnnouncement(blockHeight uint32, flags lnwire.ChanUpdateChanFlags, nodeKey *btcec.PrivateKey, timestamp uint32, - extraBytes ...[]byte) (*lnwire.ChannelUpdate, error) { + extraBytes ...[]byte) (*lnwire.ChannelUpdate1, error) { var err error htlcMinMsat := lnwire.MilliSatoshi(prand.Int63()) - a := &lnwire.ChannelUpdate{ + a := &lnwire.ChannelUpdate1{ ShortChannelID: lnwire.ShortChannelID{ BlockHeight: blockHeight, }, @@ -602,7 +602,7 @@ func createUpdateAnnouncement(blockHeight uint32, return a, nil } -func signUpdate(nodeKey *btcec.PrivateKey, a *lnwire.ChannelUpdate) error { +func signUpdate(nodeKey *btcec.PrivateKey, a *lnwire.ChannelUpdate1) error { signer := mock.SingleSigner{Privkey: nodeKey} sig, err := netann.SignAnnouncement(&signer, testKeyLoc, a) if err != nil { @@ -619,9 +619,9 @@ func signUpdate(nodeKey *btcec.PrivateKey, a *lnwire.ChannelUpdate) error { func createAnnouncementWithoutProof(blockHeight uint32, key1, key2 *btcec.PublicKey, - extraBytes ...[]byte) *lnwire.ChannelAnnouncement { + extraBytes ...[]byte) *lnwire.ChannelAnnouncement1 { - a := &lnwire.ChannelAnnouncement{ + a := &lnwire.ChannelAnnouncement1{ ShortChannelID: lnwire.ShortChannelID{ BlockHeight: blockHeight, TxIndex: 0, @@ -641,13 +641,13 @@ func createAnnouncementWithoutProof(blockHeight uint32, } func createRemoteChannelAnnouncement(blockHeight uint32, - extraBytes ...[]byte) (*lnwire.ChannelAnnouncement, error) { + extraBytes ...[]byte) (*lnwire.ChannelAnnouncement1, error) { return createChannelAnnouncement(blockHeight, remoteKeyPriv1, remoteKeyPriv2, extraBytes...) } func createChannelAnnouncement(blockHeight uint32, key1, key2 *btcec.PrivateKey, - extraBytes ...[]byte) (*lnwire.ChannelAnnouncement, error) { + extraBytes ...[]byte) (*lnwire.ChannelAnnouncement1, error) { a := createAnnouncementWithoutProof(blockHeight, key1.PubKey(), key2.PubKey(), extraBytes...) @@ -731,7 +731,7 @@ func createTestCtx(t *testing.T, startHeight uint32) (*testCtx, error) { return false } - signAliasUpdate := func(*lnwire.ChannelUpdate) (*ecdsa.Signature, + signAliasUpdate := func(*lnwire.ChannelUpdate1) (*ecdsa.Signature, error) { return nil, nil @@ -1433,7 +1433,7 @@ func TestSignatureAnnouncementRetryAtStartup(t *testing.T) { return false } - signAliasUpdate := func(*lnwire.ChannelUpdate) (*ecdsa.Signature, + signAliasUpdate := func(*lnwire.ChannelUpdate1) (*ecdsa.Signature, error) { return nil, nil @@ -1511,7 +1511,7 @@ out: case msg := <-sentToPeer: // Since the ChannelUpdate will also be resent as it is // sent reliably, we'll need to filter it out. - if _, ok := msg.(*lnwire.AnnounceSignatures); !ok { + if _, ok := msg.(*lnwire.AnnounceSignatures1); !ok { continue } @@ -1737,9 +1737,10 @@ func TestSignatureAnnouncementFullProofWhenRemoteProof(t *testing.T) { // We expect the gossiper to send this message to the remote peer. select { case msg := <-sentToPeer: - _, ok := msg.(*lnwire.ChannelAnnouncement) + _, ok := msg.(*lnwire.ChannelAnnouncement1) if !ok { - t.Fatalf("expected ChannelAnnouncement, instead got %T", msg) + t.Fatalf("expected ChannelAnnouncement1, instead got "+ + "%T", msg) } case <-time.After(2 * time.Second): t.Fatal("did not send local proof to peer") @@ -1836,7 +1837,7 @@ func TestDeDuplicatedAnnouncements(t *testing.T) { t.Fatal("channel update not replaced in batch") } - assertChannelUpdate := func(channelUpdate *lnwire.ChannelUpdate) { + assertChannelUpdate := func(channelUpdate *lnwire.ChannelUpdate1) { channelKey := channelUpdateID{ ua3.ShortChannelID, ua3.ChannelFlags, @@ -2474,7 +2475,7 @@ func TestReceiveRemoteChannelUpdateFirst(t *testing.T) { t.Fatalf("remote update was not processed") } - // Check that the ChannelEdgePolicy was added to the graph. + // Check that the ChannelEdgePolicy1 was added to the graph. chanInfo, e1, e2, err = ctx.router.GetChannelByID( batch.chanUpdAnn1.ShortChannelID, ) @@ -2772,9 +2773,9 @@ func TestRetransmit(t *testing.T) { var chanAnn, chanUpd, nodeAnn int for _, msg := range anns { switch msg.(type) { - case *lnwire.ChannelAnnouncement: + case *lnwire.ChannelAnnouncement1: chanAnn++ - case *lnwire.ChannelUpdate: + case *lnwire.ChannelUpdate1: chanUpd++ case *lnwire.NodeAnnouncement: nodeAnn++ @@ -3205,7 +3206,7 @@ func TestSendChannelUpdateReliably(t *testing.T) { // already been announced. We'll keep track of the old message that is // now stale to use later on. staleChannelUpdate := batch.chanUpdAnn1 - newChannelUpdate := &lnwire.ChannelUpdate{} + newChannelUpdate := &lnwire.ChannelUpdate1{} *newChannelUpdate = *staleChannelUpdate newChannelUpdate.Timestamp++ if err := signUpdate(selfKeyPriv, newChannelUpdate); err != nil { @@ -3259,9 +3260,9 @@ func TestSendChannelUpdateReliably(t *testing.T) { } switch msg := msg.(type) { - case *lnwire.ChannelUpdate: + case *lnwire.ChannelUpdate1: assertMessage(t, staleChannelUpdate, msg) - case *lnwire.AnnounceSignatures: + case *lnwire.AnnounceSignatures1: assertMessage(t, batch.localProofAnn, msg) default: t.Fatalf("send unexpected %v message", msg.MsgType()) @@ -3439,8 +3440,8 @@ out: var edgesToUpdate []EdgeWithInfo err = ctx.router.ForAllOutgoingChannels(func( _ kvdb.RTx, - info *models.ChannelEdgeInfo, - edge *models.ChannelEdgePolicy) error { + info *models.ChannelEdgeInfo1, + edge *models.ChannelEdgePolicy1) error { edge.TimeLockDelta = uint16(newTimeLockDelta) edgesToUpdate = append(edgesToUpdate, EdgeWithInfo{ @@ -3461,7 +3462,7 @@ out: // being the channel our first private channel. for i := 0; i < numChannels-1; i++ { assertBroadcastMsg(t, ctx, func(msg lnwire.Message) error { - upd, ok := msg.(*lnwire.ChannelUpdate) + upd, ok := msg.(*lnwire.ChannelUpdate1) if !ok { return fmt.Errorf("channel update not "+ "broadcast, instead %T was", msg) @@ -3485,7 +3486,7 @@ out: // remote peer via the reliable sender. select { case msg := <-sentMsgs: - upd, ok := msg.(*lnwire.ChannelUpdate) + upd, ok := msg.(*lnwire.ChannelUpdate1) if !ok { t.Fatalf("channel update not "+ "broadcast, instead %T was", msg) @@ -3509,7 +3510,7 @@ out: for { select { case msg := <-ctx.broadcastedMessage: - if upd, ok := msg.msg.(*lnwire.ChannelUpdate); ok { + if upd, ok := msg.msg.(*lnwire.ChannelUpdate1); ok { if upd.ShortChannelID == firstChanID { t.Fatalf("chan update msg received: %v", spew.Sdump(msg)) @@ -3835,7 +3836,7 @@ func TestRateLimitChannelUpdates(t *testing.T) { // We'll define a helper to assert whether updates should be rate // limited or not depending on their contents. - assertRateLimit := func(update *lnwire.ChannelUpdate, peer lnpeer.Peer, + assertRateLimit := func(update *lnwire.ChannelUpdate1, peer lnpeer.Peer, shouldRateLimit bool) { t.Helper() diff --git a/discovery/message_store.go b/discovery/message_store.go index 10fa51623a..025a89d67a 100644 --- a/discovery/message_store.go +++ b/discovery/message_store.go @@ -83,9 +83,9 @@ func NewMessageStore(db kvdb.Backend) (*MessageStore, error) { func msgShortChanID(msg lnwire.Message) (lnwire.ShortChannelID, error) { var shortChanID lnwire.ShortChannelID switch msg := msg.(type) { - case *lnwire.AnnounceSignatures: + case *lnwire.AnnounceSignatures1: shortChanID = msg.ShortChannelID - case *lnwire.ChannelUpdate: + case *lnwire.ChannelUpdate1: shortChanID = msg.ShortChannelID default: return shortChanID, ErrUnsupportedMessage @@ -160,7 +160,7 @@ func (s *MessageStore) DeleteMessage(msg lnwire.Message, // In the event that we're attempting to delete a ChannelUpdate // from the store, we'll make sure that we're actually deleting // the correct one as it can be overwritten. - if msg, ok := msg.(*lnwire.ChannelUpdate); ok { + if msg, ok := msg.(*lnwire.ChannelUpdate1); ok { // Deleting a value from a bucket that doesn't exist // acts as a NOP, so we'll return if a message doesn't // exist under this key. @@ -176,7 +176,13 @@ func (s *MessageStore) DeleteMessage(msg lnwire.Message, // If the timestamps don't match, then the update stored // should be the latest one, so we'll avoid deleting it. - if msg.Timestamp != dbMsg.(*lnwire.ChannelUpdate).Timestamp { + m, ok := dbMsg.(*lnwire.ChannelUpdate1) + if !ok { + return fmt.Errorf("expected "+ + "*lnwire.ChannelUpdate1, got: %T", + dbMsg) + } + if msg.Timestamp != m.Timestamp { return nil } } diff --git a/discovery/message_store_test.go b/discovery/message_store_test.go index e812c3f1a3..36c082e36f 100644 --- a/discovery/message_store_test.go +++ b/discovery/message_store_test.go @@ -52,15 +52,15 @@ func randCompressedPubKey(t *testing.T) [33]byte { return compressedPubKey } -func randAnnounceSignatures() *lnwire.AnnounceSignatures { - return &lnwire.AnnounceSignatures{ +func randAnnounceSignatures() *lnwire.AnnounceSignatures1 { + return &lnwire.AnnounceSignatures1{ ShortChannelID: lnwire.NewShortChanIDFromInt(rand.Uint64()), ExtraOpaqueData: make([]byte, 0), } } -func randChannelUpdate() *lnwire.ChannelUpdate { - return &lnwire.ChannelUpdate{ +func randChannelUpdate() *lnwire.ChannelUpdate1 { + return &lnwire.ChannelUpdate1{ ShortChannelID: lnwire.NewShortChanIDFromInt(rand.Uint64()), ExtraOpaqueData: make([]byte, 0), } @@ -116,9 +116,9 @@ func TestMessageStoreMessages(t *testing.T) { for _, msg := range peerMsgs { var shortChanID uint64 switch msg := msg.(type) { - case *lnwire.AnnounceSignatures: + case *lnwire.AnnounceSignatures1: shortChanID = msg.ShortChannelID.ToUint64() - case *lnwire.ChannelUpdate: + case *lnwire.ChannelUpdate1: shortChanID = msg.ShortChannelID.ToUint64() default: t.Fatalf("found unexpected message type %T", msg) diff --git a/discovery/syncer.go b/discovery/syncer.go index 512c9f631f..a62063e980 100644 --- a/discovery/syncer.go +++ b/discovery/syncer.go @@ -1412,9 +1412,11 @@ func (g *GossipSyncer) FilterGossipMsgs(msgs ...msgWithSenders) { // set of channel announcements and channel updates. This will allow us // to quickly check if we should forward a chan ann, based on the known // channel updates for a channel. - chanUpdateIndex := make(map[lnwire.ShortChannelID][]*lnwire.ChannelUpdate) + chanUpdateIndex := make( + map[lnwire.ShortChannelID][]*lnwire.ChannelUpdate1, + ) for _, msg := range msgs { - chanUpdate, ok := msg.msg.(*lnwire.ChannelUpdate) + chanUpdate, ok := msg.msg.(*lnwire.ChannelUpdate1) if !ok { continue } @@ -1453,7 +1455,7 @@ func (g *GossipSyncer) FilterGossipMsgs(msgs ...msgWithSenders) { // For each channel announcement message, we'll only send this // message if the channel updates for the channel are between // our time range. - case *lnwire.ChannelAnnouncement: + case *lnwire.ChannelAnnouncement1: // First, we'll check if the channel updates are in // this message batch. chanUpdates, ok := chanUpdateIndex[msg.ShortChannelID] @@ -1484,7 +1486,7 @@ func (g *GossipSyncer) FilterGossipMsgs(msgs ...msgWithSenders) { // For each channel update, we'll only send if it the timestamp // is between our time range. - case *lnwire.ChannelUpdate: + case *lnwire.ChannelUpdate1: if passesFilter(msg.Timestamp) { msgsToSend = append(msgsToSend, msg) } diff --git a/discovery/syncer_test.go b/discovery/syncer_test.go index 15e2442e15..c381494446 100644 --- a/discovery/syncer_test.go +++ b/discovery/syncer_test.go @@ -52,7 +52,7 @@ type mockChannelGraphTimeSeries struct { annResp chan []lnwire.Message updateReq chan lnwire.ShortChannelID - updateResp chan []*lnwire.ChannelUpdate + updateResp chan []*lnwire.ChannelUpdate1 } func newMockChannelGraphTimeSeries( @@ -74,7 +74,7 @@ func newMockChannelGraphTimeSeries( annResp: make(chan []lnwire.Message, 1), updateReq: make(chan lnwire.ShortChannelID, 1), - updateResp: make(chan []*lnwire.ChannelUpdate, 1), + updateResp: make(chan []*lnwire.ChannelUpdate1, 1), } } @@ -149,7 +149,7 @@ func (m *mockChannelGraphTimeSeries) FetchChanAnns(chain chainhash.Hash, return <-m.annResp, nil } func (m *mockChannelGraphTimeSeries) FetchChanUpdates(chain chainhash.Hash, - shortChanID lnwire.ShortChannelID) ([]*lnwire.ChannelUpdate, error) { + shortChanID lnwire.ShortChannelID) ([]*lnwire.ChannelUpdate1, error) { m.updateReq <- shortChanID @@ -302,36 +302,36 @@ func TestGossipSyncerFilterGossipMsgsAllInMemory(t *testing.T) { }, { // Ann tuple below horizon. - msg: &lnwire.ChannelAnnouncement{ + msg: &lnwire.ChannelAnnouncement1{ ShortChannelID: lnwire.NewShortChanIDFromInt(10), }, }, { - msg: &lnwire.ChannelUpdate{ + msg: &lnwire.ChannelUpdate1{ ShortChannelID: lnwire.NewShortChanIDFromInt(10), Timestamp: unixStamp(5), }, }, { // Ann tuple above horizon. - msg: &lnwire.ChannelAnnouncement{ + msg: &lnwire.ChannelAnnouncement1{ ShortChannelID: lnwire.NewShortChanIDFromInt(15), }, }, { - msg: &lnwire.ChannelUpdate{ + msg: &lnwire.ChannelUpdate1{ ShortChannelID: lnwire.NewShortChanIDFromInt(15), Timestamp: unixStamp(25002), }, }, { // Ann tuple beyond horizon. - msg: &lnwire.ChannelAnnouncement{ + msg: &lnwire.ChannelAnnouncement1{ ShortChannelID: lnwire.NewShortChanIDFromInt(20), }, }, { - msg: &lnwire.ChannelUpdate{ + msg: &lnwire.ChannelUpdate1{ ShortChannelID: lnwire.NewShortChanIDFromInt(20), Timestamp: unixStamp(999999), }, @@ -339,7 +339,7 @@ func TestGossipSyncerFilterGossipMsgsAllInMemory(t *testing.T) { { // Ann w/o an update at all, the update in the DB will // be below the horizon. - msg: &lnwire.ChannelAnnouncement{ + msg: &lnwire.ChannelAnnouncement1{ ShortChannelID: lnwire.NewShortChanIDFromInt(25), }, }, @@ -365,7 +365,7 @@ func TestGossipSyncerFilterGossipMsgsAllInMemory(t *testing.T) { } // If so, then we'll send back the missing update. - chanSeries.updateResp <- []*lnwire.ChannelUpdate{ + chanSeries.updateResp <- []*lnwire.ChannelUpdate1{ { ShortChannelID: lnwire.NewShortChanIDFromInt(25), Timestamp: unixStamp(5), @@ -547,7 +547,7 @@ func TestGossipSyncerApplyGossipFilter(t *testing.T) { // For this first response, we'll send back a proper // set of messages that should be echoed back. chanSeries.horizonResp <- []lnwire.Message{ - &lnwire.ChannelUpdate{ + &lnwire.ChannelUpdate1{ ShortChannelID: lnwire.NewShortChanIDFromInt(25), Timestamp: unixStamp(5), }, @@ -702,10 +702,10 @@ func TestGossipSyncerReplyShortChanIDs(t *testing.T) { } queryReply := []lnwire.Message{ - &lnwire.ChannelAnnouncement{ + &lnwire.ChannelAnnouncement1{ ShortChannelID: lnwire.NewShortChanIDFromInt(20), }, - &lnwire.ChannelUpdate{ + &lnwire.ChannelUpdate1{ ShortChannelID: lnwire.NewShortChanIDFromInt(20), Timestamp: unixStamp(999999), }, diff --git a/docs/release-notes/release-notes-0.19.0.md b/docs/release-notes/release-notes-0.19.0.md index 869e235115..fe94e0468c 100644 --- a/docs/release-notes/release-notes-0.19.0.md +++ b/docs/release-notes/release-notes-0.19.0.md @@ -48,6 +48,12 @@ # Technical and Architectural Updates ## BOLT Spec Updates +* Add new [lnwire](https://github.com/lightningnetwork/lnd/pull/8044) messages + for the Gossip 1.75 protocol. + +* Add new [channeldb](https://github.com/lightningnetwork/lnd/pull/8164) types + required for the Gossip 1.75 protocol. + ## Testing ## Database diff --git a/funding/manager.go b/funding/manager.go index 8ad16005c7..844747734c 100644 --- a/funding/manager.go +++ b/funding/manager.go @@ -532,7 +532,7 @@ type Config struct { // DeleteAliasEdge allows the Manager to delete an alias channel edge // from the graph. It also returns our local to-be-deleted policy. DeleteAliasEdge func(scid lnwire.ShortChannelID) ( - *models.ChannelEdgePolicy, error) + *models.ChannelEdgePolicy1, error) // AliasManager is an implementation of the aliasHandler interface that // abstracts away the handling of many alias functions. @@ -3408,7 +3408,7 @@ func (f *Manager) extractAnnounceParams(c *channeldb.OpenChannel) ( func (f *Manager) addToGraph(completeChan *channeldb.OpenChannel, shortChanID *lnwire.ShortChannelID, peerAlias *lnwire.ShortChannelID, - ourPolicy *models.ChannelEdgePolicy) error { + ourPolicy *models.ChannelEdgePolicy1) error { chanID := lnwire.NewChanIDFromOutPoint(completeChan.FundingOutpoint) @@ -4124,9 +4124,9 @@ func (f *Manager) ensureInitialForwardingPolicy(chanID lnwire.ChannelID, // chanAnnouncement encapsulates the two authenticated announcements that we // send out to the network after a new channel has been created locally. type chanAnnouncement struct { - chanAnn *lnwire.ChannelAnnouncement - chanUpdateAnn *lnwire.ChannelUpdate - chanProof *lnwire.AnnounceSignatures + chanAnn *lnwire.ChannelAnnouncement1 + chanUpdateAnn *lnwire.ChannelUpdate1 + chanProof *lnwire.AnnounceSignatures1 } // newChanAnnouncement creates the authenticated channel announcement messages @@ -4141,7 +4141,7 @@ func (f *Manager) newChanAnnouncement(localPubKey, remotePubKey *btcec.PublicKey, localFundingKey *keychain.KeyDescriptor, remoteFundingKey *btcec.PublicKey, shortChanID lnwire.ShortChannelID, chanID lnwire.ChannelID, fwdMinHTLC, fwdMaxHTLC lnwire.MilliSatoshi, - ourPolicy *models.ChannelEdgePolicy, + ourPolicy *models.ChannelEdgePolicy1, chanType channeldb.ChannelType) (*chanAnnouncement, error) { chainHash := *f.cfg.Wallet.Cfg.NetParams.GenesisHash @@ -4149,7 +4149,7 @@ func (f *Manager) newChanAnnouncement(localPubKey, // The unconditional section of the announcement is the ShortChannelID // itself which compactly encodes the location of the funding output // within the blockchain. - chanAnn := &lnwire.ChannelAnnouncement{ + chanAnn := &lnwire.ChannelAnnouncement1{ ShortChannelID: shortChanID, Features: lnwire.NewRawFeatureVector(), ChainHash: chainHash, @@ -4219,7 +4219,7 @@ func (f *Manager) newChanAnnouncement(localPubKey, // We announce the channel with the default values. Some of // these values can later be changed by crafting a new ChannelUpdate. - chanUpdateAnn := &lnwire.ChannelUpdate{ + chanUpdateAnn := &lnwire.ChannelUpdate1{ ShortChannelID: shortChanID, ChainHash: chainHash, Timestamp: uint32(time.Now().Unix()), @@ -4318,7 +4318,7 @@ func (f *Manager) newChanAnnouncement(localPubKey, // Finally, we'll generate the announcement proof which we'll use to // provide the other side with the necessary signatures required to // allow them to reconstruct the full channel announcement. - proof := &lnwire.AnnounceSignatures{ + proof := &lnwire.AnnounceSignatures1{ ChannelID: chanID, ShortChannelID: shortChanID, } diff --git a/funding/manager_test.go b/funding/manager_test.go index 9db175ec39..7dd91cccad 100644 --- a/funding/manager_test.go +++ b/funding/manager_test.go @@ -551,7 +551,7 @@ func createTestFundingManager(t *testing.T, privKey *btcec.PrivateKey, OpenChannelPredicate: chainedAcceptor, NotifyPendingOpenChannelEvent: evt.NotifyPendingOpenChannelEvent, DeleteAliasEdge: func(scid lnwire.ShortChannelID) ( - *models.ChannelEdgePolicy, error) { + *models.ChannelEdgePolicy1, error) { return nil, nil }, @@ -1201,9 +1201,9 @@ func assertChannelAnnouncements(t *testing.T, alice, bob *testNode, gotChannelUpdate := false for _, msg := range announcements { switch m := msg.(type) { - case *lnwire.ChannelAnnouncement: + case *lnwire.ChannelAnnouncement1: gotChannelAnnouncement = true - case *lnwire.ChannelUpdate: + case *lnwire.ChannelUpdate1: // The channel update sent by the node should // advertise the MinHTLC value required by the @@ -1290,7 +1290,7 @@ func assertAnnouncementSignatures(t *testing.T, alice, bob *testNode) { gotNodeAnnouncement := false for _, msg := range announcements { switch msg.(type) { - case *lnwire.AnnounceSignatures: + case *lnwire.AnnounceSignatures1: gotAnnounceSignatures = true case *lnwire.NodeAnnouncement: gotNodeAnnouncement = true diff --git a/graph/ann_validation.go b/graph/ann_validation.go index 3936b4652f..3c93d06e52 100644 --- a/graph/ann_validation.go +++ b/graph/ann_validation.go @@ -2,89 +2,13 @@ package graph import ( "bytes" - "fmt" "github.com/btcsuite/btcd/btcec/v2" - "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" - "github.com/davecgh/go-spew/spew" "github.com/go-errors/errors" "github.com/lightningnetwork/lnd/lnwire" ) -// ValidateChannelAnn validates the channel announcement message and checks -// that node signatures covers the announcement message, and that the bitcoin -// signatures covers the node keys. -func ValidateChannelAnn(a *lnwire.ChannelAnnouncement) error { - // First, we'll compute the digest (h) which is to be signed by each of - // the keys included within the node announcement message. This hash - // digest includes all the keys, so the (up to 4 signatures) will - // attest to the validity of each of the keys. - data, err := a.DataToSign() - if err != nil { - return err - } - dataHash := chainhash.DoubleHashB(data) - - // First we'll verify that the passed bitcoin key signature is indeed a - // signature over the computed hash digest. - bitcoinSig1, err := a.BitcoinSig1.ToSignature() - if err != nil { - return err - } - bitcoinKey1, err := btcec.ParsePubKey(a.BitcoinKey1[:]) - if err != nil { - return err - } - if !bitcoinSig1.Verify(dataHash, bitcoinKey1) { - return errors.New("can't verify first bitcoin signature") - } - - // If that checks out, then we'll verify that the second bitcoin - // signature is a valid signature of the bitcoin public key over hash - // digest as well. - bitcoinSig2, err := a.BitcoinSig2.ToSignature() - if err != nil { - return err - } - bitcoinKey2, err := btcec.ParsePubKey(a.BitcoinKey2[:]) - if err != nil { - return err - } - if !bitcoinSig2.Verify(dataHash, bitcoinKey2) { - return errors.New("can't verify second bitcoin signature") - } - - // Both node signatures attached should indeed be a valid signature - // over the selected digest of the channel announcement signature. - nodeSig1, err := a.NodeSig1.ToSignature() - if err != nil { - return err - } - nodeKey1, err := btcec.ParsePubKey(a.NodeID1[:]) - if err != nil { - return err - } - if !nodeSig1.Verify(dataHash, nodeKey1) { - return errors.New("can't verify data in first node signature") - } - - nodeSig2, err := a.NodeSig2.ToSignature() - if err != nil { - return err - } - nodeKey2, err := btcec.ParsePubKey(a.NodeID2[:]) - if err != nil { - return err - } - if !nodeSig2.Verify(dataHash, nodeKey2) { - return errors.New("can't verify data in second node signature") - } - - return nil - -} - // ValidateNodeAnn validates the node announcement by ensuring that the // attached signature is needed a signature of the node announcement under the // specified node public key. @@ -121,70 +45,3 @@ func ValidateNodeAnn(a *lnwire.NodeAnnouncement) error { return nil } - -// ValidateChannelUpdateAnn validates the channel update announcement by -// checking (1) that the included signature covers the announcement and has been -// signed by the node's private key, and (2) that the announcement's message -// flags and optional fields are sane. -func ValidateChannelUpdateAnn(pubKey *btcec.PublicKey, capacity btcutil.Amount, - a *lnwire.ChannelUpdate) error { - - if err := ValidateChannelUpdateFields(capacity, a); err != nil { - return err - } - - return VerifyChannelUpdateSignature(a, pubKey) -} - -// VerifyChannelUpdateSignature verifies that the channel update message was -// signed by the party with the given node public key. -func VerifyChannelUpdateSignature(msg *lnwire.ChannelUpdate, - pubKey *btcec.PublicKey) error { - - data, err := msg.DataToSign() - if err != nil { - return fmt.Errorf("unable to reconstruct message data: %w", err) - } - dataHash := chainhash.DoubleHashB(data) - - nodeSig, err := msg.Signature.ToSignature() - if err != nil { - return err - } - - if !nodeSig.Verify(dataHash, pubKey) { - return fmt.Errorf("invalid signature for channel update %v", - spew.Sdump(msg)) - } - - return nil -} - -// ValidateChannelUpdateFields validates a channel update's message flags and -// corresponding update fields. -func ValidateChannelUpdateFields(capacity btcutil.Amount, - msg *lnwire.ChannelUpdate) error { - - // The maxHTLC flag is mandatory. - if !msg.MessageFlags.HasMaxHtlc() { - return errors.Errorf("max htlc flag not set for channel "+ - "update %v", spew.Sdump(msg)) - } - - maxHtlc := msg.HtlcMaximumMsat - if maxHtlc == 0 || maxHtlc < msg.HtlcMinimumMsat { - return errors.Errorf("invalid max htlc for channel "+ - "update %v", spew.Sdump(msg)) - } - - // For light clients, the capacity will not be set so we'll skip - // checking whether the MaxHTLC value respects the channel's - // capacity. - capacityMsat := lnwire.NewMSatFromSatoshis(capacity) - if capacityMsat != 0 && maxHtlc > capacityMsat { - return errors.Errorf("max_htlc (%v) for channel update "+ - "greater than capacity (%v)", maxHtlc, capacityMsat) - } - - return nil -} diff --git a/graph/builder.go b/graph/builder.go index 6523b492bc..3726ccb70a 100644 --- a/graph/builder.go +++ b/graph/builder.go @@ -108,6 +108,11 @@ type Config struct { // IsAlias returns whether a passed ShortChannelID is an alias. This is // only used for our local channels. IsAlias func(scid lnwire.ShortChannelID) bool + + // FetchTxBySCID queries the chain for the transaction with the given + // SCID. A quit channel can be passed in to cancel the query. + FetchTxBySCID func(chanID *lnwire.ShortChannelID, quit chan struct{}) ( + *wire.MsgTx, error) } // Builder builds and maintains a view of the Lightning Network graph. @@ -146,7 +151,7 @@ type Builder struct { ntfnClientUpdates chan *topologyClientUpdate // channelEdgeMtx is a mutex we use to make sure we process only one - // ChannelEdgePolicy at a time for a given channelID, to ensure + // ChannelEdgePolicy1 at a time for a given channelID, to ensure // consistency between the various database accesses. channelEdgeMtx *multimutex.Mutex[uint64] @@ -482,7 +487,7 @@ func (b *Builder) syncGraphWithChain() error { // boolean is that of node 2, and the final boolean is true if the channel // is considered a zombie. func (b *Builder) isZombieChannel(e1, - e2 *models.ChannelEdgePolicy) (bool, bool, bool) { + e2 *models.ChannelEdgePolicy1) (bool, bool, bool) { chanExpiry := b.cfg.ChannelPruneExpiry @@ -538,15 +543,15 @@ func (b *Builder) pruneZombieChans() error { log.Infof("Examining channel graph for zombie channels") // A helper method to detect if the channel belongs to this node - isSelfChannelEdge := func(info *models.ChannelEdgeInfo) bool { + isSelfChannelEdge := func(info *models.ChannelEdgeInfo1) bool { return info.NodeKey1Bytes == b.cfg.SelfNode || info.NodeKey2Bytes == b.cfg.SelfNode } // First, we'll collect all the channels which are eligible for garbage // collection due to being zombies. - filterPruneChans := func(info *models.ChannelEdgeInfo, - e1, e2 *models.ChannelEdgePolicy) error { + filterPruneChans := func(info *models.ChannelEdgeInfo1, + e1, e2 *models.ChannelEdgePolicy1) error { // Exit early in case this channel is already marked to be // pruned @@ -1177,8 +1182,8 @@ func (b *Builder) processUpdate(msg interface{}, log.Tracef("Updated vertex data for node=%x", msg.PubKeyBytes) b.stats.incNumNodeUpdates() - case *models.ChannelEdgeInfo: - log.Debugf("Received ChannelEdgeInfo for channel %v", + case *models.ChannelEdgeInfo1: + log.Debugf("Received ChannelEdgeInfo1 for channel %v", msg.ChannelID) // Prior to processing the announcement we first check if we @@ -1227,7 +1232,7 @@ func (b *Builder) processUpdate(msg interface{}, // to obtain the full funding outpoint that's encoded within // the channel ID. channelID := lnwire.NewShortChanIDFromInt(msg.ChannelID) - fundingTx, err := b.fetchFundingTxWrapper(&channelID) + fundingTx, err := b.cfg.FetchTxBySCID(&channelID, b.quit) if err != nil { //nolint:lll // @@ -1351,8 +1356,8 @@ func (b *Builder) processUpdate(msg interface{}, "view: %v", err) } - case *models.ChannelEdgePolicy: - log.Debugf("Received ChannelEdgePolicy for channel %v", + case *models.ChannelEdgePolicy1: + log.Debugf("Received ChannelEdgePolicy1 for channel %v", msg.ChannelID) // We make sure to hold the mutex for this channel ID, @@ -1444,69 +1449,6 @@ func (b *Builder) processUpdate(msg interface{}, return nil } -// fetchFundingTxWrapper is a wrapper around fetchFundingTx, except that it -// will exit if the router has stopped. -func (b *Builder) fetchFundingTxWrapper(chanID *lnwire.ShortChannelID) ( - *wire.MsgTx, error) { - - txChan := make(chan *wire.MsgTx, 1) - errChan := make(chan error, 1) - - go func() { - tx, err := b.fetchFundingTx(chanID) - if err != nil { - errChan <- err - return - } - - txChan <- tx - }() - - select { - case tx := <-txChan: - return tx, nil - - case err := <-errChan: - return nil, err - - case <-b.quit: - return nil, ErrGraphBuilderShuttingDown - } -} - -// fetchFundingTx returns the funding transaction identified by the passed -// short channel ID. -// -// TODO(roasbeef): replace with call to GetBlockTransaction? (would allow to -// later use getblocktxn). -func (b *Builder) fetchFundingTx( - chanID *lnwire.ShortChannelID) (*wire.MsgTx, error) { - - // First fetch the block hash by the block number encoded, then use - // that hash to fetch the block itself. - blockNum := int64(chanID.BlockHeight) - blockHash, err := b.cfg.Chain.GetBlockHash(blockNum) - if err != nil { - return nil, err - } - fundingBlock, err := b.cfg.Chain.GetBlock(blockHash) - if err != nil { - return nil, err - } - - // As a sanity check, ensure that the advertised transaction index is - // within the bounds of the total number of transactions within a - // block. - numTxns := uint32(len(fundingBlock.Transactions)) - if chanID.TxIndex > numTxns-1 { - return nil, fmt.Errorf("tx_index=#%v "+ - "is out of range (max_index=%v), network_chan_id=%v", - chanID.TxIndex, numTxns-1, chanID) - } - - return fundingBlock.Transactions[chanID.TxIndex].Copy(), nil -} - // routingMsg couples a routing related routing topology update to the // error channel. type routingMsg struct { @@ -1517,7 +1459,7 @@ type routingMsg struct { // ApplyChannelUpdate validates a channel update and if valid, applies it to the // database. It returns a bool indicating whether the updates were successful. -func (b *Builder) ApplyChannelUpdate(msg *lnwire.ChannelUpdate) bool { +func (b *Builder) ApplyChannelUpdate(msg *lnwire.ChannelUpdate1) bool { ch, _, _, err := b.GetChannelByID(msg.ShortChannelID) if err != nil { log.Errorf("Unable to retrieve channel by id: %v", err) @@ -1541,13 +1483,13 @@ func (b *Builder) ApplyChannelUpdate(msg *lnwire.ChannelUpdate) bool { return false } - err = ValidateChannelUpdateAnn(pubKey, ch.Capacity, msg) + err = lnwire.ValidateChannelUpdateAnn(pubKey, ch.Capacity, msg) if err != nil { log.Errorf("Unable to validate channel update: %v", err) return false } - err = b.UpdateEdge(&models.ChannelEdgePolicy{ + err = b.UpdateEdge(&models.ChannelEdgePolicy1{ SigBytes: msg.Signature.ToSignatureBytes(), ChannelID: msg.ShortChannelID.ToUint64(), LastUpdate: time.Unix(int64(msg.Timestamp), 0), @@ -1600,7 +1542,7 @@ func (b *Builder) AddNode(node *channeldb.LightningNode, // in construction of payment path. // // NOTE: This method is part of the ChannelGraphSource interface. -func (b *Builder) AddEdge(edge *models.ChannelEdgeInfo, +func (b *Builder) AddEdge(edge *models.ChannelEdgeInfo1, op ...batch.SchedulerOption) error { rMsg := &routingMsg{ @@ -1626,7 +1568,7 @@ func (b *Builder) AddEdge(edge *models.ChannelEdgeInfo, // considered as not fully constructed. // // NOTE: This method is part of the ChannelGraphSource interface. -func (b *Builder) UpdateEdge(update *models.ChannelEdgePolicy, +func (b *Builder) UpdateEdge(update *models.ChannelEdgePolicy1, op ...batch.SchedulerOption) error { rMsg := &routingMsg{ @@ -1667,9 +1609,9 @@ func (b *Builder) SyncedHeight() uint32 { // // NOTE: This method is part of the ChannelGraphSource interface. func (b *Builder) GetChannelByID(chanID lnwire.ShortChannelID) ( - *models.ChannelEdgeInfo, - *models.ChannelEdgePolicy, - *models.ChannelEdgePolicy, error) { + *models.ChannelEdgeInfo1, + *models.ChannelEdgePolicy1, + *models.ChannelEdgePolicy1, error) { return b.cfg.Graph.FetchChannelEdgesByID(chanID.ToUint64()) } @@ -1702,12 +1644,12 @@ func (b *Builder) ForEachNode( // // NOTE: This method is part of the ChannelGraphSource interface. func (b *Builder) ForAllOutgoingChannels(cb func(kvdb.RTx, - *models.ChannelEdgeInfo, *models.ChannelEdgePolicy) error) error { + *models.ChannelEdgeInfo1, *models.ChannelEdgePolicy1) error) error { return b.cfg.Graph.ForEachNodeChannel(b.cfg.SelfNode, - func(tx kvdb.RTx, c *models.ChannelEdgeInfo, - e *models.ChannelEdgePolicy, - _ *models.ChannelEdgePolicy) error { + func(tx kvdb.RTx, c *models.ChannelEdgeInfo1, + e *models.ChannelEdgePolicy1, + _ *models.ChannelEdgePolicy1) error { if e == nil { return fmt.Errorf("channel from self node " + @@ -1724,7 +1666,7 @@ func (b *Builder) ForAllOutgoingChannels(cb func(kvdb.RTx, // // NOTE: This method is part of the ChannelGraphSource interface. func (b *Builder) AddProof(chanID lnwire.ShortChannelID, - proof *models.ChannelAuthProof) error { + proof *models.ChannelAuthProof1) error { info, _, _, err := b.cfg.Graph.FetchChannelEdgesByID(chanID.ToUint64()) if err != nil { diff --git a/graph/builder_test.go b/graph/builder_test.go index 600bd86344..f3deb02388 100644 --- a/graph/builder_test.go +++ b/graph/builder_test.go @@ -24,6 +24,7 @@ import ( "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/lntest/wait" + "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" "github.com/stretchr/testify/require" @@ -63,7 +64,7 @@ func TestAddProof(t *testing.T) { ctx.chain.addBlock(fundingBlock, chanID.BlockHeight, chanID.BlockHeight) // After utxo was recreated adding the edge without the proof. - edge := &models.ChannelEdgeInfo{ + edge := &models.ChannelEdgeInfo1{ ChannelID: chanID.ToUint64(), NodeKey1Bytes: node1.PubKeyBytes, NodeKey2Bytes: node2.PubKeyBytes, @@ -146,7 +147,7 @@ func TestIgnoreChannelEdgePolicyForUnknownChannel(t *testing.T) { } ctx.chain.addBlock(fundingBlock, chanID.BlockHeight, chanID.BlockHeight) - edge := &models.ChannelEdgeInfo{ + edge := &models.ChannelEdgeInfo1{ ChannelID: chanID.ToUint64(), NodeKey1Bytes: pub1, NodeKey2Bytes: pub2, @@ -154,7 +155,7 @@ func TestIgnoreChannelEdgePolicyForUnknownChannel(t *testing.T) { BitcoinKey2Bytes: pub2, AuthProof: nil, } - edgePolicy := &models.ChannelEdgePolicy{ + edgePolicy := &models.ChannelEdgePolicy1{ SigBytes: testSig.Serialize(), ChannelID: edge.ChannelID, LastUpdate: testTime, @@ -259,11 +260,11 @@ func TestWakeUpOnStaleBranch(t *testing.T) { node1 := createTestNode(t) node2 := createTestNode(t) - edge1 := &models.ChannelEdgeInfo{ + edge1 := &models.ChannelEdgeInfo1{ ChannelID: chanID1, NodeKey1Bytes: node1.PubKeyBytes, NodeKey2Bytes: node2.PubKeyBytes, - AuthProof: &models.ChannelAuthProof{ + AuthProof: &models.ChannelAuthProof1{ NodeSig1Bytes: testSig.Serialize(), NodeSig2Bytes: testSig.Serialize(), BitcoinSig1Bytes: testSig.Serialize(), @@ -277,11 +278,11 @@ func TestWakeUpOnStaleBranch(t *testing.T) { t.Fatalf("unable to add edge: %v", err) } - edge2 := &models.ChannelEdgeInfo{ + edge2 := &models.ChannelEdgeInfo1{ ChannelID: chanID2, NodeKey1Bytes: node1.PubKeyBytes, NodeKey2Bytes: node2.PubKeyBytes, - AuthProof: &models.ChannelAuthProof{ + AuthProof: &models.ChannelAuthProof1{ NodeSig1Bytes: testSig.Serialize(), NodeSig2Bytes: testSig.Serialize(), BitcoinSig1Bytes: testSig.Serialize(), @@ -353,6 +354,7 @@ func TestWakeUpOnStaleBranch(t *testing.T) { IsAlias: func(scid lnwire.ShortChannelID) bool { return false }, + FetchTxBySCID: newTxFetcher(ctx.chain), }) require.NoError(t, err) @@ -462,13 +464,13 @@ func TestDisconnectedBlocks(t *testing.T) { node1 := createTestNode(t) node2 := createTestNode(t) - edge1 := &models.ChannelEdgeInfo{ + edge1 := &models.ChannelEdgeInfo1{ ChannelID: chanID1, NodeKey1Bytes: node1.PubKeyBytes, NodeKey2Bytes: node2.PubKeyBytes, BitcoinKey1Bytes: node1.PubKeyBytes, BitcoinKey2Bytes: node2.PubKeyBytes, - AuthProof: &models.ChannelAuthProof{ + AuthProof: &models.ChannelAuthProof1{ NodeSig1Bytes: testSig.Serialize(), NodeSig2Bytes: testSig.Serialize(), BitcoinSig1Bytes: testSig.Serialize(), @@ -482,13 +484,13 @@ func TestDisconnectedBlocks(t *testing.T) { t.Fatalf("unable to add edge: %v", err) } - edge2 := &models.ChannelEdgeInfo{ + edge2 := &models.ChannelEdgeInfo1{ ChannelID: chanID2, NodeKey1Bytes: node1.PubKeyBytes, NodeKey2Bytes: node2.PubKeyBytes, BitcoinKey1Bytes: node1.PubKeyBytes, BitcoinKey2Bytes: node2.PubKeyBytes, - AuthProof: &models.ChannelAuthProof{ + AuthProof: &models.ChannelAuthProof1{ NodeSig1Bytes: testSig.Serialize(), NodeSig2Bytes: testSig.Serialize(), BitcoinSig1Bytes: testSig.Serialize(), @@ -612,11 +614,11 @@ func TestRouterChansClosedOfflinePruneGraph(t *testing.T) { node1 := createTestNode(t) node2 := createTestNode(t) - edge1 := &models.ChannelEdgeInfo{ + edge1 := &models.ChannelEdgeInfo1{ ChannelID: chanID1.ToUint64(), NodeKey1Bytes: node1.PubKeyBytes, NodeKey2Bytes: node2.PubKeyBytes, - AuthProof: &models.ChannelAuthProof{ + AuthProof: &models.ChannelAuthProof1{ NodeSig1Bytes: testSig.Serialize(), NodeSig2Bytes: testSig.Serialize(), BitcoinSig1Bytes: testSig.Serialize(), @@ -1017,7 +1019,7 @@ func TestIsStaleNode(t *testing.T) { } ctx.chain.addBlock(fundingBlock, chanID.BlockHeight, chanID.BlockHeight) - edge := &models.ChannelEdgeInfo{ + edge := &models.ChannelEdgeInfo1{ ChannelID: chanID.ToUint64(), NodeKey1Bytes: pub1, NodeKey2Bytes: pub2, @@ -1093,7 +1095,7 @@ func TestIsKnownEdge(t *testing.T) { } ctx.chain.addBlock(fundingBlock, chanID.BlockHeight, chanID.BlockHeight) - edge := &models.ChannelEdgeInfo{ + edge := &models.ChannelEdgeInfo1{ ChannelID: chanID.ToUint64(), NodeKey1Bytes: pub1, NodeKey2Bytes: pub2, @@ -1149,7 +1151,7 @@ func TestIsStaleEdgePolicy(t *testing.T) { t.Fatalf("router failed to detect fresh edge policy") } - edge := &models.ChannelEdgeInfo{ + edge := &models.ChannelEdgeInfo1{ ChannelID: chanID.ToUint64(), NodeKey1Bytes: pub1, NodeKey2Bytes: pub2, @@ -1162,7 +1164,7 @@ func TestIsStaleEdgePolicy(t *testing.T) { } // We'll also add two edge policies, one for each direction. - edgePolicy := &models.ChannelEdgePolicy{ + edgePolicy := &models.ChannelEdgePolicy1{ SigBytes: testSig.Serialize(), ChannelID: edge.ChannelID, LastUpdate: updateTimeStamp, @@ -1176,7 +1178,7 @@ func TestIsStaleEdgePolicy(t *testing.T) { t.Fatalf("unable to update edge policy: %v", err) } - edgePolicy = &models.ChannelEdgePolicy{ + edgePolicy = &models.ChannelEdgePolicy1{ SigBytes: testSig.Serialize(), ChannelID: edge.ChannelID, LastUpdate: updateTimeStamp, @@ -1231,7 +1233,7 @@ const ( // newChannelEdgeInfo is a helper function used to create a new channel edge, // possibly skipping adding it to parts of the chain/state as well. func newChannelEdgeInfo(t *testing.T, ctx *testCtx, fundingHeight uint32, - ecm edgeCreationModifier) (*models.ChannelEdgeInfo, error) { + ecm edgeCreationModifier) (*models.ChannelEdgeInfo1, error) { node1 := createTestNode(t) node2 := createTestNode(t) @@ -1244,7 +1246,7 @@ func newChannelEdgeInfo(t *testing.T, ctx *testCtx, fundingHeight uint32, return nil, fmt.Errorf("unable to create edge: %w", err) } - edge := &models.ChannelEdgeInfo{ + edge := &models.ChannelEdgeInfo1{ ChannelID: chanID.ToUint64(), NodeKey1Bytes: node1.PubKeyBytes, NodeKey2Bytes: node2.PubKeyBytes, @@ -1275,7 +1277,7 @@ func newChannelEdgeInfo(t *testing.T, ctx *testCtx, fundingHeight uint32, } func assertChanChainRejection(t *testing.T, ctx *testCtx, - edge *models.ChannelEdgeInfo, failCode errorCode) { + edge *models.ChannelEdgeInfo1, failCode errorCode) { t.Helper() @@ -1573,7 +1575,7 @@ func parseTestGraph(t *testing.T, useCache bool, path string) ( // We first insert the existence of the edge between the two // nodes. - edgeInfo := models.ChannelEdgeInfo{ + edgeInfo := models.ChannelEdgeInfo1{ ChannelID: edge.ChannelID, AuthProof: &testAuthProof, ChannelPoint: fundingPoint, @@ -1607,7 +1609,7 @@ func parseTestGraph(t *testing.T, useCache bool, path string) ( targetNode = edgeInfo.NodeKey2Bytes } - edgePolicy := &models.ChannelEdgePolicy{ + edgePolicy := &models.ChannelEdgePolicy1{ SigBytes: testSig.Serialize(), MessageFlags: lnwire.ChanUpdateMsgFlags( edge.MessageFlags, @@ -1945,7 +1947,7 @@ func createTestGraphFromChannels(t *testing.T, useCache bool, // We first insert the existence of the edge between the two // nodes. - edgeInfo := models.ChannelEdgeInfo{ + edgeInfo := models.ChannelEdgeInfo1{ ChannelID: channelID, AuthProof: &testAuthProof, ChannelPoint: *fundingPoint, @@ -1987,7 +1989,7 @@ func createTestGraphFromChannels(t *testing.T, useCache bool, channelFlags |= lnwire.ChanUpdateDisabled } - edgePolicy := &models.ChannelEdgePolicy{ + edgePolicy := &models.ChannelEdgePolicy1{ SigBytes: testSig.Serialize(), MessageFlags: msgFlags, ChannelFlags: channelFlags, @@ -2018,7 +2020,7 @@ func createTestGraphFromChannels(t *testing.T, useCache bool, } channelFlags |= lnwire.ChanUpdateDirection - edgePolicy := &models.ChannelEdgePolicy{ + edgePolicy := &models.ChannelEdgePolicy1{ SigBytes: testSig.Serialize(), MessageFlags: msgFlags, ChannelFlags: channelFlags, @@ -2071,3 +2073,24 @@ func (m *mockLink) EligibleToForward() bool { func (m *mockLink) MayAddOutgoingHtlc(_ lnwire.MilliSatoshi) error { return m.mayAddOutgoingErr } + +func newTxFetcher(chain lnwallet.BlockChainIO) func( + chanID *lnwire.ShortChannelID, quit chan struct{}) (*wire.MsgTx, + error) { + + return func(chanID *lnwire.ShortChannelID, quit chan struct{}) ( + *wire.MsgTx, error) { + + blockNum := int64(chanID.BlockHeight) + blockHash, err := chain.GetBlockHash(blockNum) + if err != nil { + return nil, err + } + fundingBlock, err := chain.GetBlock(blockHash) + if err != nil { + return nil, err + } + + return fundingBlock.Transactions[chanID.TxIndex], nil + } +} diff --git a/graph/interfaces.go b/graph/interfaces.go index 7ae79f9a9f..43ed155aa3 100644 --- a/graph/interfaces.go +++ b/graph/interfaces.go @@ -29,17 +29,17 @@ type ChannelGraphSource interface { // AddEdge is used to add edge/channel to the topology of the router, // after all information about channel will be gathered this // edge/channel might be used in construction of payment path. - AddEdge(edge *models.ChannelEdgeInfo, + AddEdge(edge *models.ChannelEdgeInfo1, op ...batch.SchedulerOption) error // AddProof updates the channel edge info with proof which is needed to // properly announce the edge to the rest of the network. AddProof(chanID lnwire.ShortChannelID, - proof *models.ChannelAuthProof) error + proof *models.ChannelAuthProof1) error // UpdateEdge is used to update edge information, without this message // edge considered as not fully constructed. - UpdateEdge(policy *models.ChannelEdgePolicy, + UpdateEdge(policy *models.ChannelEdgePolicy1, op ...batch.SchedulerOption) error // IsStaleNode returns true if the graph source has a node announcement @@ -70,8 +70,8 @@ type ChannelGraphSource interface { // emanating from the "source" node which is the center of the // star-graph. ForAllOutgoingChannels(cb func(tx kvdb.RTx, - c *models.ChannelEdgeInfo, - e *models.ChannelEdgePolicy) error) error + c *models.ChannelEdgeInfo1, + e *models.ChannelEdgePolicy1) error) error // CurrentBlockHeight returns the block height from POV of the router // subsystem. @@ -79,8 +79,8 @@ type ChannelGraphSource interface { // GetChannelByID return the channel by the channel id. GetChannelByID(chanID lnwire.ShortChannelID) ( - *models.ChannelEdgeInfo, *models.ChannelEdgePolicy, - *models.ChannelEdgePolicy, error) + *models.ChannelEdgeInfo1, *models.ChannelEdgePolicy1, + *models.ChannelEdgePolicy1, error) // FetchLightningNode attempts to look up a target node by its identity // public key. channeldb.ErrGraphNodeNotFound is returned if the node @@ -110,7 +110,7 @@ type DB interface { // slice of channels that have been closed by the target block are // returned if the function succeeds without error. PruneGraph(spentOutputs []*wire.OutPoint, blockHash *chainhash.Hash, - blockHeight uint32) ([]*models.ChannelEdgeInfo, error) + blockHeight uint32) ([]*models.ChannelEdgeInfo1, error) // ChannelView returns the verifiable edge information for each active // channel within the known channel graph. The set of UTXO's (along with @@ -169,7 +169,7 @@ type DB interface { // set to the last prune height valid for the remaining chain. // Channels that were removed from the graph resulting from the // disconnected block are returned. - DisconnectBlockAtHeight(height uint32) ([]*models.ChannelEdgeInfo, + DisconnectBlockAtHeight(height uint32) ([]*models.ChannelEdgeInfo1, error) // HasChannelEdge returns true if the database knows of a channel edge @@ -188,11 +188,11 @@ type DB interface { // either direction. // // ErrZombieEdge an be returned if the edge is currently marked as a - // zombie within the database. In this case, the ChannelEdgePolicy's - // will be nil, and the ChannelEdgeInfo will only include the public + // zombie within the database. In this case, the ChannelEdgePolicy1's + // will be nil, and the ChannelEdgeInfo1 will only include the public // keys of each node. - FetchChannelEdgesByID(chanID uint64) (*models.ChannelEdgeInfo, - *models.ChannelEdgePolicy, *models.ChannelEdgePolicy, error) + FetchChannelEdgesByID(chanID uint64) (*models.ChannelEdgeInfo1, + *models.ChannelEdgePolicy1, *models.ChannelEdgePolicy1, error) // AddLightningNode adds a vertex/node to the graph database. If the // node is not in the database from before, this will add a new, @@ -210,7 +210,7 @@ type DB interface { // and the set of features that the channel supports. The chanPoint and // chanID are used to uniquely identify the edge globally within the // database. - AddChannelEdge(edge *models.ChannelEdgeInfo, + AddChannelEdge(edge *models.ChannelEdgeInfo1, op ...batch.SchedulerOption) error // MarkEdgeZombie attempts to mark a channel identified by its channel @@ -220,13 +220,13 @@ type DB interface { // UpdateEdgePolicy updates the edge routing policy for a single // directed edge within the database for the referenced channel. The - // `flags` attribute within the ChannelEdgePolicy determines which of + // `flags` attribute within the ChannelEdgePolicy1 determines which of // the directed edges are being updated. If the flag is 1, then the // first node's information is being updated, otherwise it's the second // node's information. The node ordering is determined by the // lexicographical ordering of the identity public keys of the nodes on // either side of the channel. - UpdateEdgePolicy(edge *models.ChannelEdgePolicy, + UpdateEdgePolicy(edge *models.ChannelEdgePolicy1, op ...batch.SchedulerOption) error // HasLightningNode determines if the graph has a vertex identified by @@ -258,16 +258,16 @@ type DB interface { // // Unknown policies are passed into the callback as nil values. ForEachNodeChannel(nodePub route.Vertex, cb func(kvdb.RTx, - *models.ChannelEdgeInfo, - *models.ChannelEdgePolicy, - *models.ChannelEdgePolicy) error) error + *models.ChannelEdgeInfo1, + *models.ChannelEdgePolicy1, + *models.ChannelEdgePolicy1) error) error // UpdateChannelEdge retrieves and update edge of the graph database. // Method only reserved for updating an edge info after its already been // created. In order to maintain this constraints, we return an error in // the scenario that an edge info hasn't yet been created yet, but // someone attempts to update it. - UpdateChannelEdge(edge *models.ChannelEdgeInfo) error + UpdateChannelEdge(edge *models.ChannelEdgeInfo1) error // IsPublicNode is a helper method that determines whether the node with // the given public key is seen as a public node in the graph from the diff --git a/graph/notifications.go b/graph/notifications.go index 14ea3d127d..47e1c62f6c 100644 --- a/graph/notifications.go +++ b/graph/notifications.go @@ -211,7 +211,7 @@ type ClosedChanSummary struct { // createCloseSummaries takes in a slice of channels closed at the target block // height and creates a slice of summaries which of each channel closure. func createCloseSummaries(blockHeight uint32, - closedChans ...*models.ChannelEdgeInfo) []*ClosedChanSummary { + closedChans ...*models.ChannelEdgeInfo1) []*ClosedChanSummary { closeSummaries := make([]*ClosedChanSummary, len(closedChans)) for i, closedChan := range closedChans { @@ -337,12 +337,12 @@ func addToTopologyChange(graph DB, update *TopologyChange, // We ignore initial channel announcements as we'll only send out // updates once the individual edges themselves have been updated. - case *models.ChannelEdgeInfo: + case *models.ChannelEdgeInfo1: return nil // Any new ChannelUpdateAnnouncements will generate a corresponding // ChannelEdgeUpdate notification. - case *models.ChannelEdgePolicy: + case *models.ChannelEdgePolicy1: // We'll need to fetch the edge's information from the database // in order to get the information concerning which nodes are // being connected. diff --git a/graph/notifications_test.go b/graph/notifications_test.go index 09ebf1211b..3d2fae0a0c 100644 --- a/graph/notifications_test.go +++ b/graph/notifications_test.go @@ -69,7 +69,7 @@ var ( _ = testSScalar.SetByteSlice(testSBytes) testSig = ecdsa.NewSignature(testRScalar, testSScalar) - testAuthProof = models.ChannelAuthProof{ + testAuthProof = models.ChannelAuthProof1{ NodeSig1Bytes: testSig.Serialize(), NodeSig2Bytes: testSig.Serialize(), BitcoinSig1Bytes: testSig.Serialize(), @@ -99,7 +99,7 @@ func createTestNode(t *testing.T) *channeldb.LightningNode { } func randEdgePolicy(chanID *lnwire.ShortChannelID, - node *channeldb.LightningNode) (*models.ChannelEdgePolicy, error) { + node *channeldb.LightningNode) (*models.ChannelEdgePolicy1, error) { InboundFee := models.InboundFee{ Base: prand.Int31() * -1, @@ -112,7 +112,7 @@ func randEdgePolicy(chanID *lnwire.ShortChannelID, return nil, err } - return &models.ChannelEdgePolicy{ + return &models.ChannelEdgePolicy1{ SigBytes: testSig.Serialize(), ChannelID: chanID.ToUint64(), LastUpdate: time.Unix(int64(prand.Int31()), 0), @@ -466,11 +466,11 @@ func TestEdgeUpdateNotification(t *testing.T) { // Finally, to conclude our test set up, we'll create a channel // update to announce the created channel between the two nodes. - edge := &models.ChannelEdgeInfo{ + edge := &models.ChannelEdgeInfo1{ ChannelID: chanID.ToUint64(), NodeKey1Bytes: node1.PubKeyBytes, NodeKey2Bytes: node2.PubKeyBytes, - AuthProof: &models.ChannelAuthProof{ + AuthProof: &models.ChannelAuthProof1{ NodeSig1Bytes: testSig.Serialize(), NodeSig2Bytes: testSig.Serialize(), BitcoinSig1Bytes: testSig.Serialize(), @@ -507,7 +507,7 @@ func TestEdgeUpdateNotification(t *testing.T) { } assertEdgeCorrect := func(t *testing.T, edgeUpdate *ChannelEdgeUpdate, - edgeAnn *models.ChannelEdgePolicy) { + edgeAnn *models.ChannelEdgePolicy1) { if edgeUpdate.ChanID != edgeAnn.ChannelID { t.Fatalf("channel ID of edge doesn't match: "+ @@ -653,11 +653,11 @@ func TestNodeUpdateNotification(t *testing.T) { testFeaturesBuf := new(bytes.Buffer) require.NoError(t, testFeatures.Encode(testFeaturesBuf)) - edge := &models.ChannelEdgeInfo{ + edge := &models.ChannelEdgeInfo1{ ChannelID: chanID.ToUint64(), NodeKey1Bytes: node1.PubKeyBytes, NodeKey2Bytes: node2.PubKeyBytes, - AuthProof: &models.ChannelAuthProof{ + AuthProof: &models.ChannelAuthProof1{ NodeSig1Bytes: testSig.Serialize(), NodeSig2Bytes: testSig.Serialize(), BitcoinSig1Bytes: testSig.Serialize(), @@ -837,11 +837,11 @@ func TestNotificationCancellation(t *testing.T) { // to the client. ntfnClient.Cancel() - edge := &models.ChannelEdgeInfo{ + edge := &models.ChannelEdgeInfo1{ ChannelID: chanID.ToUint64(), NodeKey1Bytes: node1.PubKeyBytes, NodeKey2Bytes: node2.PubKeyBytes, - AuthProof: &models.ChannelAuthProof{ + AuthProof: &models.ChannelAuthProof1{ NodeSig1Bytes: testSig.Serialize(), NodeSig2Bytes: testSig.Serialize(), BitcoinSig1Bytes: testSig.Serialize(), @@ -906,11 +906,11 @@ func TestChannelCloseNotification(t *testing.T) { // Finally, to conclude our test set up, we'll create a channel // announcement to announce the created channel between the two nodes. - edge := &models.ChannelEdgeInfo{ + edge := &models.ChannelEdgeInfo1{ ChannelID: chanID.ToUint64(), NodeKey1Bytes: node1.PubKeyBytes, NodeKey2Bytes: node2.PubKeyBytes, - AuthProof: &models.ChannelAuthProof{ + AuthProof: &models.ChannelAuthProof1{ NodeSig1Bytes: testSig.Serialize(), NodeSig2Bytes: testSig.Serialize(), BitcoinSig1Bytes: testSig.Serialize(), @@ -1077,6 +1077,7 @@ func (c *testCtx) RestartBuilder(t *testing.T) { IsAlias: func(scid lnwire.ShortChannelID) bool { return false }, + FetchTxBySCID: newTxFetcher(c.chain), }) require.NoError(t, err) require.NoError(t, builder.Start()) @@ -1175,6 +1176,7 @@ func createTestCtxFromGraphInstanceAssumeValid(t *testing.T, IsAlias: func(scid lnwire.ShortChannelID) bool { return false }, + FetchTxBySCID: newTxFetcher(chain), }) require.NoError(t, err) require.NoError(t, graphBuilder.Start()) diff --git a/graph/validation_barrier.go b/graph/validation_barrier.go index 2f3c8c02ce..a7c4645001 100644 --- a/graph/validation_barrier.go +++ b/graph/validation_barrier.go @@ -102,7 +102,7 @@ func (v *ValidationBarrier) InitJobDependencies(job interface{}) { // ChannelUpdates for the same channel, or NodeAnnouncements of nodes // that are involved in this channel. This goes for both the wire // type,s and also the types that we use within the database. - case *lnwire.ChannelAnnouncement: + case *lnwire.ChannelAnnouncement1: // We ensure that we only create a new announcement signal iff, // one doesn't already exist, as there may be duplicate @@ -126,7 +126,7 @@ func (v *ValidationBarrier) InitJobDependencies(job interface{}) { v.nodeAnnDependencies[route.Vertex(msg.NodeID1)] = signals v.nodeAnnDependencies[route.Vertex(msg.NodeID2)] = signals } - case *models.ChannelEdgeInfo: + case *models.ChannelEdgeInfo1: shortID := lnwire.NewShortChanIDFromInt(msg.ChannelID) if _, ok := v.chanAnnFinSignal[shortID]; !ok { @@ -144,16 +144,16 @@ func (v *ValidationBarrier) InitJobDependencies(job interface{}) { // These other types don't have any dependants, so no further // initialization needs to be done beyond just occupying a job slot. - case *models.ChannelEdgePolicy: + case *models.ChannelEdgePolicy1: return - case *lnwire.ChannelUpdate: + case *lnwire.ChannelUpdate1: return case *lnwire.NodeAnnouncement: // TODO(roasbeef): node ann needs to wait on existing channel updates return case *channeldb.LightningNode: return - case *lnwire.AnnounceSignatures: + case *lnwire.AnnounceSignatures1: // TODO(roasbeef): need to wait on chan ann? return } @@ -188,11 +188,11 @@ func (v *ValidationBarrier) WaitForDependants(job interface{}) error { switch msg := job.(type) { // Any ChannelUpdate or NodeAnnouncement jobs will need to wait on the // completion of any active ChannelAnnouncement jobs related to them. - case *models.ChannelEdgePolicy: + case *models.ChannelEdgePolicy1: shortID := lnwire.NewShortChanIDFromInt(msg.ChannelID) signals, ok = v.chanEdgeDependencies[shortID] - jobDesc = fmt.Sprintf("job=lnwire.ChannelEdgePolicy, scid=%v", + jobDesc = fmt.Sprintf("job=lnwire.ChannelEdgePolicy1, scid=%v", msg.ChannelID) case *channeldb.LightningNode: @@ -202,7 +202,7 @@ func (v *ValidationBarrier) WaitForDependants(job interface{}) error { jobDesc = fmt.Sprintf("job=channeldb.LightningNode, pub=%s", vertex) - case *lnwire.ChannelUpdate: + case *lnwire.ChannelUpdate1: signals, ok = v.chanEdgeDependencies[msg.ShortChannelID] jobDesc = fmt.Sprintf("job=lnwire.ChannelUpdate, scid=%v", @@ -216,10 +216,10 @@ func (v *ValidationBarrier) WaitForDependants(job interface{}) error { // Other types of jobs can be executed immediately, so we'll just // return directly. - case *lnwire.AnnounceSignatures: + case *lnwire.AnnounceSignatures1: // TODO(roasbeef): need to wait on chan ann? - case *models.ChannelEdgeInfo: - case *lnwire.ChannelAnnouncement: + case *models.ChannelEdgeInfo1: + case *lnwire.ChannelAnnouncement1: } // Release the lock once the above read is finished. @@ -264,7 +264,7 @@ func (v *ValidationBarrier) SignalDependants(job interface{}, allow bool) { // If we've just finished executing a ChannelAnnouncement, then we'll // close out the signal, and remove the signal from the map of active // ones. This will allow/deny any dependent jobs to continue execution. - case *models.ChannelEdgeInfo: + case *models.ChannelEdgeInfo1: shortID := lnwire.NewShortChanIDFromInt(msg.ChannelID) finSignals, ok := v.chanAnnFinSignal[shortID] if ok { @@ -275,7 +275,7 @@ func (v *ValidationBarrier) SignalDependants(job interface{}, allow bool) { } delete(v.chanAnnFinSignal, shortID) } - case *lnwire.ChannelAnnouncement: + case *lnwire.ChannelAnnouncement1: finSignals, ok := v.chanAnnFinSignal[msg.ShortChannelID] if ok { if allow { @@ -295,13 +295,13 @@ func (v *ValidationBarrier) SignalDependants(job interface{}, allow bool) { delete(v.nodeAnnDependencies, route.Vertex(msg.PubKeyBytes)) case *lnwire.NodeAnnouncement: delete(v.nodeAnnDependencies, route.Vertex(msg.NodeID)) - case *lnwire.ChannelUpdate: + case *lnwire.ChannelUpdate1: delete(v.chanEdgeDependencies, msg.ShortChannelID) - case *models.ChannelEdgePolicy: + case *models.ChannelEdgePolicy1: shortID := lnwire.NewShortChanIDFromInt(msg.ChannelID) delete(v.chanEdgeDependencies, shortID) - case *lnwire.AnnounceSignatures: + case *lnwire.AnnounceSignatures1: return } } diff --git a/graph/validation_barrier_test.go b/graph/validation_barrier_test.go index da404443f5..38fc7a0870 100644 --- a/graph/validation_barrier_test.go +++ b/graph/validation_barrier_test.go @@ -73,9 +73,9 @@ func TestValidationBarrierQuit(t *testing.T) { // Create a set of unique channel announcements that we will prep for // validation. - anns := make([]*lnwire.ChannelAnnouncement, 0, numTasks) + anns := make([]*lnwire.ChannelAnnouncement1, 0, numTasks) for i := 0; i < numTasks; i++ { - anns = append(anns, &lnwire.ChannelAnnouncement{ + anns = append(anns, &lnwire.ChannelAnnouncement1{ ShortChannelID: lnwire.NewShortChanIDFromInt(uint64(i)), NodeID1: nodeIDFromInt(uint64(2 * i)), NodeID2: nodeIDFromInt(uint64(2*i + 1)), @@ -85,9 +85,9 @@ func TestValidationBarrierQuit(t *testing.T) { // Create a set of channel updates, that must wait until their // associated channel announcement has been verified. - chanUpds := make([]*lnwire.ChannelUpdate, 0, numTasks) + chanUpds := make([]*lnwire.ChannelUpdate1, 0, numTasks) for i := 0; i < numTasks; i++ { - chanUpds = append(chanUpds, &lnwire.ChannelUpdate{ + chanUpds = append(chanUpds, &lnwire.ChannelUpdate1{ ShortChannelID: lnwire.NewShortChanIDFromInt(uint64(i)), }) barrier.InitJobDependencies(chanUpds[i]) diff --git a/htlcswitch/interfaces.go b/htlcswitch/interfaces.go index 1311373a17..9459b2fc36 100644 --- a/htlcswitch/interfaces.go +++ b/htlcswitch/interfaces.go @@ -85,7 +85,7 @@ type scidAliasHandler interface { // HTLCs on option_scid_alias channels. attachFailAliasUpdate(failClosure func( sid lnwire.ShortChannelID, - incoming bool) *lnwire.ChannelUpdate) + incoming bool) *lnwire.ChannelUpdate1) // getAliases fetches the link's underlying aliases. This is used by // the Switch to determine whether to forward an HTLC and where to diff --git a/htlcswitch/link.go b/htlcswitch/link.go index ff140352c0..2b62f2b3bc 100644 --- a/htlcswitch/link.go +++ b/htlcswitch/link.go @@ -119,7 +119,8 @@ type ChannelLinkConfig struct { // specified when we receive an incoming HTLC. This will be used to // provide payment senders our latest policy when sending encrypted // error messages. - FetchLastChannelUpdate func(lnwire.ShortChannelID) (*lnwire.ChannelUpdate, error) + FetchLastChannelUpdate func(lnwire.ShortChannelID) ( + *lnwire.ChannelUpdate1, error) // Peer is a lightning network node with which we have the channel link // opened. @@ -261,7 +262,7 @@ type ChannelLinkConfig struct { // FailAliasUpdate is a function used to fail an HTLC for an // option_scid_alias channel. FailAliasUpdate func(sid lnwire.ShortChannelID, - incoming bool) *lnwire.ChannelUpdate + incoming bool) *lnwire.ChannelUpdate1 // GetAliases is used by the link and switch to fetch the set of // aliases for a given link. @@ -763,7 +764,7 @@ func shouldAdjustCommitFee(netFee, chanFee, } // failCb is used to cut down on the argument verbosity. -type failCb func(update *lnwire.ChannelUpdate) lnwire.FailureMessage +type failCb func(update *lnwire.ChannelUpdate1) lnwire.FailureMessage // createFailureWithUpdate creates a ChannelUpdate when failing an incoming or // outgoing HTLC. It may return a FailureMessage that references a channel's @@ -2962,7 +2963,7 @@ func (l *channelLink) getAliases() []lnwire.ShortChannelID { // // Part of the scidAliasHandler interface. func (l *channelLink) attachFailAliasUpdate(closure func( - sid lnwire.ShortChannelID, incoming bool) *lnwire.ChannelUpdate) { + sid lnwire.ShortChannelID, incoming bool) *lnwire.ChannelUpdate1) { l.Lock() l.cfg.FailAliasUpdate = closure @@ -3054,7 +3055,7 @@ func (l *channelLink) CheckHtlcForward(payHash [32]byte, // As part of the returned error, we'll send our latest routing // policy so the sending node obtains the most up to date data. - cb := func(upd *lnwire.ChannelUpdate) lnwire.FailureMessage { + cb := func(upd *lnwire.ChannelUpdate1) lnwire.FailureMessage { return lnwire.NewFeeInsufficient(amtToForward, *upd) } failure := l.createFailureWithUpdate(false, originalScid, cb) @@ -3082,7 +3083,7 @@ func (l *channelLink) CheckHtlcForward(payHash [32]byte, // Grab the latest routing policy so the sending node is up to // date with our current policy. - cb := func(upd *lnwire.ChannelUpdate) lnwire.FailureMessage { + cb := func(upd *lnwire.ChannelUpdate1) lnwire.FailureMessage { return lnwire.NewIncorrectCltvExpiry( incomingTimeout, *upd, ) @@ -3131,7 +3132,7 @@ func (l *channelLink) canSendHtlc(policy models.ForwardingPolicy, // As part of the returned error, we'll send our latest routing // policy so the sending node obtains the most up to date data. - cb := func(upd *lnwire.ChannelUpdate) lnwire.FailureMessage { + cb := func(upd *lnwire.ChannelUpdate1) lnwire.FailureMessage { return lnwire.NewAmountBelowMinimum(amt, *upd) } failure := l.createFailureWithUpdate(false, originalScid, cb) @@ -3146,7 +3147,7 @@ func (l *channelLink) canSendHtlc(policy models.ForwardingPolicy, // As part of the returned error, we'll send our latest routing // policy so the sending node obtains the most up-to-date data. - cb := func(upd *lnwire.ChannelUpdate) lnwire.FailureMessage { + cb := func(upd *lnwire.ChannelUpdate1) lnwire.FailureMessage { return lnwire.NewTemporaryChannelFailure(upd) } failure := l.createFailureWithUpdate(false, originalScid, cb) @@ -3161,7 +3162,7 @@ func (l *channelLink) canSendHtlc(policy models.ForwardingPolicy, "outgoing_expiry=%v, best_height=%v", payHash[:], timeout, heightNow) - cb := func(upd *lnwire.ChannelUpdate) lnwire.FailureMessage { + cb := func(upd *lnwire.ChannelUpdate1) lnwire.FailureMessage { return lnwire.NewExpiryTooSoon(*upd) } failure := l.createFailureWithUpdate(false, originalScid, cb) @@ -3181,7 +3182,7 @@ func (l *channelLink) canSendHtlc(policy models.ForwardingPolicy, if amt > l.Bandwidth() { l.log.Warnf("insufficient bandwidth to route htlc: %v is "+ "larger than %v", amt, l.Bandwidth()) - cb := func(upd *lnwire.ChannelUpdate) lnwire.FailureMessage { + cb := func(upd *lnwire.ChannelUpdate1) lnwire.FailureMessage { return lnwire.NewTemporaryChannelFailure(upd) } failure := l.createFailureWithUpdate(false, originalScid, cb) @@ -3694,7 +3695,7 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg, l.log.Errorf("unable to encode the "+ "remaining route %v", err) - cb := func(upd *lnwire.ChannelUpdate) lnwire.FailureMessage { + cb := func(upd *lnwire.ChannelUpdate1) lnwire.FailureMessage { //nolint:lll return lnwire.NewTemporaryChannelFailure(upd) } diff --git a/htlcswitch/link_test.go b/htlcswitch/link_test.go index 8343d5a1c2..d369d7862e 100644 --- a/htlcswitch/link_test.go +++ b/htlcswitch/link_test.go @@ -6164,13 +6164,13 @@ func TestForwardingAsymmetricTimeLockPolicies(t *testing.T) { // forwarding policy. func TestCheckHtlcForward(t *testing.T) { fetchLastChannelUpdate := func(lnwire.ShortChannelID) ( - *lnwire.ChannelUpdate, error) { + *lnwire.ChannelUpdate1, error) { - return &lnwire.ChannelUpdate{}, nil + return &lnwire.ChannelUpdate1{}, nil } failAliasUpdate := func(sid lnwire.ShortChannelID, - incoming bool) *lnwire.ChannelUpdate { + incoming bool) *lnwire.ChannelUpdate1 { return nil } diff --git a/htlcswitch/mock.go b/htlcswitch/mock.go index 6a9628f940..a2df0851ed 100644 --- a/htlcswitch/mock.go +++ b/htlcswitch/mock.go @@ -166,7 +166,7 @@ type mockServer struct { var _ lnpeer.Peer = (*mockServer)(nil) func initSwitchWithDB(startingHeight uint32, db *channeldb.DB) (*Switch, error) { - signAliasUpdate := func(u *lnwire.ChannelUpdate) (*ecdsa.Signature, + signAliasUpdate := func(u *lnwire.ChannelUpdate1) (*ecdsa.Signature, error) { return testSig, nil @@ -182,9 +182,9 @@ func initSwitchWithDB(startingHeight uint32, db *channeldb.DB) (*Switch, error) events: make(map[time.Time]channeldb.ForwardingEvent), }, FetchLastChannelUpdate: func(scid lnwire.ShortChannelID) ( - *lnwire.ChannelUpdate, error) { + *lnwire.ChannelUpdate1, error) { - return &lnwire.ChannelUpdate{ + return &lnwire.ChannelUpdate1{ ShortChannelID: scid, }, nil }, @@ -732,7 +732,7 @@ type mockChannelLink struct { checkHtlcForwardResult *LinkError failAliasUpdate func(sid lnwire.ShortChannelID, - incoming bool) *lnwire.ChannelUpdate + incoming bool) *lnwire.ChannelUpdate1 confirmedZC bool } @@ -867,7 +867,7 @@ func (f *mockChannelLink) AttachMailBox(mailBox MailBox) { } func (f *mockChannelLink) attachFailAliasUpdate(closure func( - sid lnwire.ShortChannelID, incoming bool) *lnwire.ChannelUpdate) { + sid lnwire.ShortChannelID, incoming bool) *lnwire.ChannelUpdate1) { f.failAliasUpdate = closure } diff --git a/htlcswitch/switch.go b/htlcswitch/switch.go index efd469f785..2c7a628745 100644 --- a/htlcswitch/switch.go +++ b/htlcswitch/switch.go @@ -173,7 +173,8 @@ type Config struct { // specified when we receive an incoming HTLC. This will be used to // provide payment senders our latest policy when sending encrypted // error messages. - FetchLastChannelUpdate func(lnwire.ShortChannelID) (*lnwire.ChannelUpdate, error) + FetchLastChannelUpdate func(lnwire.ShortChannelID) ( + *lnwire.ChannelUpdate1, error) // Notifier is an instance of a chain notifier that we'll use to signal // the switch when a new block has arrived. @@ -220,7 +221,7 @@ type Config struct { // option_scid_alias channels. This avoids a potential privacy leak by // replacing the public, confirmed SCID with the alias in the // ChannelUpdate. - SignAliasUpdate func(u *lnwire.ChannelUpdate) (*ecdsa.Signature, + SignAliasUpdate func(u *lnwire.ChannelUpdate1) (*ecdsa.Signature, error) // IsAlias returns whether or not a given SCID is an alias. @@ -2615,7 +2616,7 @@ func (s *Switch) failMailboxUpdate(outgoingScid, // and the caller is expected to handle this properly. In this case, a return // to the original non-alias behavior is expected. func (s *Switch) failAliasUpdate(scid lnwire.ShortChannelID, - incoming bool) *lnwire.ChannelUpdate { + incoming bool) *lnwire.ChannelUpdate1 { // This function does not defer the unlocking because of the database // lookups for ChannelUpdate. diff --git a/htlcswitch/switch_test.go b/htlcswitch/switch_test.go index 0bc0df2d46..825ee6c652 100644 --- a/htlcswitch/switch_test.go +++ b/htlcswitch/switch_test.go @@ -3951,7 +3951,7 @@ func TestSwitchHoldForward(t *testing.T) { // Simulate an error during the composition of the failure message. currentCallback := c.s.cfg.FetchLastChannelUpdate c.s.cfg.FetchLastChannelUpdate = func( - lnwire.ShortChannelID) (*lnwire.ChannelUpdate, error) { + lnwire.ShortChannelID) (*lnwire.ChannelUpdate1, error) { return nil, errors.New("cannot fetch update") } diff --git a/htlcswitch/test_utils.go b/htlcswitch/test_utils.go index a71577ef2b..abd48e806d 100644 --- a/htlcswitch/test_utils.go +++ b/htlcswitch/test_utils.go @@ -91,8 +91,10 @@ func genIDs() (lnwire.ChannelID, lnwire.ChannelID, lnwire.ShortChannelID, // mockGetChanUpdateMessage helper function which returns topology update of // the channel -func mockGetChanUpdateMessage(cid lnwire.ShortChannelID) (*lnwire.ChannelUpdate, error) { - return &lnwire.ChannelUpdate{ +func mockGetChanUpdateMessage(_ lnwire.ShortChannelID) (*lnwire.ChannelUpdate1, + error) { + + return &lnwire.ChannelUpdate1{ Signature: wireSig, }, nil } diff --git a/lnrpc/devrpc/dev_server.go b/lnrpc/devrpc/dev_server.go index 662c0d08d9..97384a6797 100644 --- a/lnrpc/devrpc/dev_server.go +++ b/lnrpc/devrpc/dev_server.go @@ -262,7 +262,7 @@ func (s *Server) ImportGraph(ctx context.Context, for _, rpcEdge := range graph.Edges { rpcEdge := rpcEdge - edge := &models.ChannelEdgeInfo{ + edge := &models.ChannelEdgeInfo1{ ChannelID: rpcEdge.ChannelId, ChainHash: *s.cfg.ActiveNetParams.GenesisHash, Capacity: btcutil.Amount(rpcEdge.Capacity), @@ -289,8 +289,8 @@ func (s *Server) ImportGraph(ctx context.Context, rpcEdge.ChanPoint, err) } - makePolicy := func(rpcPolicy *lnrpc.RoutingPolicy) *models.ChannelEdgePolicy { //nolint:lll - policy := &models.ChannelEdgePolicy{ + makePolicy := func(rpcPolicy *lnrpc.RoutingPolicy) *models.ChannelEdgePolicy1 { //nolint:lll + policy := &models.ChannelEdgePolicy1{ ChannelID: rpcEdge.ChannelId, LastUpdate: time.Unix( int64(rpcPolicy.LastUpdate), 0, diff --git a/lnrpc/invoicesrpc/addinvoice.go b/lnrpc/invoicesrpc/addinvoice.go index dcb1bef71e..180c81f68e 100644 --- a/lnrpc/invoicesrpc/addinvoice.go +++ b/lnrpc/invoicesrpc/addinvoice.go @@ -624,7 +624,7 @@ func AddInvoice(ctx context.Context, cfg *AddInvoiceConfig, // chanCanBeHopHint returns true if the target channel is eligible to be a hop // hint. func chanCanBeHopHint(channel *HopHintInfo, cfg *SelectHopHintsCfg) ( - *models.ChannelEdgePolicy, bool) { + *models.ChannelEdgePolicy1, bool) { // Since we're only interested in our private channels, we'll skip // public ones. @@ -679,7 +679,7 @@ func chanCanBeHopHint(channel *HopHintInfo, cfg *SelectHopHintsCfg) ( // Now, we'll need to determine which is the correct policy for HTLCs // being sent from the remote node. - var remotePolicy *models.ChannelEdgePolicy + var remotePolicy *models.ChannelEdgePolicy1 if bytes.Equal(remotePub[:], info.NodeKey1Bytes[:]) { remotePolicy = p1 } else { @@ -737,9 +737,9 @@ func newHopHintInfo(c *channeldb.OpenChannel, isActive bool) *HopHintInfo { } // newHopHint returns a new hop hint using the relevant data from a hopHintInfo -// and a ChannelEdgePolicy. +// and a ChannelEdgePolicy1. func newHopHint(hopHintInfo *HopHintInfo, - chanPolicy *models.ChannelEdgePolicy) zpay32.HopHint { + chanPolicy *models.ChannelEdgePolicy1) zpay32.HopHint { return zpay32.HopHint{ NodeID: hopHintInfo.RemotePubkey, @@ -762,8 +762,8 @@ type SelectHopHintsCfg struct { // FetchChannelEdgesByID attempts to lookup the two directed edges for // the channel identified by the channel ID. - FetchChannelEdgesByID func(chanID uint64) (*models.ChannelEdgeInfo, - *models.ChannelEdgePolicy, *models.ChannelEdgePolicy, + FetchChannelEdgesByID func(chanID uint64) (*models.ChannelEdgeInfo1, + *models.ChannelEdgePolicy1, *models.ChannelEdgePolicy1, error) // GetAlias allows the peer's alias SCID to be retrieved for private diff --git a/lnrpc/invoicesrpc/addinvoice_test.go b/lnrpc/invoicesrpc/addinvoice_test.go index 76a529f8c6..2bb94da8b0 100644 --- a/lnrpc/invoicesrpc/addinvoice_test.go +++ b/lnrpc/invoicesrpc/addinvoice_test.go @@ -67,8 +67,8 @@ func (h *hopHintsConfigMock) FetchAllChannels() ([]*channeldb.OpenChannel, // FetchChannelEdgesByID attempts to lookup the two directed edges for // the channel identified by the channel ID. func (h *hopHintsConfigMock) FetchChannelEdgesByID(chanID uint64) ( - *models.ChannelEdgeInfo, *models.ChannelEdgePolicy, - *models.ChannelEdgePolicy, error) { + *models.ChannelEdgeInfo1, *models.ChannelEdgePolicy1, + *models.ChannelEdgePolicy1, error) { args := h.Mock.Called(chanID) @@ -80,13 +80,13 @@ func (h *hopHintsConfigMock) FetchChannelEdgesByID(chanID uint64) ( return nil, nil, nil, err } - edgeInfo, ok := args.Get(0).(*models.ChannelEdgeInfo) + edgeInfo, ok := args.Get(0).(*models.ChannelEdgeInfo1) require.True(h.t, ok) - policy1, ok := args.Get(1).(*models.ChannelEdgePolicy) + policy1, ok := args.Get(1).(*models.ChannelEdgePolicy1) require.True(h.t, ok) - policy2, ok := args.Get(2).(*models.ChannelEdgePolicy) + policy2, ok := args.Get(2).(*models.ChannelEdgePolicy1) require.True(h.t, ok) return edgeInfo, policy1, policy2, err @@ -226,9 +226,9 @@ var shouldIncludeChannelTestCases = []struct { h.Mock.On( "FetchChannelEdgesByID", mock.Anything, ).Once().Return( - &models.ChannelEdgeInfo{}, - &models.ChannelEdgePolicy{}, - &models.ChannelEdgePolicy{}, nil, + &models.ChannelEdgeInfo1{}, + &models.ChannelEdgePolicy1{}, + &models.ChannelEdgePolicy1{}, nil, ) h.Mock.On( @@ -264,9 +264,9 @@ var shouldIncludeChannelTestCases = []struct { h.Mock.On( "FetchChannelEdgesByID", mock.Anything, ).Once().Return( - &models.ChannelEdgeInfo{}, - &models.ChannelEdgePolicy{}, - &models.ChannelEdgePolicy{}, nil, + &models.ChannelEdgeInfo1{}, + &models.ChannelEdgePolicy1{}, + &models.ChannelEdgePolicy1{}, nil, ) alias := lnwire.ShortChannelID{TxPosition: 5} h.Mock.On( @@ -305,15 +305,15 @@ var shouldIncludeChannelTestCases = []struct { h.Mock.On( "FetchChannelEdgesByID", mock.Anything, ).Once().Return( - &models.ChannelEdgeInfo{ + &models.ChannelEdgeInfo1{ NodeKey1Bytes: selectedPolicy, }, - &models.ChannelEdgePolicy{ + &models.ChannelEdgePolicy1{ FeeBaseMSat: 1000, FeeProportionalMillionths: 20, TimeLockDelta: 13, }, - &models.ChannelEdgePolicy{}, + &models.ChannelEdgePolicy1{}, nil, ) }, @@ -353,9 +353,9 @@ var shouldIncludeChannelTestCases = []struct { h.Mock.On( "FetchChannelEdgesByID", mock.Anything, ).Once().Return( - &models.ChannelEdgeInfo{}, - &models.ChannelEdgePolicy{}, - &models.ChannelEdgePolicy{ + &models.ChannelEdgeInfo1{}, + &models.ChannelEdgePolicy1{}, + &models.ChannelEdgePolicy1{ FeeBaseMSat: 1000, FeeProportionalMillionths: 20, TimeLockDelta: 13, @@ -398,9 +398,9 @@ var shouldIncludeChannelTestCases = []struct { h.Mock.On( "FetchChannelEdgesByID", mock.Anything, ).Once().Return( - &models.ChannelEdgeInfo{}, - &models.ChannelEdgePolicy{}, - &models.ChannelEdgePolicy{ + &models.ChannelEdgeInfo1{}, + &models.ChannelEdgePolicy1{}, + &models.ChannelEdgePolicy1{ FeeBaseMSat: 1000, FeeProportionalMillionths: 20, TimeLockDelta: 13, @@ -565,9 +565,9 @@ var populateHopHintsTestCases = []struct { h.Mock.On( "FetchChannelEdgesByID", mock.Anything, ).Once().Return( - &models.ChannelEdgeInfo{}, - &models.ChannelEdgePolicy{}, - &models.ChannelEdgePolicy{}, nil, + &models.ChannelEdgeInfo1{}, + &models.ChannelEdgePolicy1{}, + &models.ChannelEdgePolicy1{}, nil, ) }, maxHopHints: 1, @@ -615,9 +615,9 @@ var populateHopHintsTestCases = []struct { h.Mock.On( "FetchChannelEdgesByID", mock.Anything, ).Once().Return( - &models.ChannelEdgeInfo{}, - &models.ChannelEdgePolicy{}, - &models.ChannelEdgePolicy{}, nil, + &models.ChannelEdgeInfo1{}, + &models.ChannelEdgePolicy1{}, + &models.ChannelEdgePolicy1{}, nil, ) }, maxHopHints: 10, @@ -666,9 +666,9 @@ var populateHopHintsTestCases = []struct { h.Mock.On( "FetchChannelEdgesByID", mock.Anything, ).Once().Return( - &models.ChannelEdgeInfo{}, - &models.ChannelEdgePolicy{}, - &models.ChannelEdgePolicy{}, nil, + &models.ChannelEdgeInfo1{}, + &models.ChannelEdgePolicy1{}, + &models.ChannelEdgePolicy1{}, nil, ) }, maxHopHints: 1, @@ -699,9 +699,9 @@ var populateHopHintsTestCases = []struct { h.Mock.On( "FetchChannelEdgesByID", mock.Anything, ).Once().Return( - &models.ChannelEdgeInfo{}, - &models.ChannelEdgePolicy{}, - &models.ChannelEdgePolicy{}, nil, + &models.ChannelEdgeInfo1{}, + &models.ChannelEdgePolicy1{}, + &models.ChannelEdgePolicy1{}, nil, ) // Prepare the mock for the second channel. @@ -716,9 +716,9 @@ var populateHopHintsTestCases = []struct { h.Mock.On( "FetchChannelEdgesByID", mock.Anything, ).Once().Return( - &models.ChannelEdgeInfo{}, - &models.ChannelEdgePolicy{}, - &models.ChannelEdgePolicy{}, nil, + &models.ChannelEdgeInfo1{}, + &models.ChannelEdgePolicy1{}, + &models.ChannelEdgePolicy1{}, nil, ) }, maxHopHints: 10, @@ -753,9 +753,9 @@ var populateHopHintsTestCases = []struct { h.Mock.On( "FetchChannelEdgesByID", mock.Anything, ).Once().Return( - &models.ChannelEdgeInfo{}, - &models.ChannelEdgePolicy{}, - &models.ChannelEdgePolicy{}, nil, + &models.ChannelEdgeInfo1{}, + &models.ChannelEdgePolicy1{}, + &models.ChannelEdgePolicy1{}, nil, ) // Prepare the mock for the second channel. @@ -770,9 +770,9 @@ var populateHopHintsTestCases = []struct { h.Mock.On( "FetchChannelEdgesByID", mock.Anything, ).Once().Return( - &models.ChannelEdgeInfo{}, - &models.ChannelEdgePolicy{}, - &models.ChannelEdgePolicy{}, nil, + &models.ChannelEdgeInfo1{}, + &models.ChannelEdgePolicy1{}, + &models.ChannelEdgePolicy1{}, nil, ) }, maxHopHints: 10, @@ -808,9 +808,9 @@ var populateHopHintsTestCases = []struct { h.Mock.On( "FetchChannelEdgesByID", mock.Anything, ).Once().Return( - &models.ChannelEdgeInfo{}, - &models.ChannelEdgePolicy{}, - &models.ChannelEdgePolicy{}, nil, + &models.ChannelEdgeInfo1{}, + &models.ChannelEdgePolicy1{}, + &models.ChannelEdgePolicy1{}, nil, ) }, maxHopHints: 1, diff --git a/lnrpc/routerrpc/router_backend.go b/lnrpc/routerrpc/router_backend.go index ac493ba330..b37df7218f 100644 --- a/lnrpc/routerrpc/router_backend.go +++ b/lnrpc/routerrpc/router_backend.go @@ -1605,7 +1605,7 @@ func marshallWireError(msg lnwire.FailureMessage, // marshallChannelUpdate marshalls a channel update as received over the wire to // the router rpc format. -func marshallChannelUpdate(update *lnwire.ChannelUpdate) *lnrpc.ChannelUpdate { +func marshallChannelUpdate(update *lnwire.ChannelUpdate1) *lnrpc.ChannelUpdate { if update == nil { return nil } diff --git a/lnwire/announcement_signatures.go b/lnwire/announcement_signatures.go index 49b6106119..a2bc21f39d 100644 --- a/lnwire/announcement_signatures.go +++ b/lnwire/announcement_signatures.go @@ -5,11 +5,11 @@ import ( "io" ) -// AnnounceSignatures is a direct message between two endpoints of a +// AnnounceSignatures1 is a direct message between two endpoints of a // channel and serves as an opt-in mechanism to allow the announcement of // the channel to the rest of the network. It contains the necessary // signatures by the sender to construct the channel announcement message. -type AnnounceSignatures struct { +type AnnounceSignatures1 struct { // ChannelID is the unique description of the funding transaction. // Channel id is better for users and debugging and short channel id is // used for quick test on existence of the particular utxo inside the @@ -43,15 +43,19 @@ type AnnounceSignatures struct { ExtraOpaqueData ExtraOpaqueData } -// A compile time check to ensure AnnounceSignatures implements the +// A compile time check to ensure AnnounceSignatures1 implements the // lnwire.Message interface. -var _ Message = (*AnnounceSignatures)(nil) +var _ Message = (*AnnounceSignatures1)(nil) -// Decode deserializes a serialized AnnounceSignatures stored in the passed +// A compile time check to ensure AnnounceSignatures1 implements the +// lnwire.AnnounceSignatures interface. +var _ AnnounceSignatures = (*AnnounceSignatures1)(nil) + +// Decode deserializes a serialized AnnounceSignatures1 stored in the passed // io.Reader observing the specified protocol version. // // This is part of the lnwire.Message interface. -func (a *AnnounceSignatures) Decode(r io.Reader, pver uint32) error { +func (a *AnnounceSignatures1) Decode(r io.Reader, _ uint32) error { return ReadElements(r, &a.ChannelID, &a.ShortChannelID, @@ -61,11 +65,11 @@ func (a *AnnounceSignatures) Decode(r io.Reader, pver uint32) error { ) } -// Encode serializes the target AnnounceSignatures into the passed io.Writer +// Encode serializes the target AnnounceSignatures1 into the passed io.Writer // observing the protocol version specified. // // This is part of the lnwire.Message interface. -func (a *AnnounceSignatures) Encode(w *bytes.Buffer, pver uint32) error { +func (a *AnnounceSignatures1) Encode(w *bytes.Buffer, _ uint32) error { if err := WriteChannelID(w, a.ChannelID); err != nil { return err } @@ -89,6 +93,20 @@ func (a *AnnounceSignatures) Encode(w *bytes.Buffer, pver uint32) error { // wire. // // This is part of the lnwire.Message interface. -func (a *AnnounceSignatures) MsgType() MessageType { +func (a *AnnounceSignatures1) MsgType() MessageType { return MsgAnnounceSignatures } + +// SCID returns the ShortChannelID of the channel. +// +// This is part of the lnwire.AnnounceSignatures interface. +func (a *AnnounceSignatures1) SCID() ShortChannelID { + return a.ShortChannelID +} + +// ChanID returns the ChannelID identifying the channel. +// +// This is part of the lnwire.AnnounceSignatures interface. +func (a *AnnounceSignatures1) ChanID() ChannelID { + return a.ChannelID +} diff --git a/lnwire/announcement_signatures_2.go b/lnwire/announcement_signatures_2.go new file mode 100644 index 0000000000..e0fd9d3e1a --- /dev/null +++ b/lnwire/announcement_signatures_2.go @@ -0,0 +1,97 @@ +package lnwire + +import ( + "bytes" + "io" +) + +// AnnounceSignatures2 is a direct message between two endpoints of a +// channel and serves as an opt-in mechanism to allow the announcement of +// a taproot channel to the rest of the network. It contains the necessary +// signatures by the sender to construct the channel_announcement_2 message. +type AnnounceSignatures2 struct { + // ChannelID is the unique description of the funding transaction. + // Channel id is better for users and debugging and short channel id is + // used for quick test on existence of the particular utxo inside the + // blockchain, because it contains information about block. + ChannelID ChannelID + + // ShortChannelID is the unique description of the funding transaction. + // It is constructed with the most significant 3 bytes as the block + // height, the next 3 bytes indicating the transaction index within the + // block, and the least significant two bytes indicating the output + // index which pays to the channel. + ShortChannelID ShortChannelID + + // PartialSignature is the combination of the partial Schnorr signature + // created for the node's bitcoin key with the partial signature created + // for the node's node ID key. + PartialSignature PartialSig + + // ExtraOpaqueData is the set of data that was appended to this + // message, some of which we may not actually know how to iterate or + // parse. By holding onto this data, we ensure that we're able to + // properly validate the set of signatures that cover these new fields, + // and ensure we're able to make upgrades to the network in a forwards + // compatible manner. + ExtraOpaqueData ExtraOpaqueData +} + +// A compile time check to ensure AnnounceSignatures2 implements the +// lnwire.Message interface. +var _ Message = (*AnnounceSignatures2)(nil) + +// Decode deserializes a serialized AnnounceSignatures2 stored in the passed +// io.Reader observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (a *AnnounceSignatures2) Decode(r io.Reader, _ uint32) error { + return ReadElements(r, + &a.ChannelID, + &a.ShortChannelID, + &a.PartialSignature, + &a.ExtraOpaqueData, + ) +} + +// Encode serializes the target AnnounceSignatures2 into the passed io.Writer +// observing the protocol version specified. +// +// This is part of the lnwire.Message interface. +func (a *AnnounceSignatures2) Encode(w *bytes.Buffer, _ uint32) error { + if err := WriteChannelID(w, a.ChannelID); err != nil { + return err + } + + if err := WriteShortChannelID(w, a.ShortChannelID); err != nil { + return err + } + + if err := WritePartialSig(w, a.PartialSignature); err != nil { + return err + } + + return WriteBytes(w, a.ExtraOpaqueData) +} + +// MsgType returns the integer uniquely identifying this message type on the +// wire. +// +// This is part of the lnwire.Message interface. +func (a *AnnounceSignatures2) MsgType() MessageType { + return MsgAnnounceSignatures2 +} + +// SCID returns the ShortChannelID of the channel. +// +// NOTE: this is part of the AnnounceSignatures interface. +func (a *AnnounceSignatures2) SCID() ShortChannelID { + return a.ShortChannelID +} + +// ChanID returns the ChannelID identifying the channel. +// +// NOTE: this is part of the AnnounceSignatures interface. +func (a *AnnounceSignatures2) ChanID() ChannelID { + return a.ChannelID +} diff --git a/lnwire/boolean.go b/lnwire/boolean.go new file mode 100644 index 0000000000..0b70603973 --- /dev/null +++ b/lnwire/boolean.go @@ -0,0 +1,80 @@ +package lnwire + +import ( + "errors" + "io" + + "github.com/lightningnetwork/lnd/tlv" +) + +// Boolean wraps a boolean in a struct to help it satisfy the tlv.RecordProducer +// interface. If a boolean tlv record is not present, this has a meaning of +// false. If it is present but has a length of 0, then this means true. +// Otherwise, if it is present but has a length of 1 then the value has been +// encoded explicitly. +type Boolean struct { + B bool +} + +// Record returns the tlv record for the boolean entry. +func (b *Boolean) Record() tlv.Record { + return tlv.MakeDynamicRecord( + 0, &b.B, b.size, booleanEncoder, booleanDecoder, + ) +} + +// size returns the number of bytes required to encode the Boolean. If the +// underlying bool is true, then we will have a zero length tlv record, +// otherwise we will have a 1 byte record. +func (b *Boolean) size() uint64 { + if b.B { + return 0 + } + + return 1 +} + +func booleanEncoder(w io.Writer, val interface{}, buf *[8]byte) error { + if v, ok := val.(*bool); ok { + // If the underlying value is true, then we can just make the + // tlv zero value as that implies true. + if *v { + return nil + } + + // If it is false, then we encode it explicitly. + buf[0] = 0 + _, err := w.Write(buf[:1]) + + return err + } + + return tlv.NewTypeForEncodingErr(val, "bool") +} + +func booleanDecoder(r io.Reader, val interface{}, buf *[8]byte, + l uint64) error { + + if v, ok := val.(*bool); ok && (l == 0 || l == 1) { + // If the length is zero, then the value is true. + if l == 0 { + *v = true + + return nil + } + + // Else, the length is 1 and the value will have been encoded + // explicitly. + if _, err := io.ReadFull(r, buf[:1]); err != nil { + return err + } + if buf[0] != 0 && buf[0] != 1 { + return errors.New("corrupted data") + } + *v = buf[0] != 0 + + return nil + } + + return tlv.NewTypeForEncodingErr(val, "bool") +} diff --git a/lnwire/boolean_test.go b/lnwire/boolean_test.go new file mode 100644 index 0000000000..490a005123 --- /dev/null +++ b/lnwire/boolean_test.go @@ -0,0 +1,133 @@ +package lnwire + +import ( + "bytes" + "io" + "testing" + + "github.com/lightningnetwork/lnd/tlv" + "github.com/stretchr/testify/require" +) + +// TestBooleanRecord tests the encoding and decoding of a boolean tlv record. +func TestBooleanRecord(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + encodeFn func(w *bytes.Buffer) error + expectedBool bool + }{ + { + name: "omitted boolean record", + encodeFn: encodedWireMsgOmitBool, + expectedBool: false, + }, + { + name: "zero length tlv", + encodeFn: encodedWireMsgZeroLenTrue, + expectedBool: true, + }, + { + name: "explicitly encoded false", + encodeFn: encodedWireMsgExplicitFalse, + expectedBool: false, + }, + { + name: "explicitly encoded true", + encodeFn: encodedWireMsgExplicitTrue, + expectedBool: true, + }, + } + + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + var buf bytes.Buffer + require.NoError(t, test.encodeFn(&buf)) + + msg := &wireMsg{} + require.NoError(t, msg.decodeWireMsg(&buf)) + + require.Equal( + t, test.expectedBool, msg.DisableFlag.Val.B, + ) + }) + } +} + +type wireMsg struct { + DisableFlag tlv.RecordT[tlv.TlvType2, Boolean] + + ExtraOpaqueData ExtraOpaqueData +} + +func encodedWireMsgExplicitFalse(w *bytes.Buffer) error { + disableFlag := tlv.ZeroRecordT[tlv.TlvType2, Boolean]() + + var b ExtraOpaqueData + err := EncodeMessageExtraData(&b, &disableFlag) + if err != nil { + return err + } + + return WriteBytes(w, b) +} + +func encodedWireMsgExplicitTrue(w *bytes.Buffer) error { + disableFlag := tlv.ZeroRecordT[tlv.TlvType2, bool]() + disableFlag.Val = true + + var b ExtraOpaqueData + err := EncodeMessageExtraData(&b, &disableFlag) + if err != nil { + return err + } + + return WriteBytes(w, b) +} + +func encodedWireMsgZeroLenTrue(w *bytes.Buffer) error { + disableFlag := tlv.ZeroRecordT[tlv.TlvType2, Boolean]() + disableFlag.Val.B = true + + var b ExtraOpaqueData + err := EncodeMessageExtraData(&b, &disableFlag) + if err != nil { + return err + } + + return WriteBytes(w, b) +} + +func encodedWireMsgOmitBool(w *bytes.Buffer) error { + var b ExtraOpaqueData + err := EncodeMessageExtraData(&b) + if err != nil { + return err + } + + return WriteBytes(w, b) +} + +func (m *wireMsg) decodeWireMsg(r io.Reader) error { + // First extract into extra opaque data. + var tlvRecords ExtraOpaqueData + if err := ReadElements(r, &tlvRecords); err != nil { + return err + } + + disableFlag := tlv.ZeroRecordT[tlv.TlvType2, Boolean]() + + typeMap, err := tlvRecords.ExtractRecords(&disableFlag) + if err != nil { + return err + } + + if _, ok := typeMap[m.DisableFlag.TlvType()]; ok { + m.DisableFlag = disableFlag + } + + return nil +} diff --git a/lnwire/channel_announcement.go b/lnwire/channel_announcement.go index 2b34c0f990..2396035feb 100644 --- a/lnwire/channel_announcement.go +++ b/lnwire/channel_announcement.go @@ -2,15 +2,17 @@ package lnwire import ( "bytes" + "fmt" "io" + "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/chaincfg/chainhash" ) -// ChannelAnnouncement message is used to announce the existence of a channel +// ChannelAnnouncement1 message is used to announce the existence of a channel // between two peers in the overlay, which is propagated by the discovery // service over broadcast handler. -type ChannelAnnouncement struct { +type ChannelAnnouncement1 struct { // This signatures are used by nodes in order to create cross // references between node's channel and node. Requiring both nodes // to sign indicates they are both willing to route other payments via @@ -60,13 +62,13 @@ type ChannelAnnouncement struct { // A compile time check to ensure ChannelAnnouncement implements the // lnwire.Message interface. -var _ Message = (*ChannelAnnouncement)(nil) +var _ Message = (*ChannelAnnouncement1)(nil) // Decode deserializes a serialized ChannelAnnouncement stored in the passed // io.Reader observing the specified protocol version. // // This is part of the lnwire.Message interface. -func (a *ChannelAnnouncement) Decode(r io.Reader, pver uint32) error { +func (a *ChannelAnnouncement1) Decode(r io.Reader, pver uint32) error { return ReadElements(r, &a.NodeSig1, &a.NodeSig2, @@ -87,7 +89,7 @@ func (a *ChannelAnnouncement) Decode(r io.Reader, pver uint32) error { // observing the protocol version specified. // // This is part of the lnwire.Message interface. -func (a *ChannelAnnouncement) Encode(w *bytes.Buffer, pver uint32) error { +func (a *ChannelAnnouncement1) Encode(w *bytes.Buffer, pver uint32) error { if err := WriteSig(w, a.NodeSig1); err != nil { return err } @@ -139,13 +141,13 @@ func (a *ChannelAnnouncement) Encode(w *bytes.Buffer, pver uint32) error { // wire. // // This is part of the lnwire.Message interface. -func (a *ChannelAnnouncement) MsgType() MessageType { +func (a *ChannelAnnouncement1) MsgType() MessageType { return MsgChannelAnnouncement } // DataToSign is used to retrieve part of the announcement message which should // be signed. -func (a *ChannelAnnouncement) DataToSign() ([]byte, error) { +func (a *ChannelAnnouncement1) DataToSign() ([]byte, error) { // We should not include the signatures itself. b := make([]byte, 0, MaxMsgBody) buf := bytes.NewBuffer(b) @@ -184,3 +186,114 @@ func (a *ChannelAnnouncement) DataToSign() ([]byte, error) { return buf.Bytes(), nil } + +// Validate validates the channel announcement message and checks that node +// signatures covers the announcement message, and that the bitcoin signatures +// covers the node keys. +// +// NOTE: This is part of the ChannelAnnouncement interface. +func (a *ChannelAnnouncement1) Validate(_ func(id *ShortChannelID) ( + []byte, error)) error { + + // First, we'll compute the digest (h) which is to be signed by each of + // the keys included within the node announcement message. This hash + // digest includes all the keys, so the (up to 4 signatures) will + // attest to the validity of each of the keys. + data, err := a.DataToSign() + if err != nil { + return err + } + dataHash := chainhash.DoubleHashB(data) + + // First we'll verify that the passed bitcoin key signature is indeed a + // signature over the computed hash digest. + bitcoinSig1, err := a.BitcoinSig1.ToSignature() + if err != nil { + return err + } + bitcoinKey1, err := btcec.ParsePubKey(a.BitcoinKey1[:]) + if err != nil { + return err + } + if !bitcoinSig1.Verify(dataHash, bitcoinKey1) { + return fmt.Errorf("can't verify first bitcoin signature") + } + + // If that checks out, then we'll verify that the second bitcoin + // signature is a valid signature of the bitcoin public key over hash + // digest as well. + bitcoinSig2, err := a.BitcoinSig2.ToSignature() + if err != nil { + return err + } + bitcoinKey2, err := btcec.ParsePubKey(a.BitcoinKey2[:]) + if err != nil { + return err + } + if !bitcoinSig2.Verify(dataHash, bitcoinKey2) { + return fmt.Errorf("can't verify second bitcoin signature") + } + + // Both node signatures attached should indeed be a valid signature + // over the selected digest of the channel announcement signature. + nodeSig1, err := a.NodeSig1.ToSignature() + if err != nil { + return err + } + nodeKey1, err := btcec.ParsePubKey(a.NodeID1[:]) + if err != nil { + return err + } + if !nodeSig1.Verify(dataHash, nodeKey1) { + return fmt.Errorf("can't verify data in first node signature") + } + + nodeSig2, err := a.NodeSig2.ToSignature() + if err != nil { + return err + } + nodeKey2, err := btcec.ParsePubKey(a.NodeID2[:]) + if err != nil { + return err + } + if !nodeSig2.Verify(dataHash, nodeKey2) { + return fmt.Errorf("can't verify data in second node signature") + } + + return nil +} + +// Node1KeyBytes returns the bytes representing the public key of node 1 in the +// channel. +// +// NOTE: This is part of the ChannelAnnouncement interface. +func (a *ChannelAnnouncement1) Node1KeyBytes() [33]byte { + return a.NodeID1 +} + +// Node2KeyBytes returns the bytes representing the public key of node 2 in the +// channel. +// +// NOTE: This is part of the ChannelAnnouncement interface. +func (a *ChannelAnnouncement1) Node2KeyBytes() [33]byte { + return a.NodeID2 +} + +// GetChainHash returns the hash of the chain which this channel's funding +// transaction is confirmed in. +// +// NOTE: This is part of the ChannelAnnouncement interface. +func (a *ChannelAnnouncement1) GetChainHash() chainhash.Hash { + return a.ChainHash +} + +// SCID returns the short channel ID of the channel. +// +// NOTE: This is part of the ChannelAnnouncement interface. +func (a *ChannelAnnouncement1) SCID() ShortChannelID { + return a.ShortChannelID +} + +// A compile-time check to ensure that ChannelAnnouncement1 implements the +// ChannelAnnouncement interface. +var _ ChannelAnnouncement = (*ChannelAnnouncement1)(nil) diff --git a/lnwire/channel_announcement_2.go b/lnwire/channel_announcement_2.go new file mode 100644 index 0000000000..6207881523 --- /dev/null +++ b/lnwire/channel_announcement_2.go @@ -0,0 +1,344 @@ +package lnwire + +import ( + "bytes" + "fmt" + "io" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/btcsuite/btcd/btcec/v2/schnorr" + "github.com/btcsuite/btcd/btcec/v2/schnorr/musig2" + "github.com/btcsuite/btcd/chaincfg" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/lightningnetwork/lnd/tlv" +) + +const ( + // chanAnn2MsgName is a string representing the name of the + // ChannelAnnouncement2 message. This string will be used during the + // construction of the tagged hash message to be signed when producing + // the signature for the ChannelAnnouncement2 message. + chanAnn2MsgName = "channel_announcement_2" + + // chanAnn2SigFieldName is the name of the signature field of the + // ChannelAnnouncement2 message. This string will be used during the + // construction of the tagged hash message to be signed when producing + // the signature for the ChannelAnnouncement2 message. + chanAnn2SigFieldName = "announcement_signature" +) + +// ChannelAnnouncement2 message is used to announce the existence of a taproot +// channel between two peers in the network. +type ChannelAnnouncement2 struct { + // Signature is a Schnorr signature over the TLV stream of the message. + Signature Sig + + // ChainHash denotes the target chain that this channel was opened + // within. This value should be the genesis hash of the target chain. + ChainHash tlv.RecordT[tlv.TlvType0, chainhash.Hash] + + // Features is the feature vector that encodes the features supported + // by the target node. This field can be used to signal the type of the + // channel, or modifications to the fields that would normally follow + // this vector. + Features tlv.RecordT[tlv.TlvType2, RawFeatureVector] + + // ShortChannelID is the unique description of the funding transaction, + // or where exactly it's located within the target blockchain. + ShortChannelID tlv.RecordT[tlv.TlvType4, ShortChannelID] + + // Capacity is the number of satoshis of the capacity of this channel. + // It must be less than or equal to the value of the on-chain funding + // output. + Capacity tlv.RecordT[tlv.TlvType6, uint64] + + // NodeID1 is the numerically-lesser public key ID of one of the channel + // operators. + NodeID1 tlv.RecordT[tlv.TlvType8, [33]byte] + + // NodeID2 is the numerically-greater public key ID of one of the + // channel operators. + NodeID2 tlv.RecordT[tlv.TlvType10, [33]byte] + + // BitcoinKey1 is the public key of the key used by Node1 in the + // construction of the on-chain funding transaction. This is an optional + // field and only needs to be set if the 4-of-4 MuSig construction was + // used in the creation of the message signature. + BitcoinKey1 tlv.OptionalRecordT[tlv.TlvType12, [33]byte] + + // BitcoinKey2 is the public key of the key used by Node2 in the + // construction of the on-chain funding transaction. This is an optional + // field and only needs to be set if the 4-of-4 MuSig construction was + // used in the creation of the message signature. + BitcoinKey2 tlv.OptionalRecordT[tlv.TlvType14, [33]byte] + + // MerkleRootHash is the hash used to create the optional tweak in the + // funding output. If this is not set but the bitcoin keys are, then + // the funding output is a pure 2-of-2 MuSig aggregate public key. + MerkleRootHash tlv.OptionalRecordT[tlv.TlvType16, [32]byte] + + // ExtraOpaqueData is the set of data that was appended to this + // message, some of which we may not actually know how to iterate or + // parse. By holding onto this data, we ensure that we're able to + // properly validate the set of signatures that cover these new fields, + // and ensure we're able to make upgrades to the network in a forwards + // compatible manner. + ExtraOpaqueData ExtraOpaqueData +} + +// Decode deserializes a serialized AnnounceSignatures1 stored in the passed +// io.Reader observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (c *ChannelAnnouncement2) Decode(r io.Reader, _ uint32) error { + err := ReadElement(r, &c.Signature) + if err != nil { + return err + } + c.Signature.ForceSchnorr() + + return c.DecodeTLVRecords(r) +} + +// DecodeTLVRecords decodes only the TLV section of the message. +func (c *ChannelAnnouncement2) DecodeTLVRecords(r io.Reader) error { + // First extract into extra opaque data. + var tlvRecords ExtraOpaqueData + if err := ReadElements(r, &tlvRecords); err != nil { + return err + } + + var ( + chainHash = tlv.ZeroRecordT[tlv.TlvType0, [32]byte]() + btcKey1 = tlv.ZeroRecordT[tlv.TlvType12, [33]byte]() + btcKey2 = tlv.ZeroRecordT[tlv.TlvType14, [33]byte]() + merkleRootHash = tlv.ZeroRecordT[tlv.TlvType16, [32]byte]() + ) + typeMap, err := tlvRecords.ExtractRecords( + &chainHash, &c.Features, &c.ShortChannelID, &c.Capacity, + &c.NodeID1, &c.NodeID2, &btcKey1, &btcKey2, &merkleRootHash, + ) + if err != nil { + return err + } + + // By default, the chain-hash is the bitcoin mainnet genesis block hash. + c.ChainHash.Val = *chaincfg.MainNetParams.GenesisHash + if _, ok := typeMap[c.ChainHash.TlvType()]; ok { + c.ChainHash.Val = chainHash.Val + } + + if _, ok := typeMap[c.BitcoinKey1.TlvType()]; ok { + c.BitcoinKey1 = tlv.SomeRecordT(btcKey1) + } + + if _, ok := typeMap[c.BitcoinKey2.TlvType()]; ok { + c.BitcoinKey2 = tlv.SomeRecordT(btcKey2) + } + + if _, ok := typeMap[c.MerkleRootHash.TlvType()]; ok { + c.MerkleRootHash = tlv.SomeRecordT(merkleRootHash) + } + + if len(tlvRecords) != 0 { + c.ExtraOpaqueData = tlvRecords + } + + return nil +} + +// Encode serializes the target AnnounceSignatures1 into the passed io.Writer +// observing the protocol version specified. +// +// This is part of the lnwire.Message interface. +func (c *ChannelAnnouncement2) Encode(w *bytes.Buffer, _ uint32) error { + _, err := w.Write(c.Signature.RawBytes()) + if err != nil { + return err + } + _, err = c.DataToSign() + if err != nil { + return err + } + + return WriteBytes(w, c.ExtraOpaqueData) +} + +// DigestToSign computes the digest of the message to be signed. +func (c *ChannelAnnouncement2) DigestToSign() (*chainhash.Hash, error) { + data, err := c.DataToSign() + if err != nil { + return nil, err + } + + return MsgHash(chanAnn2MsgName, chanAnn2SigFieldName, data), nil +} + +// DataToSign encodes the data to be signed into the ExtraOpaqueData member and +// returns it. +func (c *ChannelAnnouncement2) DataToSign() ([]byte, error) { + // The chain-hash record is only included if it is _not_ equal to the + // bitcoin mainnet genisis block hash. + var recordProducers []tlv.RecordProducer + if !c.ChainHash.Val.IsEqual(chaincfg.MainNetParams.GenesisHash) { + hash := tlv.ZeroRecordT[tlv.TlvType0, [32]byte]() + hash.Val = c.ChainHash.Val + + recordProducers = append(recordProducers, &hash) + } + + recordProducers = append(recordProducers, + &c.Features, &c.ShortChannelID, &c.Capacity, &c.NodeID1, + &c.NodeID2, + ) + + c.BitcoinKey1.WhenSome(func(key tlv.RecordT[tlv.TlvType12, [33]byte]) { + recordProducers = append(recordProducers, &key) + }) + + c.BitcoinKey2.WhenSome(func(key tlv.RecordT[tlv.TlvType14, [33]byte]) { + recordProducers = append(recordProducers, &key) + }) + + c.MerkleRootHash.WhenSome( + func(hash tlv.RecordT[tlv.TlvType16, [32]byte]) { + recordProducers = append(recordProducers, &hash) + }, + ) + + err := EncodeMessageExtraData(&c.ExtraOpaqueData, recordProducers...) + if err != nil { + return nil, err + } + + return c.ExtraOpaqueData, nil +} + +// MsgType returns the integer uniquely identifying this message type on the +// wire. +// +// This is part of the lnwire.Message interface. +func (c *ChannelAnnouncement2) MsgType() MessageType { + return MsgChannelAnnouncement2 +} + +// A compile time check to ensure ChannelAnnouncement2 implements the +// lnwire.Message interface. +var _ Message = (*ChannelAnnouncement2)(nil) + +// Node1KeyBytes returns the bytes representing the public key of node 1 in the +// channel. +// +// NOTE: This is part of the ChannelAnnouncement interface. +func (c *ChannelAnnouncement2) Node1KeyBytes() [33]byte { + return c.NodeID1.Val +} + +// Node2KeyBytes returns the bytes representing the public key of node 2 in the +// channel. +// +// NOTE: This is part of the ChannelAnnouncement interface. +func (c *ChannelAnnouncement2) Node2KeyBytes() [33]byte { + return c.NodeID2.Val +} + +// GetChainHash returns the hash of the chain which this channel's funding +// transaction is confirmed in. +// +// NOTE: This is part of the ChannelAnnouncement interface. +func (c *ChannelAnnouncement2) GetChainHash() chainhash.Hash { + return c.ChainHash.Val +} + +// SCID returns the short channel ID of the channel. +// +// NOTE: This is part of the ChannelAnnouncement interface. +func (c *ChannelAnnouncement2) SCID() ShortChannelID { + return c.ShortChannelID.Val +} + +// Validate checks that the announcement signature is valid. +// +// NOTE: This is part of the ChannelAnnouncement interface. +func (c *ChannelAnnouncement2) Validate( + fetchPkScript func(id *ShortChannelID) ([]byte, error)) error { + + dataHash, err := c.DigestToSign() + if err != nil { + return err + } + + sig, err := c.Signature.ToSignature() + if err != nil { + return err + } + + nodeKey1, err := btcec.ParsePubKey(c.NodeID1.Val[:]) + if err != nil { + return err + } + + nodeKey2, err := btcec.ParsePubKey(c.NodeID2.Val[:]) + if err != nil { + return err + } + + keys := []*btcec.PublicKey{ + nodeKey1, nodeKey2, + } + + // If the bitcoin keys are provided in the announcement, then it is + // assumed that the signature of the announcement is a 4-of-4 MuSig2 + // over the bitcoin keys and node ID keys. + if c.BitcoinKey1.IsSome() && c.BitcoinKey2.IsSome() { + var ( + btcKey1 tlv.RecordT[tlv.TlvType12, [33]byte] + btcKey2 tlv.RecordT[tlv.TlvType14, [33]byte] + ) + + btcKey1 = c.BitcoinKey1.UnwrapOr(btcKey1) + btcKey2 = c.BitcoinKey2.UnwrapOr(btcKey2) + + bitcoinKey1, err := btcec.ParsePubKey(btcKey1.Val[:]) + if err != nil { + return err + } + + bitcoinKey2, err := btcec.ParsePubKey(btcKey2.Val[:]) + if err != nil { + return err + } + + keys = append(keys, bitcoinKey1, bitcoinKey2) + } else { + // If bitcoin keys are not provided, then we need to get the + // on-chain output key since this will be the 3rd key in the + // 3-of-3 MuSig2 signature. + pkScript, err := fetchPkScript(&c.ShortChannelID.Val) + if err != nil { + return err + } + + outputKey, err := schnorr.ParsePubKey(pkScript[2:]) + if err != nil { + return err + } + + keys = append(keys, outputKey) + } + + aggKey, _, _, err := musig2.AggregateKeys(keys, true) + if err != nil { + return err + } + + if !sig.Verify(dataHash.CloneBytes(), aggKey.FinalKey) { + return fmt.Errorf("invalid sig") + } + + return nil +} + +// A compile-time check to ensure that ChannelAnnouncement2 implements the +// ChannelAnnouncement interface. +var _ ChannelAnnouncement = (*ChannelAnnouncement2)(nil) diff --git a/lnwire/channel_announcement_test.go b/lnwire/channel_announcement_test.go new file mode 100644 index 0000000000..87f7da021a --- /dev/null +++ b/lnwire/channel_announcement_test.go @@ -0,0 +1,301 @@ +package lnwire + +import ( + "bytes" + "testing" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/btcsuite/btcd/btcec/v2/schnorr/musig2" + "github.com/btcsuite/btcd/chaincfg" + "github.com/lightningnetwork/lnd/input" + "github.com/lightningnetwork/lnd/tlv" + "github.com/stretchr/testify/require" +) + +// TestChanAnnounce2Validation checks that the various forms of the +// channel_announcement_2 message are validated correctly. +func TestChanAnnounce2Validation(t *testing.T) { + t.Parallel() + + t.Run( + "test 4-of-4 MuSig2 channel announcement", + test4of4MuSig2ChanAnnouncement, + ) + + t.Run( + "test 3-of-3 MuSig2 channel announcement", + test3of3MuSig2ChanAnnouncement, + ) +} + +// test4of4MuSig2ChanAnnouncement covers the case where both bitcoin keys are +// present in the channel announcement. In this case, the signature should be +// a 4-of-4 MuSig2. +func test4of4MuSig2ChanAnnouncement(t *testing.T) { + t.Parallel() + + // Generate the keys for node 1 and node2. + node1, node2 := genChanAnnKeys(t) + + // Build the unsigned channel announcement. + ann := buildUnsignedChanAnnouncement(node1, node2, true) + + // Serialise the bytes that need to be signed. + msg, err := ann.DigestToSign() + require.NoError(t, err) + + var msgBytes [32]byte + copy(msgBytes[:], msg.CloneBytes()) + + // Generate the 4 nonces required for producing the signature. + var ( + node1NodeNonce = genNonceForPubKey(t, node1.nodePub) + node1BtcNonce = genNonceForPubKey(t, node1.btcPub) + node2NodeNonce = genNonceForPubKey(t, node2.nodePub) + node2BtcNonce = genNonceForPubKey(t, node2.btcPub) + ) + + nonceAgg, err := musig2.AggregateNonces([][66]byte{ + node1NodeNonce.PubNonce, + node1BtcNonce.PubNonce, + node2NodeNonce.PubNonce, + node2BtcNonce.PubNonce, + }) + require.NoError(t, err) + + pubKeys := []*btcec.PublicKey{ + node1.nodePub, node2.nodePub, node1.btcPub, node2.btcPub, + } + + // Let Node1 sign the announcement message with its node key. + psA1, err := musig2.Sign( + node1NodeNonce.SecNonce, node1.nodePriv, nonceAgg, pubKeys, + msgBytes, musig2.WithSortedKeys(), + ) + require.NoError(t, err) + + // Let Node1 sign the announcement message with its bitcoin key. + psA2, err := musig2.Sign( + node1BtcNonce.SecNonce, node1.btcPriv, nonceAgg, pubKeys, + msgBytes, musig2.WithSortedKeys(), + ) + require.NoError(t, err) + + // Let Node2 sign the announcement message with its node key. + psB1, err := musig2.Sign( + node2NodeNonce.SecNonce, node2.nodePriv, nonceAgg, pubKeys, + msgBytes, musig2.WithSortedKeys(), + ) + require.NoError(t, err) + + // Let Node2 sign the announcement message with its bitcoin key. + psB2, err := musig2.Sign( + node2BtcNonce.SecNonce, node2.btcPriv, nonceAgg, pubKeys, + msgBytes, musig2.WithSortedKeys(), + ) + require.NoError(t, err) + + // Finally, combine the partial signatures from Node1 and Node2 and add + // the signature to the announcement message. + s := musig2.CombineSigs(psA1.R, []*musig2.PartialSignature{ + psA1, psA2, psB1, psB2, + }) + + sig, err := NewSigFromSignature(s) + require.NoError(t, err) + + ann.Signature = sig + + // Validate the announcement. + require.NoError(t, ann.Validate(nil)) +} + +// test3of3MuSig2ChanAnnouncement covers the case where no bitcoin keys are +// present in the channel announcement. In this case, the signature should be +// a 3-of-3 MuSig2 where the keys making up the pub key are: node1 ID, node2 ID +// and the output key found on-chain in the funding transaction. As the +// verifier, we don't care about the construction of the output key. We only +// care that the channel peers were able to sign for the output key. In reality, +// this key will likely be constructed from at least 1 key from each peer and +// the partial signature for it will be constructed via nested MuSig2 but for +// the sake of the test, we will just have it be backed by a single key. +func test3of3MuSig2ChanAnnouncement(t *testing.T) { + // Generate the keys for node 1 and node 2. + node1, node2 := genChanAnnKeys(t) + + // Build the unsigned channel announcement. + ann := buildUnsignedChanAnnouncement(node1, node2, false) + + // Serialise the bytes that need to be signed. + msg, err := ann.DigestToSign() + require.NoError(t, err) + + var msgBytes [32]byte + copy(msgBytes[:], msg.CloneBytes()) + + // Create a random 3rd key to be used for the output key. + outputKeyPriv, err := btcec.NewPrivateKey() + require.NoError(t, err) + + outputKey := outputKeyPriv.PubKey() + + // Ensure that the output key has an even Y by negating the private key + // if required. + if outputKey.SerializeCompressed()[0] == + input.PubKeyFormatCompressedOdd { + + outputKeyPriv.Key.Negate() + outputKey = outputKeyPriv.PubKey() + } + + // Generate the nonces required for producing the partial signatures. + var ( + node1NodeNonce = genNonceForPubKey(t, node1.nodePub) + node2NodeNonce = genNonceForPubKey(t, node2.nodePub) + outputKeyNonce = genNonceForPubKey(t, outputKey) + ) + + nonceAgg, err := musig2.AggregateNonces([][66]byte{ + node1NodeNonce.PubNonce, + node2NodeNonce.PubNonce, + outputKeyNonce.PubNonce, + }) + require.NoError(t, err) + + pkScript, err := input.PayToTaprootScript(outputKey) + require.NoError(t, err) + + // We'll pass in a mock tx fetcher that will return the funding output + // containing this key. This is needed since the output key can not be + // determined from the channel announcement itself. + fetchTx := func(chanID *ShortChannelID) ([]byte, error) { + return pkScript, nil + } + + pubKeys := []*btcec.PublicKey{node1.nodePub, node2.nodePub, outputKey} + + // Let Node1 sign the announcement message with its node key. + psA, err := musig2.Sign( + node1NodeNonce.SecNonce, node1.nodePriv, nonceAgg, pubKeys, + msgBytes, musig2.WithSortedKeys(), + ) + require.NoError(t, err) + + // Let Node2 sign the announcement message with its node key. + psB, err := musig2.Sign( + node2NodeNonce.SecNonce, node2.nodePriv, nonceAgg, pubKeys, + msgBytes, musig2.WithSortedKeys(), + ) + require.NoError(t, err) + + // Create a partial sig for the output key. + psO, err := musig2.Sign( + outputKeyNonce.SecNonce, outputKeyPriv, nonceAgg, pubKeys, + msgBytes, musig2.WithSortedKeys(), + ) + require.NoError(t, err) + + // Finally, combine the partial signatures from Node1 and Node2 and add + // the signature to the announcement message. + s := musig2.CombineSigs(psA.R, []*musig2.PartialSignature{ + psA, psB, psO, + }) + + sig, err := NewSigFromSignature(s) + require.NoError(t, err) + + ann.Signature = sig + + // Validate the announcement. + require.NoError(t, ann.Validate(fetchTx)) +} + +func genNonceForPubKey(t *testing.T, pub *btcec.PublicKey) *musig2.Nonces { + nonce, err := musig2.GenNonces(musig2.WithPublicKey(pub)) + require.NoError(t, err) + + return nonce +} + +type keyRing struct { + nodePriv *btcec.PrivateKey + nodePub *btcec.PublicKey + btcPriv *btcec.PrivateKey + btcPub *btcec.PublicKey +} + +func genChanAnnKeys(t *testing.T) (*keyRing, *keyRing) { + // Let Alice and Bob derive the various keys they need. + aliceNodePrivKey, err := btcec.NewPrivateKey() + require.NoError(t, err) + + aliceNodeID := aliceNodePrivKey.PubKey() + + aliceBtcPrivKey, err := btcec.NewPrivateKey() + require.NoError(t, err) + + bobNodePrivKey, err := btcec.NewPrivateKey() + require.NoError(t, err) + + bobNodeID := bobNodePrivKey.PubKey() + + bobBtcPrivKey, err := btcec.NewPrivateKey() + require.NoError(t, err) + + alice := &keyRing{ + nodePriv: aliceNodePrivKey, + nodePub: aliceNodePrivKey.PubKey(), + btcPriv: aliceBtcPrivKey, + btcPub: aliceBtcPrivKey.PubKey(), + } + + bob := &keyRing{ + nodePriv: bobNodePrivKey, + nodePub: bobNodePrivKey.PubKey(), + btcPriv: bobBtcPrivKey, + btcPub: bobBtcPrivKey.PubKey(), + } + + if bytes.Compare( + aliceNodeID.SerializeCompressed(), + bobNodeID.SerializeCompressed(), + ) != -1 { + + return bob, alice + } + + return alice, bob +} + +func buildUnsignedChanAnnouncement(node1, node2 *keyRing, + withBtcKeys bool) *ChannelAnnouncement2 { + + var ann ChannelAnnouncement2 + ann.ChainHash.Val = *chaincfg.MainNetParams.GenesisHash + features := NewRawFeatureVector() + ann.Features.Val = *features + ann.ShortChannelID.Val = ShortChannelID{ + BlockHeight: 1000, + TxIndex: 100, + TxPosition: 0, + } + ann.Capacity.Val = 100000 + + copy(ann.NodeID1.Val[:], node1.nodePub.SerializeCompressed()) + copy(ann.NodeID2.Val[:], node2.nodePub.SerializeCompressed()) + + if !withBtcKeys { + return &ann + } + + btcKey1Bytes := tlv.ZeroRecordT[tlv.TlvType12, [33]byte]() + btcKey2Bytes := tlv.ZeroRecordT[tlv.TlvType14, [33]byte]() + + copy(btcKey1Bytes.Val[:], node1.btcPub.SerializeCompressed()) + copy(btcKey2Bytes.Val[:], node2.btcPub.SerializeCompressed()) + + ann.BitcoinKey1 = tlv.SomeRecordT(btcKey1Bytes) + ann.BitcoinKey2 = tlv.SomeRecordT(btcKey2Bytes) + + return &ann +} diff --git a/lnwire/channel_ready.go b/lnwire/channel_ready.go index bdcb95ce8d..912a068bde 100644 --- a/lnwire/channel_ready.go +++ b/lnwire/channel_ready.go @@ -33,6 +33,16 @@ type ChannelReady struct { // to accept a new commitment state transition. NextLocalNonce OptMusig2NonceTLV + // AnnouncementNodeNonce is an optional field that stores a public + // nonce that will be used along with the node's ID key during signing + // of the ChannelAnnouncement2 message. + AnnouncementNodeNonce tlv.OptionalRecordT[tlv.TlvType0, Musig2Nonce] + + // AnnouncementBitcoinNonce is an optional field that stores a public + // nonce that will be used along with the node's bitcoin key during + // signing of the ChannelAnnouncement2 message. + AnnouncementBitcoinNonce tlv.OptionalRecordT[tlv.TlvType2, Musig2Nonce] + // ExtraData is the set of data that was appended to this message to // fill out the full maximum transport message size. These fields can // be used to specify optional data such as custom TLV fields. @@ -78,9 +88,11 @@ func (c *ChannelReady) Decode(r io.Reader, _ uint32) error { var ( aliasScid ShortChannelID localNonce = c.NextLocalNonce.Zero() + nodeNonce = tlv.ZeroRecordT[tlv.TlvType0, Musig2Nonce]() + btcNonce = tlv.ZeroRecordT[tlv.TlvType2, Musig2Nonce]() ) typeMap, err := tlvRecords.ExtractRecords( - &aliasScid, &localNonce, + &btcNonce, &aliasScid, &nodeNonce, &localNonce, ) if err != nil { return err @@ -94,6 +106,14 @@ func (c *ChannelReady) Decode(r io.Reader, _ uint32) error { if val, ok := typeMap[c.NextLocalNonce.TlvType()]; ok && val == nil { c.NextLocalNonce = tlv.SomeRecordT(localNonce) } + val, ok := typeMap[c.AnnouncementBitcoinNonce.TlvType()] + if ok && val == nil { + c.AnnouncementBitcoinNonce = tlv.SomeRecordT(btcNonce) + } + val, ok = typeMap[c.AnnouncementNodeNonce.TlvType()] + if ok && val == nil { + c.AnnouncementNodeNonce = tlv.SomeRecordT(nodeNonce) + } if len(tlvRecords) != 0 { c.ExtraData = tlvRecords @@ -117,13 +137,24 @@ func (c *ChannelReady) Encode(w *bytes.Buffer, _ uint32) error { } // We'll only encode the AliasScid in a TLV segment if it exists. - recordProducers := make([]tlv.RecordProducer, 0, 2) + recordProducers := make([]tlv.RecordProducer, 0, 4) if c.AliasScid != nil { recordProducers = append(recordProducers, c.AliasScid) } c.NextLocalNonce.WhenSome(func(localNonce Musig2NonceTLV) { recordProducers = append(recordProducers, &localNonce) }) + c.AnnouncementBitcoinNonce.WhenSome( + func(nonce tlv.RecordT[tlv.TlvType2, Musig2Nonce]) { + recordProducers = append(recordProducers, &nonce) + }, + ) + c.AnnouncementNodeNonce.WhenSome( + func(nonce tlv.RecordT[tlv.TlvType0, Musig2Nonce]) { + recordProducers = append(recordProducers, &nonce) + }, + ) + err := EncodeMessageExtraData(&c.ExtraData, recordProducers...) if err != nil { return err diff --git a/lnwire/channel_update.go b/lnwire/channel_update.go index 7f42a58b46..bd3852247e 100644 --- a/lnwire/channel_update.go +++ b/lnwire/channel_update.go @@ -5,7 +5,11 @@ import ( "fmt" "io" + "github.com/btcsuite/btcd/btcec/v2" + "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/davecgh/go-spew/spew" + "github.com/lightningnetwork/lnd/input" ) // ChanUpdateMsgFlags is a bitfield that signals whether optional fields are @@ -56,11 +60,11 @@ func (c ChanUpdateChanFlags) String() string { return fmt.Sprintf("%08b", c) } -// ChannelUpdate message is used after channel has been initially announced. +// ChannelUpdate1 message is used after channel has been initially announced. // Each side independently announces its fees and minimum expiry for HTLCs and // other parameters. Also this message is used to redeclare initially set // channel parameters. -type ChannelUpdate struct { +type ChannelUpdate1 struct { // Signature is used to validate the announced data and prove the // ownership of node id. Signature Sig @@ -122,13 +126,13 @@ type ChannelUpdate struct { // A compile time check to ensure ChannelUpdate implements the lnwire.Message // interface. -var _ Message = (*ChannelUpdate)(nil) +var _ Message = (*ChannelUpdate1)(nil) // Decode deserializes a serialized ChannelUpdate stored in the passed // io.Reader observing the specified protocol version. // // This is part of the lnwire.Message interface. -func (a *ChannelUpdate) Decode(r io.Reader, pver uint32) error { +func (a *ChannelUpdate1) Decode(r io.Reader, pver uint32) error { err := ReadElements(r, &a.Signature, a.ChainHash[:], @@ -159,7 +163,7 @@ func (a *ChannelUpdate) Decode(r io.Reader, pver uint32) error { // observing the protocol version specified. // // This is part of the lnwire.Message interface. -func (a *ChannelUpdate) Encode(w *bytes.Buffer, pver uint32) error { +func (a *ChannelUpdate1) Encode(w *bytes.Buffer, pver uint32) error { if err := WriteSig(w, a.Signature); err != nil { return err } @@ -217,13 +221,13 @@ func (a *ChannelUpdate) Encode(w *bytes.Buffer, pver uint32) error { // wire. // // This is part of the lnwire.Message interface. -func (a *ChannelUpdate) MsgType() MessageType { +func (a *ChannelUpdate1) MsgType() MessageType { return MsgChannelUpdate } // DataToSign is used to retrieve part of the announcement message which should // be signed. -func (a *ChannelUpdate) DataToSign() ([]byte, error) { +func (a *ChannelUpdate1) DataToSign() ([]byte, error) { // We should not include the signatures itself. b := make([]byte, 0, MaxMsgBody) buf := bytes.NewBuffer(b) @@ -279,3 +283,156 @@ func (a *ChannelUpdate) DataToSign() ([]byte, error) { return buf.Bytes(), nil } + +// Validate validates the channel update's message flags and corresponding +// update fields. +// +// NOTE: this is part of the ChannelUpdate interface. +func (a *ChannelUpdate1) Validate(capacity btcutil.Amount) error { + // The maxHTLC flag is mandatory. + if !a.MessageFlags.HasMaxHtlc() { + return fmt.Errorf("max htlc flag not set for channel update %v", + spew.Sdump(a)) + } + + maxHtlc := a.HtlcMaximumMsat + if maxHtlc == 0 || maxHtlc < a.HtlcMinimumMsat { + return fmt.Errorf("invalid max htlc for channel update %v", + spew.Sdump(a)) + } + + // For light clients, the capacity will not be set, so we'll skip + // checking whether the MaxHTLC value respects the channel's capacity. + capacityMsat := NewMSatFromSatoshis(capacity) + if capacityMsat != 0 && maxHtlc > capacityMsat { + return fmt.Errorf("max_htlc (%v) for channel update greater "+ + "than capacity (%v)", maxHtlc, capacityMsat) + } + + return nil +} + +// VerifySig verifies that the channel update message was signed by the party +// with the given node public key. +// +// NOTE: this is part of the ChannelUpdate interface. +func (a *ChannelUpdate1) VerifySig(pubKey *btcec.PublicKey) error { + data, err := a.DataToSign() + if err != nil { + return fmt.Errorf("unable to reconstruct message data: %w", err) + } + + dataHash := chainhash.DoubleHashB(data) + nodeSig, err := a.Signature.ToSignature() + if err != nil { + return err + } + if !nodeSig.Verify(dataHash, pubKey) { + return fmt.Errorf("invalid signature for channel update %v", + spew.Sdump(a)) + } + + return nil +} + +// SCID returns the ShortChannelID of the channel that the update applies to. +// +// NOTE: this is part of the ChannelUpdate interface. +func (a *ChannelUpdate1) SCID() ShortChannelID { + return a.ShortChannelID +} + +// IsNode1 is true if the update was produced by node 1 of the channel peers. +// Node 1 is the node with the lexicographically smaller public key. +// +// NOTE: this is part of the ChannelUpdate interface. +func (a *ChannelUpdate1) IsNode1() bool { + return a.ChannelFlags&ChanUpdateDirection == 0 +} + +// IsDisabled is true if the update is announcing that the channel should be +// considered disabled. +// +// NOTE: this is part of the ChannelUpdate interface. +func (a *ChannelUpdate1) IsDisabled() bool { + return a.ChannelFlags&ChanUpdateDisabled == ChanUpdateDisabled +} + +// GetChainHash returns the hash of the chain that the message is referring to. +// +// NOTE: this is part of the ChannelUpdate interface. +func (a *ChannelUpdate1) GetChainHash() chainhash.Hash { + return a.ChainHash +} + +// ForwardingPolicy returns the set of forwarding constraints of the update. +// +// NOTE: this is part of the ChannelUpdate interface. +func (a *ChannelUpdate1) ForwardingPolicy() *ForwardingPolicy { + return &ForwardingPolicy{ + TimeLockDelta: a.TimeLockDelta, + BaseFee: MilliSatoshi(a.BaseFee), + FeeRate: MilliSatoshi(a.FeeRate), + MinHTLC: a.HtlcMinimumMsat, + HasMaxHTLC: a.MessageFlags.HasMaxHtlc(), + MaxHTLC: a.HtlcMaximumMsat, + } +} + +// CmpAge can be used to determine if the update is older or newer than the +// passed update. It returns 1 if this update is newer, -1 if it is older, and +// 0 if they are the same age. +// +// NOTE: this is part of the ChannelUpdate interface. +func (a *ChannelUpdate1) CmpAge(update ChannelUpdate) (CompareResult, error) { + other, ok := update.(*ChannelUpdate1) + if !ok { + return 0, fmt.Errorf("expected *ChannelUpdate1, got: %T", + update) + } + + switch { + case a.Timestamp > other.Timestamp: + return GreaterThan, nil + case a.Timestamp < other.Timestamp: + return LessThan, nil + default: + return EqualTo, nil + } +} + +// SetDisabled can be used to adjust the disabled flag of an update. +// +// NOTE: this is part of the ChannelUpdate interface. +func (a *ChannelUpdate1) SetDisabled(disabled bool) { + if disabled { + a.ChannelFlags |= ChanUpdateDisabled + } else { + a.ChannelFlags &= ^ChanUpdateDisabled + } +} + +// SetSig can be used to adjust the signature of the update. +// +// NOTE: this is part of the ChannelUpdate interface. +func (a *ChannelUpdate1) SetSig(sig input.Signature) error { + s, err := NewSigFromSignature(sig) + if err != nil { + return err + } + + a.Signature = s + + return nil +} + +// SetSCID can be used to overwrite the SCID of the update. +// +// NOTE: this is part of the ChannelUpdate interface. +func (a *ChannelUpdate1) SetSCID(scid ShortChannelID) { + a.ShortChannelID = scid +} + +// A compile time assertion to ensure ChannelUpdate1 implements the +// ChannelUpdate interface. +var _ ChannelUpdate = (*ChannelUpdate1)(nil) diff --git a/lnwire/channel_update_2.go b/lnwire/channel_update_2.go new file mode 100644 index 0000000000..27e8ffe3b1 --- /dev/null +++ b/lnwire/channel_update_2.go @@ -0,0 +1,488 @@ +package lnwire + +import ( + "bytes" + "fmt" + "io" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/chaincfg" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/davecgh/go-spew/spew" + "github.com/lightningnetwork/lnd/input" + "github.com/lightningnetwork/lnd/tlv" +) + +const ( + defaultCltvExpiryDelta = uint16(80) + defaultHtlcMinMsat = MilliSatoshi(1) + defaultFeeBaseMsat = uint32(1000) + defaultFeeProportionalMillionths = uint32(1) + + // chanUpdate2MsgName is a string representing the name of the + // ChannelUpdate2 message. This string will be used during the + // construction of the tagged hash message to be signed when producing + // the signature for the ChannelUpdate2 message. + chanUpdate2MsgName = "channel_update_2" + + // chanUpdate2SigField is the name of the signature field of the + // ChannelUpdate2 message. This string will be used during the + // construction of the tagged hash message to be signed when producing + // the signature for the ChannelUpdate2 message. + chanUpdate2SigField = "bip340_sig" +) + +// ChannelUpdate2 message is used after taproot channel has been initially +// announced. Each side independently announces its fees and minimum expiry for +// HTLCs and other parameters. This message is also used to redeclare initially +// set channel parameters. +type ChannelUpdate2 struct { + // Signature is used to validate the announced data and prove the + // ownership of node id. + Signature Sig + + // ChainHash denotes the target chain that this channel was opened + // within. This value should be the genesis hash of the target chain. + // Along with the short channel ID, this uniquely identifies the + // channel globally in a blockchain. + ChainHash tlv.RecordT[tlv.TlvType0, chainhash.Hash] + + // ShortChannelID is the unique description of the funding transaction. + ShortChannelID tlv.RecordT[tlv.TlvType2, ShortChannelID] + + // BlockHeight allows ordering in the case of multiple announcements. We + // should ignore the message if block height is not greater than the + // last-received. The block height must always be greater or equal to + // the block height that the channel funding transaction was confirmed + // in. + BlockHeight tlv.RecordT[tlv.TlvType4, uint32] + + // DisabledFlags is an optional bitfield that describes various reasons + // that the node is communicating that the channel should be considered + // disabled. + DisabledFlags tlv.RecordT[tlv.TlvType6, ChanUpdateDisableFlags] + + // Direction is false if this update was produced by node 1 of the + // channel announcement and true if it is from node 2. + Direction tlv.RecordT[tlv.TlvType8, Boolean] + + // CLTVExpiryDelta is the minimum number of blocks this node requires to + // be added to the expiry of HTLCs. This is a security parameter + // determined by the node operator. This value represents the required + // gap between the time locks of the incoming and outgoing HTLC's set + // to this node. + CLTVExpiryDelta tlv.RecordT[tlv.TlvType10, uint16] + + // HTLCMinimumMsat is the minimum HTLC value which will be accepted. + HTLCMinimumMsat tlv.RecordT[tlv.TlvType12, MilliSatoshi] + + // HtlcMaximumMsat is the maximum HTLC value which will be accepted. + HTLCMaximumMsat tlv.RecordT[tlv.TlvType14, MilliSatoshi] + + // FeeBaseMsat is the base fee that must be used for incoming HTLC's to + // this particular channel. This value will be tacked onto the required + // for a payment independent of the size of the payment. + FeeBaseMsat tlv.RecordT[tlv.TlvType16, uint32] + + // FeeProportionalMillionths is the fee rate that will be charged per + // millionth of a satoshi. + FeeProportionalMillionths tlv.RecordT[tlv.TlvType18, uint32] + + // ExtraOpaqueData is the set of data that was appended to this message + // to fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraOpaqueData ExtraOpaqueData +} + +// Decode deserializes a serialized ChannelUpdate2 stored in the passed +// io.Reader observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (c *ChannelUpdate2) Decode(r io.Reader, _ uint32) error { + err := ReadElement(r, &c.Signature) + if err != nil { + return err + } + c.Signature.ForceSchnorr() + + return c.DecodeTLVRecords(r) +} + +// DecodeTLVRecords decodes only the TLV section of the message. +func (c *ChannelUpdate2) DecodeTLVRecords(r io.Reader) error { + // First extract into extra opaque data. + var tlvRecords ExtraOpaqueData + if err := ReadElements(r, &tlvRecords); err != nil { + return err + } + + var chainHash = tlv.ZeroRecordT[tlv.TlvType0, [32]byte]() + typeMap, err := tlvRecords.ExtractRecords( + &chainHash, &c.ShortChannelID, &c.BlockHeight, &c.DisabledFlags, + &c.Direction, &c.CLTVExpiryDelta, &c.HTLCMinimumMsat, + &c.HTLCMaximumMsat, &c.FeeBaseMsat, + &c.FeeProportionalMillionths, + ) + if err != nil { + return err + } + + // By default, the chain-hash is the bitcoin mainnet genesis block hash. + c.ChainHash.Val = *chaincfg.MainNetParams.GenesisHash + if _, ok := typeMap[c.ChainHash.TlvType()]; ok { + c.ChainHash.Val = chainHash.Val + } + + // If the CLTV expiry delta was not encoded, then set it to the default + // value. + if _, ok := typeMap[c.CLTVExpiryDelta.TlvType()]; !ok { + c.CLTVExpiryDelta.Val = defaultCltvExpiryDelta + } + + // If the HTLC Minimum msat was not encoded, then set it to the default + // value. + if _, ok := typeMap[c.HTLCMinimumMsat.TlvType()]; !ok { + c.HTLCMinimumMsat.Val = defaultHtlcMinMsat + } + + // If the base fee was not encoded, then set it to the default value. + if _, ok := typeMap[c.FeeBaseMsat.TlvType()]; !ok { + c.FeeBaseMsat.Val = defaultFeeBaseMsat + } + + // If the proportional fee was not encoded, then set it to the default + // value. + if _, ok := typeMap[c.FeeProportionalMillionths.TlvType()]; !ok { + c.FeeProportionalMillionths.Val = defaultFeeProportionalMillionths //nolint:lll + } + + if len(tlvRecords) != 0 { + c.ExtraOpaqueData = tlvRecords + } + + return nil +} + +// Encode serializes the target ChannelUpdate2 into the passed io.Writer +// observing the protocol version specified. +// +// This is part of the lnwire.Message interface. +func (c *ChannelUpdate2) Encode(w *bytes.Buffer, _ uint32) error { + _, err := w.Write(c.Signature.RawBytes()) + if err != nil { + return err + } + + _, err = c.DataToSign() + if err != nil { + return err + } + + return WriteBytes(w, c.ExtraOpaqueData) +} + +// DigestTag returns the tag to be used when signing the digest. +func (c *ChannelUpdate2) DigestTag() []byte { + return MsgTag(chanUpdate2MsgName, chanUpdate2SigField) +} + +// DigestToSign computes the digest of the message to be signed. +func (c *ChannelUpdate2) DigestToSign() ([]byte, error) { + data, err := c.DataToSign() + if err != nil { + return nil, err + } + + hash := MsgHash(chanUpdate2MsgName, chanUpdate2SigField, data) + + return hash[:], nil +} + +// DataToSign is used to retrieve part of the announcement message which should +// be signed. For the ChannelUpdate2 message, this includes the serialised TLV +// records. +func (c *ChannelUpdate2) DataToSign() ([]byte, error) { + // The chain-hash record is only included if it is _not_ equal to the + // bitcoin mainnet genisis block hash. + var recordProducers []tlv.RecordProducer + if !c.ChainHash.Val.IsEqual(chaincfg.MainNetParams.GenesisHash) { + hash := tlv.ZeroRecordT[tlv.TlvType0, [32]byte]() + hash.Val = c.ChainHash.Val + + recordProducers = append(recordProducers, &hash) + } + + recordProducers = append(recordProducers, + &c.ShortChannelID, &c.BlockHeight, + ) + + // Only include the disable flags if any bit is set. + if !c.DisabledFlags.Val.IsEnabled() { + recordProducers = append(recordProducers, &c.DisabledFlags) + } + + // We only need to encode the direction if the direction is set to 1. + if c.Direction.Val.B { + recordProducers = append(recordProducers, &c.Direction) + } + + // We only encode the cltv expiry delta if it is not equal to the + // default. + if c.CLTVExpiryDelta.Val != defaultCltvExpiryDelta { + recordProducers = append(recordProducers, &c.CLTVExpiryDelta) + } + + if c.HTLCMinimumMsat.Val != defaultHtlcMinMsat { + recordProducers = append(recordProducers, &c.HTLCMinimumMsat) + } + + recordProducers = append(recordProducers, &c.HTLCMaximumMsat) + + if c.FeeBaseMsat.Val != defaultFeeBaseMsat { + recordProducers = append(recordProducers, &c.FeeBaseMsat) + } + + if c.FeeProportionalMillionths.Val != defaultFeeProportionalMillionths { + recordProducers = append( + recordProducers, &c.FeeProportionalMillionths, + ) + } + + err := EncodeMessageExtraData(&c.ExtraOpaqueData, recordProducers...) + if err != nil { + return nil, err + } + + return c.ExtraOpaqueData, nil +} + +// MsgType returns the integer uniquely identifying this message type on the +// wire. +// +// This is part of the lnwire.Message interface. +func (c *ChannelUpdate2) MsgType() MessageType { + return MsgChannelUpdate2 +} + +// A compile time check to ensure ChannelUpdate2 implements the +// lnwire.Message interface. +var _ Message = (*ChannelUpdate2)(nil) + +// SCID returns the ShortChannelID of the channel that the update applies to. +// +// NOTE: this is part of the ChannelUpdate interface. +func (c *ChannelUpdate2) SCID() ShortChannelID { + return c.ShortChannelID.Val +} + +// IsNode1 is true if the update was produced by node 1 of the channel peers. +// Node 1 is the node with the lexicographically smaller public key. +// +// NOTE: this is part of the ChannelUpdate interface. +func (c *ChannelUpdate2) IsNode1() bool { + return !c.Direction.Val.B +} + +// IsDisabled is true if the update is announcing that the channel should be +// considered disabled. +// +// NOTE: this is part of the ChannelUpdate interface. +func (c *ChannelUpdate2) IsDisabled() bool { + return !c.DisabledFlags.Val.IsEnabled() +} + +// GetChainHash returns the hash of the chain that the message is referring to. +// +// NOTE: this is part of the ChannelUpdate interface. +func (c *ChannelUpdate2) GetChainHash() chainhash.Hash { + return c.ChainHash.Val +} + +// ForwardingPolicy returns the set of forwarding constraints of the update. +// +// NOTE: this is part of the ChannelUpdate interface. +func (c *ChannelUpdate2) ForwardingPolicy() *ForwardingPolicy { + return &ForwardingPolicy{ + TimeLockDelta: c.CLTVExpiryDelta.Val, + BaseFee: MilliSatoshi(c.FeeBaseMsat.Val), + FeeRate: MilliSatoshi(c.FeeProportionalMillionths.Val), + MinHTLC: c.HTLCMinimumMsat.Val, + HasMaxHTLC: true, + MaxHTLC: c.HTLCMaximumMsat.Val, + } +} + +// CmpAge can be used to determine if the update is older or newer than the +// passed update. It returns 1 if this update is newer, -1 if it is older, and +// 0 if they are the same age. +// +// NOTE: this is part of the ChannelUpdate interface. +func (c *ChannelUpdate2) CmpAge(update ChannelUpdate) (CompareResult, error) { + other, ok := update.(*ChannelUpdate2) + if !ok { + return 0, fmt.Errorf("expected *ChannelUpdate2, got: %T", + update) + } + + switch { + case c.BlockHeight.Val > other.BlockHeight.Val: + return GreaterThan, nil + case c.BlockHeight.Val < other.BlockHeight.Val: + return LessThan, nil + default: + return EqualTo, nil + } +} + +// SetDisabled can be used to adjust the disabled flag of an update. +// +// NOTE: this is part of the ChannelUpdate interface. +func (c *ChannelUpdate2) SetDisabled(disabled bool) { + if disabled { + c.DisabledFlags.Val |= ChanUpdateDisableIncoming + c.DisabledFlags.Val |= ChanUpdateDisableOutgoing + } else { + c.DisabledFlags.Val &^= ChanUpdateDisableIncoming + c.DisabledFlags.Val &^= ChanUpdateDisableOutgoing + } +} + +// SetSig can be used to adjust the signature of the update. +// +// NOTE: this is part of the ChannelUpdate interface. +func (c *ChannelUpdate2) SetSig(signature input.Signature) error { + sig, err := NewSigFromSignature(signature) + if err != nil { + return err + } + + c.Signature = sig + + return nil +} + +// SetSCID can be used to overwrite the SCID of the update. +// +// NOTE: this is part of the ChannelUpdate interface. +func (c *ChannelUpdate2) SetSCID(scid ShortChannelID) { + c.ShortChannelID.Val = scid +} + +// Validate validates the sanity of the channel update message +// +// NOTE: this is part of the ChannelUpdate interface. +func (c *ChannelUpdate2) Validate(capacity btcutil.Amount) error { + maxHtlc := c.HTLCMaximumMsat.Val + if maxHtlc == 0 || maxHtlc < c.HTLCMinimumMsat.Val { + return fmt.Errorf("invalid max htlc for channel update %v", + spew.Sdump(c)) + } + + // Checking whether the MaxHTLC value respects the channel's + // capacity. + capacityMsat := NewMSatFromSatoshis(capacity) + if maxHtlc > capacityMsat { + return fmt.Errorf("max_htlc (%v) for channel update greater "+ + "than capacity (%v)", maxHtlc, capacityMsat) + } + + return nil +} + +// VerifySig verifies that the message was signed by the given pub key. +// +// NOTE: this is part of the ChannelUpdate interface. +func (c *ChannelUpdate2) VerifySig(pubKey *btcec.PublicKey) error { + digest, err := c.DigestToSign() + if err != nil { + return fmt.Errorf("unable to reconstruct message data: %w", err) + } + + nodeSig, err := c.Signature.ToSignature() + if err != nil { + return err + } + + if !nodeSig.Verify(digest, pubKey) { + return fmt.Errorf("invalid signature for channel update %v", + spew.Sdump(c)) + } + + return nil +} + +// A compile time check to ensure ChannelUpdate2 implements the +// lnwire.ChannelUpdate interface. +var _ ChannelUpdate = (*ChannelUpdate2)(nil) + +// ChanUpdateDisableFlags is a bit vector that can be used to indicate various +// reasons for the channel being marked as disabled. +type ChanUpdateDisableFlags uint8 + +const ( + // ChanUpdateDisableIncoming is a bit indicates that a channel is + // disabled in the inbound direction meaning that the node broadcasting + // the update is communicating that they cannot receive funds. + ChanUpdateDisableIncoming ChanUpdateDisableFlags = 1 << iota + + // ChanUpdateDisableOutgoing is a bit indicates that a channel is + // disabled in the outbound direction meaning that the node broadcasting + // the update is communicating that they cannot send or route funds. + ChanUpdateDisableOutgoing = 2 +) + +// IncomingDisabled returns true if the ChanUpdateDisableIncoming bit is set. +func (c ChanUpdateDisableFlags) IncomingDisabled() bool { + return c&ChanUpdateDisableIncoming == ChanUpdateDisableIncoming +} + +// OutgoingDisabled returns true if the ChanUpdateDisableOutgoing bit is set. +func (c ChanUpdateDisableFlags) OutgoingDisabled() bool { + return c&ChanUpdateDisableOutgoing == ChanUpdateDisableOutgoing +} + +// IsEnabled returns true if none of the disable bits are set. +func (c ChanUpdateDisableFlags) IsEnabled() bool { + return c == 0 +} + +// String returns the bitfield flags as a string. +func (c ChanUpdateDisableFlags) String() string { + return fmt.Sprintf("%08b", c) +} + +// Record returns the tlv record for the disable flags. +func (c *ChanUpdateDisableFlags) Record() tlv.Record { + return tlv.MakeStaticRecord(0, c, 1, encodeDisableFlags, + decodeDisableFlags) +} + +func encodeDisableFlags(w io.Writer, val interface{}, buf *[8]byte) error { + if v, ok := val.(*ChanUpdateDisableFlags); ok { + flagsInt := uint8(*v) + + return tlv.EUint8(w, &flagsInt, buf) + } + + return tlv.NewTypeForEncodingErr(val, "lnwire.ChanUpdateDisableFlags") +} + +func decodeDisableFlags(r io.Reader, val interface{}, buf *[8]byte, + l uint64) error { + + if v, ok := val.(*ChanUpdateDisableFlags); ok { + var flagsInt uint8 + err := tlv.DUint8(r, &flagsInt, buf, l) + if err != nil { + return err + } + + *v = ChanUpdateDisableFlags(flagsInt) + + return nil + } + + return tlv.NewTypeForDecodingErr(val, "lnwire.ChanUpdateDisableFlags", + l, l) +} diff --git a/lnwire/features.go b/lnwire/features.go index 437a50dc21..c5ae014f79 100644 --- a/lnwire/features.go +++ b/lnwire/features.go @@ -642,31 +642,39 @@ func (fv *RawFeatureVector) sizeFunc() uint64 { // Record returns a TLV record that can be used to encode/decode raw feature // vectors. Note that the length of the feature vector is not included, because // it is covered by the TLV record's length field. -func (fv *RawFeatureVector) Record(recordType tlv.Type) tlv.Record { +func (fv *RawFeatureVector) Record() tlv.Record { return tlv.MakeDynamicRecord( - recordType, fv, fv.sizeFunc, rawFeatureEncoder, - rawFeatureDecoder, + 0, fv, fv.sizeFunc, rawFeatureEncoder, rawFeatureDecoder, ) } // rawFeatureEncoder is a custom TLV encoder for raw feature vectors. func rawFeatureEncoder(w io.Writer, val interface{}, _ *[8]byte) error { - if f, ok := val.(*RawFeatureVector); ok { - return f.encode(w, f.SerializeSize(), 8) + if v, ok := val.(*RawFeatureVector); ok { + // Encode the feature bits as a byte slice without its length + // prepended, as that's already taken care of by the TLV record. + fv := *v + return fv.encode(w, fv.SerializeSize(), 8) } - return tlv.NewTypeForEncodingErr(val, "*lnwire.RawFeatureVector") + return tlv.NewTypeForEncodingErr(val, "lnwire.RawFeatureVector") } // rawFeatureDecoder is a custom TLV decoder for raw feature vectors. func rawFeatureDecoder(r io.Reader, val interface{}, _ *[8]byte, l uint64) error { - if f, ok := val.(*RawFeatureVector); ok { - return f.decode(r, int(l), 8) + if v, ok := val.(*RawFeatureVector); ok { + fv := NewRawFeatureVector() + if err := fv.decode(r, int(l), 8); err != nil { + return err + } + *v = *fv + + return nil } - return tlv.NewTypeForEncodingErr(val, "*lnwire.RawFeatureVector") + return tlv.NewTypeForEncodingErr(val, "lnwire.RawFeatureVector") } // FeatureVector represents a set of enabled features. The set stores diff --git a/lnwire/gossip_timestamp_range.go b/lnwire/gossip_timestamp_range.go index d2dbddecbc..7b628752a0 100644 --- a/lnwire/gossip_timestamp_range.go +++ b/lnwire/gossip_timestamp_range.go @@ -5,6 +5,7 @@ import ( "io" "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/lightningnetwork/lnd/tlv" ) // GossipTimestampRange is a message that allows the sender to restrict the set @@ -17,15 +18,31 @@ type GossipTimestampRange struct { ChainHash chainhash.Hash // FirstTimestamp is the timestamp of the earliest announcement message - // that should be sent by the receiver. + // that should be sent by the receiver. This is only to be used for + // querying message of gossip 1.0 which are timestamped using Unix + // timestamps. FirstBlockHeight and BlockRange should be used to + // query for announcement messages timestamped using block heights. FirstTimestamp uint32 // TimestampRange is the horizon beyond the FirstTimestamp that any // announcement messages should be sent for. The receiving node MUST // NOT send any announcements that have a timestamp greater than - // FirstTimestamp + TimestampRange. + // FirstTimestamp + TimestampRange. This is used together with + // FirstTimestamp to query for gossip 1.0 messages timestamped with + // Unix timestamps. TimestampRange uint32 + // FirstBlockHeight is the height of earliest announcement message that + // should be sent by the receiver. This is used only for querying + // announcement messages that use block heights as a timestamp. + FirstBlockHeight tlv.OptionalRecordT[tlv.TlvType2, uint32] + + // BlockRange is the horizon beyond FirstBlockHeight that any + // announcement messages should be sent for. The receiving node MUST NOT + // send any announcements that have a timestamp greater than + // FirstBlockHeight + BlockRange. + BlockRange tlv.OptionalRecordT[tlv.TlvType4, uint32] + // ExtraData is the set of data that was appended to this message to // fill out the full maximum transport message size. These fields can // be used to specify optional data such as custom TLV fields. @@ -45,13 +62,42 @@ var _ Message = (*GossipTimestampRange)(nil) // passed io.Reader observing the specified protocol version. // // This is part of the lnwire.Message interface. -func (g *GossipTimestampRange) Decode(r io.Reader, pver uint32) error { - return ReadElements(r, +func (g *GossipTimestampRange) Decode(r io.Reader, _ uint32) error { + err := ReadElements(r, g.ChainHash[:], &g.FirstTimestamp, &g.TimestampRange, - &g.ExtraData, ) + if err != nil { + return err + } + + var tlvRecords ExtraOpaqueData + if err := ReadElements(r, &tlvRecords); err != nil { + return err + } + + var ( + firstBlock = tlv.ZeroRecordT[tlv.TlvType2, uint32]() + blockRange = tlv.ZeroRecordT[tlv.TlvType4, uint32]() + ) + typeMap, err := tlvRecords.ExtractRecords(&firstBlock, &blockRange) + if err != nil { + return err + } + + if val, ok := typeMap[g.FirstBlockHeight.TlvType()]; ok && val == nil { + g.FirstBlockHeight = tlv.SomeRecordT(firstBlock) + } + if val, ok := typeMap[g.BlockRange.TlvType()]; ok && val == nil { + g.BlockRange = tlv.SomeRecordT(blockRange) + } + + if len(tlvRecords) != 0 { + g.ExtraData = tlvRecords + } + + return nil } // Encode serializes the target GossipTimestampRange into the passed io.Writer @@ -71,6 +117,22 @@ func (g *GossipTimestampRange) Encode(w *bytes.Buffer, pver uint32) error { return err } + recordProducers := make([]tlv.RecordProducer, 0, 2) + g.FirstBlockHeight.WhenSome( + func(height tlv.RecordT[tlv.TlvType2, uint32]) { + recordProducers = append(recordProducers, &height) + }, + ) + g.BlockRange.WhenSome( + func(blockRange tlv.RecordT[tlv.TlvType4, uint32]) { + recordProducers = append(recordProducers, &blockRange) + }, + ) + err := EncodeMessageExtraData(&g.ExtraData, recordProducers...) + if err != nil { + return err + } + return WriteBytes(w, g.ExtraData) } diff --git a/lnwire/interfaces.go b/lnwire/interfaces.go new file mode 100644 index 0000000000..79c665254c --- /dev/null +++ b/lnwire/interfaces.go @@ -0,0 +1,154 @@ +package lnwire + +import ( + "github.com/btcsuite/btcd/btcec/v2" + "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/lightningnetwork/lnd/input" +) + +// AnnounceSignatures is an interface that represents a message used to +// exchange signatures of a ChannelAnnouncment message during the funding flow. +type AnnounceSignatures interface { + // SCID returns the ShortChannelID of the channel. + SCID() ShortChannelID + + // ChanID returns the ChannelID identifying the channel. + ChanID() ChannelID + + Message +} + +// ChannelAnnouncement is an interface that must be satisfied by any message +// used to announce and prove the existence of a channel. +type ChannelAnnouncement interface { + // SCID returns the short channel ID of the channel. + SCID() ShortChannelID + + // GetChainHash returns the hash of the chain which this channel's + // funding transaction is confirmed in. + GetChainHash() chainhash.Hash + + // Node1KeyBytes returns the bytes representing the public key of node + // 1 in the channel. + Node1KeyBytes() [33]byte + + // Node2KeyBytes returns the bytes representing the public key of node + // 2 in the channel. + Node2KeyBytes() [33]byte + + // Validate checks the various signatures of the announcement. + Validate(fetchPKScript func(id *ShortChannelID) ([]byte, error)) error + + Message +} + +// ValidateChannelUpdateAnn validates the channel update announcement by +// checking (1) that the included signature covers the announcement and has been +// signed by the node's private key, and (2) that the announcement's message +// flags and optional fields are sane. +func ValidateChannelUpdateAnn(pubKey *btcec.PublicKey, capacity btcutil.Amount, + a ChannelUpdate) error { + + if err := a.Validate(capacity); err != nil { + return err + } + + return a.VerifySig(pubKey) +} + +// CompareResult represents the result after comparing two things. +type CompareResult uint8 + +const ( + // LessThan indicates that base object is less than the object it was + // compared to. + LessThan CompareResult = iota + + // EqualTo indicates that the base object is equal to the object it was + // compared to. + EqualTo + + // GreaterThan indicates that base object is greater than the object it + // was compared to. + GreaterThan +) + +// ChannelUpdate is an interface that describes a message used to update the +// forwarding rules of a channel. +// +//nolint:interfacebloat +type ChannelUpdate interface { + // SCID returns the ShortChannelID of the channel that the update + // applies to. + SCID() ShortChannelID + + // IsNode1 is true if the update was produced by node 1 of the channel + // peers. Node 1 is the node with the lexicographically smaller public + // key. + IsNode1() bool + + // IsDisabled is true if the update is announcing that the channel + // should be considered disabled. + IsDisabled() bool + + // GetChainHash returns the hash of the chain that the message is + // referring to. + GetChainHash() chainhash.Hash + + // ForwardingPolicy returns the set of forwarding constraints of the + // update. + ForwardingPolicy() *ForwardingPolicy + + // CmpAge can be used to determine if the update is older or newer than + // the passed update. It returns LessThan if this update is older than + // the passed update, GreaterThan if it is newer and EqualTo if they are + // the same age. + CmpAge(update ChannelUpdate) (CompareResult, error) + + // SetDisabled can be used to adjust the disabled flag of an update. + SetDisabled(bool) + + // SetSig can be used to adjust the signature of the update. + SetSig(signature input.Signature) error + + // SetSCID can be used to overwrite the SCID of the update. + SetSCID(scid ShortChannelID) + + // Validate validates the sanity of the channel update message. + Validate(capacity btcutil.Amount) error + + // VerifySig verifies that the message was signed by the given pub key. + VerifySig(pubKey *btcec.PublicKey) error + + Message +} + +// ForwardingPolicy defines the set of forwarding constraints advertised in a +// ChannelUpdate message. +type ForwardingPolicy struct { + // TimeLockDelta is the minimum number of blocks that the node requires + // to be added to the expiry of HTLCs. This is a security parameter + // determined by the node operator. This value represents the required + // gap between the time locks of the incoming and outgoing HTLC's set + // to this node. + TimeLockDelta uint16 + + // BaseFee is the base fee that must be used for incoming HTLC's to + // this particular channel. This value will be tacked onto the required + // for a payment independent of the size of the payment. + BaseFee MilliSatoshi + + // FeeRate is the fee rate that will be charged per millionth of a + // satoshi. + FeeRate MilliSatoshi + + // HtlcMinimumMsat is the minimum HTLC value which will be accepted. + MinHTLC MilliSatoshi + + // HasMaxHTLC is true if the MaxHTLC field is provided in the update. + HasMaxHTLC bool + + // HtlcMaximumMsat is the maximum HTLC value which will be accepted. + MaxHTLC MilliSatoshi +} diff --git a/lnwire/lnwire.go b/lnwire/lnwire.go index 8ab082b0bd..866128ca67 100644 --- a/lnwire/lnwire.go +++ b/lnwire/lnwire.go @@ -936,6 +936,19 @@ func ReadElement(r io.Reader, element interface{}) error { } *e = addrBytes[:length] + case *PartialSig: + var sBytes [32]byte + if _, err := io.ReadFull(r, sBytes[:]); err != nil { + return err + } + + var s btcec.ModNScalar + s.SetBytes(&sBytes) + + *e = PartialSig{ + Sig: s, + } + case *ExtraOpaqueData: return e.Decode(r) diff --git a/lnwire/lnwire_test.go b/lnwire/lnwire_test.go index 122a996601..554e17d33d 100644 --- a/lnwire/lnwire_test.go +++ b/lnwire/lnwire_test.go @@ -18,6 +18,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcec/v2/ecdsa" "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/fn" @@ -29,17 +30,25 @@ import ( ) var ( - shaHash1Bytes, _ = hex.DecodeString("e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855") - shaHash1, _ = chainhash.NewHash(shaHash1Bytes) - outpoint1 = wire.NewOutPoint(shaHash1, 0) - - testRBytes, _ = hex.DecodeString("8ce2bc69281ce27da07e6683571319d18e949ddfa2965fb6caa1bf0314f882d7") - testSBytes, _ = hex.DecodeString("299105481d63e0f4bc2a88121167221b6700d72a0ead154c03be696a292d24ae") - testRScalar = new(btcec.ModNScalar) - testSScalar = new(btcec.ModNScalar) - _ = testRScalar.SetByteSlice(testRBytes) - _ = testSScalar.SetByteSlice(testSBytes) - testSig = ecdsa.NewSignature(testRScalar, testSScalar) + shaHash1Bytes, _ = hex.DecodeString("e3b0c44298fc1c149afbf4c8996fb" + + "92427ae41e4649b934ca495991b7852b855") + + shaHash1, _ = chainhash.NewHash(shaHash1Bytes) + outpoint1 = wire.NewOutPoint(shaHash1, 0) + + testRBytes, _ = hex.DecodeString("8ce2bc69281ce27da07e6683571" + + "319d18e949ddfa2965fb6caa1bf0314f882d7") + testSBytes, _ = hex.DecodeString("299105481d63e0f4bc2a" + + "88121167221b6700d72a0ead154c03be696a292d24ae") + testRScalar = new(btcec.ModNScalar) + testSScalar = new(btcec.ModNScalar) + _ = testRScalar.SetByteSlice(testRBytes) + _ = testSScalar.SetByteSlice(testSBytes) + testSig = ecdsa.NewSignature(testRScalar, testSScalar) + testSchnorrSigStr, _ = hex.DecodeString("04E7F9037658A92AFEB4F2" + + "5BAE5339E3DDCA81A353493827D26F16D92308E49E2A25E9220867" + + "8A2DF86970DA91B03A8AF8815A8A60498B358DAF560B347AA557") + testSchnorrSig, _ = NewSigFromSchnorrRawSignature(testSchnorrSigStr) ) const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" @@ -134,17 +143,15 @@ func randPubKey() (*btcec.PublicKey, error) { return priv.PubKey(), nil } -func randRawKey() ([33]byte, error) { +func randRawKey(t *testing.T) [33]byte { var n [33]byte priv, err := btcec.NewPrivateKey() - if err != nil { - return n, err - } + require.NoError(t, err) copy(n[:], priv.PubKey().SerializeCompressed()) - return n, nil + return n } func randDeliveryAddress(r *rand.Rand) (DeliveryAddress, error) { @@ -677,18 +684,13 @@ func TestLightningWireProtocol(t *testing.T) { }, MsgChannelReady: func(v []reflect.Value, r *rand.Rand) { var c [32]byte - if _, err := r.Read(c[:]); err != nil { - t.Fatalf("unable to generate chan id: %v", err) - return - } + _, err := r.Read(c[:]) + require.NoError(t, err) pubKey, err := randPubKey() - if err != nil { - t.Fatalf("unable to generate key: %v", err) - return - } + require.NoError(t, err) - req := NewChannelReady(ChannelID(c), pubKey) + req := NewChannelReady(c, pubKey) if r.Int31()%2 == 0 { scid := NewShortChanIDFromInt(uint64(r.Int63())) @@ -698,6 +700,24 @@ func TestLightningWireProtocol(t *testing.T) { req.NextLocalNonce = someLocalNonce[NonceRecordTypeT](r) } + if r.Int31()%2 == 0 { + nodeNonce := tlv.ZeroRecordT[ + tlv.TlvType0, Musig2Nonce, + ]() + nodeNonce.Val = randLocalNonce(r) + req.AnnouncementNodeNonce = tlv.SomeRecordT( + nodeNonce, + ) + + btcNonce := tlv.ZeroRecordT[ + tlv.TlvType2, Musig2Nonce, + ]() + btcNonce.Val = randLocalNonce(r) + req.AnnouncementBitcoinNonce = tlv.SomeRecordT( + btcNonce, + ) + } + v[0] = reflect.ValueOf(*req) }, MsgShutdown: func(v []reflect.Value, r *rand.Rand) { @@ -926,8 +946,14 @@ func TestLightningWireProtocol(t *testing.T) { }, MsgChannelAnnouncement: func(v []reflect.Value, r *rand.Rand) { var err error - req := ChannelAnnouncement{ - ShortChannelID: NewShortChanIDFromInt(uint64(r.Int63())), + req := ChannelAnnouncement1{ + ShortChannelID: NewShortChanIDFromInt( + uint64(r.Int63()), + ), + NodeID1: randRawKey(t), + NodeID2: randRawKey(t), + BitcoinKey1: randRawKey(t), + BitcoinKey2: randRawKey(t), Features: randRawFeatureVector(r), ExtraOpaqueData: make([]byte, 0), } @@ -952,26 +978,6 @@ func TestLightningWireProtocol(t *testing.T) { return } - req.NodeID1, err = randRawKey() - if err != nil { - t.Fatalf("unable to generate key: %v", err) - return - } - req.NodeID2, err = randRawKey() - if err != nil { - t.Fatalf("unable to generate key: %v", err) - return - } - req.BitcoinKey1, err = randRawKey() - if err != nil { - t.Fatalf("unable to generate key: %v", err) - return - } - req.BitcoinKey2, err = randRawKey() - if err != nil { - t.Fatalf("unable to generate key: %v", err) - return - } if _, err := r.Read(req.ChainHash[:]); err != nil { t.Fatalf("unable to generate chain hash: %v", err) return @@ -993,6 +999,7 @@ func TestLightningWireProtocol(t *testing.T) { MsgNodeAnnouncement: func(v []reflect.Value, r *rand.Rand) { var err error req := NodeAnnouncement{ + NodeID: randRawKey(t), Features: randRawFeatureVector(r), Timestamp: uint32(r.Int31()), Alias: randAlias(r), @@ -1009,12 +1016,6 @@ func TestLightningWireProtocol(t *testing.T) { return } - req.NodeID, err = randRawKey() - if err != nil { - t.Fatalf("unable to generate key: %v", err) - return - } - req.Addresses, err = randAddrs(r) if err != nil { t.Fatalf("unable to generate addresses: %v", err) @@ -1047,8 +1048,10 @@ func TestLightningWireProtocol(t *testing.T) { maxHtlc = 0 } - req := ChannelUpdate{ - ShortChannelID: NewShortChanIDFromInt(uint64(r.Int63())), + req := ChannelUpdate1{ + ShortChannelID: NewShortChanIDFromInt( + uint64(r.Int63()), + ), Timestamp: uint32(r.Int31()), MessageFlags: msgFlags, ChannelFlags: ChanUpdateChanFlags(r.Int31()), @@ -1085,7 +1088,7 @@ func TestLightningWireProtocol(t *testing.T) { }, MsgAnnounceSignatures: func(v []reflect.Value, r *rand.Rand) { var err error - req := AnnounceSignatures{ + req := AnnounceSignatures1{ ShortChannelID: NewShortChanIDFromInt(uint64(r.Int63())), ExtraOpaqueData: make([]byte, 0), } @@ -1149,6 +1152,35 @@ func TestLightningWireProtocol(t *testing.T) { v[0] = reflect.ValueOf(req) }, + MsgGossipTimestampRange: func(v []reflect.Value, r *rand.Rand) { + req := GossipTimestampRange{ + FirstTimestamp: rand.Uint32(), + TimestampRange: rand.Uint32(), + ExtraData: make([]byte, 0), + } + + _, err := rand.Read(req.ChainHash[:]) + require.NoError(t, err) + + // Sometimes add a block range. + if r.Int31()%2 == 0 { + firstBlock := tlv.ZeroRecordT[ + tlv.TlvType2, uint32, + ]() + firstBlock.Val = rand.Uint32() + req.FirstBlockHeight = tlv.SomeRecordT( + firstBlock, + ) + + blockRange := tlv.ZeroRecordT[ + tlv.TlvType4, uint32, + ]() + blockRange.Val = rand.Uint32() + req.BlockRange = tlv.SomeRecordT(blockRange) + } + + v[0] = reflect.ValueOf(req) + }, MsgQueryShortChanIDs: func(v []reflect.Value, r *rand.Rand) { req := QueryShortChanIDs{ ExtraData: make([]byte, 0), @@ -1389,6 +1421,174 @@ func TestLightningWireProtocol(t *testing.T) { v[0] = reflect.ValueOf(*req) }, + MsgAnnounceSignatures2: func(v []reflect.Value, + r *rand.Rand) { + + req := AnnounceSignatures2{ + ShortChannelID: NewShortChanIDFromInt( + uint64(r.Int63()), + ), + ExtraOpaqueData: make([]byte, 0), + } + + _, err := r.Read(req.ChannelID[:]) + require.NoError(t, err) + + partialSig, err := randPartialSig(r) + require.NoError(t, err) + + req.PartialSignature = *partialSig + + numExtraBytes := r.Int31n(1000) + if numExtraBytes > 0 { + req.ExtraOpaqueData = make( + []byte, numExtraBytes, + ) + _, err := r.Read(req.ExtraOpaqueData[:]) + require.NoError(t, err) + } + + v[0] = reflect.ValueOf(req) + }, + MsgChannelAnnouncement2: func(v []reflect.Value, r *rand.Rand) { + req := ChannelAnnouncement2{ + Signature: testSchnorrSig, + ExtraOpaqueData: make([]byte, 0), + } + + req.ShortChannelID.Val = NewShortChanIDFromInt( + uint64(r.Int63()), + ) + req.Capacity.Val = rand.Uint64() + + req.Features.Val = *randRawFeatureVector(r) + + req.NodeID1.Val = randRawKey(t) + req.NodeID2.Val = randRawKey(t) + + // Sometimes set chain hash to bitcoin mainnet genesis + // hash. + req.ChainHash.Val = *chaincfg.MainNetParams.GenesisHash + if r.Int31()%2 == 0 { + _, err := r.Read(req.ChainHash.Val[:]) + require.NoError(t, err) + } + + // Sometimes set the bitcoin keys. + if r.Int31()%2 == 0 { + btcKey1 := tlv.ZeroRecordT[ + tlv.TlvType12, [33]byte, + ]() + btcKey1.Val = randRawKey(t) + req.BitcoinKey1 = tlv.SomeRecordT(btcKey1) + + btcKey2 := tlv.ZeroRecordT[ + tlv.TlvType14, [33]byte, + ]() + btcKey2.Val = randRawKey(t) + req.BitcoinKey2 = tlv.SomeRecordT(btcKey2) + + // Occasionally also set the merkle root hash. + if r.Int31()%2 == 0 { + hash := tlv.ZeroRecordT[ + tlv.TlvType16, [32]byte, + ]() + + _, err := r.Read(hash.Val[:]) + require.NoError(t, err) + + req.MerkleRootHash = tlv.SomeRecordT( + hash, + ) + } + } + + numExtraBytes := r.Int31n(1000) + if numExtraBytes > 0 { + req.ExtraOpaqueData = make( + []byte, numExtraBytes, + ) + _, err := r.Read(req.ExtraOpaqueData[:]) + require.NoError(t, err) + } + + v[0] = reflect.ValueOf(req) + }, + MsgChannelUpdate2: func(v []reflect.Value, r *rand.Rand) { + req := ChannelUpdate2{ + Signature: testSchnorrSig, + ExtraOpaqueData: make([]byte, 0), + } + + req.ShortChannelID.Val = NewShortChanIDFromInt( + uint64(r.Int63()), + ) + req.BlockHeight.Val = r.Uint32() + req.HTLCMaximumMsat.Val = MilliSatoshi(r.Uint64()) + + // Sometimes set chain hash to bitcoin mainnet genesis + // hash. + req.ChainHash.Val = *chaincfg.MainNetParams.GenesisHash + if r.Int31()%2 == 0 { + _, err := r.Read(req.ChainHash.Val[:]) + require.NoError(t, err) + } + + // Sometimes use default htlc min msat. + req.HTLCMinimumMsat.Val = defaultHtlcMinMsat + if r.Int31()%2 == 0 { + req.HTLCMinimumMsat.Val = MilliSatoshi( + r.Uint64(), + ) + } + + // Sometimes set the cltv expiry delta to the default. + req.CLTVExpiryDelta.Val = defaultCltvExpiryDelta + if r.Int31()%2 == 0 { + req.CLTVExpiryDelta.Val = uint16(r.Int31()) + } + + // Sometimes use default fee base. + req.FeeBaseMsat.Val = defaultFeeBaseMsat + if r.Int31()%2 == 0 { + req.FeeBaseMsat.Val = r.Uint32() + } + + // Sometimes use default proportional fee. + req.FeeProportionalMillionths.Val = + defaultFeeProportionalMillionths + if r.Int31()%2 == 0 { + req.FeeProportionalMillionths.Val = r.Uint32() + } + + // Alternate between the two direction possibilities. + if r.Int31()%2 == 0 { + req.Direction.Val.B = true + } + + // Sometimes set the incoming disabled flag. + if r.Int31()%2 == 0 { + req.DisabledFlags.Val |= + ChanUpdateDisableIncoming + } + + // Sometimes set the outgoing disabled flag. + if r.Int31()%2 == 0 { + req.DisabledFlags.Val |= + ChanUpdateDisableOutgoing + } + + numExtraBytes := r.Int31n(1000) + if numExtraBytes > 0 { + req.ExtraOpaqueData = make( + []byte, numExtraBytes, + ) + _, err := r.Read(req.ExtraOpaqueData[:]) + require.NoError(t, err) + } + + v[0] = reflect.ValueOf(req) + }, } // With the above types defined, we'll now generate a slice of @@ -1553,7 +1753,7 @@ func TestLightningWireProtocol(t *testing.T) { }, { msgType: MsgChannelAnnouncement, - scenario: func(m ChannelAnnouncement) bool { + scenario: func(m ChannelAnnouncement1) bool { return mainScenario(&m) }, }, @@ -1565,13 +1765,13 @@ func TestLightningWireProtocol(t *testing.T) { }, { msgType: MsgChannelUpdate, - scenario: func(m ChannelUpdate) bool { + scenario: func(m ChannelUpdate1) bool { return mainScenario(&m) }, }, { msgType: MsgAnnounceSignatures, - scenario: func(m AnnounceSignatures) bool { + scenario: func(m AnnounceSignatures1) bool { return mainScenario(&m) }, }, @@ -1617,6 +1817,24 @@ func TestLightningWireProtocol(t *testing.T) { return mainScenario(&m) }, }, + { + msgType: MsgAnnounceSignatures2, + scenario: func(m AnnounceSignatures2) bool { + return mainScenario(&m) + }, + }, + { + msgType: MsgChannelAnnouncement2, + scenario: func(m ChannelAnnouncement2) bool { + return mainScenario(&m) + }, + }, + { + msgType: MsgChannelUpdate2, + scenario: func(m ChannelUpdate2) bool { + return mainScenario(&m) + }, + }, } for _, test := range tests { var config *quick.Config diff --git a/lnwire/message.go b/lnwire/message.go index 2bf64a3131..a758db000d 100644 --- a/lnwire/message.go +++ b/lnwire/message.go @@ -52,11 +52,14 @@ const ( MsgNodeAnnouncement = 257 MsgChannelUpdate = 258 MsgAnnounceSignatures = 259 + MsgAnnounceSignatures2 = 260 MsgQueryShortChanIDs = 261 MsgReplyShortChanIDsEnd = 262 MsgQueryChannelRange = 263 MsgReplyChannelRange = 264 MsgGossipTimestampRange = 265 + MsgChannelAnnouncement2 = 267 + MsgChannelUpdate2 = 271 MsgKickoffSig = 777 ) @@ -155,6 +158,12 @@ func (t MessageType) String() string { return "ClosingComplete" case MsgClosingSig: return "ClosingSig" + case MsgAnnounceSignatures2: + return "MsgAnnounceSignatures2" + case MsgChannelAnnouncement2: + return "ChannelAnnouncement2" + case MsgChannelUpdate2: + return "ChannelUpdate2" default: return "" } @@ -259,15 +268,15 @@ func makeEmptyMessage(msgType MessageType) (Message, error) { case MsgError: msg = &Error{} case MsgChannelAnnouncement: - msg = &ChannelAnnouncement{} + msg = &ChannelAnnouncement1{} case MsgChannelUpdate: - msg = &ChannelUpdate{} + msg = &ChannelUpdate1{} case MsgNodeAnnouncement: msg = &NodeAnnouncement{} case MsgPing: msg = &Ping{} case MsgAnnounceSignatures: - msg = &AnnounceSignatures{} + msg = &AnnounceSignatures1{} case MsgPong: msg = &Pong{} case MsgQueryShortChanIDs: @@ -284,6 +293,12 @@ func makeEmptyMessage(msgType MessageType) (Message, error) { msg = &ClosingComplete{} case MsgClosingSig: msg = &ClosingSig{} + case MsgAnnounceSignatures2: + msg = &AnnounceSignatures2{} + case MsgChannelAnnouncement2: + msg = &ChannelAnnouncement2{} + case MsgChannelUpdate2: + msg = &ChannelUpdate2{} default: // If the message is not within our custom range and has not // specifically been overridden, return an unknown message. diff --git a/lnwire/message_test.go b/lnwire/message_test.go index bbb434785f..d42d579115 100644 --- a/lnwire/message_test.go +++ b/lnwire/message_test.go @@ -645,11 +645,11 @@ func newMsgChannelReestablish(t testing.TB, } func newMsgChannelAnnouncement(t testing.TB, - r *rand.Rand) *lnwire.ChannelAnnouncement { + r *rand.Rand) *lnwire.ChannelAnnouncement1 { t.Helper() - msg := &lnwire.ChannelAnnouncement{ + msg := &lnwire.ChannelAnnouncement1{ ShortChannelID: lnwire.NewShortChanIDFromInt(uint64(r.Int63())), Features: rawFeatureVector(), NodeID1: randRawKey(t), @@ -692,7 +692,7 @@ func newMsgNodeAnnouncement(t testing.TB, return msg } -func newMsgChannelUpdate(t testing.TB, r *rand.Rand) *lnwire.ChannelUpdate { +func newMsgChannelUpdate(t testing.TB, r *rand.Rand) *lnwire.ChannelUpdate1 { t.Helper() msgFlags := lnwire.ChanUpdateMsgFlags(r.Int31()) @@ -706,7 +706,7 @@ func newMsgChannelUpdate(t testing.TB, r *rand.Rand) *lnwire.ChannelUpdate { maxHtlc = 0 } - msg := &lnwire.ChannelUpdate{ + msg := &lnwire.ChannelUpdate1{ ShortChannelID: lnwire.NewShortChanIDFromInt(r.Uint64()), Timestamp: uint32(r.Int31()), MessageFlags: msgFlags, @@ -727,11 +727,11 @@ func newMsgChannelUpdate(t testing.TB, r *rand.Rand) *lnwire.ChannelUpdate { } func newMsgAnnounceSignatures(t testing.TB, - r *rand.Rand) *lnwire.AnnounceSignatures { + r *rand.Rand) *lnwire.AnnounceSignatures1 { t.Helper() - msg := &lnwire.AnnounceSignatures{ + msg := &lnwire.AnnounceSignatures1{ ShortChannelID: lnwire.NewShortChanIDFromInt( uint64(r.Int63()), ), diff --git a/lnwire/msat.go b/lnwire/msat.go index 7473d72c82..2966e5ddb2 100644 --- a/lnwire/msat.go +++ b/lnwire/msat.go @@ -2,8 +2,10 @@ package lnwire import ( "fmt" + "io" "github.com/btcsuite/btcd/btcutil" + "github.com/lightningnetwork/lnd/tlv" ) const ( @@ -49,3 +51,40 @@ func (m MilliSatoshi) String() string { } // TODO(roasbeef): extend with arithmetic operations? + +// Record returns a TLV record that can be used to encode/decode a MilliSatoshi +// to/from a TLV stream. +func (m *MilliSatoshi) Record() tlv.Record { + return tlv.MakeDynamicRecord( + 0, m, tlv.SizeBigSize(m), encodeMilliSatoshis, + decodeMilliSatoshis, + ) +} + +func encodeMilliSatoshis(w io.Writer, val interface{}, buf *[8]byte) error { + if v, ok := val.(*MilliSatoshi); ok { + bigSize := uint64(*v) + + return tlv.EBigSize(w, &bigSize, buf) + } + + return tlv.NewTypeForEncodingErr(val, "lnwire.MilliSatoshi") +} + +func decodeMilliSatoshis(r io.Reader, val interface{}, buf *[8]byte, + l uint64) error { + + if v, ok := val.(*MilliSatoshi); ok { + var bigSize uint64 + err := tlv.DBigSize(r, &bigSize, buf, l) + if err != nil { + return err + } + + *v = MilliSatoshi(bigSize) + + return nil + } + + return tlv.NewTypeForDecodingErr(val, "lnwire.MilliSatoshi", l, l) +} diff --git a/lnwire/msg_hash.go b/lnwire/msg_hash.go new file mode 100644 index 0000000000..5812322c07 --- /dev/null +++ b/lnwire/msg_hash.go @@ -0,0 +1,30 @@ +package lnwire + +import ( + "github.com/btcsuite/btcd/chaincfg/chainhash" +) + +// MsgHashTag will prefix the message name and the field name in order to +// construct the message tag. +const MsgHashTag = "lightning" + +// MsgTag computes the full tag that will be used to prefix a message before +// calculating the tagged hash. The tag is constructed as follows: +// +// tag = "lightning"||"msg_name"||"field_name" +func MsgTag(msgName, fieldName string) []byte { + tag := []byte(MsgHashTag) + tag = append(tag, []byte(msgName)...) + + return append(tag, []byte(fieldName)...) +} + +// MsgHash computes the tagged hash of the given message as follows: +// +// tag = "lightning"||"msg_name"||"field_name" +// hash = sha256(sha246(tag) || sha256(tag) || msg) +func MsgHash(msgName, fieldName string, msg []byte) *chainhash.Hash { + tag := MsgTag(msgName, fieldName) + + return chainhash.TaggedHash(tag, msg) +} diff --git a/lnwire/onion_error.go b/lnwire/onion_error.go index 66db6a9f58..9a669271d9 100644 --- a/lnwire/onion_error.go +++ b/lnwire/onion_error.go @@ -600,7 +600,7 @@ func (f *FailInvalidOnionKey) Error() string { // unable to pull out a fully valid version, then we'll fall back to the // regular parsing mechanism which includes the length prefix an NO type byte. func parseChannelUpdateCompatibilityMode(reader io.Reader, length uint16, - chanUpdate *ChannelUpdate, pver uint32) error { + chanUpdate *ChannelUpdate1, pver uint32) error { // Instantiate a LimitReader because there may be additional data // present after the channel update. Without limiting the stream, the @@ -647,11 +647,13 @@ type FailTemporaryChannelFailure struct { // which caused the failure. // // NOTE: This field is optional. - Update *ChannelUpdate + Update *ChannelUpdate1 } // NewTemporaryChannelFailure creates new instance of the FailTemporaryChannelFailure. -func NewTemporaryChannelFailure(update *ChannelUpdate) *FailTemporaryChannelFailure { +func NewTemporaryChannelFailure( + update *ChannelUpdate1) *FailTemporaryChannelFailure { + return &FailTemporaryChannelFailure{Update: update} } @@ -685,7 +687,7 @@ func (f *FailTemporaryChannelFailure) Decode(r io.Reader, pver uint32) error { } if length != 0 { - f.Update = &ChannelUpdate{} + f.Update = &ChannelUpdate1{} return parseChannelUpdateCompatibilityMode( r, length, f.Update, pver, @@ -720,12 +722,12 @@ type FailAmountBelowMinimum struct { // Update is used to update information about state of the channel // which caused the failure. - Update ChannelUpdate + Update ChannelUpdate1 } // NewAmountBelowMinimum creates new instance of the FailAmountBelowMinimum. func NewAmountBelowMinimum(htlcMsat MilliSatoshi, - update ChannelUpdate) *FailAmountBelowMinimum { + update ChannelUpdate1) *FailAmountBelowMinimum { return &FailAmountBelowMinimum{ HtlcMsat: htlcMsat, @@ -761,7 +763,7 @@ func (f *FailAmountBelowMinimum) Decode(r io.Reader, pver uint32) error { return err } - f.Update = ChannelUpdate{} + f.Update = ChannelUpdate1{} return parseChannelUpdateCompatibilityMode( r, length, &f.Update, pver, @@ -790,12 +792,12 @@ type FailFeeInsufficient struct { // Update is used to update information about state of the channel // which caused the failure. - Update ChannelUpdate + Update ChannelUpdate1 } // NewFeeInsufficient creates new instance of the FailFeeInsufficient. func NewFeeInsufficient(htlcMsat MilliSatoshi, - update ChannelUpdate) *FailFeeInsufficient { + update ChannelUpdate1) *FailFeeInsufficient { return &FailFeeInsufficient{ HtlcMsat: htlcMsat, Update: update, @@ -830,7 +832,7 @@ func (f *FailFeeInsufficient) Decode(r io.Reader, pver uint32) error { return err } - f.Update = ChannelUpdate{} + f.Update = ChannelUpdate1{} return parseChannelUpdateCompatibilityMode( r, length, &f.Update, pver, @@ -861,12 +863,12 @@ type FailIncorrectCltvExpiry struct { // Update is used to update information about state of the channel // which caused the failure. - Update ChannelUpdate + Update ChannelUpdate1 } // NewIncorrectCltvExpiry creates new instance of the FailIncorrectCltvExpiry. func NewIncorrectCltvExpiry(cltvExpiry uint32, - update ChannelUpdate) *FailIncorrectCltvExpiry { + update ChannelUpdate1) *FailIncorrectCltvExpiry { return &FailIncorrectCltvExpiry{ CltvExpiry: cltvExpiry, @@ -899,7 +901,7 @@ func (f *FailIncorrectCltvExpiry) Decode(r io.Reader, pver uint32) error { return err } - f.Update = ChannelUpdate{} + f.Update = ChannelUpdate1{} return parseChannelUpdateCompatibilityMode( r, length, &f.Update, pver, @@ -924,11 +926,11 @@ func (f *FailIncorrectCltvExpiry) Encode(w *bytes.Buffer, pver uint32) error { type FailExpiryTooSoon struct { // Update is used to update information about state of the channel // which caused the failure. - Update ChannelUpdate + Update ChannelUpdate1 } // NewExpiryTooSoon creates new instance of the FailExpiryTooSoon. -func NewExpiryTooSoon(update ChannelUpdate) *FailExpiryTooSoon { +func NewExpiryTooSoon(update ChannelUpdate1) *FailExpiryTooSoon { return &FailExpiryTooSoon{ Update: update, } @@ -957,7 +959,7 @@ func (f *FailExpiryTooSoon) Decode(r io.Reader, pver uint32) error { return err } - f.Update = ChannelUpdate{} + f.Update = ChannelUpdate1{} return parseChannelUpdateCompatibilityMode( r, length, &f.Update, pver, @@ -983,11 +985,13 @@ type FailChannelDisabled struct { // Update is used to update information about state of the channel // which caused the failure. - Update ChannelUpdate + Update ChannelUpdate1 } // NewChannelDisabled creates new instance of the FailChannelDisabled. -func NewChannelDisabled(flags uint16, update ChannelUpdate) *FailChannelDisabled { +func NewChannelDisabled(flags uint16, + update ChannelUpdate1) *FailChannelDisabled { + return &FailChannelDisabled{ Flags: flags, Update: update, @@ -1022,7 +1026,7 @@ func (f *FailChannelDisabled) Decode(r io.Reader, pver uint32) error { return err } - f.Update = ChannelUpdate{} + f.Update = ChannelUpdate1{} return parseChannelUpdateCompatibilityMode( r, length, &f.Update, pver, @@ -1510,7 +1514,7 @@ func makeEmptyOnionError(code FailCode) (FailureMessage, error) { // writeOnionErrorChanUpdate writes out a ChannelUpdate using the onion error // format. The format is that we first write out the true serialized length of // the channel update, followed by the serialized channel update itself. -func writeOnionErrorChanUpdate(w *bytes.Buffer, chanUpdate *ChannelUpdate, +func writeOnionErrorChanUpdate(w *bytes.Buffer, chanUpdate *ChannelUpdate1, pver uint32) error { // First, we encode the channel update in a temporary buffer in order diff --git a/lnwire/onion_error_test.go b/lnwire/onion_error_test.go index a06c572dc4..37cd94b8ac 100644 --- a/lnwire/onion_error_test.go +++ b/lnwire/onion_error_test.go @@ -20,7 +20,7 @@ var ( testType = uint64(3) testOffset = uint16(24) sig, _ = NewSigFromSignature(testSig) - testChannelUpdate = ChannelUpdate{ + testChannelUpdate = ChannelUpdate1{ Signature: sig, ShortChannelID: NewShortChanIDFromInt(1), Timestamp: 1, @@ -137,7 +137,7 @@ func TestChannelUpdateCompatibilityParsing(t *testing.T) { // Now that we have the set of bytes encoded, we'll ensure that we're // able to decode it using our compatibility method, as it's a regular // encoded channel update message. - var newChanUpdate ChannelUpdate + var newChanUpdate ChannelUpdate1 err := parseChannelUpdateCompatibilityMode( &b, uint16(b.Len()), &newChanUpdate, 0, ) @@ -164,7 +164,7 @@ func TestChannelUpdateCompatibilityParsing(t *testing.T) { // We should be able to properly parse the encoded channel update // message even with the extra two bytes. - var newChanUpdate2 ChannelUpdate + var newChanUpdate2 ChannelUpdate1 err = parseChannelUpdateCompatibilityMode( &b, uint16(b.Len()), &newChanUpdate2, 0, ) diff --git a/lnwire/writer.go b/lnwire/writer.go index fa6247de0b..31a0facb21 100644 --- a/lnwire/writer.go +++ b/lnwire/writer.go @@ -173,6 +173,14 @@ func WriteSigs(buf *bytes.Buffer, sigs []Sig) error { return nil } +// WritePartialSig appends the serialised partial signature to the provided +// buffer. +func WritePartialSig(buf *bytes.Buffer, sig PartialSig) error { + sigBytes := sig.Sig.Bytes() + + return WriteBytes(buf, sigBytes[:]) +} + // WriteFailCode appends the FailCode to the provided buffer. func WriteFailCode(buf *bytes.Buffer, e FailCode) error { return WriteUint16(buf, uint16(e)) diff --git a/netann/chan_status_manager.go b/netann/chan_status_manager.go index f1e6aa578a..c4db4009dc 100644 --- a/netann/chan_status_manager.go +++ b/netann/chan_status_manager.go @@ -60,7 +60,7 @@ type ChanStatusConfig struct { // ApplyChannelUpdate processes new ChannelUpdates signed by our node by // updating our local routing table and broadcasting the update to our // peers. - ApplyChannelUpdate func(*lnwire.ChannelUpdate, *wire.OutPoint, + ApplyChannelUpdate func(*lnwire.ChannelUpdate1, *wire.OutPoint, bool) error // DB stores the set of channels that are to be monitored. @@ -650,7 +650,7 @@ func (m *ChanStatusManager) signAndSendNextUpdate(outpoint wire.OutPoint, // in case our ChannelEdgePolicy is not found in the database. Also returns if // the channel is private by checking AuthProof for nil. func (m *ChanStatusManager) fetchLastChanUpdateByOutPoint(op wire.OutPoint) ( - *lnwire.ChannelUpdate, bool, error) { + *lnwire.ChannelUpdate1, bool, error) { // Get the edge info and policies for this channel from the graph. info, edge1, edge2, err := m.cfg.Graph.FetchChannelEdgesByOutpoint(&op) diff --git a/netann/chan_status_manager_test.go b/netann/chan_status_manager_test.go index 1e64a53f8c..7af98bd801 100644 --- a/netann/chan_status_manager_test.go +++ b/netann/chan_status_manager_test.go @@ -66,8 +66,8 @@ func createChannel(t *testing.T) *channeldb.OpenChannel { // our `pubkey` with the direction bit set appropriately in the policies. Our // update will be created with the disabled bit set if startEnabled is false. func createEdgePolicies(t *testing.T, channel *channeldb.OpenChannel, - pubkey *btcec.PublicKey, startEnabled bool) (*models.ChannelEdgeInfo, - *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) { + pubkey *btcec.PublicKey, startEnabled bool) (*models.ChannelEdgeInfo1, + *models.ChannelEdgePolicy1, *models.ChannelEdgePolicy1) { var ( pubkey1 [33]byte @@ -99,18 +99,18 @@ func createEdgePolicies(t *testing.T, channel *channeldb.OpenChannel, // bit. dir2 |= lnwire.ChanUpdateDirection - return &models.ChannelEdgeInfo{ + return &models.ChannelEdgeInfo1{ ChannelPoint: channel.FundingOutpoint, NodeKey1Bytes: pubkey1, NodeKey2Bytes: pubkey2, }, - &models.ChannelEdgePolicy{ + &models.ChannelEdgePolicy1{ ChannelID: channel.ShortChanID().ToUint64(), ChannelFlags: dir1, LastUpdate: time.Now(), SigBytes: testSigBytes, }, - &models.ChannelEdgePolicy{ + &models.ChannelEdgePolicy1{ ChannelID: channel.ShortChanID().ToUint64(), ChannelFlags: dir2, LastUpdate: time.Now(), @@ -121,12 +121,12 @@ func createEdgePolicies(t *testing.T, channel *channeldb.OpenChannel, type mockGraph struct { mu sync.Mutex channels []*channeldb.OpenChannel - chanInfos map[wire.OutPoint]*models.ChannelEdgeInfo - chanPols1 map[wire.OutPoint]*models.ChannelEdgePolicy - chanPols2 map[wire.OutPoint]*models.ChannelEdgePolicy + chanInfos map[wire.OutPoint]*models.ChannelEdgeInfo1 + chanPols1 map[wire.OutPoint]*models.ChannelEdgePolicy1 + chanPols2 map[wire.OutPoint]*models.ChannelEdgePolicy1 sidToCid map[lnwire.ShortChannelID]wire.OutPoint - updates chan *lnwire.ChannelUpdate + updates chan *lnwire.ChannelUpdate1 } func newMockGraph(t *testing.T, numChannels int, @@ -134,11 +134,11 @@ func newMockGraph(t *testing.T, numChannels int, g := &mockGraph{ channels: make([]*channeldb.OpenChannel, 0, numChannels), - chanInfos: make(map[wire.OutPoint]*models.ChannelEdgeInfo), - chanPols1: make(map[wire.OutPoint]*models.ChannelEdgePolicy), - chanPols2: make(map[wire.OutPoint]*models.ChannelEdgePolicy), + chanInfos: make(map[wire.OutPoint]*models.ChannelEdgeInfo1), + chanPols1: make(map[wire.OutPoint]*models.ChannelEdgePolicy1), + chanPols2: make(map[wire.OutPoint]*models.ChannelEdgePolicy1), sidToCid: make(map[lnwire.ShortChannelID]wire.OutPoint), - updates: make(chan *lnwire.ChannelUpdate, 2*numChannels), + updates: make(chan *lnwire.ChannelUpdate1, 2*numChannels), } for i := 0; i < numChannels; i++ { @@ -160,8 +160,8 @@ func (g *mockGraph) FetchAllOpenChannels() ([]*channeldb.OpenChannel, error) { } func (g *mockGraph) FetchChannelEdgesByOutpoint( - op *wire.OutPoint) (*models.ChannelEdgeInfo, - *models.ChannelEdgePolicy, *models.ChannelEdgePolicy, error) { + op *wire.OutPoint) (*models.ChannelEdgeInfo1, + *models.ChannelEdgePolicy1, *models.ChannelEdgePolicy1, error) { g.mu.Lock() defer g.mu.Unlock() @@ -177,7 +177,7 @@ func (g *mockGraph) FetchChannelEdgesByOutpoint( return info, pol1, pol2, nil } -func (g *mockGraph) ApplyChannelUpdate(update *lnwire.ChannelUpdate, +func (g *mockGraph) ApplyChannelUpdate(update *lnwire.ChannelUpdate1, op *wire.OutPoint, private bool) error { g.mu.Lock() @@ -210,7 +210,7 @@ func (g *mockGraph) ApplyChannelUpdate(update *lnwire.ChannelUpdate, timestamp := time.Unix(int64(update.Timestamp), 0) - policy := &models.ChannelEdgePolicy{ + policy := &models.ChannelEdgePolicy1{ ChannelID: update.ShortChannelID.ToUint64(), ChannelFlags: update.ChannelFlags, LastUpdate: timestamp, @@ -248,8 +248,8 @@ func (g *mockGraph) addChannel(channel *channeldb.OpenChannel) { } func (g *mockGraph) addEdgePolicy(c *channeldb.OpenChannel, - info *models.ChannelEdgeInfo, - pol1, pol2 *models.ChannelEdgePolicy) { + info *models.ChannelEdgeInfo1, + pol1, pol2 *models.ChannelEdgePolicy1) { g.mu.Lock() defer g.mu.Unlock() diff --git a/netann/channel_announcement.go b/netann/channel_announcement.go index 8fc040f6f3..83a8282a69 100644 --- a/netann/channel_announcement.go +++ b/netann/channel_announcement.go @@ -12,16 +12,16 @@ import ( // function is used to transform out database structs into the corresponding wire // structs for announcing new channels to other peers, or simply syncing up a // peer's initial routing table upon connect. -func CreateChanAnnouncement(chanProof *models.ChannelAuthProof, - chanInfo *models.ChannelEdgeInfo, - e1, e2 *models.ChannelEdgePolicy) (*lnwire.ChannelAnnouncement, - *lnwire.ChannelUpdate, *lnwire.ChannelUpdate, error) { +func CreateChanAnnouncement(chanProof *models.ChannelAuthProof1, + chanInfo *models.ChannelEdgeInfo1, + e1, e2 *models.ChannelEdgePolicy1) (*lnwire.ChannelAnnouncement1, + *lnwire.ChannelUpdate1, *lnwire.ChannelUpdate1, error) { // First, using the parameters of the channel, along with the channel // authentication chanProof, we'll create re-create the original // authenticated channel announcement. chanID := lnwire.NewShortChanIDFromInt(chanInfo.ChannelID) - chanAnn := &lnwire.ChannelAnnouncement{ + chanAnn := &lnwire.ChannelAnnouncement1{ ShortChannelID: chanID, NodeID1: chanInfo.NodeKey1Bytes, NodeID2: chanInfo.NodeKey2Bytes, @@ -68,7 +68,7 @@ func CreateChanAnnouncement(chanProof *models.ChannelAuthProof, // Since it's up to a node's policy as to whether they advertise the // edge in a direction, we don't create an advertisement if the edge is // nil. - var edge1Ann, edge2Ann *lnwire.ChannelUpdate + var edge1Ann, edge2Ann *lnwire.ChannelUpdate1 if e1 != nil { edge1Ann, err = ChannelUpdateFromEdge(chanInfo, e1) if err != nil { diff --git a/netann/channel_announcement_test.go b/netann/channel_announcement_test.go index 92a6c74ce0..844a9b1d4d 100644 --- a/netann/channel_announcement_test.go +++ b/netann/channel_announcement_test.go @@ -24,7 +24,7 @@ func TestCreateChanAnnouncement(t *testing.T) { t.Fatalf("unable to encode features: %v", err) } - expChanAnn := &lnwire.ChannelAnnouncement{ + expChanAnn := &lnwire.ChannelAnnouncement1{ ChainHash: chainhash.Hash{0x1}, ShortChannelID: lnwire.ShortChannelID{BlockHeight: 1}, NodeID1: key, @@ -39,13 +39,13 @@ func TestCreateChanAnnouncement(t *testing.T) { ExtraOpaqueData: []byte{0x1}, } - chanProof := &models.ChannelAuthProof{ + chanProof := &models.ChannelAuthProof1{ NodeSig1Bytes: expChanAnn.NodeSig1.ToSignatureBytes(), NodeSig2Bytes: expChanAnn.NodeSig2.ToSignatureBytes(), BitcoinSig1Bytes: expChanAnn.BitcoinSig1.ToSignatureBytes(), BitcoinSig2Bytes: expChanAnn.BitcoinSig2.ToSignatureBytes(), } - chanInfo := &models.ChannelEdgeInfo{ + chanInfo := &models.ChannelEdgeInfo1{ ChainHash: expChanAnn.ChainHash, ChannelID: expChanAnn.ShortChannelID.ToUint64(), ChannelPoint: wire.OutPoint{Index: 1}, diff --git a/netann/channel_update.go b/netann/channel_update.go index b93deb1d0c..8faa34cb2f 100644 --- a/netann/channel_update.go +++ b/netann/channel_update.go @@ -18,12 +18,12 @@ var ErrUnableToExtractChanUpdate = fmt.Errorf("unable to extract ChannelUpdate") // ChannelUpdateModifier is a closure that makes in-place modifications to an // lnwire.ChannelUpdate. -type ChannelUpdateModifier func(*lnwire.ChannelUpdate) +type ChannelUpdateModifier func(*lnwire.ChannelUpdate1) // ChanUpdSetDisable is a functional option that sets the disabled channel flag // if disabled is true, and clears the bit otherwise. func ChanUpdSetDisable(disabled bool) ChannelUpdateModifier { - return func(update *lnwire.ChannelUpdate) { + return func(update *lnwire.ChannelUpdate1) { if disabled { // Set the bit responsible for marking a channel as // disabled. @@ -39,7 +39,7 @@ func ChanUpdSetDisable(disabled bool) ChannelUpdateModifier { // ChanUpdSetTimestamp is a functional option that sets the timestamp of the // update to the current time, or increments it if the timestamp is already in // the future. -func ChanUpdSetTimestamp(update *lnwire.ChannelUpdate) { +func ChanUpdSetTimestamp(update *lnwire.ChannelUpdate1) { newTimestamp := uint32(time.Now().Unix()) if newTimestamp <= update.Timestamp { // Increment the prior value to ensure the timestamp @@ -57,7 +57,7 @@ func ChanUpdSetTimestamp(update *lnwire.ChannelUpdate) { // // NOTE: This method modifies the given update. func SignChannelUpdate(signer lnwallet.MessageSigner, keyLoc keychain.KeyLocator, - update *lnwire.ChannelUpdate, mods ...ChannelUpdateModifier) error { + update *lnwire.ChannelUpdate1, mods ...ChannelUpdateModifier) error { // Apply the requested changes to the channel update. for _, modifier := range mods { @@ -84,12 +84,12 @@ func SignChannelUpdate(signer lnwallet.MessageSigner, keyLoc keychain.KeyLocator // // NOTE: The passed policies can be nil. func ExtractChannelUpdate(ownerPubKey []byte, - info *models.ChannelEdgeInfo, - policies ...*models.ChannelEdgePolicy) ( - *lnwire.ChannelUpdate, error) { + info *models.ChannelEdgeInfo1, + policies ...*models.ChannelEdgePolicy1) ( + *lnwire.ChannelUpdate1, error) { // Helper function to extract the owner of the given policy. - owner := func(edge *models.ChannelEdgePolicy) []byte { + owner := func(edge *models.ChannelEdgePolicy1) []byte { var pubKey *btcec.PublicKey if edge.ChannelFlags&lnwire.ChanUpdateDirection == 0 { pubKey, _ = info.NodeKey1() @@ -117,10 +117,10 @@ func ExtractChannelUpdate(ownerPubKey []byte, // UnsignedChannelUpdateFromEdge reconstructs an unsigned ChannelUpdate from the // given edge info and policy. -func UnsignedChannelUpdateFromEdge(info *models.ChannelEdgeInfo, - policy *models.ChannelEdgePolicy) *lnwire.ChannelUpdate { +func UnsignedChannelUpdateFromEdge(info *models.ChannelEdgeInfo1, + policy *models.ChannelEdgePolicy1) *lnwire.ChannelUpdate1 { - return &lnwire.ChannelUpdate{ + return &lnwire.ChannelUpdate1{ ChainHash: info.ChainHash, ShortChannelID: lnwire.NewShortChanIDFromInt(policy.ChannelID), Timestamp: uint32(policy.LastUpdate.Unix()), @@ -137,8 +137,8 @@ func UnsignedChannelUpdateFromEdge(info *models.ChannelEdgeInfo, // ChannelUpdateFromEdge reconstructs a signed ChannelUpdate from the given edge // info and policy. -func ChannelUpdateFromEdge(info *models.ChannelEdgeInfo, - policy *models.ChannelEdgePolicy) (*lnwire.ChannelUpdate, error) { +func ChannelUpdateFromEdge(info *models.ChannelEdgeInfo1, + policy *models.ChannelEdgePolicy1) (*lnwire.ChannelUpdate1, error) { update := UnsignedChannelUpdateFromEdge(info, policy) diff --git a/netann/channel_update_test.go b/netann/channel_update_test.go index 7af51effc0..689b48caba 100644 --- a/netann/channel_update_test.go +++ b/netann/channel_update_test.go @@ -7,7 +7,6 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcec/v2/ecdsa" - "github.com/lightningnetwork/lnd/graph" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" @@ -111,7 +110,7 @@ func TestUpdateDisableFlag(t *testing.T) { // Create the initial update, the only fields we are // concerned with in this test are the timestamp and the // channel flags. - ogUpdate := &lnwire.ChannelUpdate{ + ogUpdate := &lnwire.ChannelUpdate1{ Timestamp: uint32(tc.startTime.Unix()), } if !tc.startEnabled { @@ -122,7 +121,7 @@ func TestUpdateDisableFlag(t *testing.T) { // the original. UpdateDisableFlag will mutate the // passed channel update, so we keep the old one to test // against. - newUpdate := &lnwire.ChannelUpdate{ + newUpdate := &lnwire.ChannelUpdate1{ Timestamp: ogUpdate.Timestamp, ChannelFlags: ogUpdate.ChannelFlags, } @@ -182,9 +181,7 @@ func TestUpdateDisableFlag(t *testing.T) { // Finally, validate the signature using the router's // verification logic. - err = graph.VerifyChannelUpdateSignature( - newUpdate, pubKey, - ) + err = newUpdate.VerifySig(pubKey) if err != nil { t.Fatalf("channel update failed to "+ "validate: %v", err) diff --git a/netann/interface.go b/netann/interface.go index d6cdb46d0e..8fd5eaf275 100644 --- a/netann/interface.go +++ b/netann/interface.go @@ -19,6 +19,6 @@ type DB interface { type ChannelGraph interface { // FetchChannelEdgesByOutpoint returns the channel edge info and most // recent channel edge policies for a given outpoint. - FetchChannelEdgesByOutpoint(*wire.OutPoint) (*models.ChannelEdgeInfo, - *models.ChannelEdgePolicy, *models.ChannelEdgePolicy, error) + FetchChannelEdgesByOutpoint(*wire.OutPoint) (*models.ChannelEdgeInfo1, + *models.ChannelEdgePolicy1, *models.ChannelEdgePolicy1, error) } diff --git a/netann/sign.go b/netann/sign.go index 4295331360..0c7612eac4 100644 --- a/netann/sign.go +++ b/netann/sign.go @@ -20,9 +20,9 @@ func SignAnnouncement(signer lnwallet.MessageSigner, keyLoc keychain.KeyLocator, ) switch m := msg.(type) { - case *lnwire.ChannelAnnouncement: + case *lnwire.ChannelAnnouncement1: data, err = m.DataToSign() - case *lnwire.ChannelUpdate: + case *lnwire.ChannelUpdate1: data, err = m.DataToSign() case *lnwire.NodeAnnouncement: data, err = m.DataToSign() diff --git a/peer/brontide.go b/peer/brontide.go index 25c7cea6f2..a609e9abb9 100644 --- a/peer/brontide.go +++ b/peer/brontide.go @@ -301,7 +301,7 @@ type Config struct { // FetchLastChanUpdate fetches our latest channel update for a target // channel. - FetchLastChanUpdate func(lnwire.ShortChannelID) (*lnwire.ChannelUpdate, + FetchLastChanUpdate func(lnwire.ShortChannelID) (*lnwire.ChannelUpdate1, error) // FundingManager is an implementation of the funding.Controller interface. @@ -1027,7 +1027,7 @@ func (p *Brontide) loadActiveChannels(chans []*channeldb.OpenChannel) ( // // TODO(roasbeef): can add helper method to get policy for // particular channel. - var selfPolicy *models.ChannelEdgePolicy + var selfPolicy *models.ChannelEdgePolicy1 if info != nil && bytes.Equal(info.NodeKey1Bytes[:], p.cfg.ServerPubKey[:]) { @@ -1954,10 +1954,10 @@ out: nextMsg.MsgType()) } - case *lnwire.ChannelUpdate, - *lnwire.ChannelAnnouncement, + case *lnwire.ChannelUpdate1, + *lnwire.ChannelAnnouncement1, *lnwire.NodeAnnouncement, - *lnwire.AnnounceSignatures, + *lnwire.AnnounceSignatures1, *lnwire.GossipTimestampRange, *lnwire.QueryShortChanIDs, *lnwire.QueryChannelRange, @@ -2215,15 +2215,15 @@ func messageSummary(msg lnwire.Message) string { case *lnwire.Error: return fmt.Sprintf("%v", msg.Error()) - case *lnwire.AnnounceSignatures: + case *lnwire.AnnounceSignatures1: return fmt.Sprintf("chan_id=%v, short_chan_id=%v", msg.ChannelID, msg.ShortChannelID.ToUint64()) - case *lnwire.ChannelAnnouncement: + case *lnwire.ChannelAnnouncement1: return fmt.Sprintf("chain_hash=%v, short_chan_id=%v", msg.ChainHash, msg.ShortChannelID.ToUint64()) - case *lnwire.ChannelUpdate: + case *lnwire.ChannelUpdate1: return fmt.Sprintf("chain_hash=%v, short_chan_id=%v, "+ "mflags=%v, cflags=%v, update_time=%v", msg.ChainHash, msg.ShortChannelID.ToUint64(), msg.MessageFlags, diff --git a/peer/test_utils.go b/peer/test_utils.go index e0ae29be8b..0575acca55 100644 --- a/peer/test_utils.go +++ b/peer/test_utils.go @@ -611,7 +611,7 @@ func createTestPeer(t *testing.T) *peerTestCtx { IsChannelActive: func(lnwire.ChannelID) bool { return true }, - ApplyChannelUpdate: func(*lnwire.ChannelUpdate, + ApplyChannelUpdate: func(*lnwire.ChannelUpdate1, *wire.OutPoint, bool) error { return nil @@ -719,9 +719,9 @@ func createTestPeer(t *testing.T) *peerTestCtx { }, PongBuf: make([]byte, lnwire.MaxPongBytes), FetchLastChanUpdate: func(chanID lnwire.ShortChannelID, - ) (*lnwire.ChannelUpdate, error) { + ) (*lnwire.ChannelUpdate1, error) { - return &lnwire.ChannelUpdate{}, nil + return &lnwire.ChannelUpdate1{}, nil }, } diff --git a/routing/blindedpath/blinded_path.go b/routing/blindedpath/blinded_path.go index 19a6edcaa6..6ccd5e7437 100644 --- a/routing/blindedpath/blinded_path.go +++ b/routing/blindedpath/blinded_path.go @@ -42,8 +42,8 @@ type BuildBlindedPathCfg struct { // FetchChannelEdgesByID attempts to look up the two directed edges for // the channel identified by the channel ID. - FetchChannelEdgesByID func(chanID uint64) (*models.ChannelEdgeInfo, - *models.ChannelEdgePolicy, *models.ChannelEdgePolicy, error) + FetchChannelEdgesByID func(chanID uint64) (*models.ChannelEdgeInfo1, + *models.ChannelEdgePolicy1, *models.ChannelEdgePolicy1, error) // FetchOurOpenChannels fetches this node's set of open channels. FetchOurOpenChannels func() ([]*channeldb.OpenChannel, error) @@ -634,7 +634,7 @@ func getNodeChannelPolicy(cfg *BuildBlindedPathCfg, chanID uint64, // node in question. We know the update is the correct one if the // "ToNode" for the fetched policy is _not_ equal to the node ID in // question. - var policy *models.ChannelEdgePolicy + var policy *models.ChannelEdgePolicy1 switch { case update1 != nil && !bytes.Equal(update1.ToNode[:], nodeID[:]): policy = update1 diff --git a/routing/blindedpath/blinded_path_test.go b/routing/blindedpath/blinded_path_test.go index 51d028eafb..d593b034fb 100644 --- a/routing/blindedpath/blinded_path_test.go +++ b/routing/blindedpath/blinded_path_test.go @@ -580,7 +580,7 @@ func TestBuildBlindedPath(t *testing.T) { }, } - realPolicies := map[uint64]*models.ChannelEdgePolicy{ + realPolicies := map[uint64]*models.ChannelEdgePolicy1{ chanCB: { ChannelID: chanCB, ToNode: bob, @@ -598,8 +598,8 @@ func TestBuildBlindedPath(t *testing.T) { return []*route.Route{realRoute}, nil }, FetchChannelEdgesByID: func(chanID uint64) ( - *models.ChannelEdgeInfo, *models.ChannelEdgePolicy, - *models.ChannelEdgePolicy, error) { + *models.ChannelEdgeInfo1, *models.ChannelEdgePolicy1, + *models.ChannelEdgePolicy1, error) { return nil, realPolicies[chanID], nil, nil }, @@ -748,7 +748,7 @@ func TestBuildBlindedPathWithDummyHops(t *testing.T) { }, } - realPolicies := map[uint64]*models.ChannelEdgePolicy{ + realPolicies := map[uint64]*models.ChannelEdgePolicy1{ chanCB: { ChannelID: chanCB, ToNode: bob, @@ -766,8 +766,8 @@ func TestBuildBlindedPathWithDummyHops(t *testing.T) { return []*route.Route{realRoute}, nil }, FetchChannelEdgesByID: func(chanID uint64) ( - *models.ChannelEdgeInfo, *models.ChannelEdgePolicy, - *models.ChannelEdgePolicy, error) { + *models.ChannelEdgeInfo1, *models.ChannelEdgePolicy1, + *models.ChannelEdgePolicy1, error) { policy, ok := realPolicies[chanID] if !ok { @@ -937,8 +937,8 @@ func TestBuildBlindedPathWithDummyHops(t *testing.T) { nil }, FetchChannelEdgesByID: func(chanID uint64) ( - *models.ChannelEdgeInfo, *models.ChannelEdgePolicy, - *models.ChannelEdgePolicy, error) { + *models.ChannelEdgeInfo1, *models.ChannelEdgePolicy1, + *models.ChannelEdgePolicy1, error) { // Force the call to error for the first 2 channels. if errCount < 2 { diff --git a/routing/localchans/manager.go b/routing/localchans/manager.go index f0f9b88de0..55a9af095d 100644 --- a/routing/localchans/manager.go +++ b/routing/localchans/manager.go @@ -32,8 +32,8 @@ type Manager struct { // ForAllOutgoingChannels is required to iterate over all our local // channels. ForAllOutgoingChannels func(cb func(kvdb.RTx, - *models.ChannelEdgeInfo, - *models.ChannelEdgePolicy) error) error + *models.ChannelEdgeInfo1, + *models.ChannelEdgePolicy1) error) error // FetchChannel is used to query local channel parameters. Optionally an // existing db tx can be supplied. @@ -74,8 +74,8 @@ func (r *Manager) UpdatePolicy(newSchema routing.ChannelPolicy, // otherwise we'll collect them all. err := r.ForAllOutgoingChannels(func( tx kvdb.RTx, - info *models.ChannelEdgeInfo, - edge *models.ChannelEdgePolicy) error { + info *models.ChannelEdgeInfo1, + edge *models.ChannelEdgePolicy1) error { // If we have a channel filter, and this channel isn't a part // of it, then we'll skip it. @@ -182,7 +182,7 @@ func (r *Manager) UpdatePolicy(newSchema routing.ChannelPolicy, // updateEdge updates the given edge with the new schema. func (r *Manager) updateEdge(tx kvdb.RTx, chanPoint wire.OutPoint, - edge *models.ChannelEdgePolicy, + edge *models.ChannelEdgePolicy1, newSchema routing.ChannelPolicy) error { // Update forwarding fee scheme and required time lock delta. diff --git a/routing/localchans/manager_test.go b/routing/localchans/manager_test.go index 7594eef04a..0054396d08 100644 --- a/routing/localchans/manager_test.go +++ b/routing/localchans/manager_test.go @@ -22,7 +22,7 @@ func TestManager(t *testing.T) { t.Parallel() type channel struct { - edgeInfo *models.ChannelEdgeInfo + edgeInfo *models.ChannelEdgeInfo1 } var ( @@ -44,7 +44,7 @@ func TestManager(t *testing.T) { MaxHTLC: 5000, } - currentPolicy := models.ChannelEdgePolicy{ + currentPolicy := models.ChannelEdgePolicy1{ MinHTLC: minHTLC, MessageFlags: lnwire.ChanUpdateRequiredMaxHtlc, } @@ -107,8 +107,8 @@ func TestManager(t *testing.T) { } forAllOutgoingChannels := func(cb func(kvdb.RTx, - *models.ChannelEdgeInfo, - *models.ChannelEdgePolicy) error) error { + *models.ChannelEdgeInfo1, + *models.ChannelEdgePolicy1) error) error { for _, c := range channelSet { if err := cb(nil, c.edgeInfo, ¤tPolicy); err != nil { @@ -152,7 +152,7 @@ func TestManager(t *testing.T) { tests := []struct { name string - currentPolicy models.ChannelEdgePolicy + currentPolicy models.ChannelEdgePolicy1 newPolicy routing.ChannelPolicy channelSet []channel specifiedChanPoints []wire.OutPoint @@ -166,7 +166,7 @@ func TestManager(t *testing.T) { newPolicy: newPolicy, channelSet: []channel{ { - edgeInfo: &models.ChannelEdgeInfo{ + edgeInfo: &models.ChannelEdgeInfo1{ Capacity: chanCap, ChannelPoint: chanPointValid, }, @@ -183,7 +183,7 @@ func TestManager(t *testing.T) { newPolicy: newPolicy, channelSet: []channel{ { - edgeInfo: &models.ChannelEdgeInfo{ + edgeInfo: &models.ChannelEdgeInfo1{ Capacity: chanCap, ChannelPoint: chanPointValid, }, @@ -200,7 +200,7 @@ func TestManager(t *testing.T) { newPolicy: newPolicy, channelSet: []channel{ { - edgeInfo: &models.ChannelEdgeInfo{ + edgeInfo: &models.ChannelEdgeInfo1{ Capacity: chanCap, ChannelPoint: chanPointValid, }, @@ -221,7 +221,7 @@ func TestManager(t *testing.T) { newPolicy: noMaxHtlcPolicy, channelSet: []channel{ { - edgeInfo: &models.ChannelEdgeInfo{ + edgeInfo: &models.ChannelEdgeInfo1{ Capacity: chanCap, ChannelPoint: chanPointValid, }, diff --git a/routing/missioncontrol_test.go b/routing/missioncontrol_test.go index 27391d53e2..4a0f738715 100644 --- a/routing/missioncontrol_test.go +++ b/routing/missioncontrol_test.go @@ -197,7 +197,7 @@ func TestMissionControl(t *testing.T) { // A node level failure should bring probability of all known channels // back to zero. - ctx.reportFailure(0, lnwire.NewExpiryTooSoon(lnwire.ChannelUpdate{})) + ctx.reportFailure(0, lnwire.NewExpiryTooSoon(lnwire.ChannelUpdate1{})) ctx.expectP(1000, 0) // Check whether history snapshot looks sane. @@ -219,14 +219,14 @@ func TestMissionControlChannelUpdate(t *testing.T) { // Report a policy related failure. Because it is the first, we don't // expect a penalty. ctx.reportFailure( - 0, lnwire.NewFeeInsufficient(0, lnwire.ChannelUpdate{}), + 0, lnwire.NewFeeInsufficient(0, lnwire.ChannelUpdate1{}), ) ctx.expectP(100, testAprioriHopProbability) // Report another failure for the same channel. We expect it to be // pruned. ctx.reportFailure( - 0, lnwire.NewFeeInsufficient(0, lnwire.ChannelUpdate{}), + 0, lnwire.NewFeeInsufficient(0, lnwire.ChannelUpdate1{}), ) ctx.expectP(100, 0) } diff --git a/routing/mock_test.go b/routing/mock_test.go index 306c182107..e0dab5e798 100644 --- a/routing/mock_test.go +++ b/routing/mock_test.go @@ -182,7 +182,7 @@ func (m *mockPaymentSessionOld) RequestRoute(_, _ lnwire.MilliSatoshi, return r, nil } -func (m *mockPaymentSessionOld) UpdateAdditionalEdge(_ *lnwire.ChannelUpdate, +func (m *mockPaymentSessionOld) UpdateAdditionalEdge(_ *lnwire.ChannelUpdate1, _ *btcec.PublicKey, _ *models.CachedEdgePolicy) bool { return false @@ -702,7 +702,7 @@ func (m *mockPaymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, return args.Get(0).(*route.Route), args.Error(1) } -func (m *mockPaymentSession) UpdateAdditionalEdge(msg *lnwire.ChannelUpdate, +func (m *mockPaymentSession) UpdateAdditionalEdge(msg *lnwire.ChannelUpdate1, pubKey *btcec.PublicKey, policy *models.CachedEdgePolicy) bool { args := m.Called(msg, pubKey, policy) diff --git a/routing/pathfind.go b/routing/pathfind.go index 01325c2238..a415bb9960 100644 --- a/routing/pathfind.go +++ b/routing/pathfind.go @@ -87,7 +87,7 @@ var ( ) // edgePolicyWithSource is a helper struct to keep track of the source node -// of a channel edge. ChannelEdgePolicy only contains to destination node +// of a channel edge. ChannelEdgePolicy1 only contains to destination node // of the edge. type edgePolicyWithSource struct { sourceNode route.Vertex @@ -1127,7 +1127,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, // destination features were provided. This is fine though, since our // route construction does not care where the features are actually // taken from. In the future we may wish to do route construction within - // findPath, and avoid using ChannelEdgePolicy altogether. + // findPath, and avoid using ChannelEdgePolicy1 altogether. pathEdges[len(pathEdges)-1].policy.ToNodeFeatures = features log.Debugf("Found route: probability=%v, hops=%v, fee=%v", diff --git a/routing/pathfind_test.go b/routing/pathfind_test.go index 38942a31d2..8af7dd9529 100644 --- a/routing/pathfind_test.go +++ b/routing/pathfind_test.go @@ -98,7 +98,7 @@ var ( _ = testSScalar.SetByteSlice(testSBytes) testSig = ecdsa.NewSignature(testRScalar, testSScalar) - testAuthProof = models.ChannelAuthProof{ + testAuthProof = models.ChannelAuthProof1{ NodeSig1Bytes: testSig.Serialize(), NodeSig2Bytes: testSig.Serialize(), BitcoinSig1Bytes: testSig.Serialize(), @@ -337,7 +337,7 @@ func parseTestGraph(t *testing.T, useCache bool, path string) ( // We first insert the existence of the edge between the two // nodes. - edgeInfo := models.ChannelEdgeInfo{ + edgeInfo := models.ChannelEdgeInfo1{ ChannelID: edge.ChannelID, AuthProof: &testAuthProof, ChannelPoint: fundingPoint, @@ -368,7 +368,7 @@ func parseTestGraph(t *testing.T, useCache bool, path string) ( targetNode = edgeInfo.NodeKey2Bytes } - edgePolicy := &models.ChannelEdgePolicy{ + edgePolicy := &models.ChannelEdgePolicy1{ SigBytes: testSig.Serialize(), MessageFlags: lnwire.ChanUpdateMsgFlags(edge.MessageFlags), ChannelFlags: channelFlags, @@ -652,7 +652,7 @@ func createTestGraphFromChannels(t *testing.T, useCache bool, // We first insert the existence of the edge between the two // nodes. - edgeInfo := models.ChannelEdgeInfo{ + edgeInfo := models.ChannelEdgeInfo1{ ChannelID: channelID, AuthProof: &testAuthProof, ChannelPoint: *fundingPoint, @@ -692,7 +692,7 @@ func createTestGraphFromChannels(t *testing.T, useCache bool, channelFlags |= lnwire.ChanUpdateDisabled } - edgePolicy := &models.ChannelEdgePolicy{ + edgePolicy := &models.ChannelEdgePolicy1{ SigBytes: testSig.Serialize(), MessageFlags: msgFlags, ChannelFlags: channelFlags, @@ -722,7 +722,7 @@ func createTestGraphFromChannels(t *testing.T, useCache bool, } channelFlags |= lnwire.ChanUpdateDirection - edgePolicy := &models.ChannelEdgePolicy{ + edgePolicy := &models.ChannelEdgePolicy1{ SigBytes: testSig.Serialize(), MessageFlags: msgFlags, ChannelFlags: channelFlags, diff --git a/routing/payment_session.go b/routing/payment_session.go index 00b4ab70ed..88c45e0151 100644 --- a/routing/payment_session.go +++ b/routing/payment_session.go @@ -8,7 +8,6 @@ import ( "github.com/lightningnetwork/lnd/build" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb/models" - "github.com/lightningnetwork/lnd/graph" "github.com/lightningnetwork/lnd/lnutils" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" @@ -145,8 +144,8 @@ type PaymentSession interface { // (private channels) and applies the update from the message. Returns // a boolean to indicate whether the update has been applied without // error. - UpdateAdditionalEdge(msg *lnwire.ChannelUpdate, pubKey *btcec.PublicKey, - policy *models.CachedEdgePolicy) bool + UpdateAdditionalEdge(msg *lnwire.ChannelUpdate1, + pubKey *btcec.PublicKey, policy *models.CachedEdgePolicy) bool // GetAdditionalEdgePolicy uses the public key and channel ID to query // the ephemeral channel edge policy for additional edges. Returns a nil @@ -432,11 +431,11 @@ func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, // validates the message signature and checks it's up to date, then applies the // updates to the supplied policy. It returns a boolean to indicate whether // there's an error when applying the updates. -func (p *paymentSession) UpdateAdditionalEdge(msg *lnwire.ChannelUpdate, +func (p *paymentSession) UpdateAdditionalEdge(msg *lnwire.ChannelUpdate1, pubKey *btcec.PublicKey, policy *models.CachedEdgePolicy) bool { // Validate the message signature. - if err := graph.VerifyChannelUpdateSignature(msg, pubKey); err != nil { + if err := msg.VerifySig(pubKey); err != nil { log.Errorf( "Unable to validate channel update signature: %v", err, ) diff --git a/routing/payment_session_test.go b/routing/payment_session_test.go index f6873aa752..86f0e33268 100644 --- a/routing/payment_session_test.go +++ b/routing/payment_session_test.go @@ -147,7 +147,7 @@ func TestUpdateAdditionalEdge(t *testing.T) { ) // Create the channel update message and sign. - msg := &lnwire.ChannelUpdate{ + msg := &lnwire.ChannelUpdate1{ ShortChannelID: lnwire.NewShortChanIDFromInt(testChannelID), Timestamp: uint32(time.Now().Unix()), BaseFee: newFeeBaseMSat, diff --git a/routing/result_interpretation_test.go b/routing/result_interpretation_test.go index bf7d6d3edd..68b527e5a9 100644 --- a/routing/result_interpretation_test.go +++ b/routing/result_interpretation_test.go @@ -164,7 +164,7 @@ var resultTestCases = []resultTestCase{ name: "fail expiry too soon", route: &routeFourHop, failureSrcIdx: 3, - failure: lnwire.NewExpiryTooSoon(lnwire.ChannelUpdate{}), + failure: lnwire.NewExpiryTooSoon(lnwire.ChannelUpdate1{}), expectedResult: &interpretedResult{ pairResults: map[DirectedNodePair]pairResult{ @@ -266,8 +266,9 @@ var resultTestCases = []resultTestCase{ name: "fail fee insufficient intermediate", route: &routeFourHop, failureSrcIdx: 2, - failure: lnwire.NewFeeInsufficient(0, lnwire.ChannelUpdate{}), - + failure: lnwire.NewFeeInsufficient( + 0, lnwire.ChannelUpdate1{}, + ), expectedResult: &interpretedResult{ pairResults: map[DirectedNodePair]pairResult{ getTestPair(0, 1): { diff --git a/routing/router.go b/routing/router.go index 0b6a9beacd..e9282db6c1 100644 --- a/routing/router.go +++ b/routing/router.go @@ -284,7 +284,7 @@ type Config struct { // ApplyChannelUpdate can be called to apply a new channel update to the // graph that we received from a payment failure. - ApplyChannelUpdate func(msg *lnwire.ChannelUpdate) bool + ApplyChannelUpdate func(msg *lnwire.ChannelUpdate1) bool // ClosedSCIDs is used by the router to fetch closed channels. // @@ -1270,9 +1270,9 @@ func (r *ChannelRouter) sendPayment(ctx context.Context, // extractChannelUpdate examines the error and extracts the channel update. func (r *ChannelRouter) extractChannelUpdate( - failure lnwire.FailureMessage) *lnwire.ChannelUpdate { + failure lnwire.FailureMessage) *lnwire.ChannelUpdate1 { - var update *lnwire.ChannelUpdate + var update *lnwire.ChannelUpdate1 switch onionErr := failure.(type) { case *lnwire.FailExpiryTooSoon: update = &onionErr.Update diff --git a/routing/router_test.go b/routing/router_test.go index 411a6c81cd..0f7528455a 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -220,7 +220,7 @@ func createTestCtxFromFile(t *testing.T, // Add valid signature to channel update simulated as error received from the // network. func signErrChanUpdate(t *testing.T, key *btcec.PrivateKey, - errChanUpdate *lnwire.ChannelUpdate) { + errChanUpdate *lnwire.ChannelUpdate1) { chanUpdateMsg, err := errChanUpdate.DataToSign() require.NoError(t, err, "failed to retrieve data to sign") @@ -485,7 +485,7 @@ func TestChannelUpdateValidation(t *testing.T) { // Set up a channel update message with an invalid signature to be // returned to the sender. var invalidSignature lnwire.Sig - errChanUpdate := lnwire.ChannelUpdate{ + errChanUpdate := lnwire.ChannelUpdate1{ Signature: invalidSignature, FeeRate: 500, ShortChannelID: lnwire.NewShortChanIDFromInt(1), @@ -590,7 +590,7 @@ func TestSendPaymentErrorRepeatedFeeInsufficient(t *testing.T) { ) require.NoError(t, err, "unable to fetch chan id") - errChanUpdate := lnwire.ChannelUpdate{ + errChanUpdate := lnwire.ChannelUpdate1{ ShortChannelID: lnwire.NewShortChanIDFromInt( songokuSophonChanID, ), @@ -709,7 +709,7 @@ func TestSendPaymentErrorFeeInsufficientPrivateEdge(t *testing.T) { // Prepare an error update for the private channel, with twice the // original fee. updatedFeeBaseMSat := feeBaseMSat * 2 - errChanUpdate := lnwire.ChannelUpdate{ + errChanUpdate := lnwire.ChannelUpdate1{ ShortChannelID: lnwire.NewShortChanIDFromInt(privateChannelID), Timestamp: uint32(testTime.Add(time.Minute).Unix()), BaseFee: updatedFeeBaseMSat, @@ -835,7 +835,7 @@ func TestSendPaymentPrivateEdgeUpdateFeeExceedsLimit(t *testing.T) { // Prepare an error update for the private channel. The updated fee // will exceeds the feeLimit. updatedFeeBaseMSat := feeBaseMSat + uint32(feeLimit) - errChanUpdate := lnwire.ChannelUpdate{ + errChanUpdate := lnwire.ChannelUpdate1{ ShortChannelID: lnwire.NewShortChanIDFromInt(privateChannelID), Timestamp: uint32(testTime.Add(time.Minute).Unix()), BaseFee: updatedFeeBaseMSat, @@ -936,7 +936,7 @@ func TestSendPaymentErrorNonFinalTimeLockErrors(t *testing.T) { _, _, edgeUpdateToFail, err := ctx.graph.FetchChannelEdgesByID(chanID) require.NoError(t, err, "unable to fetch chan id") - errChanUpdate := lnwire.ChannelUpdate{ + errChanUpdate := lnwire.ChannelUpdate1{ ShortChannelID: lnwire.NewShortChanIDFromInt(chanID), Timestamp: uint32(edgeUpdateToFail.LastUpdate.Unix()), MessageFlags: edgeUpdateToFail.MessageFlags, @@ -1399,7 +1399,7 @@ func TestSendToRouteStructuredError(t *testing.T) { testCases := map[int]lnwire.FailureMessage{ finalHopIndex: lnwire.NewFailIncorrectDetails(payAmt, 100), 1: &lnwire.FailFeeInsufficient{ - Update: lnwire.ChannelUpdate{}, + Update: lnwire.ChannelUpdate1{}, }, } @@ -2727,7 +2727,7 @@ func TestAddEdgeUnknownVertexes(t *testing.T) { ) require.NoError(t, err, "unable to create channel edge") - edge := &models.ChannelEdgeInfo{ + edge := &models.ChannelEdgeInfo1{ ChannelID: chanID.ToUint64(), NodeKey1Bytes: pub1, NodeKey2Bytes: pub2, @@ -2739,7 +2739,7 @@ func TestAddEdgeUnknownVertexes(t *testing.T) { // We must add the edge policy to be able to use the edge for route // finding. - edgePolicy := &models.ChannelEdgePolicy{ + edgePolicy := &models.ChannelEdgePolicy1{ SigBytes: testSig.Serialize(), ChannelID: edge.ChannelID, LastUpdate: testTime, @@ -2754,7 +2754,7 @@ func TestAddEdgeUnknownVertexes(t *testing.T) { require.NoError(t, ctx.graph.UpdateEdgePolicy(edgePolicy)) // Create edge in the other direction as well. - edgePolicy = &models.ChannelEdgePolicy{ + edgePolicy = &models.ChannelEdgePolicy1{ SigBytes: testSig.Serialize(), ChannelID: edge.ChannelID, LastUpdate: testTime, @@ -2806,7 +2806,7 @@ func TestAddEdgeUnknownVertexes(t *testing.T) { 10000, 510) require.NoError(t, err, "unable to create channel edge") - edge = &models.ChannelEdgeInfo{ + edge = &models.ChannelEdgeInfo1{ ChannelID: chanID.ToUint64(), AuthProof: nil, } @@ -2817,7 +2817,7 @@ func TestAddEdgeUnknownVertexes(t *testing.T) { require.NoError(t, ctx.graph.AddChannelEdge(edge)) - edgePolicy = &models.ChannelEdgePolicy{ + edgePolicy = &models.ChannelEdgePolicy1{ SigBytes: testSig.Serialize(), ChannelID: edge.ChannelID, LastUpdate: testTime, @@ -2831,7 +2831,7 @@ func TestAddEdgeUnknownVertexes(t *testing.T) { require.NoError(t, ctx.graph.UpdateEdgePolicy(edgePolicy)) - edgePolicy = &models.ChannelEdgePolicy{ + edgePolicy = &models.ChannelEdgePolicy1{ SigBytes: testSig.Serialize(), ChannelID: edge.ChannelID, LastUpdate: testTime, @@ -2928,12 +2928,12 @@ func createDummyLightningPayment(t *testing.T, type mockGraphBuilder struct { rejectUpdate bool - updateEdge func(update *models.ChannelEdgePolicy) error + updateEdge func(update *models.ChannelEdgePolicy1) error } func newMockGraphBuilder(graph graph.DB) *mockGraphBuilder { return &mockGraphBuilder{ - updateEdge: func(update *models.ChannelEdgePolicy) error { + updateEdge: func(update *models.ChannelEdgePolicy1) error { return graph.UpdateEdgePolicy(update) }, } @@ -2943,12 +2943,12 @@ func (m *mockGraphBuilder) setNextReject(reject bool) { m.rejectUpdate = reject } -func (m *mockGraphBuilder) ApplyChannelUpdate(msg *lnwire.ChannelUpdate) bool { +func (m *mockGraphBuilder) ApplyChannelUpdate(msg *lnwire.ChannelUpdate1) bool { if m.rejectUpdate { return false } - err := m.updateEdge(&models.ChannelEdgePolicy{ + err := m.updateEdge(&models.ChannelEdgePolicy1{ SigBytes: msg.Signature.ToSignatureBytes(), ChannelID: msg.ShortChannelID.ToUint64(), LastUpdate: time.Unix(int64(msg.Timestamp), 0), diff --git a/rpcserver.go b/rpcserver.go index cd2198e288..cdd4a3b4c8 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -6220,8 +6220,8 @@ func (r *rpcServer) DescribeGraph(ctx context.Context, // Next, for each active channel we know of within the graph, create a // similar response which details both the edge information as well as // the routing policies of th nodes connecting the two edges. - err = graph.ForEachChannel(func(edgeInfo *models.ChannelEdgeInfo, - c1, c2 *models.ChannelEdgePolicy) error { + err = graph.ForEachChannel(func(edgeInfo *models.ChannelEdgeInfo1, + c1, c2 *models.ChannelEdgePolicy1) error { // Do not include unannounced channels unless specifically // requested. Unannounced channels include both private channels as @@ -6293,8 +6293,8 @@ func extractInboundFeeSafe(data lnwire.ExtraOpaqueData) lnwire.Fee { return inboundFee } -func marshalDBEdge(edgeInfo *models.ChannelEdgeInfo, - c1, c2 *models.ChannelEdgePolicy) *lnrpc.ChannelEdge { +func marshalDBEdge(edgeInfo *models.ChannelEdgeInfo1, + c1, c2 *models.ChannelEdgePolicy1) *lnrpc.ChannelEdge { // Make sure the policies match the node they belong to. c1 should point // to the policy for NodeKey1, and c2 for NodeKey2. @@ -6337,7 +6337,7 @@ func marshalDBEdge(edgeInfo *models.ChannelEdgeInfo, } func marshalDBRoutingPolicy( - policy *models.ChannelEdgePolicy) *lnrpc.RoutingPolicy { + policy *models.ChannelEdgePolicy1) *lnrpc.RoutingPolicy { disabled := policy.ChannelFlags&lnwire.ChanUpdateDisabled != 0 @@ -6427,8 +6427,8 @@ func (r *rpcServer) GetChanInfo(_ context.Context, graph := r.server.graphDB var ( - edgeInfo *models.ChannelEdgeInfo - edge1, edge2 *models.ChannelEdgePolicy + edgeInfo *models.ChannelEdgeInfo1 + edge1, edge2 *models.ChannelEdgePolicy1 err error ) @@ -6497,8 +6497,8 @@ func (r *rpcServer) GetNodeInfo(ctx context.Context, ) err = graph.ForEachNodeChannel(node.PubKeyBytes, - func(_ kvdb.RTx, edge *models.ChannelEdgeInfo, - c1, c2 *models.ChannelEdgePolicy) error { + func(_ kvdb.RTx, edge *models.ChannelEdgeInfo1, + c1, c2 *models.ChannelEdgePolicy1) error { numChannels++ totalCapacity += edge.Capacity @@ -7156,8 +7156,8 @@ func (r *rpcServer) FeeReport(ctx context.Context, var feeReports []*lnrpc.ChannelFeeReport err = channelGraph.ForEachNodeChannel(selfNode.PubKeyBytes, - func(_ kvdb.RTx, chanInfo *models.ChannelEdgeInfo, - edgePolicy, _ *models.ChannelEdgePolicy) error { + func(_ kvdb.RTx, chanInfo *models.ChannelEdgeInfo1, + edgePolicy, _ *models.ChannelEdgePolicy1) error { // Self node should always have policies for its // channels. diff --git a/server.go b/server.go index 02161bf73b..e6843d61fa 100644 --- a/server.go +++ b/server.go @@ -992,6 +992,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, AssumeChannelValid: cfg.Routing.AssumeChannelValid, StrictZombiePruning: strictPruning, IsAlias: aliasmgr.IsAlias, + FetchTxBySCID: s.fetchTxBySCID, }) if err != nil { return nil, fmt.Errorf("can't create graph builder: %w", err) @@ -1063,6 +1064,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, GetAlias: s.aliasMgr.GetPeerAlias, FindChannel: s.findChannel, IsStillZombieChannel: s.graphBuilder.IsZombieChannel, + FetchTxBySCID: s.fetchTxBySCID, }, nodeKeyDesc) //nolint:lll @@ -1285,7 +1287,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, // Wrap the DeleteChannelEdges method so that the funding manager can // use it without depending on several layers of indirection. deleteAliasEdge := func(scid lnwire.ShortChannelID) ( - *models.ChannelEdgePolicy, error) { + *models.ChannelEdgePolicy1, error) { info, e1, e2, err := s.graphDB.FetchChannelEdgesByID( scid.ToUint64(), @@ -1304,7 +1306,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, var ourKey [33]byte copy(ourKey[:], nodeKeyDesc.PubKey.SerializeCompressed()) - var ourPolicy *models.ChannelEdgePolicy + var ourPolicy *models.ChannelEdgePolicy1 if info != nil && info.NodeKey1Bytes == ourKey { ourPolicy = e1 } else { @@ -1736,7 +1738,7 @@ func (s *server) UpdateRoutingConfig(cfg *routing.MissionControlConfig) { // signAliasUpdate takes a ChannelUpdate and returns the signature. This is // used for option_scid_alias channels where the ChannelUpdate to be sent back // may differ from what is on disk. -func (s *server) signAliasUpdate(u *lnwire.ChannelUpdate) (*ecdsa.Signature, +func (s *server) signAliasUpdate(u *lnwire.ChannelUpdate1) (*ecdsa.Signature, error) { data, err := u.DataToSign() @@ -3223,8 +3225,8 @@ func (s *server) establishPersistentConnections() error { selfPub := s.identityECDH.PubKey().SerializeCompressed() err = s.graphDB.ForEachNodeChannel(sourceNode.PubKeyBytes, func( tx kvdb.RTx, - chanInfo *models.ChannelEdgeInfo, - policy, _ *models.ChannelEdgePolicy) error { + chanInfo *models.ChannelEdgeInfo1, + policy, _ *models.ChannelEdgePolicy1) error { // If the remote party has announced the channel to us, but we // haven't yet, then we won't have a policy. However, we don't @@ -4755,10 +4757,10 @@ func (s *server) fetchNodeAdvertisedAddrs(pub *btcec.PublicKey) ([]net.Addr, err // fetchLastChanUpdate returns a function which is able to retrieve our latest // channel update for a target channel. func (s *server) fetchLastChanUpdate() func(lnwire.ShortChannelID) ( - *lnwire.ChannelUpdate, error) { + *lnwire.ChannelUpdate1, error) { ourPubKey := s.identityECDH.PubKey().SerializeCompressed() - return func(cid lnwire.ShortChannelID) (*lnwire.ChannelUpdate, error) { + return func(cid lnwire.ShortChannelID) (*lnwire.ChannelUpdate1, error) { info, edge1, edge2, err := s.graphBuilder.GetChannelByID(cid) if err != nil { return nil, err @@ -4773,7 +4775,7 @@ func (s *server) fetchLastChanUpdate() func(lnwire.ShortChannelID) ( // applyChannelUpdate applies the channel update to the different sub-systems of // the server. The useAlias boolean denotes whether or not to send an alias in // place of the real SCID. -func (s *server) applyChannelUpdate(update *lnwire.ChannelUpdate, +func (s *server) applyChannelUpdate(update *lnwire.ChannelUpdate1, op *wire.OutPoint, useAlias bool) error { var ( @@ -4832,6 +4834,72 @@ func (s *server) SendCustomMessage(peerPub [33]byte, msgType lnwire.MessageType, return peer.SendMessageLazy(true, msg) } +// fetchFundingTxWrapper is a wrapper around fetchFundingTx, except that it +// will exit if the router has stopped. +func (s *server) fetchTxBySCID(chanID *lnwire.ShortChannelID, + quit chan struct{}) (*wire.MsgTx, error) { + + txChan := make(chan *wire.MsgTx, 1) + errChan := make(chan error, 1) + + go func() { + tx, err := s.fetchFundingTx(chanID) + if err != nil { + errChan <- err + return + } + + txChan <- tx + }() + + select { + case tx := <-txChan: + return tx, nil + + case err := <-errChan: + return nil, err + + case <-quit: + return nil, fmt.Errorf("subsystem shutting down") + + case <-s.quit: + return nil, ErrServerShuttingDown + } +} + +// fetchFundingTx returns the funding transaction identified by the passed +// short channel ID. +// +// TODO(roasbeef): replace with call to GetBlockTransaction? (would allow to +// later use getblocktxn). +func (s *server) fetchFundingTx( + chanID *lnwire.ShortChannelID) (*wire.MsgTx, error) { + + // First fetch the block hash by the block number encoded, then use + // that hash to fetch the block itself. + blockNum := int64(chanID.BlockHeight) + blockHash, err := s.cc.ChainIO.GetBlockHash(blockNum) + if err != nil { + return nil, err + } + fundingBlock, err := s.cc.ChainIO.GetBlock(blockHash) + if err != nil { + return nil, err + } + + // As a sanity check, ensure that the advertised transaction index is + // within the bounds of the total number of transactions within a + // block. + numTxns := uint32(len(fundingBlock.Transactions)) + if chanID.TxIndex > numTxns-1 { + return nil, fmt.Errorf("tx_index=#%v "+ + "is out of range (max_index=%v), network_chan_id=%v", + chanID.TxIndex, numTxns-1, chanID) + } + + return fundingBlock.Transactions[chanID.TxIndex].Copy(), nil +} + // newSweepPkScriptGen creates closure that generates a new public key script // which should be used to sweep any funds into the on-chain wallet. // Specifically, the script generated is a version 0, pay-to-witness-pubkey-hash