Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add handshake hooking capabilities #631

Merged
merged 1 commit into from
May 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"time"

"github.com/pion/dtls/v2/pkg/crypto/elliptic"
"github.com/pion/dtls/v2/pkg/protocol/handshake"
"github.com/pion/logging"
)

Expand Down Expand Up @@ -196,6 +197,23 @@ type Config struct {
// If no PaddingLengthGenerator is specified, padding will not be applied.
// https://datatracker.ietf.org/doc/html/rfc9146#section-4
PaddingLengthGenerator func(uint) uint

// Handshake hooks: hooks can be used for testing invalid messages,
// mimicking other implementations or randomizing fields, which is valuable
// for applications that need censorship-resistance by making
// fingerprinting more difficult.

// ClientHelloMessageHook, if not nil, is called when a Client Hello message is sent
// from a client. The returned handshake message replaces the original message.
ClientHelloMessageHook func(handshake.MessageClientHello) handshake.Message

// ServerHelloMessageHook, if not nil, is called when a Server Hello message is sent
// from a server. The returned handshake message replaces the original message.
ServerHelloMessageHook func(handshake.MessageServerHello) handshake.Message

// CertificateRequestMessageHook, if not nil, is called when a Certificate Request
// message is sent from a server. The returned handshake message replaces the original message.
CertificateRequestMessageHook func(handshake.MessageCertificateRequest) handshake.Message
}

