diff --git a/config.go b/config.go index 604a4d575..6482a5e9c 100644 --- a/config.go +++ b/config.go @@ -11,6 +11,7 @@ import ( "crypto/tls" "crypto/x509" "io" + "net" "time" "github.com/pion/dtls/v2/pkg/crypto/elliptic" @@ -219,7 +220,7 @@ var defaultCurves = []elliptic.Curve{elliptic.X25519, elliptic.P256, elliptic.P3 // PSKCallback is called once we have the remote's PSKIdentityHint. // If the remote provided none it will be nil -type PSKCallback func([]byte) ([]byte, error) +type PSKCallback func([]byte, net.Addr) ([]byte, error) // ClientAuthType declares the policy the server will follow for // TLS Client Authentication. diff --git a/config_test.go b/config_test.go index 811427a0c..4423c83a3 100644 --- a/config_test.go +++ b/config_test.go @@ -9,6 +9,7 @@ import ( "crypto/rsa" "crypto/tls" "errors" + "net" "testing" "github.com/pion/dtls/v2/pkg/crypto/selfsign" @@ -47,7 +48,7 @@ func TestValidateConfig(t *testing.T) { "PSK and Certificate, valid cipher suites": { config: &Config{ CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8, TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, - PSK: func(hint []byte) ([]byte, error) { + PSK: func(hint []byte, addr net.Addr) ([]byte, error) { return nil, nil }, Certificates: []tls.Certificate{cert}, @@ -56,7 +57,7 @@ func TestValidateConfig(t *testing.T) { "PSK and Certificate, no PSK cipher suite": { config: &Config{ CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, - PSK: func(hint []byte) ([]byte, error) { + PSK: func(hint []byte, addr net.Addr) ([]byte, error) { return nil, nil }, Certificates: []tls.Certificate{cert}, @@ -66,7 +67,7 @@ func TestValidateConfig(t *testing.T) { "PSK and Certificate, no non-PSK cipher suite": { config: &Config{ CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8}, - PSK: func(hint []byte) ([]byte, error) { + PSK: func(hint []byte, addr net.Addr) ([]byte, error) { return nil, nil }, Certificates: []tls.Certificate{cert}, diff --git a/conn_test.go b/conn_test.go index d3226d876..3706932cb 100644 --- a/conn_test.go +++ b/conn_test.go @@ -540,7 +540,7 @@ func TestPSK(t *testing.T) { ca, cb := dpipe.Pipe() go func() { conf := &Config{ - PSK: func(hint []byte) ([]byte, error) { + PSK: func(hint []byte, addr net.Addr) ([]byte, error) { if !bytes.Equal(test.ServerIdentity, hint) { return nil, fmt.Errorf("TestPSK: Client got invalid identity expected(% 02x) actual(% 02x)", test.ServerIdentity, hint) //nolint:goerr113 } @@ -557,7 +557,7 @@ func TestPSK(t *testing.T) { }() config := &Config{ - PSK: func(hint []byte) ([]byte, error) { + PSK: func(hint []byte, addr net.Addr) ([]byte, error) { if !bytes.Equal(clientIdentity, hint) { return nil, fmt.Errorf("%w: expected(% 02x) actual(% 02x)", errTestPSKInvalidIdentity, clientIdentity, hint) } @@ -620,7 +620,7 @@ func TestPSKHintFail(t *testing.T) { ca, cb := dpipe.Pipe() go func() { conf := &Config{ - PSK: func(hint []byte) ([]byte, error) { + PSK: func(hint []byte, addr net.Addr) ([]byte, error) { return nil, pskRejected }, PSKIdentityHint: []byte{}, @@ -632,7 +632,7 @@ func TestPSKHintFail(t *testing.T) { }() config := &Config{ - PSK: func(hint []byte) ([]byte, error) { + PSK: func(hint []byte, addr net.Addr) ([]byte, error) { return nil, pskRejected }, PSKIdentityHint: []byte{}, @@ -1556,7 +1556,7 @@ func TestCertificateAndPSKServer(t *testing.T) { go func() { config := &Config{CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}} if test.ClientPSK { - config.PSK = func([]byte) ([]byte, error) { + config.PSK = func([]byte, net.Addr) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil } config.PSKIdentityHint = []byte{0x00} @@ -1569,7 +1569,7 @@ func TestCertificateAndPSKServer(t *testing.T) { config := &Config{ CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_PSK_WITH_AES_128_GCM_SHA256}, - PSK: func([]byte) ([]byte, error) { + PSK: func([]byte, net.Addr) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil }, } @@ -1614,8 +1614,8 @@ func TestPSKConfiguration(t *testing.T) { Name: "PSK and no certificate specified", ClientHasCertificate: false, ServerHasCertificate: false, - ClientPSK: func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil }, - ServerPSK: func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil }, + ClientPSK: func([]byte, net.Addr) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil }, + ServerPSK: func([]byte, net.Addr) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil }, ClientPSKIdentity: []byte{0x00}, ServerPSKIdentity: []byte{0x00}, WantClientError: errNoAvailablePSKCipherSuite, @@ -1625,8 +1625,8 @@ func TestPSKConfiguration(t *testing.T) { Name: "PSK and certificate specified", ClientHasCertificate: true, ServerHasCertificate: true, - ClientPSK: func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil }, - ServerPSK: func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil }, + ClientPSK: func([]byte, net.Addr) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil }, + ServerPSK: func([]byte, net.Addr) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil }, ClientPSKIdentity: []byte{0x00}, ServerPSKIdentity: []byte{0x00}, WantClientError: errNoAvailablePSKCipherSuite, @@ -1636,8 +1636,8 @@ func TestPSKConfiguration(t *testing.T) { Name: "PSK and no identity specified", ClientHasCertificate: false, ServerHasCertificate: false, - ClientPSK: func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil }, - ServerPSK: func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil }, + ClientPSK: func([]byte, net.Addr) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil }, + ServerPSK: func([]byte, net.Addr) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil }, ClientPSKIdentity: nil, ServerPSKIdentity: nil, WantClientError: errPSKAndIdentityMustBeSetForClient, diff --git a/e2e/e2e_test.go b/e2e/e2e_test.go index 4a8dd5d54..7ecbbf522 100644 --- a/e2e/e2e_test.go +++ b/e2e/e2e_test.go @@ -325,7 +325,7 @@ func testPionE2ESimplePSK(t *testing.T, server, client func(*comm), opts ...dtls defer cancel() cfg := &dtls.Config{ - PSK: func(hint []byte) ([]byte, error) { + PSK: func(hint []byte, addr net.Addr) ([]byte, error) { return []byte{0xAB, 0xC1, 0x23}, nil }, PSKIdentityHint: []byte{0x01, 0x02, 0x03, 0x04, 0x05}, diff --git a/examples/dial/cid/main.go b/examples/dial/cid/main.go index 10e547706..a4f4c78ac 100644 --- a/examples/dial/cid/main.go +++ b/examples/dial/cid/main.go @@ -24,7 +24,7 @@ func main() { // Prepare the configuration of the DTLS connection config := &dtls.Config{ - PSK: func(hint []byte) ([]byte, error) { + PSK: func(hint []byte, addr net.Addr) ([]byte, error) { fmt.Printf("Server's hint: %s \n", hint) return []byte{0xAB, 0xC1, 0x23}, nil }, diff --git a/examples/dial/psk/main.go b/examples/dial/psk/main.go index 492ecdd61..a611cc08a 100644 --- a/examples/dial/psk/main.go +++ b/examples/dial/psk/main.go @@ -24,7 +24,7 @@ func main() { // Prepare the configuration of the DTLS connection config := &dtls.Config{ - PSK: func(hint []byte) ([]byte, error) { + PSK: func(hint []byte, addr net.Addr) ([]byte, error) { fmt.Printf("Server's hint: %s \n", hint) return []byte{0xAB, 0xC1, 0x23}, nil }, diff --git a/examples/listen/cid/main.go b/examples/listen/cid/main.go index 770bbcfa4..bb9397e6e 100644 --- a/examples/listen/cid/main.go +++ b/examples/listen/cid/main.go @@ -28,7 +28,7 @@ func main() { // Prepare the configuration of the DTLS connection config := &dtls.Config{ - PSK: func(hint []byte) ([]byte, error) { + PSK: func(hint []byte, addr net.Addr) ([]byte, error) { fmt.Printf("Client's hint: %s \n", hint) return []byte{0xAB, 0xC1, 0x23}, nil }, diff --git a/examples/listen/psk/main.go b/examples/listen/psk/main.go index 66f099693..81764487e 100644 --- a/examples/listen/psk/main.go +++ b/examples/listen/psk/main.go @@ -28,7 +28,7 @@ func main() { // Prepare the configuration of the DTLS connection config := &dtls.Config{ - PSK: func(hint []byte) ([]byte, error) { + PSK: func(hint []byte, addr net.Addr) ([]byte, error) { fmt.Printf("Client's hint: %s \n", hint) return []byte{0xAB, 0xC1, 0x23}, nil }, diff --git a/flight3handler.go b/flight3handler.go index 90dc1a6e3..30039d43a 100644 --- a/flight3handler.go +++ b/flight3handler.go @@ -202,14 +202,14 @@ func handleResumption(ctx context.Context, c flightConn, state *State, cache *ha return flight5b, nil, nil } -func handleServerKeyExchange(_ flightConn, state *State, cfg *handshakeConfig, h *handshake.MessageServerKeyExchange) (*alert.Alert, error) { +func handleServerKeyExchange(c flightConn, state *State, cfg *handshakeConfig, h *handshake.MessageServerKeyExchange) (*alert.Alert, error) { var err error if state.cipherSuite == nil { return &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errInvalidCipherSuite } if cfg.localPSKCallback != nil { var psk []byte - if psk, err = cfg.localPSKCallback(h.IdentityHint); err != nil { + if psk, err = cfg.localPSKCallback(h.IdentityHint, c.(*Conn).rAddr); err != nil { return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } state.IdentityHint = h.IdentityHint diff --git a/flight4handler.go b/flight4handler.go index cd8f2884a..675be62f3 100644 --- a/flight4handler.go +++ b/flight4handler.go @@ -107,7 +107,7 @@ func flight4Parse(ctx context.Context, c flightConn, state *State, cache *handsh var preMasterSecret []byte if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypePreSharedKey { var psk []byte - if psk, err = cfg.localPSKCallback(clientKeyExchange.IdentityHint); err != nil { + if psk, err = cfg.localPSKCallback(clientKeyExchange.IdentityHint, c.(*Conn).rAddr); err != nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } state.IdentityHint = clientKeyExchange.IdentityHint