Skip to content

Commit

Permalink
Add handshake hooking
Browse files Browse the repository at this point in the history
Hooking for client/server hello and certificate request messages
  • Loading branch information
theodorsm authored and at-wat committed May 13, 2024
1 parent 2c36d63 commit 8738ce1
Show file tree
Hide file tree
Showing 10 changed files with 336 additions and 76 deletions.
18 changes: 18 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"time"

"github.com/pion/dtls/v2/pkg/crypto/elliptic"
"github.com/pion/dtls/v2/pkg/protocol/handshake"
"github.com/pion/logging"
)

Expand Down Expand Up @@ -196,6 +197,23 @@ type Config struct {
// If no PaddingLengthGenerator is specified, padding will not be applied.
// https://datatracker.ietf.org/doc/html/rfc9146#section-4
PaddingLengthGenerator func(uint) uint

// Handshake hooks: hooks can be used for testing invalid messages,
// mimicking other implementations or randomizing fields, which is valuable
// for applications that need censorship-resistance by making
// fingerprinting more difficult.

// ClientHelloMessageHook, if not nil, is called when a Client Hello message is sent
// from a client. The returned handshake message replaces the original message.
ClientHelloMessageHook func(handshake.MessageClientHello) handshake.Message

// ServerHelloMessageHook, if not nil, is called when a Server Hello message is sent
// from a server. The returned handshake message replaces the original message.
ServerHelloMessageHook func(handshake.MessageServerHello) handshake.Message

// CertificateRequestMessageHook, if not nil, is called when a Certificate Request
// message is sent from a server. The returned handshake message replaces the original message.
CertificateRequestMessageHook func(handshake.MessageCertificateRequest) handshake.Message
}

