Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

routing+channeldb: use a more minimal encoding for MC routes #8911

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions channeldb/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/lightningnetwork/lnd/channeldb/migration29"
"github.com/lightningnetwork/lnd/channeldb/migration30"
"github.com/lightningnetwork/lnd/channeldb/migration31"
"github.com/lightningnetwork/lnd/channeldb/migration32"
"github.com/lightningnetwork/lnd/channeldb/migration_01_to_11"
"github.com/lightningnetwork/lnd/clock"
"github.com/lightningnetwork/lnd/invoices"
Expand Down Expand Up @@ -286,6 +287,10 @@ var (
number: 31,
migration: migration31.DeleteLastPublishedTxTLB,
},
{
number: 32,
migration: migration32.MigrateMCRouteSerialisation,
},
}

// optionalVersions stores all optional migrations that are applied
Expand Down
2 changes: 2 additions & 0 deletions channeldb/log.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/lightningnetwork/lnd/channeldb/migration24"
"github.com/lightningnetwork/lnd/channeldb/migration30"
"github.com/lightningnetwork/lnd/channeldb/migration31"
"github.com/lightningnetwork/lnd/channeldb/migration32"
"github.com/lightningnetwork/lnd/channeldb/migration_01_to_11"
"github.com/lightningnetwork/lnd/kvdb"
)
Expand Down Expand Up @@ -42,5 +43,6 @@ func UseLogger(logger btclog.Logger) {
migration24.UseLogger(logger)
migration30.UseLogger(logger)
migration31.UseLogger(logger)
migration32.UseLogger(logger)
kvdb.UseLogger(logger)
}
263 changes: 263 additions & 0 deletions channeldb/migration/lnwire21/custom_records.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,263 @@
package lnwire

import (
"bytes"
"fmt"
"io"
"sort"

"github.com/lightningnetwork/lnd/fn"
"github.com/lightningnetwork/lnd/tlv"
)

const (
// MinCustomRecordsTlvType is the minimum custom records TLV type as
// defined in BOLT 01.
MinCustomRecordsTlvType = 65536
)

// CustomRecords stores a set of custom key/value pairs. Map keys are TLV types
// which must be greater than or equal to MinCustomRecordsTlvType.
type CustomRecords map[uint64][]byte

// NewCustomRecords creates a new CustomRecords instance from a
// tlv.TypeMap.
func NewCustomRecords(tlvMap tlv.TypeMap) (CustomRecords, error) {
// Make comparisons in unit tests easy by returning nil if the map is
// empty.
if len(tlvMap) == 0 {
return nil, nil
}

customRecords := make(CustomRecords, len(tlvMap))
for k, v := range tlvMap {
customRecords[uint64(k)] = v
}

// Validate the custom records.
err := customRecords.Validate()
if err != nil {
return nil, fmt.Errorf("custom records from tlv map "+
"validation error: %w", err)
}

return customRecords, nil
}

// ParseCustomRecords creates a new CustomRecords instance from a tlv.Blob.
func ParseCustomRecords(b tlv.Blob) (CustomRecords, error) {
return ParseCustomRecordsFrom(bytes.NewReader(b))
}

// ParseCustomRecordsFrom creates a new CustomRecords instance from a reader.
func ParseCustomRecordsFrom(r io.Reader) (CustomRecords, error) {
typeMap, err := DecodeRecords(r)
if err != nil {
return nil, fmt.Errorf("error decoding HTLC record: %w", err)
}

return NewCustomRecords(typeMap)
}

// Validate checks that all custom records are in the custom type range.
func (c CustomRecords) Validate() error {
if c == nil {
return nil
}

for key := range c {
if key < MinCustomRecordsTlvType {
return fmt.Errorf("custom records entry with TLV "+
"type below min: %d", MinCustomRecordsTlvType)
}
}

return nil
}

// Copy returns a copy of the custom records.
func (c CustomRecords) Copy() CustomRecords {
if c == nil {
return nil
}

customRecords := make(CustomRecords, len(c))
for k, v := range c {
customRecords[k] = v
}

return customRecords
}

