Skip to content

Commit

Permalink
Merge pull request #1727 from mysteriumnetwork/nat-with-stun
Browse files Browse the repository at this point in the history
TTL based NAT hole punching with parallel pings
  • Loading branch information
soffokl authored Feb 22, 2020
2 parents a9f8225 + 65bb881 commit 4b20b17
Show file tree
Hide file tree
Showing 29 changed files with 453 additions and 346 deletions.
3 changes: 1 addition & 2 deletions cmd/di.go
Original file line number Diff line number Diff line change
Expand Up @@ -574,7 +574,7 @@ func newSessionManagerFactory(
sessionStorage *session.EventBasedStorage,
providerInvoiceStorage *pingpong.ProviderInvoiceStorage,
accountantPromiseStorage *pingpong.AccountantPromiseStorage,
natPingerChan func(*traversal.Params),
natPingerChan func(traversal.Params),
natTracker *event.Tracker,
serviceID string,
eventbus eventbus.EventBus,
Expand Down Expand Up @@ -811,7 +811,6 @@ func (di *Dependencies) bootstrapNATComponents(options node.Options) {
log.Debug().Msg("Experimental NAT punching enabled, creating a pinger")
di.NATPinger = traversal.NewPinger(
traversal.DefaultPingConfig(),
traversal.NewNATProxy(),
di.EventBus,
)
} else {
Expand Down
8 changes: 8 additions & 0 deletions config/flags_network.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,12 @@ var (
Usage: "Enables NAT hole punching",
Value: true,
}
// FlagNATPunchingMaxTTL sets max number of devices to try pass for NAT hole punching.
FlagNATPunchingMaxTTL = cli.IntFlag{
Name: "natpunching.max-ttl",
Usage: "Max number of devices to try pass for NAT hole punching",
Value: 10,
}
)

// RegisterFlagsNetwork function register network flags to flag list
Expand All @@ -66,6 +72,7 @@ func RegisterFlagsNetwork(flags *[]cli.Flag) {
&FlagTestnet,
&FlagLocalnet,
&FlagNATPunching,
&FlagNATPunchingMaxTTL,
&FlagAPIAddress,
&FlagBrokerAddress,
&FlagEtherRPC,
Expand All @@ -80,4 +87,5 @@ func ParseFlagsNetwork(ctx *cli.Context) {
Current.ParseStringFlag(ctx, FlagBrokerAddress)
Current.ParseStringFlag(ctx, FlagEtherRPC)
Current.ParseBoolFlag(ctx, FlagNATPunching)
Current.ParseIntFlag(ctx, FlagNATPunchingMaxTTL)
}
15 changes: 15 additions & 0 deletions core/port/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ type Pool struct {
// ServicePortSupplier provides port needed to run a service on
type ServicePortSupplier interface {
Acquire() (Port, error)
AcquireMultiple(n int) (ports []Port, err error)
}

// NewPool creates a port pool that will provide ports from range 40000-50000
Expand Down Expand Up @@ -82,3 +83,17 @@ func (pool *Pool) seekAvailablePort() (int, error) {
}
return 0, errors.New("port pool is exhausted")
}

// AcquireMultiple returns n unused ports from pool's range.
func (pool *Pool) AcquireMultiple(n int) (ports []Port, err error) {
for i := 0; i < n; i++ {
p, err := pool.Acquire()
if err != nil {
return ports, err
}

ports = append(ports, p)
}

return ports, nil
}
14 changes: 14 additions & 0 deletions core/port/pool_fixed.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,17 @@ func (pool *PoolFixed) Acquire() (port Port, err error) {
pool.port = port
return
}

// AcquireMultiple returns n unused ports from pool's range.
func (pool *PoolFixed) AcquireMultiple(n int) (ports []Port, err error) {
for i := 0; i < n; i++ {
p, err := pool.Acquire()
if err != nil {
return ports, err
}

ports = append(ports, p)
}

return ports, nil
}
4 changes: 2 additions & 2 deletions core/port/pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,12 @@ func TestAcquiredPortsAreUsable(t *testing.T) {
pool := NewPool()

port, _ := pool.Acquire()
err := listenUdp(port.Num())
err := listenUDP(port.Num())

assert.NoError(t, err)
}

func listenUdp(port int) error {
func listenUDP(port int) error {
udpAddr, err := net.ResolveUDPAddr("udp", ":"+strconv.Itoa(port))
if err != nil {
return err
Expand Down
2 changes: 1 addition & 1 deletion core/service/stub_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func (service *serviceFake) GetType() string {
}

func (service *serviceFake) ProvideConfig(_ string, _ json.RawMessage) (*session.ConfigParams, error) {
return &session.ConfigParams{TraversalParams: &traversal.Params{}}, nil
return &session.ConfigParams{TraversalParams: traversal.Params{}}, nil
}

type mockDialogWaiter struct {
Expand Down
74 changes: 41 additions & 33 deletions mobile/mysterium/openvpn_connection_setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@ import (
"time"

"github.com/mysteriumnetwork/go-openvpn/openvpn3"
"github.com/mysteriumnetwork/node/config"
"github.com/mysteriumnetwork/node/core/connection"
"github.com/mysteriumnetwork/node/core/ip"
"github.com/mysteriumnetwork/node/core/port"
"github.com/mysteriumnetwork/node/identity"
"github.com/mysteriumnetwork/node/nat/traversal"
"github.com/mysteriumnetwork/node/services/openvpn"
Expand All @@ -38,7 +40,7 @@ type natPinger interface {
SetProtectSocketCallback(SocketProtect func(socket int) bool)
}

type openvpn3SessionFactory func(connection.ConnectOptions) (*openvpn3.Session, *openvpn.ClientConfig, error)
type openvpn3SessionFactory func(connection.ConnectOptions, openvpn.VPNConfig) (*openvpn3.Session, *openvpn.ClientConfig, error)

var errSessionWrapperNotStarted = errors.New("session wrapper not started")

Expand All @@ -48,25 +50,9 @@ func NewOpenVPNConnection(sessionTracker *sessionTracker, signerFactory identity
stateCh: make(chan connection.State, 100),
natPinger: natPinger,
ipResolver: ipResolver,
pingerStop: make(chan struct{}),
}
sessionFactory := func(options connection.ConnectOptions) (*openvpn3.Session, *openvpn.ClientConfig, error) {
sessionConfig := &openvpn.VPNConfig{}
err := json.Unmarshal(options.SessionConfig, sessionConfig)
if err != nil {
return nil, nil, err
}

// override vpnClientConfig params with proxy local IP and pinger port
// do this only if connecting to natted provider
if sessionConfig.LocalPort > 0 {
sessionConfig.OriginalRemoteIP = sessionConfig.RemoteIP
sessionConfig.OriginalRemotePort = sessionConfig.RemotePort
sessionConfig.RemoteIP = "127.0.0.1"
// TODO: randomize this too?
sessionConfig.RemotePort = sessionConfig.LocalPort + 1
}

sessionFactory := func(options connection.ConnectOptions, sessionConfig openvpn.VPNConfig) (*openvpn3.Session, *openvpn.ClientConfig, error) {
vpnClientConfig, err := openvpn.NewClientConfigFromSession(sessionConfig, "", "", connection.DNSOptionAuto)
if err != nil {
return nil, nil, err
Expand Down Expand Up @@ -106,7 +92,7 @@ func NewOpenVPNConnection(sessionTracker *sessionTracker, signerFactory identity
}

type openvpnConnection struct {
pingerStop chan struct{}
ports []int
stateCh chan connection.State
stats connection.Statistics
statsMu sync.RWMutex
Expand Down Expand Up @@ -159,25 +145,39 @@ func (c *openvpnConnection) Statistics() (connection.Statistics, error) {
}

func (c *openvpnConnection) Start(options connection.ConnectOptions) error {
newSession, clientConfig, err := c.createSession(options)
sessionConfig := openvpn.VPNConfig{}
err := json.Unmarshal(options.SessionConfig, &sessionConfig)
if err != nil {
return err
}

log.Info().Msgf("Client config after session create: %v", clientConfig)
if clientConfig.LocalPort > 0 {
err := c.natPinger.PingProvider(
clientConfig.VpnConfig.OriginalRemoteIP,
clientConfig.VpnConfig.OriginalRemotePort,
clientConfig.LocalPort,
clientConfig.LocalPort+1,
c.pingerStop,
)
// TODO this backward compatibility check needs to be removed once we will start using port ranges for all peers.
if sessionConfig.LocalPort > 0 || len(sessionConfig.Ports) > 0 {
if len(c.ports) == 0 {
c.ports = []int{sessionConfig.LocalPort}
sessionConfig.Ports = []int{sessionConfig.RemotePort}
}

ip := sessionConfig.RemoteIP
localPorts := c.ports
remotePorts := sessionConfig.Ports

lPort, rPort, err := c.natPinger.PingProvider(ip, localPorts, remotePorts, sessionConfig.LocalPort)
if err != nil {
return err
return errors.Wrap(err, "could not ping provider")
}

sessionConfig.LocalPort = lPort
sessionConfig.RemotePort = rPort
}

newSession, clientConfig, err := c.createSession(options, sessionConfig)
if err != nil {
log.Info().Err(err).Msg("Client config factory error")
return err
}
log.Info().Interface("data", clientConfig).Msgf("Openvpn client configuration")

c.session = newSession
c.session.Start()
return nil
Expand All @@ -188,8 +188,6 @@ func (c *openvpnConnection) Stop() {
if c.session != nil {
c.session.Stop()
}
log.Info().Msg("Stopping NATProxy")
close(c.pingerStop)
})
}

Expand All @@ -212,8 +210,18 @@ func (c *openvpnConnection) GetConfig() (connection.ConsumerConfig, error) {
return nil, errors.Wrap(err, "failed to get consumer public IP")
}

ports, err := port.NewPool().AcquireMultiple(config.GetInt(config.FlagNATPunchingMaxTTL))
if err != nil {
return nil, err
}

for _, p := range ports {
c.ports = append(c.ports, p.Num())
}

return &openvpn.ConsumerConfig{
IP: publicIP,
IP: publicIP,
Ports: c.ports,
}, nil
}

Expand Down
34 changes: 23 additions & 11 deletions mobile/mysterium/wireguard_connection_setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,10 @@ import (
"sync"
"time"

"github.com/mysteriumnetwork/node/config"
"github.com/mysteriumnetwork/node/core/connection"
"github.com/mysteriumnetwork/node/core/ip"
"github.com/mysteriumnetwork/node/core/port"
"github.com/mysteriumnetwork/node/nat/traversal"
"github.com/mysteriumnetwork/node/services/wireguard"
wireguard_connection "github.com/mysteriumnetwork/node/services/wireguard/connection"
Expand Down Expand Up @@ -68,7 +70,6 @@ func NewWireGuardConnection(opts wireGuardOptions, device wireguardDevice, ipRes

return &wireguardConnection{
done: make(chan struct{}),
pingerStop: make(chan struct{}),
stateCh: make(chan connection.State, 100),
opts: opts,
device: device,
Expand All @@ -80,9 +81,9 @@ func NewWireGuardConnection(opts wireGuardOptions, device wireguardDevice, ipRes
}

type wireguardConnection struct {
ports []int
closeOnce sync.Once
done chan struct{}
pingerStop chan struct{}
stateCh chan connection.State
opts wireGuardOptions
privateKey string
Expand Down Expand Up @@ -125,17 +126,19 @@ func (c *wireguardConnection) Start(options connection.ConnectOptions) (err erro
}
}()

if config.LocalPort > 0 {
err = c.natPinger.PingProvider(
config.Provider.Endpoint.IP.String(),
config.RemotePort,
config.LocalPort,
0,
c.pingerStop,
)
// TODO this backward compatibility check needs to be removed once we will start using port ranges for all peers.
if config.LocalPort > 0 || len(config.Ports) > 0 {
ip := config.Provider.Endpoint.IP.String()
localPorts := c.ports
remotePorts := config.Ports

lPort, rPort, err := c.natPinger.PingProvider(ip, localPorts, remotePorts, 0)
if err != nil {
return errors.Wrap(err, "could not ping provider")
}

config.LocalPort = lPort
config.Provider.Endpoint.Port = rPort
}

if err := c.device.Start(c.privateKey, config); err != nil {
Expand All @@ -162,7 +165,6 @@ func (c *wireguardConnection) Stop() {
c.device.Stop()
c.stateCh <- connection.NotConnected

close(c.pingerStop)
close(c.stateCh)
close(c.done)
})
Expand All @@ -184,11 +186,21 @@ func (c *wireguardConnection) GetConfig() (connection.ConsumerConfig, error) {
if err != nil {
return nil, errors.Wrap(err, "failed to get consumer public IP")
}

ports, err := port.NewPool().AcquireMultiple(config.GetInt(config.FlagNATPunchingMaxTTL))
if err != nil {
return nil, err
}

for _, p := range ports {
c.ports = append(c.ports, p.Num())
}
}

return wireguard.ConsumerConfig{
PublicKey: publicKey,
IP: publicIP,
Ports: c.ports,
}, nil
}

Expand Down
Loading

0 comments on commit 4b20b17

Please sign in to comment.