Skip to content

Commit

Permalink
Merge pull request #8911 from ellemouton/reduceMCRouteEncoding
Browse files Browse the repository at this point in the history
routing+channeldb: use a more minimal encoding for MC routes
  • Loading branch information
ellemouton authored Oct 1, 2024
2 parents 1acc839 + 34303e7 commit 75eaaf7
Show file tree
Hide file tree
Showing 17 changed files with 1,998 additions and 136 deletions.
5 changes: 5 additions & 0 deletions channeldb/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/lightningnetwork/lnd/channeldb/migration29"
"github.com/lightningnetwork/lnd/channeldb/migration30"
"github.com/lightningnetwork/lnd/channeldb/migration31"
"github.com/lightningnetwork/lnd/channeldb/migration32"
"github.com/lightningnetwork/lnd/channeldb/migration_01_to_11"
"github.com/lightningnetwork/lnd/clock"
"github.com/lightningnetwork/lnd/invoices"
Expand Down Expand Up @@ -286,6 +287,10 @@ var (
number: 31,
migration: migration31.DeleteLastPublishedTxTLB,
},
{
number: 32,
migration: migration32.MigrateMCRouteSerialisation,
},
}

// optionalVersions stores all optional migrations that are applied
Expand Down
2 changes: 2 additions & 0 deletions channeldb/log.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/lightningnetwork/lnd/channeldb/migration24"
"github.com/lightningnetwork/lnd/channeldb/migration30"
"github.com/lightningnetwork/lnd/channeldb/migration31"
"github.com/lightningnetwork/lnd/channeldb/migration32"
"github.com/lightningnetwork/lnd/channeldb/migration_01_to_11"
"github.com/lightningnetwork/lnd/kvdb"
)
Expand Down Expand Up @@ -42,5 +43,6 @@ func UseLogger(logger btclog.Logger) {
migration24.UseLogger(logger)
migration30.UseLogger(logger)
migration31.UseLogger(logger)
migration32.UseLogger(logger)
kvdb.UseLogger(logger)
}
263 changes: 263 additions & 0 deletions channeldb/migration/lnwire21/custom_records.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,263 @@
package lnwire

import (
"bytes"
"fmt"
"io"
"sort"

"github.com/lightningnetwork/lnd/fn"
"github.com/lightningnetwork/lnd/tlv"
)

const (
// MinCustomRecordsTlvType is the minimum custom records TLV type as
// defined in BOLT 01.
MinCustomRecordsTlvType = 65536
)

// CustomRecords stores a set of custom key/value pairs. Map keys are TLV types
// which must be greater than or equal to MinCustomRecordsTlvType.
type CustomRecords map[uint64][]byte

// NewCustomRecords creates a new CustomRecords instance from a
// tlv.TypeMap.
func NewCustomRecords(tlvMap tlv.TypeMap) (CustomRecords, error) {
// Make comparisons in unit tests easy by returning nil if the map is
// empty.
if len(tlvMap) == 0 {
return nil, nil
}

customRecords := make(CustomRecords, len(tlvMap))
for k, v := range tlvMap {
customRecords[uint64(k)] = v
}

// Validate the custom records.
err := customRecords.Validate()
if err != nil {
return nil, fmt.Errorf("custom records from tlv map "+
"validation error: %w", err)
}

return customRecords, nil
}

// ParseCustomRecords creates a new CustomRecords instance from a tlv.Blob.
func ParseCustomRecords(b tlv.Blob) (CustomRecords, error) {
return ParseCustomRecordsFrom(bytes.NewReader(b))
}

// ParseCustomRecordsFrom creates a new CustomRecords instance from a reader.
func ParseCustomRecordsFrom(r io.Reader) (CustomRecords, error) {
typeMap, err := DecodeRecords(r)
if err != nil {
return nil, fmt.Errorf("error decoding HTLC record: %w", err)
}

return NewCustomRecords(typeMap)
}

// Validate checks that all custom records are in the custom type range.
func (c CustomRecords) Validate() error {
if c == nil {
return nil
}

for key := range c {
if key < MinCustomRecordsTlvType {
return fmt.Errorf("custom records entry with TLV "+
"type below min: %d", MinCustomRecordsTlvType)
}
}

return nil
}

