Skip to content

Commit

Permalink
- Minimum changes to be able to get the IP address in the UDP-PSK
Browse files Browse the repository at this point in the history
handshake function to implement a Brute Force Attack protection.
  • Loading branch information
tonisole committed Nov 22, 2023
1 parent a8f7062 commit 7f0f55b
Show file tree
Hide file tree
Showing 10 changed files with 26 additions and 24 deletions.
3 changes: 2 additions & 1 deletion config.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"crypto/tls"
"crypto/x509"
"io"
"net"
"time"

"github.com/pion/dtls/v2/pkg/crypto/elliptic"
Expand Down Expand Up @@ -219,7 +220,7 @@ var defaultCurves = []elliptic.Curve{elliptic.X25519, elliptic.P256, elliptic.P3

// PSKCallback is called once we have the remote's PSKIdentityHint.
// If the remote provided none it will be nil
type PSKCallback func([]byte) ([]byte, error)
type PSKCallback func([]byte, net.Addr) ([]byte, error)

// ClientAuthType declares the policy the server will follow for
// TLS Client Authentication.
Expand Down
7 changes: 4 additions & 3 deletions config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"crypto/rsa"
"crypto/tls"
"errors"
"net"
"testing"

"github.com/pion/dtls/v2/pkg/crypto/selfsign"
Expand Down Expand Up @@ -47,7 +48,7 @@ func TestValidateConfig(t *testing.T) {
"PSK and Certificate, valid cipher suites": {
config: &Config{
CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8, TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
PSK: func(hint []byte) ([]byte, error) {
PSK: func(hint []byte, addr net.Addr) ([]byte, error) {
return nil, nil
},
Certificates: []tls.Certificate{cert},
Expand All @@ -56,7 +57,7 @@ func TestValidateConfig(t *testing.T) {
"PSK and Certificate, no PSK cipher suite": {
config: &Config{
CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
PSK: func(hint []byte) ([]byte, error) {
PSK: func(hint []byte, addr net.Addr) ([]byte, error) {
return nil, nil
},
Certificates: []tls.Certificate{cert},
Expand All @@ -66,7 +67,7 @@ func TestValidateConfig(t *testing.T) {
"PSK and Certificate, no non-PSK cipher suite": {
config: &Config{
CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8},
PSK: func(hint []byte) ([]byte, error) {
PSK: func(hint []byte, addr net.Addr) ([]byte, error) {
return nil, nil
},
Certificates: []tls.Certificate{cert},
Expand Down
24 changes: 12 additions & 12 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,7 @@ func TestPSK(t *testing.T) {
ca, cb := dpipe.Pipe()
go func() {
conf := &Config{
PSK: func(hint []byte) ([]byte, error) {
PSK: func(hint []byte, addr net.Addr) ([]byte, error) {
if !bytes.Equal(test.ServerIdentity, hint) {
return nil, fmt.Errorf("TestPSK: Client got invalid identity expected(% 02x) actual(% 02x)", test.ServerIdentity, hint) //nolint:goerr113
}
Expand All @@ -557,7 +557,7 @@ func TestPSK(t *testing.T) {
}()

config := &Config{
PSK: func(hint []byte) ([]byte, error) {
PSK: func(hint []byte, addr net.Addr) ([]byte, error) {
if !bytes.Equal(clientIdentity, hint) {
return nil, fmt.Errorf("%w: expected(% 02x) actual(% 02x)", errTestPSKInvalidIdentity, clientIdentity, hint)
}
Expand Down Expand Up @@ -620,7 +620,7 @@ func TestPSKHintFail(t *testing.T) {
ca, cb := dpipe.Pipe()
go func() {
conf := &Config{
PSK: func(hint []byte) ([]byte, error) {
PSK: func(hint []byte, addr net.Addr) ([]byte, error) {
return nil, pskRejected
},
PSKIdentityHint: []byte{},
Expand All @@ -632,7 +632,7 @@ func TestPSKHintFail(t *testing.T) {
}()

config := &Config{
PSK: func(hint []byte) ([]byte, error) {
PSK: func(hint []byte, addr net.Addr) ([]byte, error) {
return nil, pskRejected
},
PSKIdentityHint: []byte{},
Expand Down Expand Up @@ -1556,7 +1556,7 @@ func TestCertificateAndPSKServer(t *testing.T) {
go func() {
config := &Config{CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}}
if test.ClientPSK {
config.PSK = func([]byte) ([]byte, error) {
config.PSK = func([]byte, net.Addr) ([]byte, error) {
return []byte{0x00, 0x01, 0x02}, nil
}
config.PSKIdentityHint = []byte{0x00}
Expand All @@ -1569,7 +1569,7 @@ func TestCertificateAndPSKServer(t *testing.T) {

config := &Config{
CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_PSK_WITH_AES_128_GCM_SHA256},
PSK: func([]byte) ([]byte, error) {
PSK: func([]byte, net.Addr) ([]byte, error) {
return []byte{0x00, 0x01, 0x02}, nil
},
}
Expand Down Expand Up @@ -1614,8 +1614,8 @@ func TestPSKConfiguration(t *testing.T) {
Name: "PSK and no certificate specified",
ClientHasCertificate: false,
ServerHasCertificate: false,
ClientPSK: func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
ServerPSK: func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
ClientPSK: func([]byte, net.Addr) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
ServerPSK: func([]byte, net.Addr) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
ClientPSKIdentity: []byte{0x00},
ServerPSKIdentity: []byte{0x00},
WantClientError: errNoAvailablePSKCipherSuite,
Expand All @@ -1625,8 +1625,8 @@ func TestPSKConfiguration(t *testing.T) {
Name: "PSK and certificate specified",
ClientHasCertificate: true,
ServerHasCertificate: true,
ClientPSK: func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
ServerPSK: func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
ClientPSK: func([]byte, net.Addr) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
ServerPSK: func([]byte, net.Addr) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
ClientPSKIdentity: []byte{0x00},
ServerPSKIdentity: []byte{0x00},
WantClientError: errNoAvailablePSKCipherSuite,
Expand All @@ -1636,8 +1636,8 @@ func TestPSKConfiguration(t *testing.T) {
Name: "PSK and no identity specified",
ClientHasCertificate: false,
ServerHasCertificate: false,
ClientPSK: func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
ServerPSK: func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
ClientPSK: func([]byte, net.Addr) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
ServerPSK: func([]byte, net.Addr) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
ClientPSKIdentity: nil,
ServerPSKIdentity: nil,
WantClientError: errPSKAndIdentityMustBeSetForClient,
Expand Down
2 changes: 1 addition & 1 deletion e2e/e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ func testPionE2ESimplePSK(t *testing.T, server, client func(*comm), opts ...dtls
defer cancel()

cfg := &dtls.Config{
PSK: func(hint []byte) ([]byte, error) {
PSK: func(hint []byte, addr net.Addr) ([]byte, error) {
return []byte{0xAB, 0xC1, 0x23}, nil
},
PSKIdentityHint: []byte{0x01, 0x02, 0x03, 0x04, 0x05},
Expand Down
2 changes: 1 addition & 1 deletion examples/dial/cid/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func main() {

// Prepare the configuration of the DTLS connection
config := &dtls.Config{
PSK: func(hint []byte) ([]byte, error) {
PSK: func(hint []byte, addr net.Addr) ([]byte, error) {
fmt.Printf("Server's hint: %s \n", hint)
return []byte{0xAB, 0xC1, 0x23}, nil
},
Expand Down
2 changes: 1 addition & 1 deletion examples/dial/psk/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func main() {

// Prepare the configuration of the DTLS connection
config := &dtls.Config{
PSK: func(hint []byte) ([]byte, error) {
PSK: func(hint []byte, addr net.Addr) ([]byte, error) {
fmt.Printf("Server's hint: %s \n", hint)
return []byte{0xAB, 0xC1, 0x23}, nil
},
Expand Down
2 changes: 1 addition & 1 deletion examples/listen/cid/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func main() {

// Prepare the configuration of the DTLS connection
config := &dtls.Config{
PSK: func(hint []byte) ([]byte, error) {
PSK: func(hint []byte, addr net.Addr) ([]byte, error) {
fmt.Printf("Client's hint: %s \n", hint)
return []byte{0xAB, 0xC1, 0x23}, nil
},
Expand Down
2 changes: 1 addition & 1 deletion examples/listen/psk/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func main() {

// Prepare the configuration of the DTLS connection
config := &dtls.Config{
PSK: func(hint []byte) ([]byte, error) {
PSK: func(hint []byte, addr net.Addr) ([]byte, error) {
fmt.Printf("Client's hint: %s \n", hint)
return []byte{0xAB, 0xC1, 0x23}, nil
},
Expand Down
4 changes: 2 additions & 2 deletions flight3handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,14 +202,14 @@ func handleResumption(ctx context.Context, c flightConn, state *State, cache *ha
return flight5b, nil, nil
}

func handleServerKeyExchange(_ flightConn, state *State, cfg *handshakeConfig, h *handshake.MessageServerKeyExchange) (*alert.Alert, error) {
func handleServerKeyExchange(c flightConn, state *State, cfg *handshakeConfig, h *handshake.MessageServerKeyExchange) (*alert.Alert, error) {
var err error
if state.cipherSuite == nil {
return &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errInvalidCipherSuite
}
if cfg.localPSKCallback != nil {
var psk []byte
if psk, err = cfg.localPSKCallback(h.IdentityHint); err != nil {
if psk, err = cfg.localPSKCallback(h.IdentityHint, c.(*Conn).rAddr); err != nil {
return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
}
state.IdentityHint = h.IdentityHint
Expand Down
2 changes: 1 addition & 1 deletion flight4handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ func flight4Parse(ctx context.Context, c flightConn, state *State, cache *handsh
var preMasterSecret []byte
if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypePreSharedKey {
var psk []byte
if psk, err = cfg.localPSKCallback(clientKeyExchange.IdentityHint); err != nil {
if psk, err = cfg.localPSKCallback(clientKeyExchange.IdentityHint, c.(*Conn).rAddr); err != nil {
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
}
state.IdentityHint = clientKeyExchange.IdentityHint
Expand Down

0 comments on commit 7f0f55b

Please sign in to comment.