func defaultConnectContextMaker() (context.Context, func()) {
Expand Down
55 changes: 29 additions & 26 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,32 +176,35 @@ func createConn(ctx context.Context, nextConn net.PacketConn, rAddr net.Addr, co
}

hsCfg := &handshakeConfig{
localPSKCallback: config.PSK,
localPSKIdentityHint: config.PSKIdentityHint,
localCipherSuites: cipherSuites,
localSignatureSchemes: signatureSchemes,
extendedMasterSecret: config.ExtendedMasterSecret,
localSRTPProtectionProfiles: config.SRTPProtectionProfiles,
serverName: serverName,
supportedProtocols: config.SupportedProtocols,
clientAuth: config.ClientAuth,
localCertificates: config.Certificates,
insecureSkipVerify: config.InsecureSkipVerify,
verifyPeerCertificate: config.VerifyPeerCertificate,
verifyConnection: config.VerifyConnection,
rootCAs: config.RootCAs,
clientCAs: config.ClientCAs,
customCipherSuites: config.CustomCipherSuites,
retransmitInterval: workerInterval,
log: logger,
initialEpoch: 0,
keyLogWriter: config.KeyLogWriter,
sessionStore: config.SessionStore,
ellipticCurves: curves,
localGetCertificate: config.GetCertificate,
localGetClientCertificate: config.GetClientCertificate,
insecureSkipHelloVerify: config.InsecureSkipVerifyHello,
connectionIDGenerator: config.ConnectionIDGenerator,
localPSKCallback: config.PSK,
localPSKIdentityHint: config.PSKIdentityHint,
localCipherSuites: cipherSuites,
localSignatureSchemes: signatureSchemes,
extendedMasterSecret: config.ExtendedMasterSecret,
localSRTPProtectionProfiles: config.SRTPProtectionProfiles,
serverName: serverName,
supportedProtocols: config.SupportedProtocols,
clientAuth: config.ClientAuth,
localCertificates: config.Certificates,
insecureSkipVerify: config.InsecureSkipVerify,
verifyPeerCertificate: config.VerifyPeerCertificate,
verifyConnection: config.VerifyConnection,
rootCAs: config.RootCAs,
clientCAs: config.ClientCAs,
customCipherSuites: config.CustomCipherSuites,
retransmitInterval: workerInterval,
log: logger,
initialEpoch: 0,
keyLogWriter: config.KeyLogWriter,
sessionStore: config.SessionStore,
ellipticCurves: curves,
localGetCertificate: config.GetCertificate,
localGetClientCertificate: config.GetClientCertificate,
insecureSkipHelloVerify: config.InsecureSkipVerifyHello,
connectionIDGenerator: config.ConnectionIDGenerator,
clientHelloMessageHook: config.ClientHelloMessageHook,
serverHelloMessageHook: config.ServerHelloMessageHook,
certificateRequestMessageHook: config.CertificateRequestMessageHook,
}

// rfc5246#section-7.4.3
Expand Down
126 changes: 125 additions & 1 deletion e2e/e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ import (

"github.com/pion/dtls/v2"
"github.com/pion/dtls/v2/pkg/crypto/selfsign"
"github.com/pion/dtls/v2/pkg/protocol/extension"
"github.com/pion/dtls/v2/pkg/protocol/handshake"
"github.com/pion/transport/v3/test"
)

Expand All @@ -33,7 +35,11 @@ const (
messageRetry = 200 * time.Millisecond
)

var errServerTimeout = errors.New("waiting on serverReady err: timeout")
var (
errServerTimeout = errors.New("waiting on serverReady err: timeout")
errHookCiphersFailed = errors.New("hook failed to modify cipherlist")
errHookAPLNFailed = errors.New("hook failed to modify APLN extension")
)

func randomPort(t testing.TB) int {
t.Helper()
Expand Down Expand Up @@ -569,6 +575,116 @@ func testPionE2ESimpleRSAClientCert(t *testing.T, server, client func(*comm), op
comm.assert(t)
}

func testPionE2ESimpleClientHelloHook(t *testing.T, server, client func(*comm), opts ...dtlsConfOpts) {
lim := test.TimeOut(time.Second * 30)
defer lim.Stop()

report := test.CheckRoutines(t)
defer report()

t.Run("ClientHello hook", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()

cert, err := selfsign.GenerateSelfSignedWithDNS("localhost")
if err != nil {
t.Fatal(err)
}

modifiedCipher := dtls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA
supportedList := []dtls.CipherSuiteID{
dtls.TLS_ECDHE_ECDSA_WITH_AES_128_CCM,
modifiedCipher,
}

ccfg := &dtls.Config{
Certificates: []tls.Certificate{cert},
VerifyConnection: func(s *dtls.State) error {
if s.CipherSuiteID != modifiedCipher {
return errHookCiphersFailed
}
return nil
},
CipherSuites: supportedList,
ClientHelloMessageHook: func(ch handshake.MessageClientHello) handshake.Message {
ch.CipherSuiteIDs = []uint16{uint16(modifiedCipher)}
return &ch
},
InsecureSkipVerify: true,
}

scfg := &dtls.Config{
Certificates: []tls.Certificate{cert},
CipherSuites: supportedList,
InsecureSkipVerify: true,
}

for _, o := range opts {
o(ccfg)
o(scfg)
}
serverPort := randomPort(t)
comm := newComm(ctx, ccfg, scfg, serverPort, server, client)
defer comm.cleanup(t)
comm.assert(t)
})
}

func testPionE2ESimpleServerHelloHook(t *testing.T, server, client func(*comm), opts ...dtlsConfOpts) {
lim := test.TimeOut(time.Second * 30)
defer lim.Stop()

report := test.CheckRoutines(t)
defer report()

t.Run("ServerHello hook", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()

cert, err := selfsign.GenerateSelfSignedWithDNS("localhost")
if err != nil {
t.Fatal(err)
}

supportedList := []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_128_CCM}

apln := "APLN"

ccfg := &dtls.Config{
Certificates: []tls.Certificate{cert},
VerifyConnection: func(s *dtls.State) error {
if s.NegotiatedProtocol != apln {
return errHookAPLNFailed
}
return nil
},
CipherSuites: supportedList,
InsecureSkipVerify: true,
}

scfg := &dtls.Config{
Certificates: []tls.Certificate{cert},
CipherSuites: supportedList,
ServerHelloMessageHook: func(sh handshake.MessageServerHello) handshake.Message {
sh.Extensions = append(sh.Extensions, &extension.ALPN{
ProtocolNameList: []string{apln},
})
return &sh
},
InsecureSkipVerify: true,
}

for _, o := range opts {
o(ccfg)
o(scfg)
}
serverPort := randomPort(t)
comm := newComm(ctx, ccfg, scfg, serverPort, server, client)
defer comm.cleanup(t)
comm.assert(t)
})
}