// Copy returns a copy of the custom records.
func (c CustomRecords) Copy() CustomRecords {
if c == nil {
return nil
}

customRecords := make(CustomRecords, len(c))
for k, v := range c {
customRecords[k] = v
}

return customRecords
}

// ExtendRecordProducers extends the given records slice with the custom
// records. The resultant records slice will be sorted if the given records
// slice contains TLV types greater than or equal to MinCustomRecordsTlvType.
func (c CustomRecords) ExtendRecordProducers(
producers []tlv.RecordProducer) ([]tlv.RecordProducer, error) {

// If the custom records are nil or empty, there is nothing to do.
if len(c) == 0 {
return producers, nil
}

// Validate the custom records.
err := c.Validate()
if err != nil {
return nil, err
}

// Ensure that the existing records slice TLV types are not also present
// in the custom records. If they are, the resultant extended records
// slice would erroneously contain duplicate TLV types.
for _, rp := range producers {
record := rp.Record()
recordTlvType := uint64(record.Type())

_, foundDuplicateTlvType := c[recordTlvType]
if foundDuplicateTlvType {
return nil, fmt.Errorf("custom records contains a TLV "+
"type that is already present in the "+
"existing records: %d", recordTlvType)
}
}

// Convert the custom records map to a TLV record producer slice and
// append them to the exiting records slice.
customRecordProducers := RecordsAsProducers(tlv.MapToRecords(c))
producers = append(producers, customRecordProducers...)

// If the records slice which was given as an argument included TLV
// values greater than or equal to the minimum custom records TLV type
// we will sort the extended records slice to ensure that it is ordered
// correctly.
SortProducers(producers)

return producers, nil
}

// RecordProducers returns a slice of record producers for the custom records.
func (c CustomRecords) RecordProducers() []tlv.RecordProducer {
// If the custom records are nil or empty, return an empty slice.
if len(c) == 0 {
return nil
}

// Convert the custom records map to a TLV record producer slice.
records := tlv.MapToRecords(c)

return RecordsAsProducers(records)
}

// Serialize serializes the custom records into a byte slice.
func (c CustomRecords) Serialize() ([]byte, error) {
records := tlv.MapToRecords(c)
return EncodeRecords(records)
}

// SerializeTo serializes the custom records into the given writer.
func (c CustomRecords) SerializeTo(w io.Writer) error {
records := tlv.MapToRecords(c)
return EncodeRecordsTo(w, records)
}

// ProduceRecordsSorted converts a slice of record producers into a slice of
// records and then sorts it by type.
func ProduceRecordsSorted(recordProducers ...tlv.RecordProducer) []tlv.Record {
records := fn.Map(func(producer tlv.RecordProducer) tlv.Record {
return producer.Record()
}, recordProducers)

// Ensure that the set of records are sorted before we attempt to
// decode from the stream, to ensure they're canonical.
tlv.SortRecords(records)

return records
}

// SortProducers sorts the given record producers by their type.
func SortProducers(producers []tlv.RecordProducer) {
sort.Slice(producers, func(i, j int) bool {
recordI := producers[i].Record()
recordJ := producers[j].Record()
return recordI.Type() < recordJ.Type()
})
}

// TlvMapToRecords converts a TLV map into a slice of records.
func TlvMapToRecords(tlvMap tlv.TypeMap) []tlv.Record {
tlvMapGeneric := make(map[uint64][]byte)
for k, v := range tlvMap {
tlvMapGeneric[uint64(k)] = v
}

return tlv.MapToRecords(tlvMapGeneric)
}

// RecordsAsProducers converts a slice of records into a slice of record
// producers.
func RecordsAsProducers(records []tlv.Record) []tlv.RecordProducer {
return fn.Map(func(record tlv.Record) tlv.RecordProducer {
return &record
}, records)
}

// EncodeRecords encodes the given records into a byte slice.
func EncodeRecords(records []tlv.Record) ([]byte, error) {
var buf bytes.Buffer
if err := EncodeRecordsTo(&buf, records); err != nil {
return nil, err
}

return buf.Bytes(), nil
}

// EncodeRecordsTo encodes the given records into the given writer.
func EncodeRecordsTo(w io.Writer, records []tlv.Record) error {
tlvStream, err := tlv.NewStream(records...)
if err != nil {
return err
}

return tlvStream.Encode(w)
}

