From ca4ff25925b6a700ebc2e015f5c81175e3d058c9 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Wed, 21 Aug 2024 12:58:01 +0200 Subject: [PATCH] channeldb: add encoding for ChannelEdgePolicy2 Similarly to the previous commit, here we add the encoding for the new ChannelEdgePolicy2. This is done in the same was as for ChannelEdgeInfo2: - a 0xff prefix - followed by a type-byte - followed by the TLV encoding of the ChannelEdgePolicy2. --- channeldb/edge_policy.go | 205 ++++++++++++++++++++++++++++++++-- channeldb/edge_policy_test.go | 171 ++++++++++++++++++++++++++++ channeldb/graph.go | 26 ++++- 3 files changed, 392 insertions(+), 10 deletions(-) create mode 100644 channeldb/edge_policy_test.go diff --git a/channeldb/edge_policy.go b/channeldb/edge_policy.go index 04ab95411ea..07893849858 100644 --- a/channeldb/edge_policy.go +++ b/channeldb/edge_policy.go @@ -1,6 +1,7 @@ package channeldb import ( + "bufio" "bytes" "encoding/binary" "fmt" @@ -11,6 +12,30 @@ import ( "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/tlv" +) + +const ( + EdgePolicy2MsgType = tlv.Type(0) + EdgePolicy2ToNode = tlv.Type(1) + + // chanEdgePolicyNewEncodingPrefix is a byte used in the channel edge + // policy encoding to signal that the new style encoding which is + // prefixed with a type byte is being used instead of the legacy + // encoding which would start with 0x02 due to the fact that the + // encoding would start with a DER encoded ecdsa signature. + chanEdgePolicyNewEncodingPrefix = 0xff +) + +// edgePolicyEncoding indicates how the bytes for a channel edge policy have +// been serialised. +type edgePolicyEncodingType uint8 + +const ( + // edgePolicy2EncodingType will be used as a prefix for edge policies + // advertised using the ChannelUpdate2 message. The type indicates how + // the bytes following should be deserialized. + edgePolicy2EncodingType edgePolicyEncodingType = 0 ) func putChanEdgePolicy(edges kvdb.RwBucket, edge *models.ChannelEdgePolicy1, @@ -60,7 +85,14 @@ func putChanEdgePolicy(edges kvdb.RwBucket, edge *models.ChannelEdgePolicy1, return err } - oldUpdateTime := uint64(oldEdgePolicy.LastUpdate.Unix()) + oldPol, ok := oldEdgePolicy.(*models.ChannelEdgePolicy1) + if !ok { + return fmt.Errorf("expected "+ + "*models.ChannelEdgePolicy1, got: %T", + oldEdgePolicy) + } + + oldUpdateTime := uint64(oldPol.LastUpdate.Unix()) var oldIndexKey [8 + 8]byte byteOrder.PutUint64(oldIndexKey[:8], oldUpdateTime) @@ -163,7 +195,13 @@ func fetchChanEdgePolicy(edges kvdb.RBucket, chanID []byte, return nil, err } - return ep, nil + pol, ok := ep.(*models.ChannelEdgePolicy1) + if !ok { + return nil, fmt.Errorf("expected *models.ChannelEdgePolicy1, "+ + "got: %T", ep) + } + + return pol, nil } func fetchChanEdgePolicies(edgeIndex kvdb.RBucket, edges kvdb.RBucket, @@ -198,8 +236,56 @@ func fetchChanEdgePolicies(edgeIndex kvdb.RBucket, edges kvdb.RBucket, return edge1, edge2, nil } -func serializeChanEdgePolicy(w io.Writer, edge *models.ChannelEdgePolicy1, - to []byte) error { +func serializeChanEdgePolicy(w io.Writer, + edgePolicy models.ChannelEdgePolicy, toNode []byte) error { + + var ( + withTypeByte bool + typeByte edgePolicyEncodingType + serialize func(w io.Writer) error + ) + + switch policy := edgePolicy.(type) { + case *models.ChannelEdgePolicy1: + serialize = func(w io.Writer) error { + copy(policy.ToNode[:], toNode) + + return serializeChanEdgePolicy1(w, policy) + } + case *models.ChannelEdgePolicy2: + withTypeByte = true + typeByte = edgePolicy2EncodingType + + serialize = func(w io.Writer) error { + copy(policy.ToNode[:], toNode) + + return serializeChanEdgePolicy2(w, policy) + } + default: + return fmt.Errorf("unhandled implementation of "+ + "ChannelEdgePolicy: %T", edgePolicy) + } + + if withTypeByte { + // First, write the identifying encoding byte to signal that + // this is not using the legacy encoding. + _, err := w.Write([]byte{chanEdgePolicyNewEncodingPrefix}) + if err != nil { + return err + } + + // Now, write the encoding type. + _, err = w.Write([]byte{byte(typeByte)}) + if err != nil { + return err + } + } + + return serialize(w) +} + +func serializeChanEdgePolicy1(w io.Writer, + edge *models.ChannelEdgePolicy1) error { err := wire.WriteVarBytes(w, 0, edge.SigBytes) if err != nil { @@ -236,7 +322,7 @@ func serializeChanEdgePolicy(w io.Writer, edge *models.ChannelEdgePolicy1, return err } - if _, err := w.Write(to); err != nil { + if _, err := w.Write(edge.ToNode[:]); err != nil { return err } @@ -265,7 +351,36 @@ func serializeChanEdgePolicy(w io.Writer, edge *models.ChannelEdgePolicy1, return nil } -func deserializeChanEdgePolicy(r io.Reader) (*models.ChannelEdgePolicy1, error) { +func serializeChanEdgePolicy2(w io.Writer, + edge *models.ChannelEdgePolicy2) error { + + if len(edge.ExtraOpaqueData) > MaxAllowedExtraOpaqueBytes { + return ErrTooManyExtraOpaqueBytes(len(edge.ExtraOpaqueData)) + } + + var b bytes.Buffer + if err := edge.Encode(&b, 0); err != nil { + return err + } + + msg := b.Bytes() + + records := []tlv.Record{ + tlv.MakePrimitiveRecord(EdgePolicy2MsgType, &msg), + tlv.MakePrimitiveRecord(EdgePolicy2ToNode, &edge.ToNode), + } + + stream, err := tlv.NewStream(records...) + if err != nil { + return err + } + + return stream.Encode(w) +} + +func deserializeChanEdgePolicy(r io.Reader) (models.ChannelEdgePolicy, + error) { + // Deserialize the policy. Note that in case an optional field is not // found, both an error and a populated policy object are returned. edge, deserializeErr := deserializeChanEdgePolicyRaw(r) @@ -278,7 +393,45 @@ func deserializeChanEdgePolicy(r io.Reader) (*models.ChannelEdgePolicy1, error) return edge, deserializeErr } -func deserializeChanEdgePolicyRaw(r io.Reader) (*models.ChannelEdgePolicy1, +func deserializeChanEdgePolicyRaw(reader io.Reader) (models.ChannelEdgePolicy, + error) { + + // Wrap the io.Reader in a bufio.Reader so that we can peak the first + // byte of the stream without actually consuming from the stream. + r := bufio.NewReader(reader) + + firstByte, err := r.Peek(1) + if err != nil { + return nil, err + } + + if firstByte[0] != chanEdgePolicyNewEncodingPrefix { + return deserializeChanEdgePolicy1Raw(r) + } + + // Pop the encoding type byte. + var scratch [1]byte + if _, err = r.Read(scratch[:]); err != nil { + return nil, err + } + + // Now, read the encoding type byte. + if _, err = r.Read(scratch[:]); err != nil { + return nil, err + } + + encoding := edgePolicyEncodingType(scratch[0]) + switch encoding { + case edgePolicy2EncodingType: + return deserializeChanEdgePolicy2Raw(r) + + default: + return nil, fmt.Errorf("unknown edge policy encoding type: %d", + encoding) + } +} + +func deserializeChanEdgePolicy1Raw(r io.Reader) (*models.ChannelEdgePolicy1, error) { edge := &models.ChannelEdgePolicy1{} @@ -364,3 +517,41 @@ func deserializeChanEdgePolicyRaw(r io.Reader) (*models.ChannelEdgePolicy1, return edge, nil } + +func deserializeChanEdgePolicy2Raw(r io.Reader) (*models.ChannelEdgePolicy2, + error) { + + var ( + msgBytes []byte + toNode [33]byte + ) + + records := []tlv.Record{ + tlv.MakePrimitiveRecord(EdgePolicy2MsgType, &msgBytes), + tlv.MakePrimitiveRecord(EdgePolicy2ToNode, &toNode), + } + + stream, err := tlv.NewStream(records...) + if err != nil { + return nil, err + } + + err = stream.Decode(r) + if err != nil { + return nil, err + } + + var ( + chanUpdate lnwire.ChannelUpdate2 + reader = bytes.NewReader(msgBytes) + ) + err = chanUpdate.Decode(reader, 0) + if err != nil { + return nil, err + } + + return &models.ChannelEdgePolicy2{ + ChannelUpdate2: chanUpdate, + ToNode: toNode, + }, nil +} diff --git a/channeldb/edge_policy_test.go b/channeldb/edge_policy_test.go new file mode 100644 index 00000000000..6c8b1e22174 --- /dev/null +++ b/channeldb/edge_policy_test.go @@ -0,0 +1,171 @@ +package channeldb + +import ( + "bytes" + "math/rand" + "reflect" + "testing" + "testing/quick" + "time" + + "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestEdgePolicySerialisation tests the serialisation and deserialization logic +// for models.ChannelEdgePolicy. +func TestEdgePolicySerialisation(t *testing.T) { + t.Parallel() + + mainScenario := func(info models.ChannelEdgePolicy) bool { + var ( + b bytes.Buffer + toNode = info.GetToNode() + ) + + err := serializeChanEdgePolicy(&b, info, toNode[:]) + require.NoError(t, err) + + newInfo, err := deserializeChanEdgePolicy(&b) + require.NoError(t, err) + + return assert.Equal(t, info, newInfo) + } + + tests := []struct { + name string + genValue func([]reflect.Value, *rand.Rand) + scenario any + }{ + { + name: "ChannelEdgePolicy1", + scenario: func(m models.ChannelEdgePolicy1) bool { + return mainScenario(&m) + }, + genValue: func(v []reflect.Value, r *rand.Rand) { + //nolint:lll + policy := &models.ChannelEdgePolicy1{ + ChannelID: r.Uint64(), + LastUpdate: time.Unix(r.Int63(), 0), + MessageFlags: lnwire.ChanUpdateMsgFlags(r.Uint32()), + ChannelFlags: lnwire.ChanUpdateChanFlags(r.Uint32()), + TimeLockDelta: uint16(r.Uint32()), + MinHTLC: lnwire.MilliSatoshi(r.Uint64()), + FeeBaseMSat: lnwire.MilliSatoshi(r.Uint64()), + FeeProportionalMillionths: lnwire.MilliSatoshi(r.Uint64()), + ExtraOpaqueData: make([]byte, 0), + } + + policy.SigBytes = make([]byte, r.Intn(80)) + _, err := r.Read(policy.SigBytes) + require.NoError(t, err) + + _, err = r.Read(policy.ToNode[:]) + require.NoError(t, err) + + numExtraBytes := r.Int31n(1000) + if numExtraBytes > 0 { + policy.ExtraOpaqueData = make( + []byte, numExtraBytes, + ) + _, err := r.Read( + policy.ExtraOpaqueData, + ) + require.NoError(t, err) + } + + // Sometimes add an MaxHTLC. + if r.Intn(2)%2 == 0 { + policy.MessageFlags |= + lnwire.ChanUpdateRequiredMaxHtlc + policy.MaxHTLC = lnwire.MilliSatoshi( + r.Uint64(), + ) + } else { + policy.MessageFlags ^= + lnwire.ChanUpdateRequiredMaxHtlc + } + + v[0] = reflect.ValueOf(*policy) + }, + }, + { + name: "ChannelEdgePolicy2", + scenario: func(m models.ChannelEdgePolicy2) bool { + return mainScenario(&m) + }, + genValue: func(v []reflect.Value, r *rand.Rand) { + policy := &models.ChannelEdgePolicy2{ + //nolint:lll + ChannelUpdate2: lnwire.ChannelUpdate2{ + Signature: testSchnorrSig, + ExtraOpaqueData: make([]byte, 0), + }, + ToNode: [33]byte{}, + } + + policy.ShortChannelID.Val = lnwire.NewShortChanIDFromInt( //nolint:lll + uint64(r.Int63()), + ) + policy.BlockHeight.Val = r.Uint32() + policy.HTLCMaximumMsat.Val = lnwire.MilliSatoshi( //nolint:lll + r.Uint64(), + ) + policy.HTLCMinimumMsat.Val = lnwire.MilliSatoshi( //nolint:lll + r.Uint64(), + ) + policy.CLTVExpiryDelta.Val = uint16(r.Int31()) + policy.FeeBaseMsat.Val = r.Uint32() + policy.FeeProportionalMillionths.Val = r.Uint32() //nolint:lll + + if r.Intn(2) == 0 { + policy.Direction.Val.B = true + } + + // Sometimes set the incoming disabled flag. + if r.Int31()%2 == 0 { + policy.DisabledFlags.Val |= + lnwire.ChanUpdateDisableIncoming + } + + // Sometimes set the outgoing disabled flag. + if r.Int31()%2 == 0 { + policy.DisabledFlags.Val |= + lnwire.ChanUpdateDisableOutgoing + } + + _, err := r.Read(policy.ToNode[:]) + require.NoError(t, err) + + numExtraBytes := r.Int31n(1000) + if numExtraBytes > 0 { + policy.ExtraOpaqueData = make( + []byte, numExtraBytes, + ) + _, err := r.Read( + policy.ExtraOpaqueData, + ) + require.NoError(t, err) + } + + v[0] = reflect.ValueOf(*policy) + }, + }, + } + + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + config := &quick.Config{ + Values: test.genValue, + } + + err := quick.Check(test.scenario, config) + require.NoError(t, err) + }) + } +} diff --git a/channeldb/graph.go b/channeldb/graph.go index 77a65fa8ef9..30ac93a3b2c 100644 --- a/channeldb/graph.go +++ b/channeldb/graph.go @@ -303,7 +303,13 @@ func (c *ChannelGraph) getChannelMap(edges kvdb.RBucket) ( return err } - channelMap[key] = edge + e, ok := edge.(*models.ChannelEdgePolicy1) + if !ok { + return fmt.Errorf("expected "+ + "*models.ChannelEdgePolicy1, got: %T", edge) + } + + channelMap[key] = e return nil }) @@ -2378,7 +2384,14 @@ func (c *ChannelGraph) FilterChannelRange(startHeight, return err } - chanInfo.Node1UpdateTimestamp = edge.LastUpdate + e, ok := edge.(*models.ChannelEdgePolicy1) + if !ok { + return fmt.Errorf("expected "+ + "*models.ChannelEdgePolicy1, "+ + "got %T", edge) + } + + chanInfo.Node1UpdateTimestamp = e.LastUpdate } rawPolicy = edges.Get(node2Key) @@ -2393,7 +2406,14 @@ func (c *ChannelGraph) FilterChannelRange(startHeight, return err } - chanInfo.Node2UpdateTimestamp = edge.LastUpdate + e, ok := edge.(*models.ChannelEdgePolicy1) + if !ok { + return fmt.Errorf("expected "+ + "*models.ChannelEdgePolicy1, "+ + "got %T", edge) + } + + chanInfo.Node2UpdateTimestamp = e.LastUpdate } channelsPerBlock[cid.BlockHeight] = append(