From 3897e9159ce8ff03084c36094c28b3ee059c0132 Mon Sep 17 00:00:00 2001 From: Alexandros Filios Date: Thu, 16 Jan 2025 08:43:02 +0100 Subject: [PATCH] Performance enhancements: Flatten binding hierarchy. Now all bindings point to the long-term id Signed-off-by: Alexandros Filios --- platform/common/driver/kvs.go | 1 + platform/common/utils/collections/maps.go | 9 + platform/view/core/endpoint/binder.go | 67 +++++ platform/view/core/endpoint/endpoint.go | 129 ++++----- platform/view/core/endpoint/endpoint_test.go | 5 +- platform/view/core/manager/context.go | 12 +- platform/view/driver/endpointservice.go | 12 +- platform/view/driver/mock/resolver.go | 272 +++++++++--------- platform/view/endpoint.go | 11 +- platform/view/sdk/dig/sdk.go | 1 + .../comm/host/rest/routing/idrouter.go | 13 +- platform/view/services/kvs/scope.go | 11 + 12 files changed, 326 insertions(+), 217 deletions(-) create mode 100644 platform/view/core/endpoint/binder.go diff --git a/platform/common/driver/kvs.go b/platform/common/driver/kvs.go index 0c1c92c2f..723327105 100644 --- a/platform/common/driver/kvs.go +++ b/platform/common/driver/kvs.go @@ -28,6 +28,7 @@ type AuditInfoKVS interface { type BindingKVS interface { GetBinding(ephemeral view.Identity) (view.Identity, error) + HaveSameBinding(this, that view.Identity) (bool, error) PutBinding(ephemeral, longTerm view.Identity) error } diff --git a/platform/common/utils/collections/maps.go b/platform/common/utils/collections/maps.go index 67df7ab2a..d76f82e1b 100644 --- a/platform/common/utils/collections/maps.go +++ b/platform/common/utils/collections/maps.go @@ -34,6 +34,15 @@ func Values[K comparable, V any](m map[K]V) []V { return res } +func ContainsValue[K, V comparable](haystack map[K]V, needle V) bool { + for _, v := range haystack { + if v == needle { + return true + } + } + return false +} + func Keys[K comparable, V any](m map[K]V) []K { res := make([]K, len(m)) i := 0 diff --git a/platform/view/core/endpoint/binder.go b/platform/view/core/endpoint/binder.go new file mode 100644 index 000000000..4aa4db953 --- /dev/null +++ b/platform/view/core/endpoint/binder.go @@ -0,0 +1,67 @@ +/* +Copyright IBM Corp. All Rights Reserved. + +SPDX-License-Identifier: Apache-2.0 +*/ + +package endpoint + +import ( + "github.com/hyperledger-labs/fabric-smart-client/platform/common/driver" + "github.com/hyperledger-labs/fabric-smart-client/platform/view/view" + "github.com/pkg/errors" +) + +type Binder interface { + GetLongTerm(ephemeral view.Identity) (view.Identity, error) + Bind(ephemeral, other view.Identity) error + IsBoundTo(this, that view.Identity) (bool, error) +} + +func NewBinder(bindingKVS driver.BindingKVS) *binder { + return &binder{bindingKVS: bindingKVS} +} + +type binder struct { + bindingKVS driver.BindingKVS +} + +func (b *binder) Bind(this, that view.Identity) error { + if this.IsNone() || that.IsNone() { + return errors.New("empty ids passed") + } + // Make sure the long term is passed, so that the hierarchy is flat and all ephemerals point to the long term + longTerm, err := b.bindingKVS.GetBinding(that) + if err != nil { + return errors.Wrapf(err, "no long term found for [%s]. if long term was passed, it has to be registered first.", that) + } + if !longTerm.IsNone() { + logger.Debugf("Long term id for [%s] is [%s]", that, longTerm) + return b.bindingKVS.PutBinding(this, longTerm) + } + + logger.Debugf("Id [%s] has no long term binding. It will be registered as a long-term id.", that) + if err := b.bindingKVS.PutBinding(that, that); err != nil { + return errors.Wrapf(err, "failed to register [%s] as long term", that) + } + err = b.bindingKVS.PutBinding(this, that) + if lt, err := b.bindingKVS.GetBinding(this); err != nil || lt.IsNone() { + logger.Errorf("wrong binding for [%s][%s]: %v", that, lt, err) + } else { + logger.Errorf("successful binding for [%s]", that) + } + return err +} + +func (b *binder) RegisterLongTerm(longTerm view.Identity) error { + // Self reference that indicates that a binding is long term + return b.bindingKVS.PutBinding(longTerm, longTerm) +} + +func (b *binder) GetLongTerm(ephemeral view.Identity) (view.Identity, error) { + return b.bindingKVS.GetBinding(ephemeral) +} + +func (b *binder) IsBoundTo(this, that view.Identity) (bool, error) { + return b.bindingKVS.HaveSameBinding(this, that) +} diff --git a/platform/view/core/endpoint/endpoint.go b/platform/view/core/endpoint/endpoint.go index 8f4f85439..d681941b5 100644 --- a/platform/view/core/endpoint/endpoint.go +++ b/platform/view/core/endpoint/endpoint.go @@ -13,8 +13,8 @@ import ( "strings" "sync" - driver2 "github.com/hyperledger-labs/fabric-smart-client/platform/common/driver" "github.com/hyperledger-labs/fabric-smart-client/platform/common/services/logging" + "github.com/hyperledger-labs/fabric-smart-client/platform/common/utils/collections" "github.com/hyperledger-labs/fabric-smart-client/platform/view/driver" "github.com/hyperledger-labs/fabric-smart-client/platform/view/view" "github.com/pkg/errors" @@ -35,6 +35,14 @@ type Resolver struct { IdentityGetter func() (view.Identity, []byte, error) } +func (r *Resolver) GetName() string { return r.Name } + +func (r *Resolver) GetId() view.Identity { return r.Id } + +func (r *Resolver) GetAddress(port driver.PortName) string { return r.Addresses[port] } + +func (r *Resolver) GetAddresses() map[driver.PortName]string { return r.Addresses } + func (r *Resolver) GetIdentity() (view.Identity, error) { if r.IdentityGetter != nil { id, _, err := r.IdentityGetter() @@ -58,7 +66,7 @@ type Discovery interface { type Service struct { resolvers []*Resolver resolversMutex sync.RWMutex - bindingKVS driver2.BindingKVS + binder Binder pkiExtractorsLock sync.RWMutex publicKeyExtractors []driver.PublicKeyExtractor @@ -66,44 +74,45 @@ type Service struct { } // NewService returns a new instance of the view-sdk endpoint service -func NewService(bindingKVS driver2.BindingKVS) (*Service, error) { +func NewService(binder Binder) (*Service, error) { er := &Service{ - bindingKVS: bindingKVS, + binder: binder, publicKeyExtractors: []driver.PublicKeyExtractor{}, publicKeyIDSynthesizer: DefaultPublicKeyIDSynthesizer{}, } return er, nil } -func (r *Service) Endpoint(party view.Identity) (map[driver.PortName]string, error) { - _, e, _, err := r.resolve(party) - return e, err -} - -func (r *Service) Resolve(party view.Identity) (string, view.Identity, map[driver.PortName]string, []byte, error) { - cursor, e, resolver, err := r.resolve(party) +func (r *Service) Resolve(party view.Identity) (driver.Resolver, []byte, error) { + resolver, err := r.resolver(party) if err != nil { - return "", nil, nil, nil, err + return nil, nil, err } - return resolver.Name, cursor, e, r.pkiResolve(resolver), nil + return resolver, r.pkiResolve(resolver), nil } -func (r *Service) resolve(party view.Identity) (view.Identity, map[driver.PortName]string, *Resolver, error) { - cursor := party - for { - // root endpoints have addresses - // is this a root endpoint - resolver, e, err := r.rootEndpoint(cursor) - if err == nil { - return cursor, e, resolver, nil - } - logger.Debugf("resolving via binding for %s", cursor) - cursor, err = r.bindingKVS.GetBinding(cursor) - if err != nil { - return nil, nil, nil, err - } - logger.Debugf("continue to [%s]", cursor) +func (r *Service) GetResolver(party view.Identity) (driver.Resolver, error) { + return r.resolver(party) +} + +func (r *Service) resolver(party view.Identity) (*Resolver, error) { + // We can skip this check, but in case the long term was passed directly, this is going to spare us a DB lookup + resolver, err := r.rootEndpoint(party) + if err == nil { + return resolver, nil + } + logger.Debugf("resolving via binding for %s", party) + party, err = r.binder.GetLongTerm(party) + if err != nil { + return nil, err } + logger.Debugf("continue to [%s]", party) + resolver, err = r.rootEndpoint(party) + if err != nil { + return nil, errors.Wrapf(err, "failed getting identity for [%s]", party) + } + + return resolver, nil } func (r *Service) Bind(longTerm view.Identity, ephemeral view.Identity) error { @@ -114,7 +123,7 @@ func (r *Service) Bind(longTerm view.Identity, ephemeral view.Identity) error { logger.Debugf("bind [%s] to [%s]", ephemeral, longTerm) - if err := r.bindingKVS.PutBinding(ephemeral, longTerm); err != nil { + if err := r.binder.Bind(ephemeral, longTerm); err != nil { return errors.WithMessagef(err, "failed storing binding of [%s] to [%s]", ephemeral.UniqueID(), longTerm.UniqueID()) } @@ -122,19 +131,11 @@ func (r *Service) Bind(longTerm view.Identity, ephemeral view.Identity) error { } func (r *Service) IsBoundTo(a view.Identity, b view.Identity) bool { - for { - if a.Equal(b) { - return true - } - next, err := r.bindingKVS.GetBinding(a) - if err != nil { - return false - } - if next.Equal(b) { - return true - } - a = next + ok, err := r.binder.IsBoundTo(a, b) + if err != nil { + logger.Errorf("error fetching entries [%s] and [%s]: %v", a, b, err) } + return ok } func (r *Service) GetIdentity(endpoint string, pkID []byte) (view.Identity, error) { @@ -143,37 +144,25 @@ func (r *Service) GetIdentity(endpoint string, pkID []byte) (view.Identity, erro // search in the resolver list for _, resolver := range r.resolvers { - resolverPKID := r.pkiResolve(resolver) - found := false - for _, addr := range resolver.Addresses { - if endpoint == addr { - found = true - break - } - } - if !found { - // check aliases - found = slices.Contains(resolver.Aliases, endpoint) - } - if endpoint == resolver.Name || - found || - endpoint == resolver.Name+"."+resolver.Domain || - bytes.Equal(pkID, resolver.Id) || - bytes.Equal(pkID, resolverPKID) { - - id, err := resolver.GetIdentity() - if err != nil { - return nil, err - } - if logger.IsEnabledFor(zapcore.DebugLevel) { - logger.Debugf("resolving [%s,%s] to %s", endpoint, view.Identity(pkID), id) - } - return id, nil + if r.matchesResolver(endpoint, pkID, resolver) { + return resolver.GetIdentity() } } return nil, errors.Errorf("identity not found at [%s,%s]", endpoint, view.Identity(pkID)) } +func (r *Service) matchesResolver(endpoint string, pkID []byte, resolver *Resolver) bool { + if len(endpoint) > 0 && (endpoint == resolver.Name || + endpoint == resolver.Name+"."+resolver.Domain || + collections.ContainsValue(resolver.Addresses, endpoint) || + slices.Contains(resolver.Aliases, endpoint)) { + return true + } + + return len(pkID) > 0 && (bytes.Equal(pkID, resolver.Id) || + bytes.Equal(pkID, r.pkiResolve(resolver))) +} + func (r *Service) AddResolver(name string, domain string, addresses map[string]string, aliases []string, id []byte) (view.Identity, error) { if logger.IsEnabledFor(zapcore.DebugLevel) { logger.Debugf("adding resolver [%s,%s,%v,%v,%s]", name, domain, addresses, aliases, view.Identity(id).String()) @@ -265,17 +254,17 @@ func (r *Service) ExtractPKI(id []byte) []byte { return nil } -func (r *Service) rootEndpoint(party view.Identity) (*Resolver, map[driver.PortName]string, error) { +func (r *Service) rootEndpoint(party view.Identity) (*Resolver, error) { r.resolversMutex.RLock() defer r.resolversMutex.RUnlock() for _, resolver := range r.resolvers { if bytes.Equal(resolver.Id, party) { - return resolver, resolver.Addresses, nil + return resolver, nil } } - return nil, nil, errors.Errorf("endpoint not found for identity %s", party.UniqueID()) + return nil, errors.Errorf("endpoint not found for identity %s", party.UniqueID()) } var portNameMap = map[string]driver.PortName{ diff --git a/platform/view/core/endpoint/endpoint_test.go b/platform/view/core/endpoint/endpoint_test.go index eb3208f62..72b02d24e 100644 --- a/platform/view/core/endpoint/endpoint_test.go +++ b/platform/view/core/endpoint/endpoint_test.go @@ -17,6 +17,7 @@ import ( type mockKVS struct{} func (k mockKVS) GetBinding(ephemeral view.Identity) (view.Identity, error) { return nil, nil } +func (k mockKVS) HaveSameBinding(this, that view.Identity) (bool, error) { return false, nil } func (k mockKVS) PutBinding(ephemeral, longTerm view.Identity) error { return nil } type mockExtractor struct{} @@ -26,7 +27,7 @@ func (m mockExtractor) ExtractPublicKey(id view.Identity) (any, error) { } func TestPKIResolveConcurrency(t *testing.T) { - svc, err := NewService(mockKVS{}) + svc, err := NewService(NewBinder(mockKVS{})) assert.NoError(err) ext := mockExtractor{} @@ -47,7 +48,7 @@ func TestPKIResolveConcurrency(t *testing.T) { } func TestGetIdentity(t *testing.T) { - svc, err := NewService(mockKVS{}) + svc, err := NewService(NewBinder(mockKVS{})) assert.NoError(err) ext := mockExtractor{} diff --git a/platform/view/core/manager/context.go b/platform/view/core/manager/context.go index c3b62a8d8..f16c71512 100644 --- a/platform/view/core/manager/context.go +++ b/platform/view/core/manager/context.go @@ -334,19 +334,23 @@ func (ctx *ctx) Dispose() { } func (ctx *ctx) newSession(view view.View, contextID string, party view.Identity) (view.Session, error) { - _, _, endpoints, pkid, err := ctx.resolver.Resolve(party) + resolver, pkid, err := ctx.resolver.Resolve(party) if err != nil { return nil, err } - return ctx.sessionFactory.NewSession(getIdentifier(view), contextID, endpoints[driver.P2PPort], pkid) + return ctx.sessionFactory.NewSession(getIdentifier(view), contextID, resolver.GetAddress(driver.P2PPort), pkid) } func (ctx *ctx) newSessionByID(sessionID, contextID string, party view.Identity) (view.Session, error) { - _, _, endpoints, pkid, err := ctx.resolver.Resolve(party) + resolver, pkid, err := ctx.resolver.Resolve(party) if err != nil { return nil, err } - return ctx.sessionFactory.NewSessionWithID(sessionID, contextID, endpoints[driver.P2PPort], pkid, nil, nil) + var endpoint string + if resolver != nil { + endpoint = resolver.GetAddress(driver.P2PPort) + } + return ctx.sessionFactory.NewSessionWithID(sessionID, contextID, endpoint, pkid, nil, nil) } func (ctx *ctx) cleanup() { diff --git a/platform/view/driver/endpointservice.go b/platform/view/driver/endpointservice.go index 67363c2c9..c109f620c 100644 --- a/platform/view/driver/endpointservice.go +++ b/platform/view/driver/endpointservice.go @@ -36,12 +36,20 @@ type PublicKeyIDSynthesizer interface { //go:generate counterfeiter -o mock/resolver.go -fake-name EndpointService . EndpointService +type Resolver interface { + GetName() string + GetId() view.Identity + GetAddress(port PortName) string + GetAddresses() map[PortName]string +} + // EndpointService models the endpoint service type EndpointService interface { // Resolve returns the identity the passed identity is bound to. // It returns also: the endpoints and the pkiID - Resolve(party view.Identity) (string, view.Identity, map[PortName]string, []byte, error) - + Resolve(party view.Identity) (Resolver, []byte, error) + // GetResolver returns the identity the passed identity is bound to + GetResolver(party view.Identity) (Resolver, error) // GetIdentity returns an identity bound to either the passed label or public-key identifier. GetIdentity(label string, pkiID []byte) (view.Identity, error) diff --git a/platform/view/driver/mock/resolver.go b/platform/view/driver/mock/resolver.go index 0047d731b..5dfa17cc2 100644 --- a/platform/view/driver/mock/resolver.go +++ b/platform/view/driver/mock/resolver.go @@ -4,8 +4,8 @@ package mock import ( "sync" + "github.com/hyperledger-labs/fabric-smart-client/platform/common/services/identity" "github.com/hyperledger-labs/fabric-smart-client/platform/view/driver" - "github.com/hyperledger-labs/fabric-smart-client/platform/view/view" ) type EndpointService struct { @@ -20,7 +20,7 @@ type EndpointService struct { addPublicKeyExtractorReturnsOnCall map[int]struct { result1 error } - AddResolverStub func(string, string, map[string]string, []string, []byte) (view.Identity, error) + AddResolverStub func(string, string, map[string]string, []string, []byte) (identity.Identity, error) addResolverMutex sync.RWMutex addResolverArgsForCall []struct { arg1 string @@ -30,18 +30,18 @@ type EndpointService struct { arg5 []byte } addResolverReturns struct { - result1 view.Identity + result1 identity.Identity result2 error } addResolverReturnsOnCall map[int]struct { - result1 view.Identity + result1 identity.Identity result2 error } - BindStub func(view.Identity, view.Identity) error + BindStub func(identity.Identity, identity.Identity) error bindMutex sync.RWMutex bindArgsForCall []struct { - arg1 view.Identity - arg2 view.Identity + arg1 identity.Identity + arg2 identity.Identity } bindReturns struct { result1 error @@ -49,38 +49,38 @@ type EndpointService struct { bindReturnsOnCall map[int]struct { result1 error } - EndpointStub func(view.Identity) (map[driver.PortName]string, error) - endpointMutex sync.RWMutex - endpointArgsForCall []struct { - arg1 view.Identity - } - endpointReturns struct { - result1 map[driver.PortName]string - result2 error - } - endpointReturnsOnCall map[int]struct { - result1 map[driver.PortName]string - result2 error - } - GetIdentityStub func(string, []byte) (view.Identity, error) + GetIdentityStub func(string, []byte) (identity.Identity, error) getIdentityMutex sync.RWMutex getIdentityArgsForCall []struct { arg1 string arg2 []byte } getIdentityReturns struct { - result1 view.Identity + result1 identity.Identity result2 error } getIdentityReturnsOnCall map[int]struct { - result1 view.Identity + result1 identity.Identity + result2 error + } + GetResolverStub func(identity.Identity) (driver.Resolver, error) + getResolverMutex sync.RWMutex + getResolverArgsForCall []struct { + arg1 identity.Identity + } + getResolverReturns struct { + result1 driver.Resolver + result2 error + } + getResolverReturnsOnCall map[int]struct { + result1 driver.Resolver result2 error } - IsBoundToStub func(view.Identity, view.Identity) bool + IsBoundToStub func(identity.Identity, identity.Identity) bool isBoundToMutex sync.RWMutex isBoundToArgsForCall []struct { - arg1 view.Identity - arg2 view.Identity + arg1 identity.Identity + arg2 identity.Identity } isBoundToReturns struct { result1 bool @@ -88,22 +88,20 @@ type EndpointService struct { isBoundToReturnsOnCall map[int]struct { result1 bool } - ResolveStub func(view.Identity) (string, view.Identity, map[driver.PortName]string, []byte, error) + ResolveStub func(identity.Identity) (driver.Resolver, []byte, error) resolveMutex sync.RWMutex resolveArgsForCall []struct { - arg1 view.Identity + arg1 identity.Identity } resolveReturns struct { - result1 view.Identity - result2 map[driver.PortName]string - result3 []byte - result4 error + result1 driver.Resolver + result2 []byte + result3 error } resolveReturnsOnCall map[int]struct { - result1 view.Identity - result2 map[driver.PortName]string - result3 []byte - result4 error + result1 driver.Resolver + result2 []byte + result3 error } SetPublicKeyIDSynthesizerStub func(driver.PublicKeyIDSynthesizer) setPublicKeyIDSynthesizerMutex sync.RWMutex @@ -175,7 +173,7 @@ func (fake *EndpointService) AddPublicKeyExtractorReturnsOnCall(i int, result1 e }{result1} } -func (fake *EndpointService) AddResolver(arg1 string, arg2 string, arg3 map[string]string, arg4 []string, arg5 []byte) (view.Identity, error) { +func (fake *EndpointService) AddResolver(arg1 string, arg2 string, arg3 map[string]string, arg4 []string, arg5 []byte) (identity.Identity, error) { var arg4Copy []string if arg4 != nil { arg4Copy = make([]string, len(arg4)) @@ -214,7 +212,7 @@ func (fake *EndpointService) AddResolverCallCount() int { return len(fake.addResolverArgsForCall) } -func (fake *EndpointService) AddResolverCalls(stub func(string, string, map[string]string, []string, []byte) (view.Identity, error)) { +func (fake *EndpointService) AddResolverCalls(stub func(string, string, map[string]string, []string, []byte) (identity.Identity, error)) { fake.addResolverMutex.Lock() defer fake.addResolverMutex.Unlock() fake.AddResolverStub = stub @@ -227,38 +225,38 @@ func (fake *EndpointService) AddResolverArgsForCall(i int) (string, string, map[ return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3, argsForCall.arg4, argsForCall.arg5 } -func (fake *EndpointService) AddResolverReturns(result1 view.Identity, result2 error) { +func (fake *EndpointService) AddResolverReturns(result1 identity.Identity, result2 error) { fake.addResolverMutex.Lock() defer fake.addResolverMutex.Unlock() fake.AddResolverStub = nil fake.addResolverReturns = struct { - result1 view.Identity + result1 identity.Identity result2 error }{result1, result2} } -func (fake *EndpointService) AddResolverReturnsOnCall(i int, result1 view.Identity, result2 error) { +func (fake *EndpointService) AddResolverReturnsOnCall(i int, result1 identity.Identity, result2 error) { fake.addResolverMutex.Lock() defer fake.addResolverMutex.Unlock() fake.AddResolverStub = nil if fake.addResolverReturnsOnCall == nil { fake.addResolverReturnsOnCall = make(map[int]struct { - result1 view.Identity + result1 identity.Identity result2 error }) } fake.addResolverReturnsOnCall[i] = struct { - result1 view.Identity + result1 identity.Identity result2 error }{result1, result2} } -func (fake *EndpointService) Bind(arg1 view.Identity, arg2 view.Identity) error { +func (fake *EndpointService) Bind(arg1 identity.Identity, arg2 identity.Identity) error { fake.bindMutex.Lock() ret, specificReturn := fake.bindReturnsOnCall[len(fake.bindArgsForCall)] fake.bindArgsForCall = append(fake.bindArgsForCall, struct { - arg1 view.Identity - arg2 view.Identity + arg1 identity.Identity + arg2 identity.Identity }{arg1, arg2}) stub := fake.BindStub fakeReturns := fake.bindReturns @@ -279,13 +277,13 @@ func (fake *EndpointService) BindCallCount() int { return len(fake.bindArgsForCall) } -func (fake *EndpointService) BindCalls(stub func(view.Identity, view.Identity) error) { +func (fake *EndpointService) BindCalls(stub func(identity.Identity, identity.Identity) error) { fake.bindMutex.Lock() defer fake.bindMutex.Unlock() fake.BindStub = stub } -func (fake *EndpointService) BindArgsForCall(i int) (view.Identity, view.Identity) { +func (fake *EndpointService) BindArgsForCall(i int) (identity.Identity, identity.Identity) { fake.bindMutex.RLock() defer fake.bindMutex.RUnlock() argsForCall := fake.bindArgsForCall[i] @@ -315,52 +313,7 @@ func (fake *EndpointService) BindReturnsOnCall(i int, result1 error) { }{result1} } -func (fake *EndpointService) EndpointCallCount() int { - fake.endpointMutex.RLock() - defer fake.endpointMutex.RUnlock() - return len(fake.endpointArgsForCall) -} - -func (fake *EndpointService) EndpointCalls(stub func(view.Identity) (map[driver.PortName]string, error)) { - fake.endpointMutex.Lock() - defer fake.endpointMutex.Unlock() - fake.EndpointStub = stub -} - -func (fake *EndpointService) EndpointArgsForCall(i int) view.Identity { - fake.endpointMutex.RLock() - defer fake.endpointMutex.RUnlock() - argsForCall := fake.endpointArgsForCall[i] - return argsForCall.arg1 -} - -func (fake *EndpointService) EndpointReturns(result1 map[driver.PortName]string, result2 error) { - fake.endpointMutex.Lock() - defer fake.endpointMutex.Unlock() - fake.EndpointStub = nil - fake.endpointReturns = struct { - result1 map[driver.PortName]string - result2 error - }{result1, result2} -} - -func (fake *EndpointService) EndpointReturnsOnCall(i int, result1 map[driver.PortName]string, result2 error) { - fake.endpointMutex.Lock() - defer fake.endpointMutex.Unlock() - fake.EndpointStub = nil - if fake.endpointReturnsOnCall == nil { - fake.endpointReturnsOnCall = make(map[int]struct { - result1 map[driver.PortName]string - result2 error - }) - } - fake.endpointReturnsOnCall[i] = struct { - result1 map[driver.PortName]string - result2 error - }{result1, result2} -} - -func (fake *EndpointService) GetIdentity(arg1 string, arg2 []byte) (view.Identity, error) { +func (fake *EndpointService) GetIdentity(arg1 string, arg2 []byte) (identity.Identity, error) { var arg2Copy []byte if arg2 != nil { arg2Copy = make([]byte, len(arg2)) @@ -391,7 +344,7 @@ func (fake *EndpointService) GetIdentityCallCount() int { return len(fake.getIdentityArgsForCall) } -func (fake *EndpointService) GetIdentityCalls(stub func(string, []byte) (view.Identity, error)) { +func (fake *EndpointService) GetIdentityCalls(stub func(string, []byte) (identity.Identity, error)) { fake.getIdentityMutex.Lock() defer fake.getIdentityMutex.Unlock() fake.GetIdentityStub = stub @@ -404,38 +357,102 @@ func (fake *EndpointService) GetIdentityArgsForCall(i int) (string, []byte) { return argsForCall.arg1, argsForCall.arg2 } -func (fake *EndpointService) GetIdentityReturns(result1 view.Identity, result2 error) { +func (fake *EndpointService) GetIdentityReturns(result1 identity.Identity, result2 error) { fake.getIdentityMutex.Lock() defer fake.getIdentityMutex.Unlock() fake.GetIdentityStub = nil fake.getIdentityReturns = struct { - result1 view.Identity + result1 identity.Identity result2 error }{result1, result2} } -func (fake *EndpointService) GetIdentityReturnsOnCall(i int, result1 view.Identity, result2 error) { +func (fake *EndpointService) GetIdentityReturnsOnCall(i int, result1 identity.Identity, result2 error) { fake.getIdentityMutex.Lock() defer fake.getIdentityMutex.Unlock() fake.GetIdentityStub = nil if fake.getIdentityReturnsOnCall == nil { fake.getIdentityReturnsOnCall = make(map[int]struct { - result1 view.Identity + result1 identity.Identity result2 error }) } fake.getIdentityReturnsOnCall[i] = struct { - result1 view.Identity + result1 identity.Identity + result2 error + }{result1, result2} +} + +func (fake *EndpointService) GetResolver(arg1 identity.Identity) (driver.Resolver, error) { + fake.getResolverMutex.Lock() + ret, specificReturn := fake.getResolverReturnsOnCall[len(fake.getResolverArgsForCall)] + fake.getResolverArgsForCall = append(fake.getResolverArgsForCall, struct { + arg1 identity.Identity + }{arg1}) + stub := fake.GetResolverStub + fakeReturns := fake.getResolverReturns + fake.recordInvocation("GetResolver", []interface{}{arg1}) + fake.getResolverMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *EndpointService) GetResolverCallCount() int { + fake.getResolverMutex.RLock() + defer fake.getResolverMutex.RUnlock() + return len(fake.getResolverArgsForCall) +} + +func (fake *EndpointService) GetResolverCalls(stub func(identity.Identity) (driver.Resolver, error)) { + fake.getResolverMutex.Lock() + defer fake.getResolverMutex.Unlock() + fake.GetResolverStub = stub +} + +func (fake *EndpointService) GetResolverArgsForCall(i int) identity.Identity { + fake.getResolverMutex.RLock() + defer fake.getResolverMutex.RUnlock() + argsForCall := fake.getResolverArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *EndpointService) GetResolverReturns(result1 driver.Resolver, result2 error) { + fake.getResolverMutex.Lock() + defer fake.getResolverMutex.Unlock() + fake.GetResolverStub = nil + fake.getResolverReturns = struct { + result1 driver.Resolver + result2 error + }{result1, result2} +} + +func (fake *EndpointService) GetResolverReturnsOnCall(i int, result1 driver.Resolver, result2 error) { + fake.getResolverMutex.Lock() + defer fake.getResolverMutex.Unlock() + fake.GetResolverStub = nil + if fake.getResolverReturnsOnCall == nil { + fake.getResolverReturnsOnCall = make(map[int]struct { + result1 driver.Resolver + result2 error + }) + } + fake.getResolverReturnsOnCall[i] = struct { + result1 driver.Resolver result2 error }{result1, result2} } -func (fake *EndpointService) IsBoundTo(arg1 view.Identity, arg2 view.Identity) bool { +func (fake *EndpointService) IsBoundTo(arg1 identity.Identity, arg2 identity.Identity) bool { fake.isBoundToMutex.Lock() ret, specificReturn := fake.isBoundToReturnsOnCall[len(fake.isBoundToArgsForCall)] fake.isBoundToArgsForCall = append(fake.isBoundToArgsForCall, struct { - arg1 view.Identity - arg2 view.Identity + arg1 identity.Identity + arg2 identity.Identity }{arg1, arg2}) stub := fake.IsBoundToStub fakeReturns := fake.isBoundToReturns @@ -456,13 +473,13 @@ func (fake *EndpointService) IsBoundToCallCount() int { return len(fake.isBoundToArgsForCall) } -func (fake *EndpointService) IsBoundToCalls(stub func(view.Identity, view.Identity) bool) { +func (fake *EndpointService) IsBoundToCalls(stub func(identity.Identity, identity.Identity) bool) { fake.isBoundToMutex.Lock() defer fake.isBoundToMutex.Unlock() fake.IsBoundToStub = stub } -func (fake *EndpointService) IsBoundToArgsForCall(i int) (view.Identity, view.Identity) { +func (fake *EndpointService) IsBoundToArgsForCall(i int) (identity.Identity, identity.Identity) { fake.isBoundToMutex.RLock() defer fake.isBoundToMutex.RUnlock() argsForCall := fake.isBoundToArgsForCall[i] @@ -492,11 +509,11 @@ func (fake *EndpointService) IsBoundToReturnsOnCall(i int, result1 bool) { }{result1} } -func (fake *EndpointService) Resolve(arg1 view.Identity) (string, view.Identity, map[driver.PortName]string, []byte, error) { +func (fake *EndpointService) Resolve(arg1 identity.Identity) (driver.Resolver, []byte, error) { fake.resolveMutex.Lock() ret, specificReturn := fake.resolveReturnsOnCall[len(fake.resolveArgsForCall)] fake.resolveArgsForCall = append(fake.resolveArgsForCall, struct { - arg1 view.Identity + arg1 identity.Identity }{arg1}) stub := fake.ResolveStub fakeReturns := fake.resolveReturns @@ -506,9 +523,9 @@ func (fake *EndpointService) Resolve(arg1 view.Identity) (string, view.Identity, return stub(arg1) } if specificReturn { - return "", ret.result1, ret.result2, ret.result3, ret.result4 + return ret.result1, ret.result2, ret.result3 } - return "", fakeReturns.result1, fakeReturns.result2, fakeReturns.result3, fakeReturns.result4 + return fakeReturns.result1, fakeReturns.result2, fakeReturns.result3 } func (fake *EndpointService) ResolveCallCount() int { @@ -517,49 +534,46 @@ func (fake *EndpointService) ResolveCallCount() int { return len(fake.resolveArgsForCall) } -func (fake *EndpointService) ResolveCalls(stub func(view.Identity) (string, view.Identity, map[driver.PortName]string, []byte, error)) { +func (fake *EndpointService) ResolveCalls(stub func(identity.Identity) (driver.Resolver, []byte, error)) { fake.resolveMutex.Lock() defer fake.resolveMutex.Unlock() fake.ResolveStub = stub } -func (fake *EndpointService) ResolveArgsForCall(i int) view.Identity { +func (fake *EndpointService) ResolveArgsForCall(i int) identity.Identity { fake.resolveMutex.RLock() defer fake.resolveMutex.RUnlock() argsForCall := fake.resolveArgsForCall[i] return argsForCall.arg1 } -func (fake *EndpointService) ResolveReturns(result1 view.Identity, result2 map[driver.PortName]string, result3 []byte, result4 error) { +func (fake *EndpointService) ResolveReturns(result1 driver.Resolver, result2 []byte, result3 error) { fake.resolveMutex.Lock() defer fake.resolveMutex.Unlock() fake.ResolveStub = nil fake.resolveReturns = struct { - result1 view.Identity - result2 map[driver.PortName]string - result3 []byte - result4 error - }{result1, result2, result3, result4} + result1 driver.Resolver + result2 []byte + result3 error + }{result1, result2, result3} } -func (fake *EndpointService) ResolveReturnsOnCall(i int, result1 view.Identity, result2 map[driver.PortName]string, result3 []byte, result4 error) { +func (fake *EndpointService) ResolveReturnsOnCall(i int, result1 driver.Resolver, result2 []byte, result3 error) { fake.resolveMutex.Lock() defer fake.resolveMutex.Unlock() fake.ResolveStub = nil if fake.resolveReturnsOnCall == nil { fake.resolveReturnsOnCall = make(map[int]struct { - result1 view.Identity - result2 map[driver.PortName]string - result3 []byte - result4 error + result1 driver.Resolver + result2 []byte + result3 error }) } fake.resolveReturnsOnCall[i] = struct { - result1 view.Identity - result2 map[driver.PortName]string - result3 []byte - result4 error - }{result1, result2, result3, result4} + result1 driver.Resolver + result2 []byte + result3 error + }{result1, result2, result3} } func (fake *EndpointService) SetPublicKeyIDSynthesizer(arg1 driver.PublicKeyIDSynthesizer) { @@ -603,10 +617,10 @@ func (fake *EndpointService) Invocations() map[string][][]interface{} { defer fake.addResolverMutex.RUnlock() fake.bindMutex.RLock() defer fake.bindMutex.RUnlock() - fake.endpointMutex.RLock() - defer fake.endpointMutex.RUnlock() fake.getIdentityMutex.RLock() defer fake.getIdentityMutex.RUnlock() + fake.getResolverMutex.RLock() + defer fake.getResolverMutex.RUnlock() fake.isBoundToMutex.RLock() defer fake.isBoundToMutex.RUnlock() fake.resolveMutex.RLock() diff --git a/platform/view/endpoint.go b/platform/view/endpoint.go index 23afc3612..497a966ed 100644 --- a/platform/view/endpoint.go +++ b/platform/view/endpoint.go @@ -47,18 +47,21 @@ func NewEndpointService(es driver.EndpointService) *EndpointService { // If the passed identity does not have any endpoint set, the service checks // if the passed identity is bound to another identity that is returned together with its endpoints and public-key identifier. func (e *EndpointService) Resolve(party view.Identity) (view.Identity, map[PortName]string, []byte, error) { - _, id, ports, raw, err := e.es.Resolve(party) + resolver, raw, err := e.es.Resolve(party) if err != nil { return nil, nil, nil, err } + if resolver == nil { + return nil, nil, raw, nil + } if logger.IsEnabledFor(zapcore.DebugLevel) { - logger.Debugf("resolved [%s] to [%s] with ports [%v]", party, id, ports) + logger.Debugf("resolved [%s] to [%s] with ports [%v]", party, resolver.GetId(), resolver.GetAddresses()) } out := map[PortName]string{} - for name, s := range ports { + for name, s := range resolver.GetAddresses() { out[PortName(name)] = s } - return id, out, raw, nil + return resolver.GetId(), out, raw, nil } func (e *EndpointService) ResolveIdentities(endpoints ...string) ([]view.Identity, error) { diff --git a/platform/view/sdk/dig/sdk.go b/platform/view/sdk/dig/sdk.go index f0d8f40b2..97a97b241 100644 --- a/platform/view/sdk/dig/sdk.go +++ b/platform/view/sdk/dig/sdk.go @@ -107,6 +107,7 @@ func (p *SDK) Install() error { p.C.Provide(kvs.NewSignerKVS, dig.As(new(driver4.SignerKVS))), p.C.Provide(kvs.NewAuditInfoKVS, dig.As(new(driver4.AuditInfoKVS))), p.C.Provide(endpoint.NewService), + p.C.Provide(endpoint.NewBinder, dig.As(new(endpoint.Binder))), p.C.Provide(digutils.Identity[*endpoint.Service](), dig.As(new(driver.EndpointService))), p.C.Provide(view.NewEndpointService), p.C.Provide(digutils.Identity[*view.EndpointService](), dig.As(new(comm.EndpointService), new(id.EndpointService), new(endpoint.Backend))), diff --git a/platform/view/services/comm/host/rest/routing/idrouter.go b/platform/view/services/comm/host/rest/routing/idrouter.go index f2b9878c3..a6ebf89f2 100644 --- a/platform/view/services/comm/host/rest/routing/idrouter.go +++ b/platform/view/services/comm/host/rest/routing/idrouter.go @@ -18,7 +18,8 @@ import ( type endpointService interface { GetIdentity(endpoint string, pkID []byte) (view2.Identity, error) - Resolve(party view2.Identity) (string, view2.Identity, map[driver.PortName]string, []byte, error) + Resolve(party view2.Identity) (driver.Resolver, []byte, error) + GetResolver(party view2.Identity) (driver.Resolver, error) } // endpointServiceIDRouter resolves the IP addresses using the resolvers of the endpoint service. @@ -37,12 +38,12 @@ func (r *endpointServiceIDRouter) Lookup(id host2.PeerID) ([]host2.PeerIPAddress logger.Errorf("failed getting identity for peer [%s]", id) return []host2.PeerIPAddress{}, false } - _, _, addresses, _, err := r.es.Resolve(identity) + resolver, err := r.es.GetResolver(identity) if err != nil { logger.Errorf("failed resolving [%s]: %v", id, err) return []host2.PeerIPAddress{}, false } - if address, ok := addresses[driver.P2PPort]; ok { + if address := resolver.GetAddress(driver.P2PPort); len(address) > 0 { logger.Debugf("Found endpoint of peer [%s]: [%s]", id, address) return []host2.PeerIPAddress{address}, true } @@ -130,11 +131,11 @@ func (r *labelResolver) getLabel(peerID host2.PeerID) (string, error) { return "", errors.Wrapf(err, "failed to find identity for peer [%s]", peerID) } - label, _, _, pkid, err := r.es.Resolve(identity) + resolver, pkid, err := r.es.Resolve(identity) if pkid == nil && err != nil { - return "", errors.Wrapf(err, "failed to resolve identity [%s] for label [%s]", identity, label) + return "", errors.Wrapf(err, "failed to resolve identity [%s] for label [%s]", identity, resolver.GetName()) } - label = strings.TrimPrefix(label, "fsc.") + label := strings.TrimPrefix(resolver.GetName(), "fsc.") r.cache[peerID] = label return label, nil } diff --git a/platform/view/services/kvs/scope.go b/platform/view/services/kvs/scope.go index 366a3afb4..fac32f13b 100644 --- a/platform/view/services/kvs/scope.go +++ b/platform/view/services/kvs/scope.go @@ -62,6 +62,17 @@ type bindingKVS struct { e *enhancedKVS[view.Identity, view.Identity] } +func (kvs *bindingKVS) HaveSameBinding(this, that view.Identity) (bool, error) { + thisBinding, err := kvs.e.Get(this) + if err != nil { + return false, err + } + thatBinding, err := kvs.e.Get(that) + if err != nil { + return false, err + } + return thisBinding.Equal(thatBinding), nil +} func (kvs *bindingKVS) GetBinding(ephemeral view.Identity) (view.Identity, error) { return kvs.e.Get(ephemeral) }