// ExtendRecordProducers extends the given records slice with the custom
// records. The resultant records slice will be sorted if the given records
// slice contains TLV types greater than or equal to MinCustomRecordsTlvType.
func (c CustomRecords) ExtendRecordProducers(
producers []tlv.RecordProducer) ([]tlv.RecordProducer, error) {

// If the custom records are nil or empty, there is nothing to do.
if len(c) == 0 {
return producers, nil
}

// Validate the custom records.
err := c.Validate()
if err != nil {
return nil, err
}

// Ensure that the existing records slice TLV types are not also present
// in the custom records. If they are, the resultant extended records
// slice would erroneously contain duplicate TLV types.
for _, rp := range producers {
record := rp.Record()
recordTlvType := uint64(record.Type())

_, foundDuplicateTlvType := c[recordTlvType]
if foundDuplicateTlvType {
return nil, fmt.Errorf("custom records contains a TLV "+
"type that is already present in the "+
"existing records: %d", recordTlvType)
}
}

// Convert the custom records map to a TLV record producer slice and
// append them to the exiting records slice.
customRecordProducers := RecordsAsProducers(tlv.MapToRecords(c))
producers = append(producers, customRecordProducers...)

// If the records slice which was given as an argument included TLV
// values greater than or equal to the minimum custom records TLV type
// we will sort the extended records slice to ensure that it is ordered
// correctly.
SortProducers(producers)

return producers, nil
}

// RecordProducers returns a slice of record producers for the custom records.
func (c CustomRecords) RecordProducers() []tlv.RecordProducer {
// If the custom records are nil or empty, return an empty slice.
if len(c) == 0 {
return nil
}

// Convert the custom records map to a TLV record producer slice.
records := tlv.MapToRecords(c)

return RecordsAsProducers(records)
}

// Serialize serializes the custom records into a byte slice.
func (c CustomRecords) Serialize() ([]byte, error) {
records := tlv.MapToRecords(c)
return EncodeRecords(records)
}

// SerializeTo serializes the custom records into the given writer.
func (c CustomRecords) SerializeTo(w io.Writer) error {
records := tlv.MapToRecords(c)
return EncodeRecordsTo(w, records)
}

// ProduceRecordsSorted converts a slice of record producers into a slice of
// records and then sorts it by type.
func ProduceRecordsSorted(recordProducers ...tlv.RecordProducer) []tlv.Record {
records := fn.Map(func(producer tlv.RecordProducer) tlv.Record {
return producer.Record()
}, recordProducers)

// Ensure that the set of records are sorted before we attempt to
// decode from the stream, to ensure they're canonical.
tlv.SortRecords(records)

return records
}

// SortProducers sorts the given record producers by their type.
func SortProducers(producers []tlv.RecordProducer) {
sort.Slice(producers, func(i, j int) bool {
recordI := producers[i].Record()
recordJ := producers[j].Record()
return recordI.Type() < recordJ.Type()
})
}

// TlvMapToRecords converts a TLV map into a slice of records.
func TlvMapToRecords(tlvMap tlv.TypeMap) []tlv.Record {
tlvMapGeneric := make(map[uint64][]byte)
for k, v := range tlvMap {
tlvMapGeneric[uint64(k)] = v
}

return tlv.MapToRecords(tlvMapGeneric)
}

// RecordsAsProducers converts a slice of records into a slice of record
// producers.
func RecordsAsProducers(records []tlv.Record) []tlv.RecordProducer {
return fn.Map(func(record tlv.Record) tlv.RecordProducer {
return &record
}, records)
}

// EncodeRecords encodes the given records into a byte slice.
func EncodeRecords(records []tlv.Record) ([]byte, error) {
var buf bytes.Buffer
if err := EncodeRecordsTo(&buf, records); err != nil {
return nil, err
}

return buf.Bytes(), nil
}

// EncodeRecordsTo encodes the given records into the given writer.
func EncodeRecordsTo(w io.Writer, records []tlv.Record) error {
tlvStream, err := tlv.NewStream(records...)
if err != nil {
return err
}

return tlvStream.Encode(w)
}

// DecodeRecords decodes the given byte slice into the given records and returns
// the rest as a TLV type map.
func DecodeRecords(r io.Reader,
records ...tlv.Record) (tlv.TypeMap, error) {

tlvStream, err := tlv.NewStream(records...)
if err != nil {
return nil, err
}

return tlvStream.DecodeWithParsedTypes(r)
}