func defaultConnectContextMaker() (context.Context, func()) {
Expand Down
55 changes: 29 additions & 26 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,32 +176,35 @@ func createConn(ctx context.Context, nextConn net.PacketConn, rAddr net.Addr, co
}

hsCfg := &handshakeConfig{
localPSKCallback: config.PSK,
localPSKIdentityHint: config.PSKIdentityHint,
localCipherSuites: cipherSuites,
localSignatureSchemes: signatureSchemes,
extendedMasterSecret: config.ExtendedMasterSecret,
localSRTPProtectionProfiles: config.SRTPProtectionProfiles,
serverName: serverName,
supportedProtocols: config.SupportedProtocols,
clientAuth: config.ClientAuth,
localCertificates: config.Certificates,
insecureSkipVerify: config.InsecureSkipVerify,
verifyPeerCertificate: config.VerifyPeerCertificate,
verifyConnection: config.VerifyConnection,
rootCAs: config.RootCAs,
clientCAs: config.ClientCAs,
customCipherSuites: config.CustomCipherSuites,
retransmitInterval: workerInterval,
log: logger,
initialEpoch: 0,
keyLogWriter: config.KeyLogWriter,
sessionStore: config.SessionStore,
ellipticCurves: curves,
localGetCertificate: config.GetCertificate,
localGetClientCertificate: config.GetClientCertificate,
insecureSkipHelloVerify: config.InsecureSkipVerifyHello,
connectionIDGenerator: config.ConnectionIDGenerator,
localPSKCallback: config.PSK,
localPSKIdentityHint: config.PSKIdentityHint,
localCipherSuites: cipherSuites,
localSignatureSchemes: signatureSchemes,
extendedMasterSecret: config.ExtendedMasterSecret,
localSRTPProtectionProfiles: config.SRTPProtectionProfiles,
serverName: serverName,
supportedProtocols: config.SupportedProtocols,
clientAuth: config.ClientAuth,
localCertificates: config.Certificates,
insecureSkipVerify: config.InsecureSkipVerify,
verifyPeerCertificate: config.VerifyPeerCertificate,
verifyConnection: config.VerifyConnection,
rootCAs: config.RootCAs,
clientCAs: config.ClientCAs,
customCipherSuites: config.CustomCipherSuites,
retransmitInterval: workerInterval,
log: logger,
initialEpoch: 0,
keyLogWriter: config.KeyLogWriter,
sessionStore: config.SessionStore,
ellipticCurves: curves,
localGetCertificate: config.GetCertificate,
localGetClientCertificate: config.GetClientCertificate,
insecureSkipHelloVerify: config.InsecureSkipVerifyHello,
connectionIDGenerator: config.ConnectionIDGenerator,
clientHelloMessageHook: config.ClientHelloMessageHook,
serverHelloMessageHook: config.ServerHelloMessageHook,
certificateRequestMessageHook: config.CertificateRequestMessageHook,
}

// rfc5246#section-7.4.3
Expand Down
126 changes: 125 additions & 1 deletion e2e/e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ import (

"github.com/pion/dtls/v2"
"github.com/pion/dtls/v2/pkg/crypto/selfsign"
"github.com/pion/dtls/v2/pkg/protocol/extension"
"github.com/pion/dtls/v2/pkg/protocol/handshake"
"github.com/pion/transport/v3/test"
)

Expand All @@ -33,7 +35,11 @@ const (
messageRetry = 200 * time.Millisecond
)

var errServerTimeout = errors.New("waiting on serverReady err: timeout")
var (
errServerTimeout = errors.New("waiting on serverReady err: timeout")
errHookCiphersFailed = errors.New("hook failed to modify cipherlist")
errHookAPLNFailed = errors.New("hook failed to modify APLN extension")
)

func randomPort(t testing.TB) int {
t.Helper()
Expand Down Expand Up @@ -569,6 +575,116 @@ func testPionE2ESimpleRSAClientCert(t *testing.T, server, client func(*comm), op
comm.assert(t)
}

func testPionE2ESimpleClientHelloHook(t *testing.T, server, client func(*comm), opts ...dtlsConfOpts) {
lim := test.TimeOut(time.Second * 30)
defer lim.Stop()

report := test.CheckRoutines(t)
defer report()

t.Run("ClientHello hook", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()

cert, err := selfsign.GenerateSelfSignedWithDNS("localhost")
if err != nil {
t.Fatal(err)
}

modifiedCipher := dtls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA
supportedList := []dtls.CipherSuiteID{
dtls.TLS_ECDHE_ECDSA_WITH_AES_128_CCM,
modifiedCipher,
}

ccfg := &dtls.Config{
Certificates: []tls.Certificate{cert},
VerifyConnection: func(s *dtls.State) error {
if s.CipherSuiteID != modifiedCipher {
return errHookCiphersFailed
}
return nil
},
CipherSuites: supportedList,
ClientHelloMessageHook: func(ch handshake.MessageClientHello) handshake.Message {
ch.CipherSuiteIDs = []uint16{uint16(modifiedCipher)}
return &ch
},
InsecureSkipVerify: true,
}

scfg := &dtls.Config{
Certificates: []tls.Certificate{cert},
CipherSuites: supportedList,
InsecureSkipVerify: true,
}

for _, o := range opts {
o(ccfg)
o(scfg)
}
serverPort := randomPort(t)
comm := newComm(ctx, ccfg, scfg, serverPort, server, client)
defer comm.cleanup(t)
comm.assert(t)
})
}

func testPionE2ESimpleServerHelloHook(t *testing.T, server, client func(*comm), opts ...dtlsConfOpts) {
lim := test.TimeOut(time.Second * 30)
defer lim.Stop()

report := test.CheckRoutines(t)
defer report()

t.Run("ServerHello hook", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()

cert, err := selfsign.GenerateSelfSignedWithDNS("localhost")
if err != nil {
t.Fatal(err)
}

supportedList := []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_128_CCM}

apln := "APLN"

ccfg := &dtls.Config{
Certificates: []tls.Certificate{cert},
VerifyConnection: func(s *dtls.State) error {
if s.NegotiatedProtocol != apln {
return errHookAPLNFailed
}
return nil
},
CipherSuites: supportedList,
InsecureSkipVerify: true,
}

scfg := &dtls.Config{
Certificates: []tls.Certificate{cert},
CipherSuites: supportedList,
ServerHelloMessageHook: func(sh handshake.MessageServerHello) handshake.Message {
sh.Extensions = append(sh.Extensions, &extension.ALPN{
ProtocolNameList: []string{apln},
})
return &sh
},
InsecureSkipVerify: true,
}

for _, o := range opts {
o(ccfg)
o(scfg)
}
serverPort := randomPort(t)
comm := newComm(ctx, ccfg, scfg, serverPort, server, client)
defer comm.cleanup(t)
comm.assert(t)
})
}

