diff --git a/channeldb/db.go b/channeldb/db.go index 1e210d032ad..d8ba05f3d26 100644 --- a/channeldb/db.go +++ b/channeldb/db.go @@ -26,6 +26,7 @@ import ( "github.com/lightningnetwork/lnd/channeldb/migration29" "github.com/lightningnetwork/lnd/channeldb/migration30" "github.com/lightningnetwork/lnd/channeldb/migration31" + "github.com/lightningnetwork/lnd/channeldb/migration32" "github.com/lightningnetwork/lnd/channeldb/migration_01_to_11" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/invoices" @@ -286,6 +287,10 @@ var ( number: 31, migration: migration31.DeleteLastPublishedTxTLB, }, + { + number: 32, + migration: migration32.MigrateMCRouteSerialisation, + }, } // optionalVersions stores all optional migrations that are applied diff --git a/channeldb/log.go b/channeldb/log.go index a53d662cdcf..e50e5054ef0 100644 --- a/channeldb/log.go +++ b/channeldb/log.go @@ -10,6 +10,7 @@ import ( "github.com/lightningnetwork/lnd/channeldb/migration24" "github.com/lightningnetwork/lnd/channeldb/migration30" "github.com/lightningnetwork/lnd/channeldb/migration31" + "github.com/lightningnetwork/lnd/channeldb/migration32" "github.com/lightningnetwork/lnd/channeldb/migration_01_to_11" "github.com/lightningnetwork/lnd/kvdb" ) @@ -42,5 +43,6 @@ func UseLogger(logger btclog.Logger) { migration24.UseLogger(logger) migration30.UseLogger(logger) migration31.UseLogger(logger) + migration32.UseLogger(logger) kvdb.UseLogger(logger) } diff --git a/channeldb/migration32/codec.go b/channeldb/migration32/codec.go new file mode 100644 index 00000000000..ee9dd3a0262 --- /dev/null +++ b/channeldb/migration32/codec.go @@ -0,0 +1,140 @@ +package migration32 + +import ( + "encoding/binary" + "fmt" + "io" + + "github.com/btcsuite/btcd/wire" + lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21" +) + +// ReadElement is a one-stop utility function to deserialize any datastructure +// encoded using the serialization format of the database. +func ReadElement(r io.Reader, element interface{}) error { + switch e := element.(type) { + case *uint32: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + case *lnwire.MilliSatoshi: + var a uint64 + if err := binary.Read(r, byteOrder, &a); err != nil { + return err + } + + *e = lnwire.MilliSatoshi(a) + + case *[]byte: + bytes, err := wire.ReadVarBytes(r, 0, 66000, "[]byte") + if err != nil { + return err + } + + *e = bytes + + case *int64, *uint64: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + case *bool: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + case *int32: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + default: + return UnknownElementType{"ReadElement", e} + } + + return nil +} + +// ReadElements deserializes a variable number of elements into the passed +// io.Reader, with each element being deserialized according to the ReadElement +// function. +func ReadElements(r io.Reader, elements ...interface{}) error { + for _, element := range elements { + err := ReadElement(r, element) + if err != nil { + return err + } + } + + return nil +} + +// UnknownElementType is an error returned when the codec is unable to encode or +// decode a particular type. +type UnknownElementType struct { + method string + element interface{} +} + +// Error returns the name of the method that encountered the error, as well as +// the type that was unsupported. +func (e UnknownElementType) Error() string { + return fmt.Sprintf("Unknown type in %s: %T", e.method, e.element) +} + +// WriteElement is a one-stop shop to write the big endian representation of +// any element which is to be serialized for storage on disk. The passed +// io.Writer should be backed by an appropriately sized byte slice, or be able +// to dynamically expand to accommodate additional data. +func WriteElement(w io.Writer, element interface{}) error { + switch e := element.(type) { + case int64, uint64: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + case uint32: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + case int32: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + case bool: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + case lnwire.MilliSatoshi: + if err := binary.Write(w, byteOrder, uint64(e)); err != nil { + return err + } + + case []byte: + if err := wire.WriteVarBytes(w, 0, e); err != nil { + return err + } + + default: + return UnknownElementType{"WriteElement", e} + } + + return nil +} + +// WriteElements is writes each element in the elements slice to the passed +// io.Writer using WriteElement. +func WriteElements(w io.Writer, elements ...interface{}) error { + for _, element := range elements { + err := WriteElement(w, element) + if err != nil { + return err + } + } + + return nil +} diff --git a/channeldb/migration32/hop.go b/channeldb/migration32/hop.go new file mode 100644 index 00000000000..40c68eacf31 --- /dev/null +++ b/channeldb/migration32/hop.go @@ -0,0 +1,99 @@ +package migration32 + +import ( + "github.com/btcsuite/btcd/btcec/v2" + "github.com/lightningnetwork/lnd/tlv" +) + +const ( + // AmtOnionType is the type used in the onion to reference the amount to + // send to the next hop. + AmtOnionType tlv.Type = 2 + + // LockTimeOnionType is the type used in the onion to reference the CLTV + // value that should be used for the next hop's HTLC. + LockTimeOnionType tlv.Type = 4 + + // NextHopOnionType is the type used in the onion to reference the ID + // of the next hop. + NextHopOnionType tlv.Type = 6 + + // EncryptedDataOnionType is the type used to include encrypted data + // provided by the receiver in the onion for use in blinded paths. + EncryptedDataOnionType tlv.Type = 10 + + // BlindingPointOnionType is the type used to include receiver provided + // ephemeral keys in the onion that are used in blinded paths. + BlindingPointOnionType tlv.Type = 12 + + // MetadataOnionType is the type used in the onion for the payment + // metadata. + MetadataOnionType tlv.Type = 16 + + // TotalAmtMsatBlindedType is the type used in the onion for the total + // amount field that is included in the final hop for blinded payments. + TotalAmtMsatBlindedType tlv.Type = 18 +) + +// NewAmtToFwdRecord creates a tlv.Record that encodes the amount_to_forward +// (type 2) for an onion payload. +func NewAmtToFwdRecord(amt *uint64) tlv.Record { + return tlv.MakeDynamicRecord( + AmtOnionType, amt, func() uint64 { + return tlv.SizeTUint64(*amt) + }, + tlv.ETUint64, tlv.DTUint64, + ) +} + +// NewLockTimeRecord creates a tlv.Record that encodes the outgoing_cltv_value +// (type 4) for an onion payload. +func NewLockTimeRecord(lockTime *uint32) tlv.Record { + return tlv.MakeDynamicRecord( + LockTimeOnionType, lockTime, func() uint64 { + return tlv.SizeTUint32(*lockTime) + }, + tlv.ETUint32, tlv.DTUint32, + ) +} + +// NewNextHopIDRecord creates a tlv.Record that encodes the short_channel_id +// (type 6) for an onion payload. +func NewNextHopIDRecord(cid *uint64) tlv.Record { + return tlv.MakePrimitiveRecord(NextHopOnionType, cid) +} + +// NewEncryptedDataRecord creates a tlv.Record that encodes the encrypted_data +// (type 10) record for an onion payload. +func NewEncryptedDataRecord(data *[]byte) tlv.Record { + return tlv.MakePrimitiveRecord(EncryptedDataOnionType, data) +} + +// NewBlindingPointRecord creates a tlv.Record that encodes the blinding_point +// (type 12) record for an onion payload. +func NewBlindingPointRecord(point **btcec.PublicKey) tlv.Record { + return tlv.MakePrimitiveRecord(BlindingPointOnionType, point) +} + +// NewMetadataRecord creates a tlv.Record that encodes the metadata (type 10) +// for an onion payload. +func NewMetadataRecord(metadata *[]byte) tlv.Record { + return tlv.MakeDynamicRecord( + MetadataOnionType, metadata, + func() uint64 { + return uint64(len(*metadata)) + }, + tlv.EVarBytes, tlv.DVarBytes, + ) +} + +// NewTotalAmtMsatBlinded creates a tlv.Record that encodes the +// total_amount_msat for the final an onion payload within a blinded route. +func NewTotalAmtMsatBlinded(amt *uint64) tlv.Record { + return tlv.MakeDynamicRecord( + TotalAmtMsatBlindedType, amt, func() uint64 { + return tlv.SizeTUint64(*amt) + }, + tlv.ETUint64, tlv.DTUint64, + ) +} diff --git a/channeldb/migration32/log.go b/channeldb/migration32/log.go new file mode 100644 index 00000000000..98709c28ea3 --- /dev/null +++ b/channeldb/migration32/log.go @@ -0,0 +1,14 @@ +package migration32 + +import ( + "github.com/btcsuite/btclog" +) + +// log is a logger that is initialized as disabled. This means the package will +// not perform any logging by default until a logger is set. +var log = btclog.Disabled + +// UseLogger uses a specified Logger to output package logging info. +func UseLogger(logger btclog.Logger) { + log = logger +} diff --git a/channeldb/migration32/migration.go b/channeldb/migration32/migration.go new file mode 100644 index 00000000000..14a1c11f5cf --- /dev/null +++ b/channeldb/migration32/migration.go @@ -0,0 +1,53 @@ +package migration32 + +import ( + "bytes" + "fmt" + + "github.com/lightningnetwork/lnd/kvdb" +) + +// MigrateMCRouteSerialisation reads all the mission control store entries and +// re-serializes them using a minimal route serialisation so that only the parts +// of the route that are actually required for mission control are persisted +func MigrateMCRouteSerialisation(tx kvdb.RwTx) error { + log.Infof("Migrating Mission Control store to use a more minimal " + + "encoding for routes") + + resultBucket := tx.ReadWriteBucket(resultsKey) + + // If the results bucket does not exist then there are no entries in + // the mission control store yet and so there is nothing to migrate. + if resultBucket == nil { + return nil + } + + // For each entry, read it into memory using the old encoding. Then, + // extract the more minimal route, re-encode and persist the entry. + return resultBucket.ForEach(func(k, v []byte) error { + // Read the entry using the old encoding. + resultOld, err := deserializeOldResult(k, v) + if err != nil { + return err + } + + // Convert to the new payment result format with the minimal + // route. + resultNew := convertPaymentResult(resultOld) + + // Serialise the new payment result using the new encoding. + key, resultNewBytes, err := serializeNewResult(resultNew) + if err != nil { + return err + } + + // Make sure that the derived key is the same. + if !bytes.Equal(key, k) { + return fmt.Errorf("new payment result key (%v) is "+ + "not the same as the old key (%v)", key, k) + } + + // Finally, overwrite the previous value with the new encoding. + return resultBucket.Put(k, resultNewBytes) + }) +} diff --git a/channeldb/migration32/migration_test.go b/channeldb/migration32/migration_test.go new file mode 100644 index 00000000000..724d6f453a2 --- /dev/null +++ b/channeldb/migration32/migration_test.go @@ -0,0 +1,167 @@ +package migration32 + +import ( + "encoding/hex" + "testing" + "time" + + "github.com/btcsuite/btcd/btcec/v2" + lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21" + "github.com/lightningnetwork/lnd/channeldb/migtest" + "github.com/lightningnetwork/lnd/kvdb" +) + +var ( + failureIndex = 8 + testPub = Vertex{2, 202, 4} + + pubkeyBytes, _ = hex.DecodeString( + "598ec453728e0ffe0ae2f5e174243cf58f2" + + "a3f2c83d2457b43036db568b11093", + ) + pubKeyY = new(btcec.FieldVal) + _ = pubKeyY.SetByteSlice(pubkeyBytes) + pubkey = btcec.NewPublicKey(new(btcec.FieldVal).SetInt(4), pubKeyY) +) + +func TestMigrateMCRouteSerialisation(t *testing.T) { + paymentResultCommon1 := paymentResultCommon{ + id: 0, + timeFwd: time.Unix(0, 1), + timeReply: time.Unix(0, 2), + success: false, + failureSourceIdx: &failureIndex, + failure: &lnwire.FailFeeInsufficient{}, + } + + customRecord := map[uint64][]byte{ + 65536: {4, 2, 2}, + } + + resultsOld := []*paymentResultOld{ + { + paymentResultCommon: paymentResultCommon1, + route: &Route{ + TotalTimeLock: 100, + TotalAmount: 400, + SourcePubKey: testPub, + Hops: []*Hop{ + // A hop with MPP, AMP + { + PubKeyBytes: testPub, + ChannelID: 100, + OutgoingTimeLock: 300, + AmtToForward: 500, + MPP: &MPP{ + paymentAddr: [32]byte{ + 4, 5, + }, + totalMsat: 900, + }, + AMP: &{ + rootShare: [32]byte{ + 0, 0, + }, + setID: [32]byte{ + 5, 5, 5, + }, + childIndex: 90, + }, + CustomRecords: customRecord, + Metadata: []byte{6, 7, 7}, + }, + // A legacy hop. + { + PubKeyBytes: testPub, + ChannelID: 800, + OutgoingTimeLock: 4, + AmtToForward: 4, + LegacyPayload: true, + }, + // A hop with a blinding key. + { + PubKeyBytes: testPub, + ChannelID: 800, + OutgoingTimeLock: 4, + AmtToForward: 4, + BlindingPoint: pubkey, + EncryptedData: []byte{ + 1, 2, 3, + }, + TotalAmtMsat: 600, + }, + }, + }, + }, + } + + expectedResultsNew := []*paymentResultNew{ + { + paymentResultCommon: paymentResultCommon1, + route: &mcRoute{ + sourcePubKey: testPub, + totalAmount: 400, + hops: []*mcHop{ + { + channelID: 100, + pubKeyBytes: testPub, + amtToFwd: 500, + }, + { + channelID: 800, + pubKeyBytes: testPub, + amtToFwd: 4, + }, + { + channelID: 800, + pubKeyBytes: testPub, + amtToFwd: 4, + hasBlindingPoint: true, + }, + }, + }, + }, + } + + // Prime the database with some mission control data that uses the + // old route encoding. + before := func(tx kvdb.RwTx) error { + resultBucket, err := tx.CreateTopLevelBucket(resultsKey) + if err != nil { + return err + } + + for _, result := range resultsOld { + k, v, err := serializeOldResult(result) + if err != nil { + return err + } + + if err := resultBucket.Put(k, v); err != nil { + return err + } + } + + return nil + } + + // After the migration, ensure that all the relevant info was + // maintained. + after := func(tx kvdb.RwTx) error { + m := make(map[string]interface{}) + for _, result := range expectedResultsNew { + k, v, err := serializeNewResult(result) + if err != nil { + return err + } + + m[string(k)] = string(v) + } + + return migtest.VerifyDB(tx, resultsKey, m) + } + + migtest.ApplyMigration( + t, before, after, MigrateMCRouteSerialisation, false, + ) +} diff --git a/channeldb/migration32/mission_control_store.go b/channeldb/migration32/mission_control_store.go new file mode 100644 index 00000000000..30cadbdcf81 --- /dev/null +++ b/channeldb/migration32/mission_control_store.go @@ -0,0 +1,288 @@ +package migration32 + +import ( + "bytes" + "encoding/binary" + "math" + "time" + + "github.com/btcsuite/btcd/wire" + lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21" +) + +const ( + // unknownFailureSourceIdx is the database encoding of an unknown error + // source. + unknownFailureSourceIdx = -1 +) + +var ( + // resultsKey is the fixed key under which the attempt results are + // stored. + resultsKey = []byte("missioncontrol-results") + + // Big endian is the preferred byte order, due to cursor scans over + // integer keys iterating in order. + byteOrder = binary.BigEndian +) + +// paymentResultCommon holds the fields that are shared by the old and new +// payment result encoding. +type paymentResultCommon struct { + id uint64 + timeFwd, timeReply time.Time + success bool + failureSourceIdx *int + failure lnwire.FailureMessage +} + +// paymentResultOld is the information that becomes available when a payment +// attempt completes. +type paymentResultOld struct { + paymentResultCommon + route *Route +} + +// deserializeOldResult deserializes a payment result using the old encoding. +func deserializeOldResult(k, v []byte) (*paymentResultOld, error) { + // Parse payment id. + result := paymentResultOld{ + paymentResultCommon: paymentResultCommon{ + id: byteOrder.Uint64(k[8:]), + }, + } + + r := bytes.NewReader(v) + + // Read timestamps, success status and failure source index. + var ( + timeFwd, timeReply uint64 + dbFailureSourceIdx int32 + ) + + err := ReadElements( + r, &timeFwd, &timeReply, &result.success, &dbFailureSourceIdx, + ) + if err != nil { + return nil, err + } + + // Convert time stamps to local time zone for consistent logging. + result.timeFwd = time.Unix(0, int64(timeFwd)).Local() + result.timeReply = time.Unix(0, int64(timeReply)).Local() + + // Convert from unknown index magic number to nil value. + if dbFailureSourceIdx != unknownFailureSourceIdx { + failureSourceIdx := int(dbFailureSourceIdx) + result.failureSourceIdx = &failureSourceIdx + } + + // Read route. + route, err := DeserializeRoute(r) + if err != nil { + return nil, err + } + result.route = &route + + // Read failure. + failureBytes, err := wire.ReadVarBytes(r, 0, math.MaxUint16, "failure") + if err != nil { + return nil, err + } + if len(failureBytes) > 0 { + result.failure, err = lnwire.DecodeFailureMessage( + bytes.NewReader(failureBytes), 0, + ) + if err != nil { + return nil, err + } + } + + return &result, nil +} + +// convertPaymentResult converts a paymentResultOld to a paymentResultNew. +func convertPaymentResult(old *paymentResultOld) *paymentResultNew { + return &paymentResultNew{ + paymentResultCommon: old.paymentResultCommon, + route: extractMCRoute(old.route), + } +} + +// paymentResultNew is the information that becomes available when a payment +// attempt completes. +type paymentResultNew struct { + paymentResultCommon + route *mcRoute +} + +// extractMCRoute extracts the fields required by MC from the Route struct to +// create the more minima mcRoute struct. +func extractMCRoute(route *Route) *mcRoute { + return &mcRoute{ + sourcePubKey: route.SourcePubKey, + totalAmount: route.TotalAmount, + hops: extractMCHops(route.Hops), + } +} + +// extractMCHops extracts the Hop fields that MC actually uses from a slice of +// Hops. +func extractMCHops(hops []*Hop) []*mcHop { + mcHops := make([]*mcHop, len(hops)) + for i, hop := range hops { + mcHops[i] = extractMCHop(hop) + } + + return mcHops +} + +// extractMCHop extracts the Hop fields that MC actually uses from a Hop. +func extractMCHop(hop *Hop) *mcHop { + return &mcHop{ + channelID: hop.ChannelID, + pubKeyBytes: hop.PubKeyBytes, + amtToFwd: hop.AmtToForward, + hasBlindingPoint: hop.BlindingPoint != nil, + } +} + +// mcRoute holds the bare minimum info about a payment attempt route that MC +// requires. +type mcRoute struct { + sourcePubKey Vertex + totalAmount lnwire.MilliSatoshi + hops []*mcHop +} + +// mcHop holds the bare minimum info about a payment attempt route hop that MC +// requires. +type mcHop struct { + channelID uint64 + pubKeyBytes Vertex + amtToFwd lnwire.MilliSatoshi + hasBlindingPoint bool +} + +// serializeOldResult serializes a payment result and returns a key and value +// byte slice to insert into the bucket. +func serializeOldResult(rp *paymentResultOld) ([]byte, []byte, error) { + // Write timestamps, success status, failure source index and route. + var b bytes.Buffer + var dbFailureSourceIdx int32 + if rp.failureSourceIdx == nil { + dbFailureSourceIdx = unknownFailureSourceIdx + } else { + dbFailureSourceIdx = int32(*rp.failureSourceIdx) + } + err := WriteElements( + &b, + uint64(rp.timeFwd.UnixNano()), + uint64(rp.timeReply.UnixNano()), + rp.success, dbFailureSourceIdx, + ) + if err != nil { + return nil, nil, err + } + + if err := SerializeRoute(&b, *rp.route); err != nil { + return nil, nil, err + } + + // Write failure. If there is no failure message, write an empty + // byte slice. + var failureBytes bytes.Buffer + if rp.failure != nil { + err := lnwire.EncodeFailureMessage(&failureBytes, rp.failure, 0) + if err != nil { + return nil, nil, err + } + } + err = wire.WriteVarBytes(&b, 0, failureBytes.Bytes()) + if err != nil { + return nil, nil, err + } + // Compose key that identifies this result. + key := getResultKeyOld(rp) + + return key, b.Bytes(), nil +} + +// getResultKeyOld returns a byte slice representing a unique key for this +// payment result. +func getResultKeyOld(rp *paymentResultOld) []byte { + var keyBytes [8 + 8 + 33]byte + + // Identify records by a combination of time, payment id and sender pub + // key. This allows importing mission control data from an external + // source without key collisions and keeps the records sorted + // chronologically. + byteOrder.PutUint64(keyBytes[:], uint64(rp.timeReply.UnixNano())) + byteOrder.PutUint64(keyBytes[8:], rp.id) + copy(keyBytes[16:], rp.route.SourcePubKey[:]) + + return keyBytes[:] +} + +// serializeNewResult serializes a payment result and returns a key and value +// byte slice to insert into the bucket. +func serializeNewResult(rp *paymentResultNew) ([]byte, []byte, error) { + // Write timestamps, success status, failure source index and route. + var b bytes.Buffer + + var dbFailureSourceIdx int32 + if rp.failureSourceIdx == nil { + dbFailureSourceIdx = unknownFailureSourceIdx + } else { + dbFailureSourceIdx = int32(*rp.failureSourceIdx) + } + + err := WriteElements( + &b, + uint64(rp.timeFwd.UnixNano()), + uint64(rp.timeReply.UnixNano()), + rp.success, dbFailureSourceIdx, + ) + if err != nil { + return nil, nil, err + } + + if err := SerializeMCRoute(&b, rp.route); err != nil { + return nil, nil, err + } + + // Write failure. If there is no failure message, write an empty + // byte slice. + var failureBytes bytes.Buffer + if rp.failure != nil { + err := lnwire.EncodeFailureMessage(&failureBytes, rp.failure, 0) + if err != nil { + return nil, nil, err + } + } + err = wire.WriteVarBytes(&b, 0, failureBytes.Bytes()) + if err != nil { + return nil, nil, err + } + + // Compose key that identifies this result. + key := getResultKeyNew(rp) + + return key, b.Bytes(), nil +} + +// getResultKeyNew returns a byte slice representing a unique key for this +// payment result. +func getResultKeyNew(rp *paymentResultNew) []byte { + var keyBytes [8 + 8 + 33]byte + + // Identify records by a combination of time, payment id and sender pub + // key. This allows importing mission control data from an external + // source without key collisions and keeps the records sorted + // chronologically. + byteOrder.PutUint64(keyBytes[:], uint64(rp.timeReply.UnixNano())) + byteOrder.PutUint64(keyBytes[8:], rp.id) + copy(keyBytes[16:], rp.route.sourcePubKey[:]) + + return keyBytes[:] +} diff --git a/channeldb/migration32/route.go b/channeldb/migration32/route.go new file mode 100644 index 00000000000..9597ff24028 --- /dev/null +++ b/channeldb/migration32/route.go @@ -0,0 +1,611 @@ +package migration32 + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/btcsuite/btcd/wire" + lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21" + "github.com/lightningnetwork/lnd/tlv" +) + +const ( + // MPPOnionType is the type used in the onion to reference the MPP + // fields: total_amt and payment_addr. + MPPOnionType tlv.Type = 8 + + // AMPOnionType is the type used in the onion to reference the AMP + // fields: root_share, set_id, and child_index. + AMPOnionType tlv.Type = 14 +) + +// VertexSize is the size of the array to store a vertex. +const VertexSize = 33 + +// Vertex is a simple alias for the serialization of a compressed Bitcoin +// public key. +type Vertex [VertexSize]byte + +// Route represents a path through the channel graph which runs over one or +// more channels in succession. This struct carries all the information +// required to craft the Sphinx onion packet, and send the payment along the +// first hop in the path. A route is only selected as valid if all the channels +// have sufficient capacity to carry the initial payment amount after fees are +// accounted for. +type Route struct { + // TotalTimeLock is the cumulative (final) time lock across the entire + // route. This is the CLTV value that should be extended to the first + // hop in the route. All other hops will decrement the time-lock as + // advertised, leaving enough time for all hops to wait for or present + // the payment preimage to complete the payment. + TotalTimeLock uint32 + + // TotalAmount is the total amount of funds required to complete a + // payment over this route. This value includes the cumulative fees at + // each hop. As a result, the HTLC extended to the first-hop in the + // route will need to have at least this many satoshis, otherwise the + // route will fail at an intermediate node due to an insufficient + // amount of fees. + TotalAmount lnwire.MilliSatoshi + + // SourcePubKey is the pubkey of the node where this route originates + // from. + SourcePubKey Vertex + + // Hops contains details concerning the specific forwarding details at + // each hop. + Hops []*Hop +} + +// Hop represents an intermediate or final node of the route. This naming +// is in line with the definition given in BOLT #4: Onion Routing Protocol. +// The struct houses the channel along which this hop can be reached and +// the values necessary to create the HTLC that needs to be sent to the +// next hop. It is also used to encode the per-hop payload included within +// the Sphinx packet. +type Hop struct { + // PubKeyBytes is the raw bytes of the public key of the target node. + PubKeyBytes Vertex + + // ChannelID is the unique channel ID for the channel. The first 3 + // bytes are the block height, the next 3 the index within the block, + // and the last 2 bytes are the output index for the channel. + ChannelID uint64 + + // OutgoingTimeLock is the timelock value that should be used when + // crafting the _outgoing_ HTLC from this hop. + OutgoingTimeLock uint32 + + // AmtToForward is the amount that this hop will forward to the next + // hop. This value is less than the value that the incoming HTLC + // carries as a fee will be subtracted by the hop. + AmtToForward lnwire.MilliSatoshi + + // MPP encapsulates the data required for option_mpp. This field should + // only be set for the final hop. + MPP *MPP + + // AMP encapsulates the data required for option_amp. This field should + // only be set for the final hop. + AMP *AMP + + // CustomRecords if non-nil are a set of additional TLV records that + // should be included in the forwarding instructions for this node. + CustomRecords CustomSet + + // LegacyPayload if true, then this signals that this node doesn't + // understand the new TLV payload, so we must instead use the legacy + // payload. + // + // NOTE: we should no longer ever create a Hop with Legacy set to true. + // The only reason we are keeping this member is that it could be the + // case that we have serialised hops persisted to disk where + // LegacyPayload is true. + LegacyPayload bool + + // Metadata is additional data that is sent along with the payment to + // the payee. + Metadata []byte + + // EncryptedData is an encrypted data blob includes for hops that are + // part of a blinded route. + EncryptedData []byte + + // BlindingPoint is an ephemeral public key used by introduction nodes + // in blinded routes to unblind their portion of the route and pass on + // the next ephemeral key to the next blinded node to do the same. + BlindingPoint *btcec.PublicKey + + // TotalAmtMsat is the total amount for a blinded payment, potentially + // spread over more than one HTLC. This field should only be set for + // the final hop in a blinded path. + TotalAmtMsat lnwire.MilliSatoshi +} + +// MPP is a record that encodes the fields necessary for multi-path payments. +type MPP struct { + // paymentAddr is a random, receiver-generated value used to avoid + // collisions with concurrent payers. + paymentAddr [32]byte + + // totalMsat is the total value of the payment, potentially spread + // across more than one HTLC. + totalMsat lnwire.MilliSatoshi +} + +// Record returns a tlv.Record that can be used to encode or decode this record. +func (r *MPP) Record() tlv.Record { + // Fixed-size, 32 byte payment address followed by truncated 64-bit + // total msat. + size := func() uint64 { + return 32 + tlv.SizeTUint64(uint64(r.totalMsat)) + } + + return tlv.MakeDynamicRecord( + MPPOnionType, r, size, MPPEncoder, MPPDecoder, + ) +} + +const ( + // minMPPLength is the minimum length of a serialized MPP TLV record, + // which occurs when the truncated encoding of total_amt_msat takes 0 + // bytes, leaving only the payment_addr. + minMPPLength = 32 + + // maxMPPLength is the maximum length of a serialized MPP TLV record, + // which occurs when the truncated encoding of total_amt_msat takes 8 + // bytes. + maxMPPLength = 40 +) + +// MPPEncoder writes the MPP record to the provided io.Writer. +func MPPEncoder(w io.Writer, val interface{}, buf *[8]byte) error { + if v, ok := val.(*MPP); ok { + err := tlv.EBytes32(w, &v.paymentAddr, buf) + if err != nil { + return err + } + + return tlv.ETUint64T(w, uint64(v.totalMsat), buf) + } + + return tlv.NewTypeForEncodingErr(val, "MPP") +} + +// MPPDecoder reads the MPP record to the provided io.Reader. +func MPPDecoder(r io.Reader, val interface{}, buf *[8]byte, l uint64) error { + if v, ok := val.(*MPP); ok && minMPPLength <= l && l <= maxMPPLength { + if err := tlv.DBytes32(r, &v.paymentAddr, buf, 32); err != nil { + return err + } + + var total uint64 + if err := tlv.DTUint64(r, &total, buf, l-32); err != nil { + return err + } + v.totalMsat = lnwire.MilliSatoshi(total) + + return nil + } + + return tlv.NewTypeForDecodingErr(val, "MPP", l, maxMPPLength) +} + +// AMP is a record that encodes the fields necessary for atomic multi-path +// payments. +type AMP struct { + rootShare [32]byte + setID [32]byte + childIndex uint32 +} + +// AMPEncoder writes the AMP record to the provided io.Writer. +func AMPEncoder(w io.Writer, val interface{}, buf *[8]byte) error { + if v, ok := val.(*AMP); ok { + if err := tlv.EBytes32(w, &v.rootShare, buf); err != nil { + return err + } + + if err := tlv.EBytes32(w, &v.setID, buf); err != nil { + return err + } + + return tlv.ETUint32T(w, v.childIndex, buf) + } + + return tlv.NewTypeForEncodingErr(val, "AMP") +} + +const ( + // minAMPLength is the minimum length of a serialized AMP TLV record, + // which occurs when the truncated encoding of child_index takes 0 + // bytes, leaving only the root_share and set_id. + minAMPLength = 64 + + // maxAMPLength is the maximum length of a serialized AMP TLV record, + // which occurs when the truncated encoding of a child_index takes 2 + // bytes. + maxAMPLength = 68 +) + +// AMPDecoder reads the AMP record from the provided io.Reader. +func AMPDecoder(r io.Reader, val interface{}, buf *[8]byte, l uint64) error { + if v, ok := val.(*AMP); ok && minAMPLength <= l && l <= maxAMPLength { + if err := tlv.DBytes32(r, &v.rootShare, buf, 32); err != nil { + return err + } + + if err := tlv.DBytes32(r, &v.setID, buf, 32); err != nil { + return err + } + + return tlv.DTUint32(r, &v.childIndex, buf, l-minAMPLength) + } + + return tlv.NewTypeForDecodingErr(val, "AMP", l, maxAMPLength) +} + +// Record returns a tlv.Record that can be used to encode or decode this record. +func (a *AMP) Record() tlv.Record { + return tlv.MakeDynamicRecord( + AMPOnionType, a, a.PayloadSize, AMPEncoder, AMPDecoder, + ) +} + +// PayloadSize returns the size this record takes up in encoded form. +func (a *AMP) PayloadSize() uint64 { + return 32 + 32 + tlv.SizeTUint32(a.childIndex) +} + +const ( + // CustomTypeStart is the start of the custom tlv type range as defined + // in BOLT 01. + CustomTypeStart = 65536 +) + +// CustomSet stores a set of custom key/value pairs. +type CustomSet map[uint64][]byte + +// Validate checks that all custom records are in the custom type range. +func (c CustomSet) Validate() error { + for key := range c { + if key < CustomTypeStart { + return fmt.Errorf("no custom records with types "+ + "below %v allowed", CustomTypeStart) + } + } + + return nil +} + +// SerializeRoute serializes a route. +func SerializeRoute(w io.Writer, r Route) error { + if err := WriteElements(w, + r.TotalTimeLock, r.TotalAmount, r.SourcePubKey[:], + ); err != nil { + return err + } + + if err := WriteElements(w, uint32(len(r.Hops))); err != nil { + return err + } + + for _, h := range r.Hops { + if err := serializeHop(w, h); err != nil { + return err + } + } + + return nil +} + +func serializeHop(w io.Writer, h *Hop) error { + if err := WriteElements(w, + h.PubKeyBytes[:], + h.ChannelID, + h.OutgoingTimeLock, + h.AmtToForward, + ); err != nil { + return err + } + + if err := binary.Write(w, byteOrder, h.LegacyPayload); err != nil { + return err + } + + // For legacy payloads, we don't need to write any TLV records, so + // we'll write a zero indicating the our serialized TLV map has no + // records. + if h.LegacyPayload { + return WriteElements(w, uint32(0)) + } + + // Gather all non-primitive TLV records so that they can be serialized + // as a single blob. + // + // TODO(conner): add migration to unify all fields in a single TLV + // blobs. The split approach will cause headaches down the road as more + // fields are added, which we can avoid by having a single TLV stream + // for all payload fields. + var records []tlv.Record + if h.MPP != nil { + records = append(records, h.MPP.Record()) + } + + // Add blinding point and encrypted data if present. + if h.EncryptedData != nil { + records = append(records, NewEncryptedDataRecord( + &h.EncryptedData, + )) + } + + if h.BlindingPoint != nil { + records = append(records, NewBlindingPointRecord( + &h.BlindingPoint, + )) + } + + if h.AMP != nil { + records = append(records, h.AMP.Record()) + } + + if h.Metadata != nil { + records = append(records, NewMetadataRecord(&h.Metadata)) + } + + if h.TotalAmtMsat != 0 { + totalMsatInt := uint64(h.TotalAmtMsat) + records = append( + records, NewTotalAmtMsatBlinded(&totalMsatInt), + ) + } + + // Final sanity check to absolutely rule out custom records that are not + // custom and write into the standard range. + if err := h.CustomRecords.Validate(); err != nil { + return err + } + + // Convert custom records to tlv and add to the record list. + // MapToRecords sorts the list, so adding it here will keep the list + // canonical. + tlvRecords := tlv.MapToRecords(h.CustomRecords) + records = append(records, tlvRecords...) + + // Otherwise, we'll transform our slice of records into a map of the + // raw bytes, then serialize them in-line with a length (number of + // elements) prefix. + mapRecords, err := tlv.RecordsToMap(records) + if err != nil { + return err + } + + numRecords := uint32(len(mapRecords)) + if err := WriteElements(w, numRecords); err != nil { + return err + } + + for recordType, rawBytes := range mapRecords { + if err := WriteElements(w, recordType); err != nil { + return err + } + + if err := wire.WriteVarBytes(w, 0, rawBytes); err != nil { + return err + } + } + + return nil +} + +// DeserializeRoute deserializes a route. +func DeserializeRoute(r io.Reader) (Route, error) { + rt := Route{} + if err := ReadElements(r, + &rt.TotalTimeLock, &rt.TotalAmount, + ); err != nil { + return rt, err + } + + var pub []byte + if err := ReadElements(r, &pub); err != nil { + return rt, err + } + copy(rt.SourcePubKey[:], pub) + + var numHops uint32 + if err := ReadElements(r, &numHops); err != nil { + return rt, err + } + + var hops []*Hop + for i := uint32(0); i < numHops; i++ { + hop, err := deserializeHop(r) + if err != nil { + return rt, err + } + hops = append(hops, hop) + } + rt.Hops = hops + + return rt, nil +} + +// maxOnionPayloadSize is the largest Sphinx payload possible, so we don't need +// to read/write a TLV stream larger than this. +const maxOnionPayloadSize = 1300 + +func deserializeHop(r io.Reader) (*Hop, error) { + h := &Hop{} + + var pub []byte + if err := ReadElements(r, &pub); err != nil { + return nil, err + } + copy(h.PubKeyBytes[:], pub) + + if err := ReadElements(r, + &h.ChannelID, &h.OutgoingTimeLock, &h.AmtToForward, + ); err != nil { + return nil, err + } + + // TODO(roasbeef): change field to allow LegacyPayload false to be the + // legacy default? + err := binary.Read(r, byteOrder, &h.LegacyPayload) + if err != nil { + return nil, err + } + + var numElements uint32 + if err := ReadElements(r, &numElements); err != nil { + return nil, err + } + + // If there're no elements, then we can return early. + if numElements == 0 { + return h, nil + } + + tlvMap := make(map[uint64][]byte) + for i := uint32(0); i < numElements; i++ { + var tlvType uint64 + if err := ReadElements(r, &tlvType); err != nil { + return nil, err + } + + rawRecordBytes, err := wire.ReadVarBytes( + r, 0, maxOnionPayloadSize, "tlv", + ) + if err != nil { + return nil, err + } + + tlvMap[tlvType] = rawRecordBytes + } + + // If the MPP type is present, remove it from the generic TLV map and + // parse it back into a proper MPP struct. + // + // TODO(conner): add migration to unify all fields in a single TLV + // blobs. The split approach will cause headaches down the road as more + // fields are added, which we can avoid by having a single TLV stream + // for all payload fields. + mppType := uint64(MPPOnionType) + if mppBytes, ok := tlvMap[mppType]; ok { + delete(tlvMap, mppType) + + var ( + mpp = &MPP{} + mppRec = mpp.Record() + r = bytes.NewReader(mppBytes) + ) + err := mppRec.Decode(r, uint64(len(mppBytes))) + if err != nil { + return nil, err + } + h.MPP = mpp + } + + // If encrypted data or blinding key are present, remove them from + // the TLV map and parse into proper types. + encryptedDataType := uint64(EncryptedDataOnionType) + if data, ok := tlvMap[encryptedDataType]; ok { + delete(tlvMap, encryptedDataType) + h.EncryptedData = data + } + + blindingType := uint64(BlindingPointOnionType) + if blindingPoint, ok := tlvMap[blindingType]; ok { + delete(tlvMap, blindingType) + + h.BlindingPoint, err = btcec.ParsePubKey(blindingPoint) + if err != nil { + return nil, fmt.Errorf("invalid blinding point: %w", + err) + } + } + + ampType := uint64(AMPOnionType) + if ampBytes, ok := tlvMap[ampType]; ok { + delete(tlvMap, ampType) + + var ( + amp = &{} + ampRec = amp.Record() + r = bytes.NewReader(ampBytes) + ) + err := ampRec.Decode(r, uint64(len(ampBytes))) + if err != nil { + return nil, err + } + h.AMP = amp + } + + // If the metadata type is present, remove it from the tlv map and + // populate directly on the hop. + metadataType := uint64(MetadataOnionType) + if metadata, ok := tlvMap[metadataType]; ok { + delete(tlvMap, metadataType) + + h.Metadata = metadata + } + + totalAmtMsatType := uint64(TotalAmtMsatBlindedType) + if totalAmtMsat, ok := tlvMap[totalAmtMsatType]; ok { + delete(tlvMap, totalAmtMsatType) + + var ( + totalAmtMsatInt uint64 + buf [8]byte + ) + if err := tlv.DTUint64( + bytes.NewReader(totalAmtMsat), + &totalAmtMsatInt, + &buf, + uint64(len(totalAmtMsat)), + ); err != nil { + return nil, err + } + + h.TotalAmtMsat = lnwire.MilliSatoshi(totalAmtMsatInt) + } + + h.CustomRecords = tlvMap + + return h, nil +} + +// SerializeMCRoute serializes an mcRoute and writes the bytes to the given +// io.Writer. +func SerializeMCRoute(w io.Writer, r *mcRoute) error { + if err := WriteElements( + w, r.totalAmount, r.sourcePubKey[:], + ); err != nil { + return err + } + + if err := WriteElements(w, uint32(len(r.hops))); err != nil { + return err + } + + for _, h := range r.hops { + if err := serializeNewHop(w, h); err != nil { + return err + } + } + + return nil +} + +func serializeNewHop(w io.Writer, h *mcHop) error { + return WriteElements(w, + h.pubKeyBytes[:], + h.channelID, + h.amtToFwd, + h.hasBlindingPoint, + ) +} diff --git a/routing/missioncontrol.go b/routing/missioncontrol.go index 1dceace7c08..cf358a2533c 100644 --- a/routing/missioncontrol.go +++ b/routing/missioncontrol.go @@ -199,7 +199,7 @@ type MissionControlPairSnapshot struct { type paymentResult struct { id uint64 timeFwd, timeReply time.Time - route *route.Route + route *mcRoute success bool failureSourceIdx *int failure lnwire.FailureMessage @@ -421,7 +421,7 @@ func (m *MissionControl) ReportPaymentFail(paymentID uint64, rt *route.Route, id: paymentID, failureSourceIdx: failureSourceIdx, failure: failure, - route: rt, + route: extractMCRoute(rt), } return m.processPaymentResult(result) @@ -439,7 +439,7 @@ func (m *MissionControl) ReportPaymentSuccess(paymentID uint64, timeReply: timestamp, id: paymentID, success: true, - route: rt, + route: extractMCRoute(rt), } _, err := m.processPaymentResult(result) diff --git a/routing/missioncontrol_store.go b/routing/missioncontrol_store.go index e149e854589..0f0dd9cf906 100644 --- a/routing/missioncontrol_store.go +++ b/routing/missioncontrol_store.go @@ -5,6 +5,7 @@ import ( "container/list" "encoding/binary" "fmt" + "io" "math" "sync" "time" @@ -187,7 +188,7 @@ func serializeResult(rp *paymentResult) ([]byte, []byte, error) { return nil, nil, err } - if err := channeldb.SerializeRoute(&b, *rp.route); err != nil { + if err := serializeRoute(&b, rp.route); err != nil { return nil, nil, err } @@ -211,6 +212,86 @@ func serializeResult(rp *paymentResult) ([]byte, []byte, error) { return key, b.Bytes(), nil } +func deserializeRoute(r io.Reader) (*mcRoute, error) { + var rt mcRoute + if err := channeldb.ReadElements( + r, &rt.totalAmount, + ); err != nil { + return nil, err + } + + var pub []byte + if err := channeldb.ReadElements(r, &pub); err != nil { + return nil, err + } + copy(rt.sourcePubKey[:], pub) + + var numHops uint32 + if err := channeldb.ReadElements(r, &numHops); err != nil { + return nil, err + } + + var hops []*mcHop + for i := uint32(0); i < numHops; i++ { + hop, err := deserializeHop(r) + if err != nil { + return nil, err + } + hops = append(hops, hop) + } + rt.hops = hops + + return &rt, nil +} + +func deserializeHop(r io.Reader) (*mcHop, error) { + var h mcHop + + var pub []byte + if err := channeldb.ReadElements(r, &pub); err != nil { + return nil, err + } + copy(h.pubKeyBytes[:], pub) + + if err := channeldb.ReadElements(r, + &h.channelID, &h.amtToFwd, &h.hasBlindingPoint, + ); err != nil { + return nil, err + } + + return &h, nil +} + +// serializeRoute serializes a route. +func serializeRoute(w io.Writer, r *mcRoute) error { + if err := channeldb.WriteElements(w, + r.totalAmount, r.sourcePubKey[:], + ); err != nil { + return err + } + + if err := channeldb.WriteElements(w, uint32(len(r.hops))); err != nil { + return err + } + + for _, h := range r.hops { + if err := serializeHop(w, h); err != nil { + return err + } + } + + return nil +} + +func serializeHop(w io.Writer, h *mcHop) error { + return channeldb.WriteElements(w, + h.pubKeyBytes[:], + h.channelID, + h.amtToFwd, + h.hasBlindingPoint, + ) +} + // deserializeResult deserializes a payment result. func deserializeResult(k, v []byte) (*paymentResult, error) { // Parse payment id. @@ -244,11 +325,11 @@ func deserializeResult(k, v []byte) (*paymentResult, error) { } // Read route. - route, err := channeldb.DeserializeRoute(r) + route, err := deserializeRoute(r) if err != nil { return nil, err } - result.route = &route + result.route = route // Read failure. failureBytes, err := wire.ReadVarBytes( @@ -490,7 +571,7 @@ func getResultKey(rp *paymentResult) []byte { // chronologically. byteOrder.PutUint64(keyBytes[:], uint64(rp.timeReply.UnixNano())) byteOrder.PutUint64(keyBytes[8:], rp.id) - copy(keyBytes[16:], rp.route.SourcePubKey[:]) + copy(keyBytes[16:], rp.route.sourcePubKey[:]) return keyBytes[:] } diff --git a/routing/missioncontrol_store_test.go b/routing/missioncontrol_store_test.go index 34b925a3e17..ff9b270260b 100644 --- a/routing/missioncontrol_store_test.go +++ b/routing/missioncontrol_store_test.go @@ -18,12 +18,11 @@ const testMaxRecords = 2 var ( // mcStoreTestRoute is a test route for the mission control store tests. - mcStoreTestRoute = route.Route{ - SourcePubKey: route.Vertex{1}, - Hops: []*route.Hop{ + mcStoreTestRoute = mcRoute{ + sourcePubKey: route.Vertex{1}, + hops: []*mcHop{ { - PubKeyBytes: route.Vertex{2}, - LegacyPayload: true, + pubKeyBytes: route.Vertex{2}, }, }, } diff --git a/routing/result_interpretation.go b/routing/result_interpretation.go index 4118286e64b..9aeb5cb607e 100644 --- a/routing/result_interpretation.go +++ b/routing/result_interpretation.go @@ -76,7 +76,7 @@ type interpretedResult struct { // interpretResult interprets a payment outcome and returns an object that // contains information required to update mission control. -func interpretResult(rt *route.Route, success bool, failureSrcIdx *int, +func interpretResult(rt *mcRoute, success bool, failureSrcIdx *int, failure lnwire.FailureMessage) *interpretedResult { i := &interpretedResult{ @@ -92,15 +92,14 @@ func interpretResult(rt *route.Route, success bool, failureSrcIdx *int, } // processSuccess processes a successful payment attempt. -func (i *interpretedResult) processSuccess(route *route.Route) { +func (i *interpretedResult) processSuccess(route *mcRoute) { // For successes, all nodes must have acted in the right way. Therefore // we mark all of them with a success result. - i.successPairRange(route, 0, len(route.Hops)-1) + i.successPairRange(route, 0, len(route.hops)-1) } // processFail processes a failed payment attempt. -func (i *interpretedResult) processFail( - rt *route.Route, errSourceIdx *int, +func (i *interpretedResult) processFail(rt *mcRoute, errSourceIdx *int, failure lnwire.FailureMessage) { if errSourceIdx == nil { @@ -125,10 +124,8 @@ func (i *interpretedResult) processFail( i.processPaymentOutcomeSelf(rt, failure) // A failure from the final hop was received. - case len(rt.Hops): - i.processPaymentOutcomeFinal( - rt, failure, - ) + case len(rt.hops): + i.processPaymentOutcomeFinal(rt, failure) // An intermediate hop failed. Interpret the outcome, update reputation // and try again. @@ -144,7 +141,7 @@ func (i *interpretedResult) processFail( // node. This indicates that the introduction node is not obeying the route // blinding specification, as we expect all errors from the introduction node // to be source from it. -func (i *interpretedResult) processPaymentOutcomeBadIntro(route *route.Route, +func (i *interpretedResult) processPaymentOutcomeBadIntro(route *mcRoute, introIdx, errSourceIdx int) { // We fail the introduction node for not obeying the specification. @@ -161,14 +158,14 @@ func (i *interpretedResult) processPaymentOutcomeBadIntro(route *route.Route, // a final failure reason because the recipient can't process the // payment (independent of the introduction failing to convert the // error, we can't complete the payment if the last hop fails). - if errSourceIdx == len(route.Hops) { + if errSourceIdx == len(route.hops) { i.finalFailureReason = &reasonError } } // processPaymentOutcomeSelf handles failures sent by ourselves. -func (i *interpretedResult) processPaymentOutcomeSelf( - rt *route.Route, failure lnwire.FailureMessage) { +func (i *interpretedResult) processPaymentOutcomeSelf(rt *mcRoute, + failure lnwire.FailureMessage) { switch failure.(type) { @@ -181,7 +178,7 @@ func (i *interpretedResult) processPaymentOutcomeSelf( i.failNode(rt, 1) // If this was a payment to a direct peer, we can stop trying. - if len(rt.Hops) == 1 { + if len(rt.hops) == 1 { i.finalFailureReason = &reasonError } @@ -191,15 +188,15 @@ func (i *interpretedResult) processPaymentOutcomeSelf( // available in the link has been updated. default: log.Warnf("Routing failure for local channel %v occurred", - rt.Hops[0].ChannelID) + rt.hops[0].channelID) } } // processPaymentOutcomeFinal handles failures sent by the final hop. -func (i *interpretedResult) processPaymentOutcomeFinal( - route *route.Route, failure lnwire.FailureMessage) { +func (i *interpretedResult) processPaymentOutcomeFinal(route *mcRoute, + failure lnwire.FailureMessage) { - n := len(route.Hops) + n := len(route.hops) failNode := func() { i.failNode(route, n) @@ -292,9 +289,10 @@ func (i *interpretedResult) processPaymentOutcomeFinal( // processPaymentOutcomeIntermediate handles failures sent by an intermediate // hop. -func (i *interpretedResult) processPaymentOutcomeIntermediate( - route *route.Route, errorSourceIdx int, - failure lnwire.FailureMessage) { +// +//nolint:funlen +func (i *interpretedResult) processPaymentOutcomeIntermediate(route *mcRoute, + errorSourceIdx int, failure lnwire.FailureMessage) { reportOutgoing := func() { i.failPair( @@ -398,8 +396,8 @@ func (i *interpretedResult) processPaymentOutcomeIntermediate( // Set the node pair for which a channel update may be out of // date. The second chance logic uses the policyFailure field. i.policyFailure = &DirectedNodePair{ - From: route.Hops[errorSourceIdx-1].PubKeyBytes, - To: route.Hops[errorSourceIdx].PubKeyBytes, + From: route.hops[errorSourceIdx-1].pubKeyBytes, + To: route.hops[errorSourceIdx].pubKeyBytes, } reportOutgoing() @@ -427,8 +425,8 @@ func (i *interpretedResult) processPaymentOutcomeIntermediate( // Set the node pair for which a channel update may be out of // date. The second chance logic uses the policyFailure field. i.policyFailure = &DirectedNodePair{ - From: route.Hops[errorSourceIdx-1].PubKeyBytes, - To: route.Hops[errorSourceIdx].PubKeyBytes, + From: route.hops[errorSourceIdx-1].pubKeyBytes, + To: route.hops[errorSourceIdx].pubKeyBytes, } // We report incoming channel. If a second pair is granted in @@ -502,16 +500,14 @@ func (i *interpretedResult) processPaymentOutcomeIntermediate( // Note that if LND is extended to support multiple blinded // routes, this will terminate the payment without re-trying // the other routes. - if introIdx == len(route.Hops)-1 { + if introIdx == len(route.hops)-1 { i.finalFailureReason = &reasonError } else { // If there are other hops between the recipient and // introduction node, then we just penalize the last // hop in the blinded route to minimize the storage of // results for ephemeral keys. - i.failPairBalance( - route, len(route.Hops)-1, - ) + i.failPairBalance(route, len(route.hops)-1) } // In all other cases, we penalize the reporting node. These are all @@ -525,9 +521,9 @@ func (i *interpretedResult) processPaymentOutcomeIntermediate( // route, using the same indexing in the route that we use for errorSourceIdx // (i.e., that we consider our own node to be at index zero). A boolean is // returned to indicate whether the route contains a blinded portion at all. -func introductionPointIndex(route *route.Route) (int, bool) { - for i, hop := range route.Hops { - if hop.BlindingPoint != nil { +func introductionPointIndex(route *mcRoute) (int, bool) { + for i, hop := range route.hops { + if hop.hasBlindingPoint { return i + 1, true } } @@ -537,8 +533,8 @@ func introductionPointIndex(route *route.Route) (int, bool) { // processPaymentOutcomeUnknown processes a payment outcome for which no failure // message or source is available. -func (i *interpretedResult) processPaymentOutcomeUnknown(route *route.Route) { - n := len(route.Hops) +func (i *interpretedResult) processPaymentOutcomeUnknown(route *mcRoute) { + n := len(route.hops) // If this is a direct payment, the destination must be at fault. if n == 1 { @@ -553,12 +549,51 @@ func (i *interpretedResult) processPaymentOutcomeUnknown(route *route.Route) { i.failPairRange(route, 0, n-1) } +func extractMCRoute(route *route.Route) *mcRoute { + return &mcRoute{ + sourcePubKey: route.SourcePubKey, + totalAmount: route.TotalAmount, + hops: extractMCHops(route.Hops), + } +} + +func extractMCHops(hops []*route.Hop) []*mcHop { + mcHops := make([]*mcHop, len(hops)) + for i, hop := range hops { + mcHops[i] = extractMCHop(hop) + } + + return mcHops +} + +func extractMCHop(hop *route.Hop) *mcHop { + return &mcHop{ + channelID: hop.ChannelID, + pubKeyBytes: hop.PubKeyBytes, + amtToFwd: hop.AmtToForward, + hasBlindingPoint: hop.BlindingPoint != nil, + } +} + +type mcRoute struct { + sourcePubKey route.Vertex + totalAmount lnwire.MilliSatoshi + hops []*mcHop +} + +type mcHop struct { + channelID uint64 + pubKeyBytes route.Vertex + amtToFwd lnwire.MilliSatoshi + hasBlindingPoint bool +} + // failNode marks the node indicated by idx in the route as failed. It also // marks the incoming and outgoing channels of the node as failed. This function // intentionally panics when the self node is failed. -func (i *interpretedResult) failNode(rt *route.Route, idx int) { +func (i *interpretedResult) failNode(rt *mcRoute, idx int) { // Mark the node as failing. - i.nodeFailure = &rt.Hops[idx-1].PubKeyBytes + i.nodeFailure = &rt.hops[idx-1].pubKeyBytes // Mark the incoming connection as failed for the node. We intent to // penalize as much as we can for a node level failure, including future @@ -574,7 +609,7 @@ func (i *interpretedResult) failNode(rt *route.Route, idx int) { // If not the ultimate node, mark the outgoing connection as failed for // the node. - if idx < len(rt.Hops) { + if idx < len(rt.hops) { outgoingChannelIdx := idx outPair, _ := getPair(rt, outgoingChannelIdx) i.pairResults[outPair] = failPairResult(0) @@ -584,18 +619,14 @@ func (i *interpretedResult) failNode(rt *route.Route, idx int) { // failPairRange marks the node pairs from node fromIdx to node toIdx as failed // in both direction. -func (i *interpretedResult) failPairRange( - rt *route.Route, fromIdx, toIdx int) { - +func (i *interpretedResult) failPairRange(rt *mcRoute, fromIdx, toIdx int) { for idx := fromIdx; idx <= toIdx; idx++ { i.failPair(rt, idx) } } // failPair marks a pair as failed in both directions. -func (i *interpretedResult) failPair( - rt *route.Route, idx int) { - +func (i *interpretedResult) failPair(rt *mcRoute, idx int) { pair, _ := getPair(rt, idx) // Report pair in both directions without a minimum penalization amount. @@ -604,9 +635,7 @@ func (i *interpretedResult) failPair( } // failPairBalance marks a pair as failed with a minimum penalization amount. -func (i *interpretedResult) failPairBalance( - rt *route.Route, channelIdx int) { - +func (i *interpretedResult) failPairBalance(rt *mcRoute, channelIdx int) { pair, amt := getPair(rt, channelIdx) i.pairResults[pair] = failPairResult(amt) @@ -614,9 +643,7 @@ func (i *interpretedResult) failPairBalance( // successPairRange marks the node pairs from node fromIdx to node toIdx as // succeeded. -func (i *interpretedResult) successPairRange( - rt *route.Route, fromIdx, toIdx int) { - +func (i *interpretedResult) successPairRange(rt *mcRoute, fromIdx, toIdx int) { for idx := fromIdx; idx <= toIdx; idx++ { pair, amt := getPair(rt, idx) @@ -626,21 +653,21 @@ func (i *interpretedResult) successPairRange( // getPair returns a node pair from the route and the amount passed between that // pair. -func getPair(rt *route.Route, channelIdx int) (DirectedNodePair, +func getPair(rt *mcRoute, channelIdx int) (DirectedNodePair, lnwire.MilliSatoshi) { - nodeTo := rt.Hops[channelIdx].PubKeyBytes + nodeTo := rt.hops[channelIdx].pubKeyBytes var ( nodeFrom route.Vertex amt lnwire.MilliSatoshi ) if channelIdx == 0 { - nodeFrom = rt.SourcePubKey - amt = rt.TotalAmount + nodeFrom = rt.sourcePubKey + amt = rt.totalAmount } else { - nodeFrom = rt.Hops[channelIdx-1].PubKeyBytes - amt = rt.Hops[channelIdx-1].AmtToForward + nodeFrom = rt.hops[channelIdx-1].pubKeyBytes + amt = rt.hops[channelIdx-1].amtToFwd } pair := NewDirectedNodePair(nodeFrom, nodeTo) diff --git a/routing/result_interpretation_test.go b/routing/result_interpretation_test.go index bf7d6d3eddd..8d452a6e8c3 100644 --- a/routing/result_interpretation_test.go +++ b/routing/result_interpretation_test.go @@ -4,7 +4,6 @@ import ( "reflect" "testing" - "github.com/btcsuite/btcd/btcec/v2" "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" @@ -15,109 +14,105 @@ var ( {1, 0}, {1, 1}, {1, 2}, {1, 3}, {1, 4}, } - // blindingPoint provides a non-nil blinding point (value is never - // used). - blindingPoint = &btcec.PublicKey{} - - routeOneHop = route.Route{ - SourcePubKey: hops[0], - TotalAmount: 100, - Hops: []*route.Hop{ - {PubKeyBytes: hops[1], AmtToForward: 99}, + routeOneHop = mcRoute{ + sourcePubKey: hops[0], + totalAmount: 100, + hops: []*mcHop{ + {pubKeyBytes: hops[1], amtToFwd: 99}, }, } - routeTwoHop = route.Route{ - SourcePubKey: hops[0], - TotalAmount: 100, - Hops: []*route.Hop{ - {PubKeyBytes: hops[1], AmtToForward: 99}, - {PubKeyBytes: hops[2], AmtToForward: 97}, + routeTwoHop = mcRoute{ + sourcePubKey: hops[0], + totalAmount: 100, + hops: []*mcHop{ + {pubKeyBytes: hops[1], amtToFwd: 99}, + {pubKeyBytes: hops[2], amtToFwd: 97}, }, } - routeThreeHop = route.Route{ - SourcePubKey: hops[0], - TotalAmount: 100, - Hops: []*route.Hop{ - {PubKeyBytes: hops[1], AmtToForward: 99}, - {PubKeyBytes: hops[2], AmtToForward: 97}, - {PubKeyBytes: hops[3], AmtToForward: 94}, + routeThreeHop = mcRoute{ + sourcePubKey: hops[0], + totalAmount: 100, + hops: []*mcHop{ + {pubKeyBytes: hops[1], amtToFwd: 99}, + {pubKeyBytes: hops[2], amtToFwd: 97}, + {pubKeyBytes: hops[3], amtToFwd: 94}, }, } - routeFourHop = route.Route{ - SourcePubKey: hops[0], - TotalAmount: 100, - Hops: []*route.Hop{ - {PubKeyBytes: hops[1], AmtToForward: 99}, - {PubKeyBytes: hops[2], AmtToForward: 97}, - {PubKeyBytes: hops[3], AmtToForward: 94}, - {PubKeyBytes: hops[4], AmtToForward: 90}, + routeFourHop = mcRoute{ + sourcePubKey: hops[0], + totalAmount: 100, + hops: []*mcHop{ + {pubKeyBytes: hops[1], amtToFwd: 99}, + {pubKeyBytes: hops[2], amtToFwd: 97}, + {pubKeyBytes: hops[3], amtToFwd: 94}, + {pubKeyBytes: hops[4], amtToFwd: 90}, }, } // blindedMultiHop is a blinded path where there are cleartext hops // before the introduction node, and an intermediate blinded hop before // the recipient after it. - blindedMultiHop = route.Route{ - SourcePubKey: hops[0], - TotalAmount: 100, - Hops: []*route.Hop{ - {PubKeyBytes: hops[1], AmtToForward: 99}, + blindedMultiHop = mcRoute{ + sourcePubKey: hops[0], + totalAmount: 100, + hops: []*mcHop{ + {pubKeyBytes: hops[1], amtToFwd: 99}, { - PubKeyBytes: hops[2], - AmtToForward: 95, - BlindingPoint: blindingPoint, + pubKeyBytes: hops[2], + amtToFwd: 95, + hasBlindingPoint: true, }, - {PubKeyBytes: hops[3], AmtToForward: 88}, - {PubKeyBytes: hops[4], AmtToForward: 77}, + {pubKeyBytes: hops[3], amtToFwd: 88}, + {pubKeyBytes: hops[4], amtToFwd: 77}, }, } // blindedSingleHop is a blinded path with a single blinded hop after // the introduction node. - blindedSingleHop = route.Route{ - SourcePubKey: hops[0], - TotalAmount: 100, - Hops: []*route.Hop{ - {PubKeyBytes: hops[1], AmtToForward: 99}, + blindedSingleHop = mcRoute{ + sourcePubKey: hops[0], + totalAmount: 100, + hops: []*mcHop{ + {pubKeyBytes: hops[1], amtToFwd: 99}, { - PubKeyBytes: hops[2], - AmtToForward: 95, - BlindingPoint: blindingPoint, + pubKeyBytes: hops[2], + amtToFwd: 95, + hasBlindingPoint: true, }, - {PubKeyBytes: hops[3], AmtToForward: 88}, + {pubKeyBytes: hops[3], amtToFwd: 88}, }, } // blindedMultiToIntroduction is a blinded path which goes directly // to the introduction node, with multiple blinded hops after it. - blindedMultiToIntroduction = route.Route{ - SourcePubKey: hops[0], - TotalAmount: 100, - Hops: []*route.Hop{ + blindedMultiToIntroduction = mcRoute{ + sourcePubKey: hops[0], + totalAmount: 100, + hops: []*mcHop{ { - PubKeyBytes: hops[1], - AmtToForward: 90, - BlindingPoint: blindingPoint, + pubKeyBytes: hops[1], + amtToFwd: 90, + hasBlindingPoint: true, }, - {PubKeyBytes: hops[2], AmtToForward: 75}, - {PubKeyBytes: hops[3], AmtToForward: 58}, + {pubKeyBytes: hops[2], amtToFwd: 75}, + {pubKeyBytes: hops[3], amtToFwd: 58}, }, } // blindedIntroReceiver is a blinded path where the introduction node // is the recipient. - blindedIntroReceiver = route.Route{ - SourcePubKey: hops[0], - TotalAmount: 100, - Hops: []*route.Hop{ - {PubKeyBytes: hops[1], AmtToForward: 95}, + blindedIntroReceiver = mcRoute{ + sourcePubKey: hops[0], + totalAmount: 100, + hops: []*mcHop{ + {pubKeyBytes: hops[1], amtToFwd: 95}, { - PubKeyBytes: hops[2], - AmtToForward: 90, - BlindingPoint: blindingPoint, + pubKeyBytes: hops[2], + amtToFwd: 90, + hasBlindingPoint: true, }, }, } @@ -134,7 +129,7 @@ func getPolicyFailure(from, to int) *DirectedNodePair { type resultTestCase struct { name string - route *route.Route + route *mcRoute success bool failureSrcIdx int failure lnwire.FailureMessage @@ -159,7 +154,7 @@ var resultTestCases = []resultTestCase{ }, }, - // Tests that a expiry too soon failure result is properly interpreted. + // Tests that an expiry too soon failure result is properly interpreted. { name: "fail expiry too soon", route: &routeFourHop,