// DecodeRecordsP2P decodes the given byte slice into the given records and
// returns the rest as a TLV type map. This function is identical to
// DecodeRecords except that the record size is capped at 65535.
func DecodeRecordsP2P(r *bytes.Reader,
records ...tlv.Record) (tlv.TypeMap, error) {

tlvStream, err := tlv.NewStream(records...)
if err != nil {
return nil, err
}

return tlvStream.DecodeWithParsedTypesP2P(r)
}

// AssertUniqueTypes asserts that the given records have unique types.
func AssertUniqueTypes(r []tlv.Record) error {
seen := make(fn.Set[tlv.Type], len(r))
for _, record := range r {
t := record.Type()
if seen.Contains(t) {
return fmt.Errorf("duplicate record type: %d", t)
}
seen.Add(t)
}

return nil
}
54 changes: 54 additions & 0 deletions channeldb/migration/lnwire21/onion_error.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ type FailCode uint16

// The currently defined onion failure types within this current version of the
// Lightning protocol.
//
//nolint:lll
const (
CodeNone FailCode = 0
CodeInvalidRealm = FlagBadOnion | 1
Expand All @@ -80,6 +82,7 @@ const (
CodeExpiryTooFar FailCode = 21
CodeInvalidOnionPayload = FlagPerm | 22
CodeMPPTimeout FailCode = 23
CodeInvalidBlinding = FlagBadOnion | FlagPerm | 24
)

// String returns the string representation of the failure code.
Expand Down Expand Up @@ -157,6 +160,9 @@ func (c FailCode) String() string {
case CodeMPPTimeout:
return "MPPTimeout"

case CodeInvalidBlinding:
ellemouton marked this conversation as resolved.
Show resolved Hide resolved
return "InvalidBlinding"

default:
return "<unknown>"
}
Expand Down Expand Up @@ -571,6 +577,51 @@ func (f *FailInvalidOnionKey) Error() string {
return fmt.Sprintf("InvalidOnionKey(onion_sha=%x)", f.OnionSHA256[:])
}

// FailInvalidBlinding is returned if there has been a route blinding related
// error.
type FailInvalidBlinding struct {
OnionSHA256 [sha256.Size]byte
}

// Code returns the failure unique code.
//
// NOTE: Part of the FailureMessage interface.
func (f *FailInvalidBlinding) Code() FailCode {
return CodeInvalidBlinding
}

// Returns a human readable string describing the target FailureMessage.
//
// NOTE: Implements the error interface.
func (f *FailInvalidBlinding) Error() string {
return f.Code().String()
}

// Decode decodes the failure from bytes stream.
//
// NOTE: Part of the Serializable interface.
func (f *FailInvalidBlinding) Decode(r io.Reader, _ uint32) error {
return ReadElement(r, f.OnionSHA256[:])
}

// Encode writes the failure in bytes stream.
//
// NOTE: Part of the Serializable interface.
func (f *FailInvalidBlinding) Encode(w *bytes.Buffer, _ uint32) error {
return WriteElement(w, f.OnionSHA256[:])
}

// NewInvalidBlinding creates new instance of FailInvalidBlinding.
func NewInvalidBlinding(onion []byte) *FailInvalidBlinding {
// The spec allows empty onion hashes for invalid blinding, so we only
// include our onion hash if it's provided.
if onion == nil {
return &FailInvalidBlinding{}
}

return &FailInvalidBlinding{OnionSHA256: sha256.Sum256(onion)}
}

// parseChannelUpdateCompatabilityMode will attempt to parse a channel updated
// encoded into an onion error payload in two ways. First, we'll try the
// compatibility oriented version wherein we'll _skip_ the length prefixing on
Expand Down Expand Up @@ -1392,6 +1443,9 @@ func makeEmptyOnionError(code FailCode) (FailureMessage, error) {
case CodeMPPTimeout:
return &FailMPPTimeout{}, nil

case CodeInvalidBlinding:
return &FailInvalidBlinding{}, nil

default:
return nil, errors.Errorf("unknown error code: %v", code)
}
Expand Down
Loading
Loading