Skip to content

Commit

Permalink
Context spreading reduced
Browse files Browse the repository at this point in the history
  • Loading branch information
soffokl committed Sep 19, 2018
1 parent 7e9dcd2 commit 1ea4751
Show file tree
Hide file tree
Showing 17 changed files with 81 additions and 93 deletions.
4 changes: 1 addition & 3 deletions core/connection/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
package connection

import (
"context"

"github.com/mysteriumnetwork/go-openvpn/openvpn"
"github.com/mysteriumnetwork/go-openvpn/openvpn/middlewares/state"

Expand All @@ -39,7 +37,7 @@ type VpnClientCreator func(session.SessionDto, identity.Identity, identity.Ident
// Manager interface provides methods to manage connection
type Manager interface {
// Connect creates new connection from given consumer to provider, reports error if connection already exists
Connect(ctx context.Context, consumerID identity.Identity, providerID identity.Identity, options ConnectOptions) error
Connect(consumerID identity.Identity, providerID identity.Identity, options ConnectOptions) error
// Status queries current status of connection
Status() ConnectionStatus
// Disconnect closes established connection, reports error if no connection
Expand Down
6 changes: 3 additions & 3 deletions core/connection/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,13 @@ func NewManager(mysteriumClient server.Client, dialogCreator DialogCreator,
}
}

func (manager *connectionManager) Connect(ctx context.Context, consumerID, providerID identity.Identity, options ConnectOptions) (err error) {
func (manager *connectionManager) Connect(consumerID, providerID identity.Identity, options ConnectOptions) (err error) {
if manager.status.State != NotConnected {
return ErrAlreadyExists
}

manager.mutex.Lock()
manager.ctx, manager.cleanConnection = context.WithCancel(ctx)
manager.ctx, manager.cleanConnection = context.WithCancel(context.Background())
manager.status = statusConnecting()
manager.mutex.Unlock()
defer func() {
Expand Down Expand Up @@ -179,7 +179,7 @@ func warnOnClean() {

// TODO this can be extracted as dependency later when node selection criteria will be clear
func (manager *connectionManager) findProposalByProviderID(providerID identity.Identity) (*dto.ServiceProposal, error) {
proposals, err := manager.mysteriumClient.FindProposals(manager.ctx, providerID.Address)
proposals, err := manager.mysteriumClient.FindProposals(providerID.Address)
if err != nil {
return nil, err
}
Expand Down
35 changes: 17 additions & 18 deletions core/connection/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
package connection

import (
"context"
"errors"
"sync"
"testing"
Expand Down Expand Up @@ -63,7 +62,7 @@ func (tc *testContext) SetupTest() {
defer tc.Unlock()

tc.fakeDiscoveryClient = server.NewClientFake()
tc.fakeDiscoveryClient.RegisterProposal(context.Background(), activeProposal, nil)
tc.fakeDiscoveryClient.RegisterProposal(activeProposal, nil)
tc.fakeDialog = &fakeDialog{}

dialogCreator := func(consumer, provider identity.Identity, contact dto.Contact) (communication.Dialog, error) {
Expand Down Expand Up @@ -116,7 +115,7 @@ func (tc *testContext) TestWhenNoConnectionIsMadeStatusIsNotConnected() {
func (tc *testContext) TestWithUnknownProviderConnectionIsNotMade() {
noProposalsError := errors.New("provider has no service proposals")

assert.Equal(tc.T(), noProposalsError, tc.connManager.Connect(context.Background(), myID, identity.FromAddress("unknown-node"), ConnectOptions{}))
assert.Equal(tc.T(), noProposalsError, tc.connManager.Connect(myID, identity.FromAddress("unknown-node"), ConnectOptions{}))
assert.Equal(tc.T(), statusNotConnected(), tc.connManager.Status())

assert.False(tc.T(), tc.fakeStatsKeeper.sessionStartMarked)
Expand All @@ -126,20 +125,20 @@ func (tc *testContext) TestOnConnectErrorStatusIsNotConnectedAndSessionStartIsNo
fatalVpnError := errors.New("fatal connection error")
tc.fakeOpenVpn.onStartReturnError = fatalVpnError

assert.Error(tc.T(), tc.connManager.Connect(context.Background(), myID, activeProviderID, ConnectOptions{}))
assert.Error(tc.T(), tc.connManager.Connect(myID, activeProviderID, ConnectOptions{}))
assert.Equal(tc.T(), statusNotConnected(), tc.connManager.Status())
assert.True(tc.T(), tc.fakeDialog.closed)
assert.False(tc.T(), tc.fakeStatsKeeper.sessionStartMarked)
}

func (tc *testContext) TestWhenManagerMadeConnectionStatusReturnsConnectedStateAndSessionId() {
err := tc.connManager.Connect(context.Background(), myID, activeProviderID, ConnectOptions{})
err := tc.connManager.Connect(myID, activeProviderID, ConnectOptions{})
assert.NoError(tc.T(), err)
assert.Equal(tc.T(), statusConnected("vpn-connection-id"), tc.connManager.Status())
}

func (tc *testContext) TestWhenManagerMadeConnectionSessionStartIsMarked() {
err := tc.connManager.Connect(context.Background(), myID, activeProviderID, ConnectOptions{})
err := tc.connManager.Connect(myID, activeProviderID, ConnectOptions{})
assert.NoError(tc.T(), err)

assert.True(tc.T(), tc.fakeStatsKeeper.sessionStartMarked)
Expand All @@ -149,7 +148,7 @@ func (tc *testContext) TestStatusReportsConnectingWhenConnectionIsInProgress() {
tc.fakeOpenVpn.onStartReportStates = []openvpn.State{}

go func() {
tc.connManager.Connect(context.Background(), myID, activeProviderID, ConnectOptions{})
tc.connManager.Connect(myID, activeProviderID, ConnectOptions{})
assert.Fail(tc.T(), "This should never return")
}()

Expand All @@ -160,7 +159,7 @@ func (tc *testContext) TestStatusReportsConnectingWhenConnectionIsInProgress() {

func (tc *testContext) TestStatusReportsDisconnectingThenNotConnected() {
tc.fakeOpenVpn.onStopReportStates = []openvpn.State{}
err := tc.connManager.Connect(context.Background(), myID, activeProviderID, ConnectOptions{})
err := tc.connManager.Connect(myID, activeProviderID, ConnectOptions{})
assert.NoError(tc.T(), err)
assert.Equal(tc.T(), statusConnected("vpn-connection-id"), tc.connManager.Status())

Expand All @@ -175,23 +174,23 @@ func (tc *testContext) TestStatusReportsDisconnectingThenNotConnected() {
}

func (tc *testContext) TestConnectResultsInAlreadyConnectedErrorWhenConnectionExists() {
assert.NoError(tc.T(), tc.connManager.Connect(context.Background(), myID, activeProviderID, ConnectOptions{}))
assert.Equal(tc.T(), ErrAlreadyExists, tc.connManager.Connect(context.Background(), myID, activeProviderID, ConnectOptions{}))
assert.NoError(tc.T(), tc.connManager.Connect(myID, activeProviderID, ConnectOptions{}))
assert.Equal(tc.T(), ErrAlreadyExists, tc.connManager.Connect(myID, activeProviderID, ConnectOptions{}))
}

func (tc *testContext) TestDisconnectReturnsErrorWhenNoConnectionExists() {
assert.Equal(tc.T(), ErrNoConnection, tc.connManager.Disconnect())
}

func (tc *testContext) TestReconnectingStatusIsReportedWhenOpenVpnGoesIntoReconnectingState() {
assert.NoError(tc.T(), tc.connManager.Connect(context.Background(), myID, activeProviderID, ConnectOptions{}))
assert.NoError(tc.T(), tc.connManager.Connect(myID, activeProviderID, ConnectOptions{}))
tc.fakeOpenVpn.reportState(openvpn.ReconnectingState)
waitABit()
assert.Equal(tc.T(), statusReconnecting(), tc.connManager.Status())
}

func (tc *testContext) TestDoubleDisconnectResultsInError() {
assert.NoError(tc.T(), tc.connManager.Connect(context.Background(), myID, activeProviderID, ConnectOptions{}))
assert.NoError(tc.T(), tc.connManager.Connect(myID, activeProviderID, ConnectOptions{}))
assert.Equal(tc.T(), statusConnected("vpn-connection-id"), tc.connManager.Status())
assert.NoError(tc.T(), tc.connManager.Disconnect())
waitABit()
Expand All @@ -200,13 +199,13 @@ func (tc *testContext) TestDoubleDisconnectResultsInError() {
}

func (tc *testContext) TestTwoConnectDisconnectCyclesReturnNoError() {
assert.NoError(tc.T(), tc.connManager.Connect(context.Background(), myID, activeProviderID, ConnectOptions{}))
assert.NoError(tc.T(), tc.connManager.Connect(myID, activeProviderID, ConnectOptions{}))
assert.Equal(tc.T(), statusConnected("vpn-connection-id"), tc.connManager.Status())
assert.NoError(tc.T(), tc.connManager.Disconnect())
waitABit()
assert.Equal(tc.T(), statusNotConnected(), tc.connManager.Status())

assert.NoError(tc.T(), tc.connManager.Connect(context.Background(), myID, activeProviderID, ConnectOptions{}))
assert.NoError(tc.T(), tc.connManager.Connect(myID, activeProviderID, ConnectOptions{}))
assert.Equal(tc.T(), statusConnected("vpn-connection-id"), tc.connManager.Status())
assert.NoError(tc.T(), tc.connManager.Disconnect())
waitABit()
Expand All @@ -216,11 +215,11 @@ func (tc *testContext) TestTwoConnectDisconnectCyclesReturnNoError() {

func (tc *testContext) TestConnectFailsIfOpenvpnFactoryReturnsError() {
tc.openvpnCreationError = errors.New("failed to create vpn instance")
assert.Error(tc.T(), tc.connManager.Connect(context.Background(), myID, activeProviderID, ConnectOptions{}))
assert.Error(tc.T(), tc.connManager.Connect(myID, activeProviderID, ConnectOptions{}))
}

func (tc *testContext) TestStatusIsConnectedWhenConnectCommandReturnsWithoutError() {
tc.connManager.Connect(context.Background(), myID, activeProviderID, ConnectOptions{})
tc.connManager.Connect(myID, activeProviderID, ConnectOptions{})
assert.Equal(tc.T(), statusConnected("vpn-connection-id"), tc.connManager.Status())
}

Expand All @@ -231,7 +230,7 @@ func (tc *testContext) TestConnectingInProgressCanBeCanceled() {
var err error
go func() {
defer connectWaiter.Done()
err = tc.connManager.Connect(context.Background(), myID, activeProviderID, ConnectOptions{})
err = tc.connManager.Connect(myID, activeProviderID, ConnectOptions{})
}()

waitABit()
Expand All @@ -252,7 +251,7 @@ func (tc *testContext) TestConnectMethodReturnsErrorIfOpenvpnClientExitsDuringCo
var err error
go func() {
defer connectWaiter.Done()
err = tc.connManager.Connect(context.Background(), myID, activeProviderID, ConnectOptions{})
err = tc.connManager.Connect(myID, activeProviderID, ConnectOptions{})
}()
waitABit()
tc.fakeOpenVpn.reportState(openvpn.ProcessExited)
Expand Down
33 changes: 18 additions & 15 deletions discovery/discovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
package discovery

import (
"context"
"time"

log "github.com/cihub/seelog"
Expand Down Expand Up @@ -54,14 +53,17 @@ func (d *Discovery) Start(ownIdentity identity.Identity, proposal dto_discovery.
d.signer = d.signerCreate(ownIdentity)
d.proposal = proposal

ctx, cancel := context.WithCancel(context.Background())
d.stop = cancel
stopLoop := make(chan bool)
d.stop = func() {
// cancel (stop) discovery loop
stopLoop <- true
}

d.proposalAnnouncementStopped.Add(1)

go d.checkRegistration()

go d.mainDiscoveryLoop(ctx)
go d.mainDiscoveryLoop(stopLoop)
}

// Wait wait for proposal announcements to stop / unregister
Expand All @@ -74,21 +76,22 @@ func (d *Discovery) Stop() {
d.stop()
}

func (d *Discovery) mainDiscoveryLoop(ctx context.Context) {
func (d *Discovery) mainDiscoveryLoop(stopLoop chan bool) {

for {
select {
case <-ctx.Done():
case <-stopLoop:
d.stopLoop()
case event := <-d.statusChan:
switch event {
case IdentityUnregistered:
d.registerIdentity()
case RegisterProposal:
go d.registerProposal(ctx)
go d.registerProposal()
case PingProposal:
go d.pingProposal(ctx)
go d.pingProposal()
case UnregisterProposal:
go d.unregisterProposal(ctx)
go d.unregisterProposal()
case IdentityRegisterFailed, ProposalUnregistered, UnregisterProposalFailed:
d.proposalAnnouncementStopped.Done()
return
Expand Down Expand Up @@ -131,8 +134,8 @@ func (d *Discovery) registerIdentity() {
}()
}

func (d *Discovery) registerProposal(ctx context.Context) {
err := d.mysteriumClient.RegisterProposal(ctx, d.proposal, d.signer)
func (d *Discovery) registerProposal() {
err := d.mysteriumClient.RegisterProposal(d.proposal, d.signer)
if err != nil {
log.Errorf("%s Failed to register proposal, retrying after 1 min. %s", logPrefix, err.Error())
time.Sleep(1 * time.Minute)
Expand All @@ -142,17 +145,17 @@ func (d *Discovery) registerProposal(ctx context.Context) {
d.changeStatus(PingProposal)
}

func (d *Discovery) pingProposal(ctx context.Context) {
func (d *Discovery) pingProposal() {
time.Sleep(1 * time.Minute)
err := d.mysteriumClient.PingProposal(ctx, d.proposal, d.signer)
err := d.mysteriumClient.PingProposal(d.proposal, d.signer)
if err != nil {
log.Error(logPrefix, "Failed to ping proposal: ", err)
}
d.changeStatus(PingProposal)
}

func (d *Discovery) unregisterProposal(ctx context.Context) {
err := d.mysteriumClient.UnregisterProposal(ctx, d.proposal, d.signer)
func (d *Discovery) unregisterProposal() {
err := d.mysteriumClient.UnregisterProposal(d.proposal, d.signer)
if err != nil {
log.Error(logPrefix, "Failed to unregister proposal: ", err)
d.changeStatus(UnregisterProposalFailed)
Expand Down
3 changes: 1 addition & 2 deletions identity/selector/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
package selector

import (
"context"
"errors"

"github.com/mysteriumnetwork/node/identity"
Expand Down Expand Up @@ -85,7 +84,7 @@ func (h *handler) UseNew(passphrase string) (id identity.Identity, err error) {
return
}

if err = h.identityAPI.RegisterIdentity(context.Background(), id, h.signerFactory(id)); err != nil {
if err = h.identityAPI.RegisterIdentity(id, h.signerFactory(id)); err != nil {
return
}

Expand Down
17 changes: 8 additions & 9 deletions requests/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ package requests

import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
Expand All @@ -35,22 +34,22 @@ const (
)

// NewGetRequest generates http Get request
func NewGetRequest(ctx context.Context, apiURI, path string, params url.Values) (*http.Request, error) {
func NewGetRequest(apiURI, path string, params url.Values) (*http.Request, error) {
pathWithQuery := fmt.Sprintf("%v?%v", path, params.Encode())
return newRequest(ctx, http.MethodGet, apiURI, pathWithQuery, nil)
return newRequest(http.MethodGet, apiURI, pathWithQuery, nil)
}

// NewPostRequest generates http Post request
func NewPostRequest(ctx context.Context, apiURI, path string, requestBody interface{}) (*http.Request, error) {
func NewPostRequest(apiURI, path string, requestBody interface{}) (*http.Request, error) {
bodyBytes, err := encodeToJSON(requestBody)
if err != nil {
return nil, err
}
return newRequest(ctx, http.MethodPost, apiURI, path, bodyBytes)
return newRequest(http.MethodPost, apiURI, path, bodyBytes)
}

// NewSignedPostRequest signs payload and generates http Post request
func NewSignedPostRequest(ctx context.Context, apiURI, path string, requestBody interface{}, signer identity.Signer) (*http.Request, error) {
func NewSignedPostRequest(apiURI, path string, requestBody interface{}, signer identity.Signer) (*http.Request, error) {
bodyBytes, err := encodeToJSON(requestBody)
if err != nil {
return nil, err
Expand All @@ -61,7 +60,7 @@ func NewSignedPostRequest(ctx context.Context, apiURI, path string, requestBody
return nil, err
}

req, err := newRequest(ctx, http.MethodPost, apiURI, path, bodyBytes)
req, err := newRequest(http.MethodPost, apiURI, path, bodyBytes)
if err != nil {
return nil, err
}
Expand All @@ -75,7 +74,7 @@ func encodeToJSON(value interface{}) ([]byte, error) {
return json.Marshal(value)
}

func newRequest(ctx context.Context, method, apiURI, path string, body []byte) (*http.Request, error) {
func newRequest(method, apiURI, path string, body []byte) (*http.Request, error) {
fullUrl := fmt.Sprintf("%v/%v", apiURI, path)
req, err := http.NewRequest(method, fullUrl, bytes.NewBuffer(body))
if err != nil {
Expand All @@ -85,5 +84,5 @@ func newRequest(ctx context.Context, method, apiURI, path string, body []byte) (
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")

return req.WithContext(ctx), nil
return req, nil
}
7 changes: 3 additions & 4 deletions requests/request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ package requests

import (
"bytes"
"context"
"io"
"net/url"
"testing"
Expand All @@ -46,7 +45,7 @@ func TestSignatureIsInsertedForSignedPost(t *testing.T) {

signer := mockedSigner{identity.SignatureBase64("deadbeef")}

req, err := NewSignedPostRequest(context.Background(), testRequestApiUrl, "/post-path", testPayload{"abc"}, &signer)
req, err := NewSignedPostRequest(testRequestApiUrl, "/post-path", testPayload{"abc"}, &signer)
assert.NoError(t, err)
assert.Equal(t, req.Header.Get("Authorization"), "Signature deadbeef")
}
Expand All @@ -57,7 +56,7 @@ func TestDoGetContactsPassedValuesForUrl(t *testing.T) {
params["param1"] = []string{"value1"}
params["param2"] = []string{"value2"}

req, err := NewGetRequest(context.Background(), testRequestApiUrl, "get-path", params)
req, err := NewGetRequest(testRequestApiUrl, "get-path", params)

assert.NoError(t, err)
assert.Equal(t, "http://testUrl/get-path?param1=value1&param2=value2", req.URL.String())
Expand All @@ -66,7 +65,7 @@ func TestDoGetContactsPassedValuesForUrl(t *testing.T) {

func TestPayloadIsSerializedSuccessfullyForPostMethod(t *testing.T) {

req, err := NewPostRequest(context.Background(), testRequestApiUrl, "post-path", testPayload{"abc"})
req, err := NewPostRequest(testRequestApiUrl, "post-path", testPayload{"abc"})

assert.NoError(t, err)

Expand Down
Loading

0 comments on commit 1ea4751

Please sign in to comment.