From 4e7cfdc52d7c6193d37546b69c861580fe333f87 Mon Sep 17 00:00:00 2001 From: theodorsm Date: Sat, 4 May 2024 13:33:45 +0200 Subject: [PATCH] Add handshake hooking Hooking for client/server hello and certificate request messages --- config.go | 18 ++++++ conn.go | 55 +++++++++--------- e2e/e2e_test.go | 126 ++++++++++++++++++++++++++++++++++++++++- flight1handler.go | 30 ++++++---- flight3handler.go | 30 ++++++---- flight4bhandler.go | 26 +++++---- flight4handler.go | 51 +++++++++++------ flight4handler_test.go | 68 ++++++++++++++++++++++ handshaker.go | 4 ++ state.go | 4 ++ 10 files changed, 336 insertions(+), 76 deletions(-) diff --git a/config.go b/config.go index 604a4d575..d765ecd91 100644 --- a/config.go +++ b/config.go @@ -14,6 +14,7 @@ import ( "time" "github.com/pion/dtls/v2/pkg/crypto/elliptic" + "github.com/pion/dtls/v2/pkg/protocol/handshake" "github.com/pion/logging" ) @@ -196,6 +197,23 @@ type Config struct { // If no PaddingLengthGenerator is specified, padding will not be applied. // https://datatracker.ietf.org/doc/html/rfc9146#section-4 PaddingLengthGenerator func(uint) uint + + // Handshake hooks: hooks can be used for testing invalid messages, + // mimicking other implementations or randomizing fields, which is valuable + // for applications that need censorship-resistance by making + // fingerprinting more difficult. + + // ClientHelloMessageHook, if not nil, is called when a Client Hello message is sent + // from a client. The returned handshake message replaces the original message. + ClientHelloMessageHook func(handshake.MessageClientHello) handshake.Message + + // ServerHelloMessageHook, if not nil, is called when a Server Hello message is sent + // from a server. The returned handshake message replaces the original message. + ServerHelloMessageHook func(handshake.MessageServerHello) handshake.Message + + // CertificateRequestMessageHook, if not nil, is called when a Certificate Request + // message is sent from a server. The returned handshake message replaces the original message. + CertificateRequestMessageHook func(handshake.MessageCertificateRequest) handshake.Message } func defaultConnectContextMaker() (context.Context, func()) { diff --git a/conn.go b/conn.go index d4259f9e8..1ddc70963 100644 --- a/conn.go +++ b/conn.go @@ -176,32 +176,35 @@ func createConn(ctx context.Context, nextConn net.PacketConn, rAddr net.Addr, co } hsCfg := &handshakeConfig{ - localPSKCallback: config.PSK, - localPSKIdentityHint: config.PSKIdentityHint, - localCipherSuites: cipherSuites, - localSignatureSchemes: signatureSchemes, - extendedMasterSecret: config.ExtendedMasterSecret, - localSRTPProtectionProfiles: config.SRTPProtectionProfiles, - serverName: serverName, - supportedProtocols: config.SupportedProtocols, - clientAuth: config.ClientAuth, - localCertificates: config.Certificates, - insecureSkipVerify: config.InsecureSkipVerify, - verifyPeerCertificate: config.VerifyPeerCertificate, - verifyConnection: config.VerifyConnection, - rootCAs: config.RootCAs, - clientCAs: config.ClientCAs, - customCipherSuites: config.CustomCipherSuites, - retransmitInterval: workerInterval, - log: logger, - initialEpoch: 0, - keyLogWriter: config.KeyLogWriter, - sessionStore: config.SessionStore, - ellipticCurves: curves, - localGetCertificate: config.GetCertificate, - localGetClientCertificate: config.GetClientCertificate, - insecureSkipHelloVerify: config.InsecureSkipVerifyHello, - connectionIDGenerator: config.ConnectionIDGenerator, + localPSKCallback: config.PSK, + localPSKIdentityHint: config.PSKIdentityHint, + localCipherSuites: cipherSuites, + localSignatureSchemes: signatureSchemes, + extendedMasterSecret: config.ExtendedMasterSecret, + localSRTPProtectionProfiles: config.SRTPProtectionProfiles, + serverName: serverName, + supportedProtocols: config.SupportedProtocols, + clientAuth: config.ClientAuth, + localCertificates: config.Certificates, + insecureSkipVerify: config.InsecureSkipVerify, + verifyPeerCertificate: config.VerifyPeerCertificate, + verifyConnection: config.VerifyConnection, + rootCAs: config.RootCAs, + clientCAs: config.ClientCAs, + customCipherSuites: config.CustomCipherSuites, + retransmitInterval: workerInterval, + log: logger, + initialEpoch: 0, + keyLogWriter: config.KeyLogWriter, + sessionStore: config.SessionStore, + ellipticCurves: curves, + localGetCertificate: config.GetCertificate, + localGetClientCertificate: config.GetClientCertificate, + insecureSkipHelloVerify: config.InsecureSkipVerifyHello, + connectionIDGenerator: config.ConnectionIDGenerator, + clientHelloMessageHook: config.ClientHelloMessageHook, + serverHelloMessageHook: config.ServerHelloMessageHook, + certificateRequestMessageHook: config.CertificateRequestMessageHook, } // rfc5246#section-7.4.3 diff --git a/e2e/e2e_test.go b/e2e/e2e_test.go index e0ca90977..ec1253ec8 100644 --- a/e2e/e2e_test.go +++ b/e2e/e2e_test.go @@ -24,6 +24,8 @@ import ( "github.com/pion/dtls/v2" "github.com/pion/dtls/v2/pkg/crypto/selfsign" + "github.com/pion/dtls/v2/pkg/protocol/extension" + "github.com/pion/dtls/v2/pkg/protocol/handshake" "github.com/pion/transport/v3/test" ) @@ -33,7 +35,11 @@ const ( messageRetry = 200 * time.Millisecond ) -var errServerTimeout = errors.New("waiting on serverReady err: timeout") +var ( + errServerTimeout = errors.New("waiting on serverReady err: timeout") + errHookCiphersFailed = errors.New("hook failed to modify cipherlist") + errHookAPLNFailed = errors.New("hook failed to modify APLN extension") +) func randomPort(t testing.TB) int { t.Helper() @@ -569,6 +575,116 @@ func testPionE2ESimpleRSAClientCert(t *testing.T, server, client func(*comm), op comm.assert(t) } +func testPionE2ESimpleClientHelloHook(t *testing.T, server, client func(*comm), opts ...dtlsConfOpts) { + lim := test.TimeOut(time.Second * 30) + defer lim.Stop() + + report := test.CheckRoutines(t) + defer report() + + t.Run("ClientHello hook", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + cert, err := selfsign.GenerateSelfSignedWithDNS("localhost") + if err != nil { + t.Fatal(err) + } + + modifiedCipher := dtls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA + supportedList := []dtls.CipherSuiteID{ + dtls.TLS_ECDHE_ECDSA_WITH_AES_128_CCM, + modifiedCipher, + } + + ccfg := &dtls.Config{ + Certificates: []tls.Certificate{cert}, + VerifyConnection: func(s *dtls.State) error { + if s.CipherSuiteID != modifiedCipher { + return errHookCiphersFailed + } + return nil + }, + CipherSuites: supportedList, + ClientHelloMessageHook: func(ch handshake.MessageClientHello) handshake.Message { + ch.CipherSuiteIDs = []uint16{uint16(modifiedCipher)} + return &ch + }, + InsecureSkipVerify: true, + } + + scfg := &dtls.Config{ + Certificates: []tls.Certificate{cert}, + CipherSuites: supportedList, + InsecureSkipVerify: true, + } + + for _, o := range opts { + o(ccfg) + o(scfg) + } + serverPort := randomPort(t) + comm := newComm(ctx, ccfg, scfg, serverPort, server, client) + defer comm.cleanup(t) + comm.assert(t) + }) +} + +func testPionE2ESimpleServerHelloHook(t *testing.T, server, client func(*comm), opts ...dtlsConfOpts) { + lim := test.TimeOut(time.Second * 30) + defer lim.Stop() + + report := test.CheckRoutines(t) + defer report() + + t.Run("ServerHello hook", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + cert, err := selfsign.GenerateSelfSignedWithDNS("localhost") + if err != nil { + t.Fatal(err) + } + + supportedList := []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_128_CCM} + + apln := "APLN" + + ccfg := &dtls.Config{ + Certificates: []tls.Certificate{cert}, + VerifyConnection: func(s *dtls.State) error { + if s.NegotiatedProtocol != apln { + return errHookAPLNFailed + } + return nil + }, + CipherSuites: supportedList, + InsecureSkipVerify: true, + } + + scfg := &dtls.Config{ + Certificates: []tls.Certificate{cert}, + CipherSuites: supportedList, + ServerHelloMessageHook: func(sh handshake.MessageServerHello) handshake.Message { + sh.Extensions = append(sh.Extensions, &extension.ALPN{ + ProtocolNameList: []string{apln}, + }) + return &sh + }, + InsecureSkipVerify: true, + } + + for _, o := range opts { + o(ccfg) + o(scfg) + } + serverPort := randomPort(t) + comm := newComm(ctx, ccfg, scfg, serverPort, server, client) + defer comm.cleanup(t) + comm.assert(t) + }) +} + func TestPionE2ESimple(t *testing.T) { testPionE2ESimple(t, serverPion, clientPion) } @@ -624,3 +740,11 @@ func TestPionE2ESimpleECDSAClientCertCID(t *testing.T) { func TestPionE2ESimpleRSAClientCertCID(t *testing.T) { testPionE2ESimpleRSAClientCert(t, serverPion, clientPion, withConnectionIDGenerator(dtls.RandomCIDGenerator(8))) } + +func TestPionE2ESimpleClientHelloHook(t *testing.T) { + testPionE2ESimpleClientHelloHook(t, serverPion, clientPion) +} + +func TestPionE2ESimpleServerHelloHook(t *testing.T) { + testPionE2ESimpleServerHelloHook(t, serverPion, clientPion) +} diff --git a/flight1handler.go b/flight1handler.go index 48bc88213..08a8f3921 100644 --- a/flight1handler.go +++ b/flight1handler.go @@ -133,23 +133,31 @@ func flight1Generate(c flightConn, state *State, _ *handshakeCache, cfg *handsha extensions = append(extensions, &extension.ConnectionID{CID: state.localConnectionID}) } + clientHello := &handshake.MessageClientHello{ + Version: protocol.Version1_2, + SessionID: state.SessionID, + Cookie: state.cookie, + Random: state.localRandom, + CipherSuiteIDs: cipherSuiteIDs(cfg.localCipherSuites), + CompressionMethods: defaultCompressionMethods(), + Extensions: extensions, + } + + var content handshake.Handshake + + if cfg.clientHelloMessageHook != nil { + content = handshake.Handshake{Message: cfg.clientHelloMessageHook(*clientHello)} + } else { + content = handshake.Handshake{Message: clientHello} + } + return []*packet{ { record: &recordlayer.RecordLayer{ Header: recordlayer.Header{ Version: protocol.Version1_2, }, - Content: &handshake.Handshake{ - Message: &handshake.MessageClientHello{ - Version: protocol.Version1_2, - SessionID: state.SessionID, - Cookie: state.cookie, - Random: state.localRandom, - CipherSuiteIDs: cipherSuiteIDs(cfg.localCipherSuites), - CompressionMethods: defaultCompressionMethods(), - Extensions: extensions, - }, - }, + Content: &content, }, }, }, nil, nil diff --git a/flight3handler.go b/flight3handler.go index 90dc1a6e3..08549fb76 100644 --- a/flight3handler.go +++ b/flight3handler.go @@ -287,23 +287,31 @@ func flight3Generate(_ flightConn, state *State, _ *handshakeCache, cfg *handsha extensions = append(extensions, &extension.ConnectionID{CID: state.localConnectionID}) } + clientHello := &handshake.MessageClientHello{ + Version: protocol.Version1_2, + SessionID: state.SessionID, + Cookie: state.cookie, + Random: state.localRandom, + CipherSuiteIDs: cipherSuiteIDs(cfg.localCipherSuites), + CompressionMethods: defaultCompressionMethods(), + Extensions: extensions, + } + + var content handshake.Handshake + + if cfg.clientHelloMessageHook != nil { + content = handshake.Handshake{Message: cfg.clientHelloMessageHook(*clientHello)} + } else { + content = handshake.Handshake{Message: clientHello} + } + return []*packet{ { record: &recordlayer.RecordLayer{ Header: recordlayer.Header{ Version: protocol.Version1_2, }, - Content: &handshake.Handshake{ - Message: &handshake.MessageClientHello{ - Version: protocol.Version1_2, - SessionID: state.SessionID, - Cookie: state.cookie, - Random: state.localRandom, - CipherSuiteIDs: cipherSuiteIDs(cfg.localCipherSuites), - CompressionMethods: defaultCompressionMethods(), - Extensions: extensions, - }, - }, + Content: &content, }, }, }, nil, nil diff --git a/flight4bhandler.go b/flight4bhandler.go index 6b1b90469..d653f4d94 100644 --- a/flight4bhandler.go +++ b/flight4bhandler.go @@ -77,15 +77,21 @@ func flight4bGenerate(_ flightConn, state *State, cache *handshakeCache, cfg *ha } cipherSuiteID := uint16(state.cipherSuite.ID()) - serverHello := &handshake.Handshake{ - Message: &handshake.MessageServerHello{ - Version: protocol.Version1_2, - Random: state.localRandom, - SessionID: state.SessionID, - CipherSuiteID: &cipherSuiteID, - CompressionMethod: defaultCompressionMethods()[0], - Extensions: extensions, - }, + var serverHello handshake.Handshake + + serverHelloMessage := &handshake.MessageServerHello{ + Version: protocol.Version1_2, + Random: state.localRandom, + SessionID: state.SessionID, + CipherSuiteID: &cipherSuiteID, + CompressionMethod: defaultCompressionMethods()[0], + Extensions: extensions, + } + + if cfg.serverHelloMessageHook != nil { + serverHello = handshake.Handshake{Message: cfg.serverHelloMessageHook(*serverHelloMessage)} + } else { + serverHello = handshake.Handshake{Message: serverHelloMessage} } serverHello.Header.MessageSequence = uint16(state.handshakeSendSequence) @@ -112,7 +118,7 @@ func flight4bGenerate(_ flightConn, state *State, cache *handshakeCache, cfg *ha Header: recordlayer.Header{ Version: protocol.Version1_2, }, - Content: serverHello, + Content: &serverHello, }, }, &packet{ diff --git a/flight4handler.go b/flight4handler.go index cd8f2884a..86f21464b 100644 --- a/flight4handler.go +++ b/flight4handler.go @@ -269,21 +269,29 @@ func flight4Generate(_ flightConn, state *State, _ *handshakeCache, cfg *handsha } } + serverHello := &handshake.MessageServerHello{ + Version: protocol.Version1_2, + Random: state.localRandom, + SessionID: state.SessionID, + CipherSuiteID: &cipherSuiteID, + CompressionMethod: defaultCompressionMethods()[0], + Extensions: extensions, + } + + var content handshake.Handshake + + if cfg.serverHelloMessageHook != nil { + content = handshake.Handshake{Message: cfg.serverHelloMessageHook(*serverHello)} + } else { + content = handshake.Handshake{Message: serverHello} + } + pkts = append(pkts, &packet{ record: &recordlayer.RecordLayer{ Header: recordlayer.Header{ Version: protocol.Version1_2, }, - Content: &handshake.Handshake{ - Message: &handshake.MessageServerHello{ - Version: protocol.Version1_2, - Random: state.localRandom, - SessionID: state.SessionID, - CipherSuiteID: &cipherSuiteID, - CompressionMethod: defaultCompressionMethods()[0], - Extensions: extensions, - }, - }, + Content: &content, }, }) @@ -354,18 +362,27 @@ func flight4Generate(_ flightConn, state *State, _ *handshakeCache, cfg *handsha // nolint:staticcheck // ignoring tlsCert.RootCAs.Subjects is deprecated ERR because cert does not come from SystemCertPool and it's ok if certificate authorities is empty. certificateAuthorities = cfg.clientCAs.Subjects() } + + certReq := &handshake.MessageCertificateRequest{ + CertificateTypes: []clientcertificate.Type{clientcertificate.RSASign, clientcertificate.ECDSASign}, + SignatureHashAlgorithms: cfg.localSignatureSchemes, + CertificateAuthoritiesNames: certificateAuthorities, + } + + var content handshake.Handshake + + if cfg.certificateRequestMessageHook != nil { + content = handshake.Handshake{Message: cfg.certificateRequestMessageHook(*certReq)} + } else { + content = handshake.Handshake{Message: certReq} + } + pkts = append(pkts, &packet{ record: &recordlayer.RecordLayer{ Header: recordlayer.Header{ Version: protocol.Version1_2, }, - Content: &handshake.Handshake{ - Message: &handshake.MessageCertificateRequest{ - CertificateTypes: []clientcertificate.Type{clientcertificate.RSASign, clientcertificate.ECDSASign}, - SignatureHashAlgorithms: cfg.localSignatureSchemes, - CertificateAuthoritiesNames: certificateAuthorities, - }, - }, + Content: &content, }, }) } diff --git a/flight4handler_test.go b/flight4handler_test.go index 304c82eff..f4446c40b 100644 --- a/flight4handler_test.go +++ b/flight4handler_test.go @@ -5,10 +5,15 @@ package dtls import ( "context" + "crypto/tls" + "errors" "testing" "time" "github.com/pion/dtls/v2/internal/ciphersuite" + "github.com/pion/dtls/v2/pkg/crypto/elliptic" + "github.com/pion/dtls/v2/pkg/crypto/selfsign" + "github.com/pion/dtls/v2/pkg/crypto/signaturehash" "github.com/pion/dtls/v2/pkg/protocol/alert" "github.com/pion/dtls/v2/pkg/protocol/handshake" "github.com/pion/transport/v3/test" @@ -16,6 +21,8 @@ import ( type flight4TestMockFlightConn struct{} +var errHookCertReqFailed = errors.New("hook failed to modify SignatureHashAlgorithms") + func (f *flight4TestMockFlightConn) notify(context.Context, alert.Level, alert.Description) error { return nil } @@ -117,3 +124,64 @@ func TestFlight4_Process_CertificateVerify(t *testing.T) { t.Fatal(err) } } + +func TestFlight4_CertificateRequestHook(t *testing.T) { + // Limit runtime in case of deadlocks + lim := test.TimeOut(5 * time.Second) + defer lim.Stop() + + // Check for leaking routines + report := test.CheckRoutines(t) + defer report() + + localKeypair, err := elliptic.GenerateKeypair(elliptic.P256) + if err != nil { + t.Fatal(err) + } + + mockConn := &flight4TestMockFlightConn{} + state := &State{ + cipherSuite: &flight4TestMockCipherSuite{t: t}, + localKeypair: localKeypair, + } + + cert, err := selfsign.GenerateSelfSignedWithDNS("localhost") + if err != nil { + t.Fatal(err) + } + + cfg := &handshakeConfig{ + localCertificates: []tls.Certificate{cert}, + localSignatureSchemes: signaturehash.Algorithms(), + clientAuth: 1, + certificateRequestMessageHook: func(mcr handshake.MessageCertificateRequest) handshake.Message { + mcr.SignatureHashAlgorithms = []signaturehash.Algorithm{} + return &mcr + }, + } + + pkts, _, err := flight4Generate(mockConn, state, nil, cfg) + if err != nil { + t.Fatal(err) + } + + for _, p := range pkts { + if h, ok := p.record.Content.(*handshake.Handshake); ok { + if h.Message.Type() == handshake.TypeCertificateRequest { + mcr := &handshake.MessageCertificateRequest{} + msg, err := h.Message.Marshal() + if err != nil { + t.Fatal(err) + } + err = mcr.Unmarshal(msg) + if err != nil { + t.Fatal(err) + } + if len(mcr.SignatureHashAlgorithms) == 0 { + return + } + } + } + } + t.Fatal(errHookCertReqFailed) +} diff --git a/handshaker.go b/handshaker.go index 46fbd38bd..317cb4e4e 100644 --- a/handshaker.go +++ b/handshaker.go @@ -125,6 +125,10 @@ type handshakeConfig struct { initialEpoch uint16 mu sync.Mutex + + clientHelloMessageHook func(handshake.MessageClientHello) handshake.Message + serverHelloMessageHook func(handshake.MessageServerHello) handshake.Message + certificateRequestMessageHook func(handshake.MessageCertificateRequest) handshake.Message } type flightConn interface { diff --git a/state.go b/state.go index b04045ac9..35a5b64fb 100644 --- a/state.go +++ b/state.go @@ -82,6 +82,7 @@ type serializedState struct { LocalConnectionID []byte RemoteConnectionID []byte IsClient bool + NegotiatedProtocol string } func (s *State) clone() *State { @@ -113,6 +114,7 @@ func (s *State) serialize() *serializedState { LocalConnectionID: s.localConnectionID, RemoteConnectionID: s.remoteConnectionID, IsClient: s.isClient, + NegotiatedProtocol: s.NegotiatedProtocol, } } @@ -157,6 +159,8 @@ func (s *State) deserialize(serialized serializedState) { s.remoteConnectionID = serialized.RemoteConnectionID s.SessionID = serialized.SessionID + + s.NegotiatedProtocol = serialized.NegotiatedProtocol } func (s *State) initCipherSuite() error {