Skip to content

Commit

Permalink
Custom cancelable implementation replaced with context
Browse files Browse the repository at this point in the history
  • Loading branch information
soffokl committed Sep 19, 2018
1 parent 44d95da commit 1a612f4
Show file tree
Hide file tree
Showing 19 changed files with 125 additions and 380 deletions.
5 changes: 4 additions & 1 deletion core/connection/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,11 @@
package connection

import (
"context"

"github.com/mysteriumnetwork/go-openvpn/openvpn"
"github.com/mysteriumnetwork/go-openvpn/openvpn/middlewares/state"

"github.com/mysteriumnetwork/node/communication"
"github.com/mysteriumnetwork/node/identity"
"github.com/mysteriumnetwork/node/service_discovery/dto"
Expand All @@ -36,7 +39,7 @@ type VpnClientCreator func(session.SessionDto, identity.Identity, identity.Ident
// Manager interface provides methods to manage connection
type Manager interface {
// Connect creates new connection from given consumer to provider, reports error if connection already exists
Connect(consumerID identity.Identity, providerID identity.Identity, options ConnectOptions) error
Connect(ctx context.Context, consumerID identity.Identity, providerID identity.Identity, options ConnectOptions) error
// Status queries current status of connection
Status() ConnectionStatus
// Disconnect closes established connection, reports error if no connection
Expand Down
96 changes: 34 additions & 62 deletions core/connection/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package connection

import (
"context"
"errors"
"sync"

Expand All @@ -30,7 +31,6 @@ import (
"github.com/mysteriumnetwork/node/server"
"github.com/mysteriumnetwork/node/service_discovery/dto"
"github.com/mysteriumnetwork/node/session"
"github.com/mysteriumnetwork/node/utils"
)

const managerLogPrefix = "[connection-manager] "
Expand All @@ -47,16 +47,16 @@ var (
)

type connectionManager struct {
ctx context.Context
//these are passed on creation
mysteriumClient server.Client
newDialog DialogCreator
newVpnClient VpnClientCreator
statsKeeper stats.SessionStatsKeeper
//these are populated by Connect at runtime
mutex sync.RWMutex
status ConnectionStatus
cleanConnection func()

mutex sync.RWMutex
}

// NewManager creates connection manager with given dependencies
Expand All @@ -72,12 +72,13 @@ func NewManager(mysteriumClient server.Client, dialogCreator DialogCreator,
}
}

func (manager *connectionManager) Connect(consumerID, providerID identity.Identity, options ConnectOptions) (err error) {
func (manager *connectionManager) Connect(ctx context.Context, consumerID, providerID identity.Identity, options ConnectOptions) (err error) {
if manager.status.State != NotConnected {
return ErrAlreadyExists
}

manager.mutex.Lock()
manager.ctx, manager.cleanConnection = context.WithCancel(ctx)
manager.status = statusConnecting()
manager.mutex.Unlock()
defer func() {
Expand All @@ -89,76 +90,57 @@ func (manager *connectionManager) Connect(consumerID, providerID identity.Identi
}()

err = manager.startConnection(consumerID, providerID, options)
if err == utils.ErrRequestCancelled {
if err == context.Canceled {
return ErrConnectionCancelled
}
return err
}

func (manager *connectionManager) startConnection(consumerID, providerID identity.Identity, options ConnectOptions) (err error) {
cancelable := utils.NewCancelable()

manager.mutex.Lock()
manager.cleanConnection = utils.CallOnce(func() {
log.Info(managerLogPrefix, "Cancelling connection initiation")
manager.status = statusDisconnecting()
cancelable.Cancel()
})
cancelCtx := manager.cleanConnection
manager.mutex.Unlock()

val, err := cancelable.
NewRequest(func() (interface{}, error) {
return manager.findProposalByProviderID(providerID)
}).
Call()
var cancel []func()
defer func() {
manager.cleanConnection = func() {
manager.status = statusDisconnecting()
cancelCtx()
for _, f := range cancel {
f()
}
}
if err != nil {
log.Info(managerLogPrefix, "Cancelling connection initiation")
defer manager.cleanConnection()
}
}()

proposal, err := manager.findProposalByProviderID(providerID)
if err != nil {
return err
}
proposal := val.(*dto.ServiceProposal)

val, err = cancelable.
NewRequest(func() (interface{}, error) {
return manager.newDialog(consumerID, providerID, proposal.ProviderContacts[0])
}).
Cleanup(utils.InvokeOnSuccess(func(val interface{}) {
val.(communication.Dialog).Close()
})).
Call()

dialog, err := manager.newDialog(consumerID, providerID, proposal.ProviderContacts[0])
if err != nil {
return err
}
dialog := val.(communication.Dialog)
cancel = append(cancel, func() { dialog.Close() })

val, err = cancelable.
NewRequest(func() (interface{}, error) {
return session.RequestSessionCreate(dialog, proposal.ID)
}).
Call()
vpnSession, err := session.RequestSessionCreate(dialog, proposal.ID)
if err != nil {
dialog.Close()
return err
}
vpnSession := val.(*session.SessionDto)

stateChannel := make(chan openvpn.State, 10)
val, err = cancelable.
NewRequest(func() (interface{}, error) {
return manager.startOpenvpnClient(*vpnSession, consumerID, providerID, stateChannel, options)
}).
Cleanup(utils.InvokeOnSuccess(func(val interface{}) {
val.(openvpn.Process).Stop()
})).
Call()
openvpnClient, err := manager.startOpenvpnClient(*vpnSession, consumerID, providerID, stateChannel, options)
if err != nil {
dialog.Close()
return err
}
openvpnClient := val.(openvpn.Process)
cancel = append(cancel, openvpnClient.Stop)

err = manager.waitForConnectedState(stateChannel, vpnSession.ID, cancelable.Cancelled)
err = manager.waitForConnectedState(stateChannel, vpnSession.ID)
if err != nil {
dialog.Close()
openvpnClient.Stop()
return err
}

Expand All @@ -168,15 +150,6 @@ func (manager *connectionManager) startConnection(consumerID, providerID identit
firewall.NewKillSwitch().Enable()
}

manager.mutex.Lock()
manager.cleanConnection = func() {
log.Info(managerLogPrefix, "Closing active connection")
manager.status = statusDisconnecting()
openvpnClient.Stop()
log.Info(managerLogPrefix, "Openvpn client stop requested")
}
manager.mutex.Unlock()

go openvpnClientWaiter(openvpnClient, dialog)
go manager.consumeOpenvpnStates(stateChannel, vpnSession.ID)
return nil
Expand Down Expand Up @@ -206,7 +179,7 @@ func warnOnClean() {

// TODO this can be extracted as dependency later when node selection criteria will be clear
func (manager *connectionManager) findProposalByProviderID(providerID identity.Identity) (*dto.ServiceProposal, error) {
proposals, err := manager.mysteriumClient.FindProposals(providerID.Address)
proposals, err := manager.mysteriumClient.FindProposals(manager.ctx, providerID.Address)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -246,8 +219,7 @@ func (manager *connectionManager) startOpenvpnClient(vpnSession session.SessionD
return openvpnClient, nil
}

func (manager *connectionManager) waitForConnectedState(stateChannel <-chan openvpn.State, sessionID session.SessionID, cancelRequest utils.CancelChannel) error {

func (manager *connectionManager) waitForConnectedState(stateChannel <-chan openvpn.State, sessionID session.SessionID) error {
for {
select {
case state, more := <-stateChannel:
Expand All @@ -262,8 +234,8 @@ func (manager *connectionManager) waitForConnectedState(stateChannel <-chan open
default:
manager.onStateChanged(state, sessionID)
}
case <-cancelRequest:
return utils.ErrRequestCancelled
case <-manager.ctx.Done():
return manager.ctx.Err()
}
}
}
Expand Down
35 changes: 18 additions & 17 deletions core/connection/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package connection

import (
"context"
"errors"
"sync"
"testing"
Expand Down Expand Up @@ -62,7 +63,7 @@ func (tc *testContext) SetupTest() {
defer tc.Unlock()

tc.fakeDiscoveryClient = server.NewClientFake()
tc.fakeDiscoveryClient.RegisterProposal(activeProposal, nil)
tc.fakeDiscoveryClient.RegisterProposal(context.Background(), activeProposal, nil)
tc.fakeDialog = &fakeDialog{}

dialogCreator := func(consumer, provider identity.Identity, contact dto.Contact) (communication.Dialog, error) {
Expand Down Expand Up @@ -115,7 +116,7 @@ func (tc *testContext) TestWhenNoConnectionIsMadeStatusIsNotConnected() {
func (tc *testContext) TestWithUnknownProviderConnectionIsNotMade() {
noProposalsError := errors.New("provider has no service proposals")

assert.Equal(tc.T(), noProposalsError, tc.connManager.Connect(myID, identity.FromAddress("unknown-node"), ConnectOptions{}))
assert.Equal(tc.T(), noProposalsError, tc.connManager.Connect(context.Background(), myID, identity.FromAddress("unknown-node"), ConnectOptions{}))
assert.Equal(tc.T(), statusNotConnected(), tc.connManager.Status())

assert.False(tc.T(), tc.fakeStatsKeeper.sessionStartMarked)
Expand All @@ -125,20 +126,20 @@ func (tc *testContext) TestOnConnectErrorStatusIsNotConnectedAndSessionStartIsNo
fatalVpnError := errors.New("fatal connection error")
tc.fakeOpenVpn.onStartReturnError = fatalVpnError

assert.Error(tc.T(), tc.connManager.Connect(myID, activeProviderID, ConnectOptions{}))
assert.Error(tc.T(), tc.connManager.Connect(context.Background(), myID, activeProviderID, ConnectOptions{}))
assert.Equal(tc.T(), statusNotConnected(), tc.connManager.Status())
assert.True(tc.T(), tc.fakeDialog.closed)
assert.False(tc.T(), tc.fakeStatsKeeper.sessionStartMarked)
}

func (tc *testContext) TestWhenManagerMadeConnectionStatusReturnsConnectedStateAndSessionId() {
err := tc.connManager.Connect(myID, activeProviderID, ConnectOptions{})
err := tc.connManager.Connect(context.Background(), myID, activeProviderID, ConnectOptions{})
assert.NoError(tc.T(), err)
assert.Equal(tc.T(), statusConnected("vpn-connection-id"), tc.connManager.Status())
}

func (tc *testContext) TestWhenManagerMadeConnectionSessionStartIsMarked() {
err := tc.connManager.Connect(myID, activeProviderID, ConnectOptions{})
err := tc.connManager.Connect(context.Background(), myID, activeProviderID, ConnectOptions{})
assert.NoError(tc.T(), err)

assert.True(tc.T(), tc.fakeStatsKeeper.sessionStartMarked)
Expand All @@ -148,7 +149,7 @@ func (tc *testContext) TestStatusReportsConnectingWhenConnectionIsInProgress() {
tc.fakeOpenVpn.onStartReportStates = []openvpn.State{}

go func() {
tc.connManager.Connect(myID, activeProviderID, ConnectOptions{})
tc.connManager.Connect(context.Background(), myID, activeProviderID, ConnectOptions{})
assert.Fail(tc.T(), "This should never return")
}()

Expand All @@ -159,7 +160,7 @@ func (tc *testContext) TestStatusReportsConnectingWhenConnectionIsInProgress() {

func (tc *testContext) TestStatusReportsDisconnectingThenNotConnected() {
tc.fakeOpenVpn.onStopReportStates = []openvpn.State{}
err := tc.connManager.Connect(myID, activeProviderID, ConnectOptions{})
err := tc.connManager.Connect(context.Background(), myID, activeProviderID, ConnectOptions{})
assert.NoError(tc.T(), err)
assert.Equal(tc.T(), statusConnected("vpn-connection-id"), tc.connManager.Status())

Expand All @@ -174,23 +175,23 @@ func (tc *testContext) TestStatusReportsDisconnectingThenNotConnected() {
}

func (tc *testContext) TestConnectResultsInAlreadyConnectedErrorWhenConnectionExists() {
assert.NoError(tc.T(), tc.connManager.Connect(myID, activeProviderID, ConnectOptions{}))
assert.Equal(tc.T(), ErrAlreadyExists, tc.connManager.Connect(myID, activeProviderID, ConnectOptions{}))
assert.NoError(tc.T(), tc.connManager.Connect(context.Background(), myID, activeProviderID, ConnectOptions{}))
assert.Equal(tc.T(), ErrAlreadyExists, tc.connManager.Connect(context.Background(), myID, activeProviderID, ConnectOptions{}))
}

func (tc *testContext) TestDisconnectReturnsErrorWhenNoConnectionExists() {
assert.Equal(tc.T(), ErrNoConnection, tc.connManager.Disconnect())
}

func (tc *testContext) TestReconnectingStatusIsReportedWhenOpenVpnGoesIntoReconnectingState() {
assert.NoError(tc.T(), tc.connManager.Connect(myID, activeProviderID, ConnectOptions{}))
assert.NoError(tc.T(), tc.connManager.Connect(context.Background(), myID, activeProviderID, ConnectOptions{}))
tc.fakeOpenVpn.reportState(openvpn.ReconnectingState)
waitABit()
assert.Equal(tc.T(), statusReconnecting(), tc.connManager.Status())
}

func (tc *testContext) TestDoubleDisconnectResultsInError() {
assert.NoError(tc.T(), tc.connManager.Connect(myID, activeProviderID, ConnectOptions{}))
assert.NoError(tc.T(), tc.connManager.Connect(context.Background(), myID, activeProviderID, ConnectOptions{}))
assert.Equal(tc.T(), statusConnected("vpn-connection-id"), tc.connManager.Status())
assert.NoError(tc.T(), tc.connManager.Disconnect())
waitABit()
Expand All @@ -199,13 +200,13 @@ func (tc *testContext) TestDoubleDisconnectResultsInError() {
}

func (tc *testContext) TestTwoConnectDisconnectCyclesReturnNoError() {
assert.NoError(tc.T(), tc.connManager.Connect(myID, activeProviderID, ConnectOptions{}))
assert.NoError(tc.T(), tc.connManager.Connect(context.Background(), myID, activeProviderID, ConnectOptions{}))
assert.Equal(tc.T(), statusConnected("vpn-connection-id"), tc.connManager.Status())
assert.NoError(tc.T(), tc.connManager.Disconnect())
waitABit()
assert.Equal(tc.T(), statusNotConnected(), tc.connManager.Status())

assert.NoError(tc.T(), tc.connManager.Connect(myID, activeProviderID, ConnectOptions{}))
assert.NoError(tc.T(), tc.connManager.Connect(context.Background(), myID, activeProviderID, ConnectOptions{}))
assert.Equal(tc.T(), statusConnected("vpn-connection-id"), tc.connManager.Status())
assert.NoError(tc.T(), tc.connManager.Disconnect())
waitABit()
Expand All @@ -215,11 +216,11 @@ func (tc *testContext) TestTwoConnectDisconnectCyclesReturnNoError() {

func (tc *testContext) TestConnectFailsIfOpenvpnFactoryReturnsError() {
tc.openvpnCreationError = errors.New("failed to create vpn instance")
assert.Error(tc.T(), tc.connManager.Connect(myID, activeProviderID, ConnectOptions{}))
assert.Error(tc.T(), tc.connManager.Connect(context.Background(), myID, activeProviderID, ConnectOptions{}))
}

func (tc *testContext) TestStatusIsConnectedWhenConnectCommandReturnsWithoutError() {
tc.connManager.Connect(myID, activeProviderID, ConnectOptions{})
tc.connManager.Connect(context.Background(), myID, activeProviderID, ConnectOptions{})
assert.Equal(tc.T(), statusConnected("vpn-connection-id"), tc.connManager.Status())
}

Expand All @@ -230,7 +231,7 @@ func (tc *testContext) TestConnectingInProgressCanBeCanceled() {
var err error
go func() {
defer connectWaiter.Done()
err = tc.connManager.Connect(myID, activeProviderID, ConnectOptions{})
err = tc.connManager.Connect(context.Background(), myID, activeProviderID, ConnectOptions{})
}()

waitABit()
Expand All @@ -251,7 +252,7 @@ func (tc *testContext) TestConnectMethodReturnsErrorIfOpenvpnClientExitsDuringCo
var err error
go func() {
defer connectWaiter.Done()
err = tc.connManager.Connect(myID, activeProviderID, ConnectOptions{})
err = tc.connManager.Connect(context.Background(), myID, activeProviderID, ConnectOptions{})
}()
waitABit()
tc.fakeOpenVpn.reportState(openvpn.ProcessExited)
Expand Down
Loading

0 comments on commit 1a612f4

Please sign in to comment.