From 100a85aa5a76dc3f2f54c267d7fbdf6667a66b49 Mon Sep 17 00:00:00 2001 From: Dmitry Shihovtsev Date: Thu, 13 Dec 2018 21:29:05 +0600 Subject: [PATCH] Added destroy callback for transport services --- services/noop/service.go | 4 ++-- services/noop/service_test.go | 2 +- services/openvpn/service/factory.go | 4 ++-- services/wireguard/service/service.go | 23 +++++++++------------- services/wireguard/service/service_test.go | 2 +- session/create_consumer.go | 6 +++--- session/create_consumer_test.go | 6 +++--- session/dto.go | 7 ++++--- session/manager.go | 14 ++++++++----- session/manager_test.go | 6 +++--- 10 files changed, 37 insertions(+), 37 deletions(-) diff --git a/services/noop/service.go b/services/noop/service.go index 89660c61cc..55deded906 100644 --- a/services/noop/service.go +++ b/services/noop/service.go @@ -55,8 +55,8 @@ type Manager struct { type negotiator struct { } -func (n *negotiator) ProvideConfig(cfg json.RawMessage) (session.ServiceConfiguration, error) { - return nil, nil +func (n *negotiator) ProvideConfig(cfg json.RawMessage) (session.ServiceConfiguration, session.DestroyCallback, error) { + return nil, nil, nil } // Start starts service - does not block diff --git a/services/noop/service_test.go b/services/noop/service_test.go index 2ded94ae6d..6dbedb206d 100644 --- a/services/noop/service_test.go +++ b/services/noop/service_test.go @@ -67,7 +67,7 @@ func Test_Manager_Start(t *testing.T) { proposal, ) - sessionConfig, err := sessionConfigProvider.ProvideConfig(nil) + sessionConfig, _, err := sessionConfigProvider.ProvideConfig(nil) assert.NoError(t, err) assert.Nil(t, sessionConfig) } diff --git a/services/openvpn/service/factory.go b/services/openvpn/service/factory.go index 74f0eeb61f..82d5b1d779 100644 --- a/services/openvpn/service/factory.go +++ b/services/openvpn/service/factory.go @@ -113,8 +113,8 @@ type OpenvpnConfigNegotiator struct { } // ProvideConfig returns the config for user -func (ocn *OpenvpnConfigNegotiator) ProvideConfig(json.RawMessage) (session.ServiceConfiguration, error) { - return &ocn.vpnConfig, nil +func (ocn *OpenvpnConfigNegotiator) ProvideConfig(json.RawMessage) (session.ServiceConfiguration, session.DestroyCallback, error) { + return &ocn.vpnConfig, nil, nil } func vpnServerIP(serviceOptions Options, outboundIP, publicIP string) string { diff --git a/services/wireguard/service/service.go b/services/wireguard/service/service.go index dad781277e..c50aaecdd0 100644 --- a/services/wireguard/service/service.go +++ b/services/wireguard/service/service.go @@ -61,34 +61,34 @@ type Manager struct { } // ProvideConfig provides the config for consumer -func (manager *Manager) ProvideConfig(publicKey json.RawMessage) (session.ServiceConfiguration, error) { +func (manager *Manager) ProvideConfig(publicKey json.RawMessage) (session.ServiceConfiguration, session.DestroyCallback, error) { key := &wg.ConsumerConfig{} err := json.Unmarshal(publicKey, key) if err != nil { - return nil, err + return nil, nil, err } connectionEndpoint, err := manager.connectionEndpointFactory() if err != nil { - return nil, err + return nil, nil, err } if err := connectionEndpoint.Start(nil); err != nil { - return nil, err + return nil, nil, err } if err := connectionEndpoint.AddPeer(key.PublicKey, nil); err != nil { - return nil, err + return nil, nil, err } config, err := connectionEndpoint.Config() if err != nil { - return nil, err + return nil, nil, err } outboundIP, err := manager.ipResolver.GetOutboundIP() if err != nil { - return nil, err + return nil, nil, err } manager.natService.Add(nat.RuleForwarding{ @@ -96,15 +96,10 @@ func (manager *Manager) ProvideConfig(publicKey json.RawMessage) (session.Servic TargetIP: outboundIP, }) if err := manager.natService.Start(); err != nil { - return nil, err + return nil, nil, err } - return config, nil -} - -// ConsumeConfig takes in the provided config and adds it to the wg device -func (manager *Manager) ConsumeConfig() error { - return nil + return config, connectionEndpoint.Stop, nil } // Start starts service - does not block diff --git a/services/wireguard/service/service_test.go b/services/wireguard/service/service_test.go index fca52cf7ad..f390718159 100644 --- a/services/wireguard/service/service_test.go +++ b/services/wireguard/service/service_test.go @@ -72,7 +72,7 @@ func Test_Manager_Start(t *testing.T) { }, proposal, ) - sessionConfig, err := sessionConfigProvider.ProvideConfig(json.RawMessage(`{"PublicKey": "gZfkZArbw9lqfl4Yzr1Kv3nqGlhe/ynH9KKRbzPFMGk="}`)) + sessionConfig, _, err := sessionConfigProvider.ProvideConfig(json.RawMessage(`{"PublicKey": "gZfkZArbw9lqfl4Yzr1Kv3nqGlhe/ynH9KKRbzPFMGk="}`)) assert.NoError(t, err) assert.NotNil(t, sessionConfig) } diff --git a/session/create_consumer.go b/session/create_consumer.go index ce2eda93a1..89d553413b 100644 --- a/session/create_consumer.go +++ b/session/create_consumer.go @@ -33,7 +33,7 @@ type createConsumer struct { // Creator defines method for session creation type Creator interface { - Create(consumerID identity.Identity, proposalID int, config ServiceConfiguration) (Session, error) + Create(consumerID identity.Identity, proposalID int, config ServiceConfiguration, destroyCallback DestroyCallback) (Session, error) } // GetMessageEndpoint returns endpoint there to receive messages @@ -51,12 +51,12 @@ func (consumer *createConsumer) NewRequest() (requestPtr interface{}) { func (consumer *createConsumer) Consume(requestPtr interface{}) (response interface{}, err error) { request := requestPtr.(*CreateRequest) - config, err := consumer.configProvider(request.Config) + config, destroyCallback, err := consumer.configProvider(request.Config) if err != nil { return responseInternalError, err } - sessionInstance, err := consumer.sessionCreator.Create(consumer.peerID, request.ProposalId, config) + sessionInstance, err := consumer.sessionCreator.Create(consumer.peerID, request.ProposalId, config, destroyCallback) switch err { case nil: return responseWithSession(sessionInstance), nil diff --git a/session/create_consumer_test.go b/session/create_consumer_test.go index 67f960c288..fc22522ea7 100644 --- a/session/create_consumer_test.go +++ b/session/create_consumer_test.go @@ -27,8 +27,8 @@ import ( ) var ( - mockConsumer = func(json.RawMessage) (ServiceConfiguration, error) { - return nil, nil + mockConsumer = func(json.RawMessage) (ServiceConfiguration, DestroyCallback, error) { + return nil, nil, nil } ) @@ -107,7 +107,7 @@ type managerFake struct { } // Create function creates and returns fake session -func (manager *managerFake) Create(consumerID identity.Identity, proposalID int, config ServiceConfiguration) (Session, error) { +func (manager *managerFake) Create(consumerID identity.Identity, proposalID int, config ServiceConfiguration, destroyCallback DestroyCallback) (Session, error) { manager.lastConsumerID = consumerID manager.lastProposalID = proposalID return manager.returnSession, manager.returnError diff --git a/session/dto.go b/session/dto.go index 6cc8f6d951..5e6e525cc5 100644 --- a/session/dto.go +++ b/session/dto.go @@ -24,9 +24,10 @@ type ID string // Session structure holds all required information about current session between service consumer and provider type Session struct { - ID ID - Config ServiceConfiguration - ConsumerID identity.Identity + ID ID + Config ServiceConfiguration + ConsumerID identity.Identity + DestroyCallback DestroyCallback } // ServiceConfiguration defines service configuration from underlying transport mechanism to be passed to remote party diff --git a/session/manager.go b/session/manager.go index 6269a1cb78..4fc9483ff6 100644 --- a/session/manager.go +++ b/session/manager.go @@ -40,14 +40,14 @@ type IDGenerator func() (ID, error) // ConfigNegotiator is able to handle config negotiations type ConfigNegotiator interface { - ProvideConfig(consumerKey json.RawMessage) (ServiceConfiguration, error) + ProvideConfig(consumerKey json.RawMessage) (ServiceConfiguration, DestroyCallback, error) } // ConfigProvider provides session config for remote client -type ConfigProvider func(consumerKey json.RawMessage) (ServiceConfiguration, error) +type ConfigProvider func(consumerKey json.RawMessage) (ServiceConfiguration, DestroyCallback, error) -// SaveCallback stores newly started sessions -type SaveCallback func(Session) +// DestroyCallback cleanups session +type DestroyCallback func() error // PromiseProcessor processes promises at provider side. // Provider checks promises from consumer and signs them also. @@ -93,7 +93,7 @@ type Manager struct { } // Create creates session instance. Multiple sessions per peerID is possible in case different services are used -func (manager *Manager) Create(consumerID identity.Identity, proposalID int, config ServiceConfiguration) (sessionInstance Session, err error) { +func (manager *Manager) Create(consumerID identity.Identity, proposalID int, config ServiceConfiguration, destroyCallback DestroyCallback) (sessionInstance Session, err error) { manager.creationLock.Lock() defer manager.creationLock.Unlock() @@ -112,6 +112,7 @@ func (manager *Manager) Create(consumerID identity.Identity, proposalID int, con return } + sessionInstance.DestroyCallback = destroyCallback manager.sessionStorage.Add(sessionInstance) return sessionInstance, nil } @@ -138,6 +139,9 @@ func (manager *Manager) Destroy(consumerID identity.Identity, sessionID string) manager.sessionStorage.Remove(ID(sessionID)) + if sessionInstance.DestroyCallback != nil { + return sessionInstance.DestroyCallback() + } return nil } diff --git a/session/manager_test.go b/session/manager_test.go index b416b4f916..d957b8a685 100644 --- a/session/manager_test.go +++ b/session/manager_test.go @@ -65,7 +65,7 @@ func TestManager_Create_StoresSession(t *testing.T) { sessionStore := NewStorageMemory() manager := NewManager(currentProposal, generateSessionID, sessionStore, &fakePromiseProcessor{}) - sessionInstance, err := manager.Create(identity.FromAddress("deadbeef"), currentProposalID, expectedSessionConfig) + sessionInstance, err := manager.Create(identity.FromAddress("deadbeef"), currentProposalID, expectedSessionConfig, nil) assert.NoError(t, err) assert.Exactly(t, expectedSession, sessionInstance) } @@ -74,7 +74,7 @@ func TestManager_Create_RejectsUnknownProposal(t *testing.T) { sessionStore := NewStorageMemory() manager := NewManager(currentProposal, generateSessionID, sessionStore, &fakePromiseProcessor{}) - sessionInstance, err := manager.Create(identity.FromAddress("deadbeef"), 69, expectedSessionConfig) + sessionInstance, err := manager.Create(identity.FromAddress("deadbeef"), 69, expectedSessionConfig, nil) assert.Exactly(t, err, ErrorInvalidProposal) assert.Exactly(t, Session{}, sessionInstance) } @@ -84,7 +84,7 @@ func TestManager_Create_StartsPromiseProcessor(t *testing.T) { sessionStore := NewStorageMemory() manager := NewManager(currentProposal, generateSessionID, sessionStore, promiseProcessor) - _, err := manager.Create(identity.FromAddress("deadbeef"), currentProposalID, expectedSessionConfig) + _, err := manager.Create(identity.FromAddress("deadbeef"), currentProposalID, expectedSessionConfig, nil) assert.NoError(t, err) assert.True(t, promiseProcessor.started) assert.Exactly(t, currentProposal, promiseProcessor.proposal)