// DecodeRecords decodes the given byte slice into the given records and returns
// the rest as a TLV type map.
func DecodeRecords(r io.Reader,
records ...tlv.Record) (tlv.TypeMap, error) {

tlvStream, err := tlv.NewStream(records...)
if err != nil {
return nil, err
}

return tlvStream.DecodeWithParsedTypes(r)
}

// DecodeRecordsP2P decodes the given byte slice into the given records and
// returns the rest as a TLV type map. This function is identical to
// DecodeRecords except that the record size is capped at 65535.
func DecodeRecordsP2P(r *bytes.Reader,
records ...tlv.Record) (tlv.TypeMap, error) {

tlvStream, err := tlv.NewStream(records...)
if err != nil {
return nil, err
}

return tlvStream.DecodeWithParsedTypesP2P(r)
}

// AssertUniqueTypes asserts that the given records have unique types.
func AssertUniqueTypes(r []tlv.Record) error {
seen := make(fn.Set[tlv.Type], len(r))
for _, record := range r {
t := record.Type()
if seen.Contains(t) {
return fmt.Errorf("duplicate record type: %d", t)
}
seen.Add(t)
}

return nil
}
54 changes: 54 additions & 0 deletions channeldb/migration/lnwire21/onion_error.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ type FailCode uint16

// The currently defined onion failure types within this current version of the
// Lightning protocol.
//
//nolint:lll
const (
CodeNone FailCode = 0
CodeInvalidRealm = FlagBadOnion | 1
Expand All @@ -80,6 +82,7 @@ const (
CodeExpiryTooFar FailCode = 21
CodeInvalidOnionPayload = FlagPerm | 22
CodeMPPTimeout FailCode = 23
CodeInvalidBlinding = FlagBadOnion | FlagPerm | 24
)

// String returns the string representation of the failure code.
Expand Down Expand Up @@ -157,6 +160,9 @@ func (c FailCode) String() string {
case CodeMPPTimeout:
return "MPPTimeout"

case CodeInvalidBlinding:
return "InvalidBlinding"

default:
return "<unknown>"
}
Expand Down Expand Up @@ -571,6 +577,51 @@ func (f *FailInvalidOnionKey) Error() string {
return fmt.Sprintf("InvalidOnionKey(onion_sha=%x)", f.OnionSHA256[:])
}

// FailInvalidBlinding is returned if there has been a route blinding related
// error.
type FailInvalidBlinding struct {
OnionSHA256 [sha256.Size]byte
}

// Code returns the failure unique code.
//
// NOTE: Part of the FailureMessage interface.
func (f *FailInvalidBlinding) Code() FailCode {
return CodeInvalidBlinding
}

// Returns a human readable string describing the target FailureMessage.
//
// NOTE: Implements the error interface.
func (f *FailInvalidBlinding) Error() string {
return f.Code().String()
}

// Decode decodes the failure from bytes stream.
//
// NOTE: Part of the Serializable interface.
func (f *FailInvalidBlinding) Decode(r io.Reader, _ uint32) error {
return ReadElement(r, f.OnionSHA256[:])
}

// Encode writes the failure in bytes stream.
//
// NOTE: Part of the Serializable interface.
func (f *FailInvalidBlinding) Encode(w *bytes.Buffer, _ uint32) error {
return WriteElement(w, f.OnionSHA256[:])
}

// NewInvalidBlinding creates new instance of FailInvalidBlinding.
func NewInvalidBlinding(onion []byte) *FailInvalidBlinding {
// The spec allows empty onion hashes for invalid blinding, so we only
// include our onion hash if it's provided.
if onion == nil {
return &FailInvalidBlinding{}
}

return &FailInvalidBlinding{OnionSHA256: sha256.Sum256(onion)}
}

// parseChannelUpdateCompatabilityMode will attempt to parse a channel updated
// encoded into an onion error payload in two ways. First, we'll try the
// compatibility oriented version wherein we'll _skip_ the length prefixing on
Expand Down Expand Up @@ -1392,6 +1443,9 @@ func makeEmptyOnionError(code FailCode) (FailureMessage, error) {
case CodeMPPTimeout:
return &FailMPPTimeout{}, nil

case CodeInvalidBlinding:
return &FailInvalidBlinding{}, nil

default:
return nil, errors.Errorf("unknown error code: %v", code)
}
Expand Down
Loading

0 comments on commit 75eaaf7

Please sign in to comment.