func TestPionE2ESimple(t *testing.T) {
testPionE2ESimple(t, serverPion, clientPion)
}
Expand Down Expand Up @@ -624,3 +740,11 @@ func TestPionE2ESimpleECDSAClientCertCID(t *testing.T) {
func TestPionE2ESimpleRSAClientCertCID(t *testing.T) {
testPionE2ESimpleRSAClientCert(t, serverPion, clientPion, withConnectionIDGenerator(dtls.RandomCIDGenerator(8)))
}

func TestPionE2ESimpleClientHelloHook(t *testing.T) {
testPionE2ESimpleClientHelloHook(t, serverPion, clientPion)
}

func TestPionE2ESimpleServerHelloHook(t *testing.T) {
testPionE2ESimpleServerHelloHook(t, serverPion, clientPion)
}
30 changes: 19 additions & 11 deletions flight1handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,23 +133,31 @@ func flight1Generate(c flightConn, state *State, _ *handshakeCache, cfg *handsha
extensions = append(extensions, &extension.ConnectionID{CID: state.localConnectionID})
}

clientHello := &handshake.MessageClientHello{
Version: protocol.Version1_2,
SessionID: state.SessionID,
Cookie: state.cookie,
Random: state.localRandom,
CipherSuiteIDs: cipherSuiteIDs(cfg.localCipherSuites),
CompressionMethods: defaultCompressionMethods(),
Extensions: extensions,
}

var content handshake.Handshake

if cfg.clientHelloMessageHook != nil {
content = handshake.Handshake{Message: cfg.clientHelloMessageHook(*clientHello)}
} else {
content = handshake.Handshake{Message: clientHello}
}

return []*packet{
{
record: &recordlayer.RecordLayer{
Header: recordlayer.Header{
Version: protocol.Version1_2,
},
Content: &handshake.Handshake{
Message: &handshake.MessageClientHello{
Version: protocol.Version1_2,
SessionID: state.SessionID,
Cookie: state.cookie,
Random: state.localRandom,
CipherSuiteIDs: cipherSuiteIDs(cfg.localCipherSuites),
CompressionMethods: defaultCompressionMethods(),
Extensions: extensions,
},
},
Content: &content,
},
},
}, nil, nil
Expand Down
30 changes: 19 additions & 11 deletions flight3handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -287,23 +287,31 @@ func flight3Generate(_ flightConn, state *State, _ *handshakeCache, cfg *handsha
extensions = append(extensions, &extension.ConnectionID{CID: state.localConnectionID})
}

clientHello := &handshake.MessageClientHello{
Version: protocol.Version1_2,
SessionID: state.SessionID,
Cookie: state.cookie,
Random: state.localRandom,
CipherSuiteIDs: cipherSuiteIDs(cfg.localCipherSuites),
CompressionMethods: defaultCompressionMethods(),
Extensions: extensions,
}

var content handshake.Handshake

if cfg.clientHelloMessageHook != nil {
content = handshake.Handshake{Message: cfg.clientHelloMessageHook(*clientHello)}
} else {
content = handshake.Handshake{Message: clientHello}
}

return []*packet{
{
record: &recordlayer.RecordLayer{
Header: recordlayer.Header{
Version: protocol.Version1_2,
},
Content: &handshake.Handshake{
Message: &handshake.MessageClientHello{
Version: protocol.Version1_2,
SessionID: state.SessionID,
Cookie: state.cookie,
Random: state.localRandom,
CipherSuiteIDs: cipherSuiteIDs(cfg.localCipherSuites),
CompressionMethods: defaultCompressionMethods(),
Extensions: extensions,
},
},
Content: &content,
},
},
}, nil, nil
Expand Down
26 changes: 16 additions & 10 deletions flight4bhandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,15 +77,21 @@
}

cipherSuiteID := uint16(state.cipherSuite.ID())
serverHello := &handshake.Handshake{
Message: &handshake.MessageServerHello{
Version: protocol.Version1_2,
Random: state.localRandom,
SessionID: state.SessionID,
CipherSuiteID: &cipherSuiteID,
CompressionMethod: defaultCompressionMethods()[0],
Extensions: extensions,
},
var serverHello handshake.Handshake

serverHelloMessage := &handshake.MessageServerHello{
Version: protocol.Version1_2,
Random: state.localRandom,
SessionID: state.SessionID,
CipherSuiteID: &cipherSuiteID,
CompressionMethod: defaultCompressionMethods()[0],
Extensions: extensions,
}

if cfg.serverHelloMessageHook != nil {
serverHello = handshake.Handshake{Message: cfg.serverHelloMessageHook(*serverHelloMessage)}

Check warning on line 92 in flight4bhandler.go

View check run for this annotation

Codecov / codecov/patch

flight4bhandler.go#L92

Added line #L92 was not covered by tests
} else {
serverHello = handshake.Handshake{Message: serverHelloMessage}
}

serverHello.Header.MessageSequence = uint16(state.handshakeSendSequence)
Expand All @@ -112,7 +118,7 @@
Header: recordlayer.Header{
Version: protocol.Version1_2,
},
Content: serverHello,
Content: &serverHello,
},
},
&packet{
Expand Down
Loading
Loading