func TestPionE2ESimple(t *testing.T) {
testPionE2ESimple(t, serverPion, clientPion)
}
Expand Down Expand Up @@ -624,3 +740,11 @@ func TestPionE2ESimpleECDSAClientCertCID(t *testing.T) {
func TestPionE2ESimpleRSAClientCertCID(t *testing.T) {
testPionE2ESimpleRSAClientCert(t, serverPion, clientPion, withConnectionIDGenerator(dtls.RandomCIDGenerator(8)))
}

func TestPionE2ESimpleClientHelloHook(t *testing.T) {
testPionE2ESimpleClientHelloHook(t, serverPion, clientPion)
}

func TestPionE2ESimpleServerHelloHook(t *testing.T) {
testPionE2ESimpleServerHelloHook(t, serverPion, clientPion)
}
30 changes: 19 additions & 11 deletions flight1handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,23 +133,31 @@ func flight1Generate(c flightConn, state *State, _ *handshakeCache, cfg *handsha
extensions = append(extensions, &extension.ConnectionID{CID: state.localConnectionID})
}

clientHello := &handshake.MessageClientHello{
Version: protocol.Version1_2,
SessionID: state.SessionID,
Cookie: state.cookie,
Random: state.localRandom,
CipherSuiteIDs: cipherSuiteIDs(cfg.localCipherSuites),
CompressionMethods: defaultCompressionMethods(),
Extensions: extensions,
}

var content handshake.Handshake

if cfg.clientHelloMessageHook != nil {
content = handshake.Handshake{Message: cfg.clientHelloMessageHook(*clientHello)}
} else {
content = handshake.Handshake{Message: clientHello}
}

return []*packet{
{
record: &recordlayer.RecordLayer{
Header: recordlayer.Header{
Version: protocol.Version1_2,
},
Content: &handshake.Handshake{
Message: &handshake.MessageClientHello{
Version: protocol.Version1_2,
SessionID: state.SessionID,
Cookie: state.cookie,
Random: state.localRandom,
CipherSuiteIDs: cipherSuiteIDs(cfg.localCipherSuites),
CompressionMethods: defaultCompressionMethods(),
Extensions: extensions,
},
},
Content: &content,
},
},
}, nil, nil
Expand Down
30 changes: 19 additions & 11 deletions flight3handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -287,23 +287,31 @@ func flight3Generate(_ flightConn, state *State, _ *handshakeCache, cfg *handsha
extensions = append(extensions, &extension.ConnectionID{CID: state.localConnectionID})
}

clientHello := &handshake.MessageClientHello{
Version: protocol.Version1_2,
SessionID: state.SessionID,
Cookie: state.cookie,
Random: state.localRandom,
CipherSuiteIDs: cipherSuiteIDs(cfg.localCipherSuites),
CompressionMethods: defaultCompressionMethods(),
Extensions: extensions,
}

var content handshake.Handshake

if cfg.clientHelloMessageHook != nil {
content = handshake.Handshake{Message: cfg.clientHelloMessageHook(*clientHello)}
} else {
content = handshake.Handshake{Message: clientHello}
}

return []*packet{
{
record: &recordlayer.RecordLayer{
Header: recordlayer.Header{
Version: protocol.Version1_2,
},
Content: &handshake.Handshake{
Message: &handshake.MessageClientHello{
Version: protocol.Version1_2,
SessionID: state.SessionID,
Cookie: state.cookie,
Random: state.localRandom,
CipherSuiteIDs: cipherSuiteIDs(cfg.localCipherSuites),
CompressionMethods: defaultCompressionMethods(),
Extensions: extensions,
},
},
Content: &content,
},
},
}, nil, nil
Expand Down
26 changes: 16 additions & 10 deletions flight4bhandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,15 +77,21 @@ func flight4bGenerate(_ flightConn, state *State, cache *handshakeCache, cfg *ha
}

cipherSuiteID := uint16(state.cipherSuite.ID())
serverHello := &handshake.Handshake{
Message: &handshake.MessageServerHello{
Version: protocol.Version1_2,
Random: state.localRandom,
SessionID: state.SessionID,
CipherSuiteID: &cipherSuiteID,
CompressionMethod: defaultCompressionMethods()[0],
Extensions: extensions,
},
var serverHello handshake.Handshake

serverHelloMessage := &handshake.MessageServerHello{
Version: protocol.Version1_2,
Random: state.localRandom,
SessionID: state.SessionID,
CipherSuiteID: &cipherSuiteID,
CompressionMethod: defaultCompressionMethods()[0],
Extensions: extensions,
}

if cfg.serverHelloMessageHook != nil {
serverHello = handshake.Handshake{Message: cfg.serverHelloMessageHook(*serverHelloMessage)}
} else {
serverHello = handshake.Handshake{Message: serverHelloMessage}
}

serverHello.Header.MessageSequence = uint16(state.handshakeSendSequence)
Expand All @@ -112,7 +118,7 @@ func flight4bGenerate(_ flightConn, state *State, cache *handshakeCache, cfg *ha
Header: recordlayer.Header{
Version: protocol.Version1_2,
},
Content: serverHello,
Content: &serverHello,
},
},
&packet{
Expand Down
Loading

0 comments on commit 8738ce1

Please sign in to comment.