diff --git a/.golangci.yml b/.golangci.yml index a3235bec2..88cb4fbf9 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -25,17 +25,32 @@ linters-settings: - ^os.Exit$ - ^panic$ - ^print(ln)?$ + varnamelen: + max-distance: 12 + min-name-length: 2 + ignore-type-assert-ok: true + ignore-map-index-ok: true + ignore-chan-recv-ok: true + ignore-decls: + - i int + - n int + - w io.Writer + - r io.Reader + - b []byte linters: enable: - asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers - bidichk # Checks for dangerous unicode character sequences - bodyclose # checks whether HTTP response body is closed successfully + - containedctx # containedctx is a linter that detects struct contained context.Context field - contextcheck # check the function whether use a non-inherited context + - cyclop # checks function and package cyclomatic complexity - decorder # check declaration order and count of types, constants, variables and functions - dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f()) - dupl # Tool for code clone detection - durationcheck # check for two durations multiplied together + - err113 # Golang linter to check the errors handling expressions - errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases - errchkjson # Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occations, where the check for the returned error can be omitted. - errname # Checks that sentinel errors are prefixed with the `Err` and error types are suffixed with the `Error`. @@ -46,18 +61,17 @@ linters: - forcetypeassert # finds forced type assertions - gci # Gci control golang package import order and make it always deterministic. - gochecknoglobals # Checks that no globals are present in Go code - - gochecknoinits # Checks that no init functions are present in Go code - gocognit # Computes and checks the cognitive complexity of functions - goconst # Finds repeated strings that could be replaced by a constant - gocritic # The most opinionated Go source code linter + - gocyclo # Computes and checks the cyclomatic complexity of functions + - godot # Check if comments end in a period - godox # Tool for detection of FIXME, TODO and other comment keywords - - err113 # Golang linter to check the errors handling expressions - gofmt # Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification - gofumpt # Gofumpt checks whether code was gofumpt-ed. - goheader # Checks is file header matches to pattern - goimports # Goimports does everything that gofmt does. Additionally it checks unused imports - gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod. - - gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations. - goprintffuncname # Checks that printf-like functions are named with `f` at the end - gosec # Inspects source code for security problems - gosimple # Linter for Go source code that specializes in simplifying a code @@ -65,9 +79,15 @@ linters: - grouper # An analyzer to analyze expression groups. - importas # Enforces consistent import aliases - ineffassign # Detects when assignments to existing variables are not used + - lll # Reports long lines + - maintidx # maintidx measures the maintainability index of each function. + - makezero # Finds slice declarations with non-zero initial length - misspell # Finds commonly misspelled English words in comments + - nakedret # Finds naked returns in functions greater than a specified function length + - nestif # Reports deeply nested if statements - nilerr # Finds the code that returns nil even if it checks that the error is not nil. - nilnil # Checks that there is no simultaneous return of `nil` error and an invalid value. + - nlreturn # nlreturn checks for a new line before return and branch statements to increase code clarity - noctx # noctx finds sending http request without context.Context - predeclared # find code that shadows one of Go's predeclared identifiers - revive # golint replacement, finds style mistakes @@ -75,28 +95,22 @@ linters: - stylecheck # Stylecheck is a replacement for golint - tagliatelle # Checks the struct tags. - tenv # tenv is analyzer that detects using os.Setenv instead of t.Setenv since Go1.17 - - tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes + - thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers - typecheck # Like the front-end of a Go compiler, parses and type-checks Go code - unconvert # Remove unnecessary type conversions - unparam # Reports unused function parameters - unused # Checks Go code for unused constants, variables, functions and types + - varnamelen # checks that the length of a variable's name matches its scope - wastedassign # wastedassign finds wasted assignment statements - whitespace # Tool for detection of leading and trailing whitespace disable: - depguard # Go linter that checks if package imports are in a list of acceptable packages - - containedctx # containedctx is a linter that detects struct contained context.Context field - - cyclop # checks function and package cyclomatic complexity - funlen # Tool for detection of long functions - - gocyclo # Computes and checks the cyclomatic complexity of functions - - godot # Check if comments end in a period - - gomnd # An analyzer to detect magic numbers. + - gochecknoinits # Checks that no init functions are present in Go code + - gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations. + - interfacebloat # A linter that checks length of interface. - ireturn # Accept Interfaces, Return Concrete Types - - lll # Reports long lines - - maintidx # maintidx measures the maintainability index of each function. - - makezero # Finds slice declarations with non-zero initial length - - nakedret # Finds naked returns in functions greater than a specified function length - - nestif # Reports deeply nested if statements - - nlreturn # nlreturn checks for a new line before return and branch statements to increase code clarity + - mnd # An analyzer to detect magic numbers - nolintlint # Reports ill-formed or insufficient nolint directives - paralleltest # paralleltest detects missing usage of t.Parallel() method in your Go test - prealloc # Finds slice declarations that could potentially be preallocated @@ -104,8 +118,7 @@ linters: - rowserrcheck # checks whether Err of rows is checked successfully - sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed. - testpackage # linter that makes you use a separate _test package - - thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers - - varnamelen # checks that the length of a variable's name matches its scope + - tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes - wrapcheck # Checks that errors returned from external packages are wrapped - wsl # Whitespace Linter - Forces you to use empty lines! diff --git a/bench_test.go b/bench_test.go index 57fe93ff1..8d90786cb 100644 --- a/bench_test.go +++ b/bench_test.go @@ -37,6 +37,7 @@ func TestSimpleReadWrite(t *testing.T) { }, false) if sErr != nil { t.Error(sErr) + return } buf := make([]byte, 1024) @@ -71,8 +72,10 @@ func TestSimpleReadWrite(t *testing.T) { } } -func benchmarkConn(b *testing.B, n int64) { - b.Run(fmt.Sprintf("%d", n), func(b *testing.B) { +func benchmarkConn(b *testing.B, payloadSize int64) { + b.Helper() + + b.Run(fmt.Sprintf("%d", payloadSize), func(b *testing.B) { ctx := context.Background() ca, cb := dpipe.Pipe() @@ -84,6 +87,7 @@ func benchmarkConn(b *testing.B, n int64) { }, false) if err != nil { b.Error(sErr) + return } server <- s @@ -91,11 +95,13 @@ func benchmarkConn(b *testing.B, n int64) { if err != nil { b.Fatal(err) } - hw := make([]byte, n) + hw := make([]byte, payloadSize) b.ReportAllocs() b.SetBytes(int64(len(hw))) go func() { - client, cErr := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{InsecureSkipVerify: true}, false) + client, cErr := testClient( + ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{InsecureSkipVerify: true}, false, + ) if cErr != nil { b.Error(err) } diff --git a/certificate.go b/certificate.go index 7e184dfd5..524b8e063 100644 --- a/certificate.go +++ b/certificate.go @@ -43,7 +43,8 @@ type CertificateRequestInfo struct { // SupportsCertificate returns nil if the provided certificate is supported by // the server that sent the CertificateRequest. Otherwise, it returns an error // describing the reason for the incompatibility. -// NOTE: original src: https://github.com/golang/go/blob/29b9a328d268d53833d2cc063d1d8b4bf6852675/src/crypto/tls/common.go#L1273 +// NOTE: original src: +// https://github.com/golang/go/blob/29b9a328d268d53833d2cc063d1d8b4bf6852675/src/crypto/tls/common.go#L1273 func (cri *CertificateRequestInfo) SupportsCertificate(c *tls.Certificate) error { if len(cri.AcceptableCAs) == 0 { return nil @@ -66,6 +67,7 @@ func (cri *CertificateRequestInfo) SupportsCertificate(c *tls.Certificate) error } } } + return errNotAcceptableCertificateChain } @@ -91,6 +93,7 @@ func (c *handshakeConfig) setNameToCertificateLocked() { c.nameToCertificate = nameToCertificate } +//nolint:cyclop func (c *handshakeConfig) getCertificate(clientHelloInfo *ClientHelloInfo) (*tls.Certificate, error) { c.mu.Lock() defer c.mu.Unlock() @@ -141,7 +144,8 @@ func (c *handshakeConfig) getCertificate(clientHelloInfo *ClientHelloInfo) (*tls return &c.localCertificates[0], nil } -// NOTE: original src: https://github.com/golang/go/blob/29b9a328d268d53833d2cc063d1d8b4bf6852675/src/crypto/tls/handshake_client.go#L974 +// NOTE: original src: +// https://github.com/golang/go/blob/29b9a328d268d53833d2cc063d1d8b4bf6852675/src/crypto/tls/handshake_client.go#L974 func (c *handshakeConfig) getClientCertificate(cri *CertificateRequestInfo) (*tls.Certificate, error) { c.mu.Lock() defer c.mu.Unlock() @@ -154,6 +158,7 @@ func (c *handshakeConfig) getClientCertificate(cri *CertificateRequestInfo) (*tl if err := cri.SupportsCertificate(&chain); err != nil { continue } + return &chain, nil } diff --git a/cipher_suite.go b/cipher_suite.go index af95dec2e..0a40918e0 100644 --- a/cipher_suite.go +++ b/cipher_suite.go @@ -16,49 +16,63 @@ import ( "github.com/pion/dtls/v3/pkg/protocol/recordlayer" ) -// CipherSuiteID is an ID for our supported CipherSuites +// CipherSuiteID is an ID for our supported CipherSuites. type CipherSuiteID = ciphersuite.ID -// Supported Cipher Suites +// Supported Cipher Suites. const ( // AES-128-CCM - TLS_ECDHE_ECDSA_WITH_AES_128_CCM CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_128_CCM //nolint:revive,stylecheck - TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8 CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8 //nolint:revive,stylecheck + //nolint:revive,stylecheck + TLS_ECDHE_ECDSA_WITH_AES_128_CCM CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_128_CCM + //nolint:revive,stylecheck + TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8 CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8 // AES-128-GCM-SHA256 - TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 //nolint:revive,stylecheck - TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 CipherSuiteID = ciphersuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 //nolint:revive,stylecheck + //nolint:revive,stylecheck + TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 + //nolint:revive,stylecheck + TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 CipherSuiteID = ciphersuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 - TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 //nolint:revive,stylecheck - TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 CipherSuiteID = ciphersuite.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 //nolint:revive,stylecheck + //nolint:revive,stylecheck + TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 + //nolint:revive,stylecheck + TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 CipherSuiteID = ciphersuite.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 // AES-256-CBC-SHA - TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA //nolint:revive,stylecheck - TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA CipherSuiteID = ciphersuite.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA //nolint:revive,stylecheck - - TLS_PSK_WITH_AES_128_CCM CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_128_CCM //nolint:revive,stylecheck - TLS_PSK_WITH_AES_128_CCM_8 CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_128_CCM_8 //nolint:revive,stylecheck - TLS_PSK_WITH_AES_256_CCM_8 CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_256_CCM_8 //nolint:revive,stylecheck - TLS_PSK_WITH_AES_128_GCM_SHA256 CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_128_GCM_SHA256 //nolint:revive,stylecheck - TLS_PSK_WITH_AES_128_CBC_SHA256 CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_128_CBC_SHA256 //nolint:revive,stylecheck - - TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256 CipherSuiteID = ciphersuite.TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256 //nolint:revive,stylecheck + //nolint:revive,stylecheck + TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA + //nolint:revive,stylecheck + TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA CipherSuiteID = ciphersuite.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA + + //nolint:revive,stylecheck + TLS_PSK_WITH_AES_128_CCM CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_128_CCM + //nolint:revive,stylecheck + TLS_PSK_WITH_AES_128_CCM_8 CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_128_CCM_8 + //nolint:revive,stylecheck + TLS_PSK_WITH_AES_256_CCM_8 CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_256_CCM_8 + //nolint:revive,stylecheck + TLS_PSK_WITH_AES_128_GCM_SHA256 CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_128_GCM_SHA256 + //nolint:revive,stylecheck + TLS_PSK_WITH_AES_128_CBC_SHA256 CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_128_CBC_SHA256 + + //nolint:revive,stylecheck + TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256 CipherSuiteID = ciphersuite.TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256 ) -// CipherSuiteAuthenticationType controls what authentication method is using during the handshake for a CipherSuite +// CipherSuiteAuthenticationType controls what authentication method is using during the handshake for a CipherSuite. type CipherSuiteAuthenticationType = ciphersuite.AuthenticationType -// AuthenticationType Enums +// AuthenticationType Enums. const ( CipherSuiteAuthenticationTypeCertificate CipherSuiteAuthenticationType = ciphersuite.AuthenticationTypeCertificate CipherSuiteAuthenticationTypePreSharedKey CipherSuiteAuthenticationType = ciphersuite.AuthenticationTypePreSharedKey CipherSuiteAuthenticationTypeAnonymous CipherSuiteAuthenticationType = ciphersuite.AuthenticationTypeAnonymous ) -// CipherSuiteKeyExchangeAlgorithm controls what exchange algorithm is using during the handshake for a CipherSuite +// CipherSuiteKeyExchangeAlgorithm controls what exchange algorithm is using during the handshake for a CipherSuite. type CipherSuiteKeyExchangeAlgorithm = ciphersuite.KeyExchangeAlgorithm -// CipherSuiteKeyExchangeAlgorithm Bitmask +// CipherSuiteKeyExchangeAlgorithm Bitmask. const ( CipherSuiteKeyExchangeAlgorithmNone CipherSuiteKeyExchangeAlgorithm = ciphersuite.KeyExchangeAlgorithmNone CipherSuiteKeyExchangeAlgorithmPsk CipherSuiteKeyExchangeAlgorithm = ciphersuite.KeyExchangeAlgorithmPsk @@ -67,7 +81,7 @@ const ( var _ = allCipherSuites() // Necessary until this function isn't only used by Go 1.14 -// CipherSuite is an interface that all DTLS CipherSuites must satisfy +// CipherSuite is an interface that all DTLS CipherSuites must satisfy. type CipherSuite interface { // String of CipherSuite, only used for logging String() string @@ -108,13 +122,14 @@ func CipherSuiteName(id CipherSuiteID) string { if suite != nil { return suite.String() } + return fmt.Sprintf("0x%04X", uint16(id)) } // Taken from https://www.iana.org/assignments/tls-parameters/tls-parameters.xml // A cipherSuite is a specific combination of key agreement, cipher and MAC // function. -func cipherSuiteForID(id CipherSuiteID, customCiphers func() []CipherSuite) CipherSuite { +func cipherSuiteForID(id CipherSuiteID, customCiphers func() []CipherSuite) CipherSuite { //nolint:cyclop switch id { //nolint:exhaustive case TLS_ECDHE_ECDSA_WITH_AES_128_CCM: return ciphersuite.NewTLSEcdheEcdsaWithAes128Ccm() @@ -157,7 +172,7 @@ func cipherSuiteForID(id CipherSuiteID, customCiphers func() []CipherSuite) Ciph return nil } -// CipherSuites we support in order of preference +// CipherSuites we support in order of preference. func defaultCipherSuites() []CipherSuite { return []CipherSuite{ &ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256{}, @@ -191,10 +206,16 @@ func cipherSuiteIDs(cipherSuites []CipherSuite) []uint16 { for _, c := range cipherSuites { rtrn = append(rtrn, uint16(c.ID())) } + return rtrn } -func parseCipherSuites(userSelectedSuites []CipherSuiteID, customCipherSuites func() []CipherSuite, includeCertificateSuites, includePSKSuites bool) ([]CipherSuite, error) { +//nolint:cyclop +func parseCipherSuites( + userSelectedSuites []CipherSuiteID, + customCipherSuites func() []CipherSuite, + includeCertificateSuites, includePSKSuites bool, +) ([]CipherSuite, error) { cipherSuitesForIDs := func(ids []CipherSuiteID) ([]CipherSuite, error) { cipherSuites := []CipherSuite{} for _, id := range ids { @@ -204,6 +225,7 @@ func parseCipherSuites(userSelectedSuites []CipherSuiteID, customCipherSuites fu } cipherSuites = append(cipherSuites, c) } + return cipherSuites, nil } @@ -272,5 +294,6 @@ func filterCipherSuitesForCertificate(cert *tls.Certificate, cipherSuites []Ciph filtered = append(filtered, c) } } + return filtered } diff --git a/cipher_suite_go114.go b/cipher_suite_go114.go index fd46d7bd9..e7a324147 100644 --- a/cipher_suite_go114.go +++ b/cipher_suite_go114.go @@ -11,10 +11,10 @@ import ( ) // VersionDTLS12 is the DTLS version in the same style as -// VersionTLSXX from crypto/tls +// VersionTLSXX from crypto/tls. const VersionDTLS12 = 0xfefd -// Convert from our cipherSuite interface to a tls.CipherSuite struct +// Convert from our cipherSuite interface to a tls.CipherSuite struct. func toTLSCipherSuite(c CipherSuite) *tls.CipherSuite { return &tls.CipherSuite{ ID: uint16(c.ID()), @@ -33,6 +33,7 @@ func CipherSuites() []*tls.CipherSuite { for i, c := range suites { res[i] = toTLSCipherSuite(c) } + return res } @@ -40,5 +41,6 @@ func CipherSuites() []*tls.CipherSuite { // this package and which have security issues. func InsecureCipherSuites() []*tls.CipherSuite { var res []*tls.CipherSuite + return res } diff --git a/cipher_suite_go114_test.go b/cipher_suite_go114_test.go index 35c4b1ef6..e93b760c5 100644 --- a/cipher_suite_go114_test.go +++ b/cipher_suite_go114_test.go @@ -30,25 +30,25 @@ func TestCipherSuites(t *testing.T) { i := i s := s t.Run(s.String(), func(t *testing.T) { - c := theirs[i] - if c.ID != uint16(s.ID()) { - t.Fatalf("Expected ID: 0x%04X, got 0x%04X", s.ID(), c.ID) + cipher := theirs[i] + if cipher.ID != uint16(s.ID()) { + t.Fatalf("Expected ID: 0x%04X, got 0x%04X", s.ID(), cipher.ID) } - if c.Name != s.String() { - t.Fatalf("Expected Name: %s, got %s", s.String(), c.Name) + if cipher.Name != s.String() { + t.Fatalf("Expected Name: %s, got %s", s.String(), cipher.Name) } - if len(c.SupportedVersions) != 1 { - t.Fatalf("Expected %d SupportedVersion, got %d", 1, len(c.SupportedVersions)) + if len(cipher.SupportedVersions) != 1 { + t.Fatalf("Expected %d SupportedVersion, got %d", 1, len(cipher.SupportedVersions)) } - if c.SupportedVersions[0] != VersionDTLS12 { - t.Fatalf("Expected SupportedVersions 0x%04X, got 0x%04X", VersionDTLS12, c.SupportedVersions[0]) + if cipher.SupportedVersions[0] != VersionDTLS12 { + t.Fatalf("Expected SupportedVersions 0x%04X, got 0x%04X", VersionDTLS12, cipher.SupportedVersions[0]) } - if c.Insecure { - t.Fatalf("Expected Insecure %t, got %t", false, c.Insecure) + if cipher.Insecure { + t.Fatalf("Expected Insecure %t, got %t", false, cipher.Insecure) } }) } diff --git a/cipher_suite_test.go b/cipher_suite_test.go index bd1f803ec..c4fd4840a 100644 --- a/cipher_suite_test.go +++ b/cipher_suite_test.go @@ -38,7 +38,7 @@ func TestAllCipherSuites(t *testing.T) { } } -// CustomCipher that is just used to assert Custom IDs work +// CustomCipher that is just used to assert Custom IDs work. type testCustomCipherSuite struct { ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256 authenticationType CipherSuiteAuthenticationType @@ -52,7 +52,7 @@ func (t *testCustomCipherSuite) AuthenticationType() CipherSuiteAuthenticationTy return t.authenticationType } -// Assert that two connections that pass in a CipherSuite with a CustomID works +// Assert that two connections that pass in a CipherSuite with a CustomID works. func TestCustomCipherSuite(t *testing.T) { type result struct { c *Conn @@ -68,14 +68,14 @@ func TestCustomCipherSuite(t *testing.T) { defer cancel() ca, cb := dpipe.Pipe() - c := make(chan result) + resultCh := make(chan result) go func() { client, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{ CipherSuites: []CipherSuiteID{}, CustomCipherSuites: cipherFactory, }, true) - c <- result{client, err} + resultCh <- result{client, err} }() server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ @@ -83,7 +83,7 @@ func TestCustomCipherSuite(t *testing.T) { CustomCipherSuites: cipherFactory, }, true) - clientResult := <-c + clientResult := <-resultCh if err != nil { t.Error(err) diff --git a/config.go b/config.go index 54a86c0ee..39b5cc38c 100644 --- a/config.go +++ b/config.go @@ -217,7 +217,8 @@ type Config struct { // message is sent from a server. The returned handshake message replaces the original message. CertificateRequestMessageHook func(handshake.MessageCertificateRequest) handshake.Message - // OnConnectionAttempt is fired Whenever a connection attempt is made, the server or application can call this callback function. + // OnConnectionAttempt is fired Whenever a connection attempt is made, + // the server or application can call this callback function. // The callback function can then implement logic to handle the connection attempt, such as logging the attempt, // checking against a list of blocked IPs, or counting the attempts to prevent brute force attacks. // If the callback function returns an error, the connection attempt will be aborted. @@ -233,14 +234,14 @@ const defaultMTU = 1200 // bytes var defaultCurves = []elliptic.Curve{elliptic.X25519, elliptic.P256, elliptic.P384} //nolint:gochecknoglobals // PSKCallback is called once we have the remote's PSKIdentityHint. -// If the remote provided none it will be nil +// If the remote provided none it will be nil. type PSKCallback func([]byte) ([]byte, error) // ClientAuthType declares the policy the server will follow for // TLS Client Authentication. type ClientAuthType int -// ClientAuthType enums +// ClientAuthType enums. const ( NoClientCert ClientAuthType = iota RequestClientCert @@ -250,17 +251,17 @@ const ( ) // ExtendedMasterSecretType declares the policy the client and server -// will follow for the Extended Master Secret extension +// will follow for the Extended Master Secret extension. type ExtendedMasterSecretType int -// ExtendedMasterSecretType enums +// ExtendedMasterSecretType enums. const ( RequestExtendedMasterSecret ExtendedMasterSecretType = iota RequireExtendedMasterSecret DisableExtendedMasterSecret ) -func validateConfig(config *Config) error { +func validateConfig(config *Config) error { //nolint:cyclop switch { case config == nil: return errNoConfigProvided @@ -283,6 +284,9 @@ func validateConfig(config *Config) error { } } - _, err := parseCipherSuites(config.CipherSuites, config.CustomCipherSuites, config.includeCertificateSuites(), config.PSK != nil) + _, err := parseCipherSuites( + config.CipherSuites, config.CustomCipherSuites, config.includeCertificateSuites(), config.PSK != nil, + ) + return err } diff --git a/config_test.go b/config_test.go index 99e25a4c2..b01de1442 100644 --- a/config_test.go +++ b/config_test.go @@ -14,26 +14,30 @@ import ( "github.com/pion/dtls/v3/pkg/crypto/selfsign" ) -func TestValidateConfig(t *testing.T) { +func TestValidateConfig(t *testing.T) { //nolint:cyclop cert, err := selfsign.GenerateSelfSigned() if err != nil { t.Fatalf("TestValidateConfig: Config validation error(%v), self signed certificate not generated", err) + return } dsaPrivateKey := &dsa.PrivateKey{} err = dsa.GenerateParameters(&dsaPrivateKey.Parameters, rand.Reader, dsa.L1024N160) if err != nil { t.Fatalf("TestValidateConfig: Config validation error(%v), DSA parameters not generated", err) + return } err = dsa.GenerateKey(dsaPrivateKey, rand.Reader) if err != nil { t.Fatalf("TestValidateConfig: Config validation error(%v), DSA private key not generated", err) + return } rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { t.Fatalf("TestValidateConfig: Config validation error(%v), RSA private key not generated", err) + return } cases := map[string]struct { diff --git a/conn.go b/conn.go index 26c56f2e0..c7cd9a3cb 100644 --- a/conn.go +++ b/conn.go @@ -33,10 +33,10 @@ const ( sessionLength = 32 defaultNamedCurve = elliptic.X25519 inboundBufferSize = 8192 - // Default replay protection window is specified by RFC 6347 Section 4.1.2.6 + // Default replay protection window is specified by RFC 6347 Section 4.1.2.6. defaultReplayProtectionWindow = 64 - // maxAppDataPacketQueueSize is the maximum number of app data packets we will - // enqueue before the handshake is completed + // maxAppDataPacketQueueSize is the maximum number of app data packets we will. + // enqueue before the handshake is completed. maxAppDataPacketQueueSize = 100 ) @@ -59,7 +59,7 @@ type recvHandshakeState struct { isRetransmit bool } -// Conn represents a DTLS connection +// Conn represents a DTLS connection. type Conn struct { lock sync.RWMutex // Internal lock (must not be public) nextConn netctx.PacketConn // Embedded Conn, typically a udpconn we read/write from @@ -99,7 +99,14 @@ type Conn struct { handshakeConfig *handshakeConfig } -func createConn(nextConn net.PacketConn, rAddr net.Addr, config *Config, isClient bool, resumeState *State) (*Conn, error) { +//nolint:cyclop +func createConn( + nextConn net.PacketConn, + rAddr net.Addr, + config *Config, + isClient bool, + resumeState *State, +) (*Conn, error) { if err := validateConfig(config); err != nil { return nil, err } @@ -130,7 +137,12 @@ func createConn(nextConn net.PacketConn, rAddr net.Addr, config *Config, isClien paddingLengthGenerator = func(uint) uint { return 0 } } - cipherSuites, err := parseCipherSuites(config.CipherSuites, config.CustomCipherSuites, config.includeCertificateSuites(), config.PSK != nil) + cipherSuites, err := parseCipherSuites( + config.CipherSuites, + config.CustomCipherSuites, + config.includeCertificateSuites(), + config.PSK != nil, + ) if err != nil { return nil, err } @@ -193,7 +205,7 @@ func createConn(nextConn net.PacketConn, rAddr net.Addr, config *Config, isClien resumeState: resumeState, } - c := &Conn{ + conn := &Conn{ rAddr: rAddr, nextConn: netctx.NewPacketConn(nextConn), handshakeConfig: handshakeConfig, @@ -214,16 +226,17 @@ func createConn(nextConn net.PacketConn, rAddr net.Addr, config *Config, isClien cancelHandshaker: func() {}, cancelHandshakeReader: func() {}, - replayProtectionWindow: uint(replayProtectionWindow), + replayProtectionWindow: uint(replayProtectionWindow), //nolint:gosec // G115 state: State{ isClient: isClient, }, } - c.setRemoteEpoch(0) - c.setLocalEpoch(0) - return c, nil + conn.setRemoteEpoch(0) + conn.setLocalEpoch(0) + + return conn, nil } // Handshake runs the client or server DTLS handshake @@ -276,7 +289,7 @@ func (c *Conn) HandshakeContext(ctx context.Context) error { var initialFlight flightVal var initialFSMState handshakeState - if c.handshakeConfig.resumeState != nil { + if c.handshakeConfig.resumeState != nil { //nolint:nestif if c.state.isClient { initialFlight = flight5 } else { @@ -338,11 +351,12 @@ func Server(conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) return nil, err } } + return createConn(conn, rAddr, config, false, nil) } // Read reads data from the connection. -func (c *Conn) Read(p []byte) (n int, err error) { +func (c *Conn) Read(buff []byte) (n int, err error) { //nolint:cyclop if err := c.Handshake(); err != nil { return 0, err } @@ -363,10 +377,11 @@ func (c *Conn) Read(p []byte) (n int, err error) { } switch val := out.(type) { case ([]byte): - if len(p) < len(val) { + if len(buff) < len(val) { return 0, errBufferTooSmall } - copy(p, val) + copy(buff, val) + return len(val), nil case (error): return 0, val @@ -375,8 +390,8 @@ func (c *Conn) Read(p []byte) (n int, err error) { } } -// Write writes len(p) bytes from p to the DTLS connection -func (c *Conn) Write(p []byte) (int, error) { +// Write writes len(payload) bytes from payload to the DTLS connection. +func (c *Conn) Write(payload []byte) (int, error) { if c.isConnectionClosed() { return 0, ErrConnClosed } @@ -391,7 +406,7 @@ func (c *Conn) Write(p []byte) (int, error) { return 0, err } - return len(p), c.writePackets(c.writeDeadline, []*packet{ + return len(payload), c.writePackets(c.writeDeadline, []*packet{ { record: &recordlayer.RecordLayer{ Header: recordlayer.Header{ @@ -399,7 +414,7 @@ func (c *Conn) Write(p []byte) (int, error) { Version: protocol.Version1_2, }, Content: &protocol.ApplicationData{ - Data: p, + Data: payload, }, }, shouldWrapCID: len(c.state.remoteConnectionID) > 0, @@ -417,6 +432,7 @@ func (c *Conn) Close() error { if handshakeDone != nil { <-handshakeDone } + return err } @@ -429,10 +445,11 @@ func (c *Conn) ConnectionState() (State, bool) { if err != nil { return State{}, false } + return *stateClone, true } -// SelectedSRTPProtectionProfile returns the selected SRTPProtectionProfile +// SelectedSRTPProtectionProfile returns the selected SRTPProtectionProfile. func (c *Conn) SelectedSRTPProtectionProfile() (SRTPProtectionProfile, bool) { profile := c.state.getSRTPProtectionProfile() if profile == 0 { @@ -442,7 +459,7 @@ func (c *Conn) SelectedSRTPProtectionProfile() (SRTPProtectionProfile, bool) { return profile, true } -// RemoteSRTPMasterKeyIdentifier returns the MasterKeyIdentifier value from the use_srtp +// RemoteSRTPMasterKeyIdentifier returns the MasterKeyIdentifier value from the use_srtp. func (c *Conn) RemoteSRTPMasterKeyIdentifier() ([]byte, bool) { if profile := c.state.getSRTPProtectionProfile(); profile == 0 { return nil, false @@ -457,26 +474,32 @@ func (c *Conn) writePackets(ctx context.Context, pkts []*packet) error { var rawPackets [][]byte - for _, p := range pkts { - if h, ok := p.record.Content.(*handshake.Handshake); ok { - handshakeRaw, err := p.record.Marshal() + for _, pkt := range pkts { + if dtlsHandshake, ok := pkt.record.Content.(*handshake.Handshake); ok { + handshakeRaw, err := pkt.record.Marshal() if err != nil { return err } c.log.Tracef("[handshake:%v] -> %s (epoch: %d, seq: %d)", - srvCliStr(c.state.isClient), h.Header.Type.String(), - p.record.Header.Epoch, h.Header.MessageSequence) - - c.handshakeCache.push(handshakeRaw[recordlayer.FixedHeaderSize:], p.record.Header.Epoch, h.Header.MessageSequence, h.Header.Type, c.state.isClient) + srvCliStr(c.state.isClient), dtlsHandshake.Header.Type.String(), + pkt.record.Header.Epoch, dtlsHandshake.Header.MessageSequence) + + c.handshakeCache.push( + handshakeRaw[recordlayer.FixedHeaderSize:], + pkt.record.Header.Epoch, + dtlsHandshake.Header.MessageSequence, + dtlsHandshake.Header.Type, + c.state.isClient, + ) - rawHandshakePackets, err := c.processHandshakePacket(p, h) + rawHandshakePackets, err := c.processHandshakePacket(pkt, dtlsHandshake) if err != nil { return err } rawPackets = append(rawPackets, rawHandshakePackets...) } else { - rawPacket, err := c.processPacket(p) + rawPacket, err := c.processPacket(pkt) if err != nil { return err } @@ -519,8 +542,8 @@ func (c *Conn) compactRawPackets(rawPackets [][]byte) [][]byte { return combinedRawPackets } -func (c *Conn) processPacket(p *packet) ([]byte, error) { - epoch := p.record.Header.Epoch +func (c *Conn) processPacket(pkt *packet) ([]byte, error) { //nolint:cyclop + epoch := pkt.record.Header.Epoch for len(c.state.localSequenceNumber) <= int(epoch) { c.state.localSequenceNumber = append(c.state.localSequenceNumber, uint64(0)) } @@ -531,51 +554,51 @@ func (c *Conn) processPacket(p *packet) ([]byte, error) { // prior to allowing the sequence number to wrap. return nil, errSequenceNumberOverflow } - p.record.Header.SequenceNumber = seq + pkt.record.Header.SequenceNumber = seq var rawPacket []byte - if p.shouldWrapCID { + if pkt.shouldWrapCID { //nolint:nestif // Record must be marshaled to populate fields used in inner plaintext. - if _, err := p.record.Marshal(); err != nil { + if _, err := pkt.record.Marshal(); err != nil { return nil, err } - content, err := p.record.Content.Marshal() + content, err := pkt.record.Content.Marshal() if err != nil { return nil, err } inner := &recordlayer.InnerPlaintext{ Content: content, - RealType: p.record.Header.ContentType, + RealType: pkt.record.Header.ContentType, } rawInner, err := inner.Marshal() //nolint:govet if err != nil { return nil, err } cidHeader := &recordlayer.Header{ - Version: p.record.Header.Version, + Version: pkt.record.Header.Version, ContentType: protocol.ContentTypeConnectionID, - Epoch: p.record.Header.Epoch, - ContentLen: uint16(len(rawInner)), + Epoch: pkt.record.Header.Epoch, + ContentLen: uint16(len(rawInner)), //nolint:gosec //G115 ConnectionID: c.state.remoteConnectionID, - SequenceNumber: p.record.Header.SequenceNumber, + SequenceNumber: pkt.record.Header.SequenceNumber, } rawPacket, err = cidHeader.Marshal() if err != nil { return nil, err } - p.record.Header = *cidHeader + pkt.record.Header = *cidHeader rawPacket = append(rawPacket, rawInner...) } else { var err error - rawPacket, err = p.record.Marshal() + rawPacket, err = pkt.record.Marshal() if err != nil { return nil, err } } - if p.shouldEncrypt { + if pkt.shouldEncrypt { var err error - rawPacket, err = c.state.cipherSuite.Encrypt(p.record, rawPacket) + rawPacket, err = c.state.cipherSuite.Encrypt(pkt.record, rawPacket) if err != nil { return nil, err } @@ -584,14 +607,15 @@ func (c *Conn) processPacket(p *packet) ([]byte, error) { return rawPacket, nil } -func (c *Conn) processHandshakePacket(p *packet, h *handshake.Handshake) ([][]byte, error) { +//nolint:cyclop +func (c *Conn) processHandshakePacket(pkt *packet, dtlsHandshake *handshake.Handshake) ([][]byte, error) { rawPackets := make([][]byte, 0) - handshakeFragments, err := c.fragmentHandshake(h) + handshakeFragments, err := c.fragmentHandshake(dtlsHandshake) if err != nil { return nil, err } - epoch := p.record.Header.Epoch + epoch := pkt.record.Header.Epoch for len(c.state.localSequenceNumber) <= int(epoch) { c.state.localSequenceNumber = append(c.state.localSequenceNumber, uint64(0)) } @@ -603,7 +627,7 @@ func (c *Conn) processHandshakePacket(p *packet, h *handshake.Handshake) ([][]by } var rawPacket []byte - if p.shouldWrapCID { + if pkt.shouldWrapCID { inner := &recordlayer.InnerPlaintext{ Content: handshakeFragment, RealType: protocol.ContentTypeHandshake, @@ -614,25 +638,25 @@ func (c *Conn) processHandshakePacket(p *packet, h *handshake.Handshake) ([][]by return nil, err } cidHeader := &recordlayer.Header{ - Version: p.record.Header.Version, + Version: pkt.record.Header.Version, ContentType: protocol.ContentTypeConnectionID, - Epoch: p.record.Header.Epoch, - ContentLen: uint16(len(rawInner)), + Epoch: pkt.record.Header.Epoch, + ContentLen: uint16(len(rawInner)), //nolint:gosec //G115 ConnectionID: c.state.remoteConnectionID, - SequenceNumber: p.record.Header.SequenceNumber, + SequenceNumber: pkt.record.Header.SequenceNumber, } rawPacket, err = cidHeader.Marshal() if err != nil { return nil, err } - p.record.Header = *cidHeader + pkt.record.Header = *cidHeader rawPacket = append(rawPacket, rawInner...) } else { recordlayerHeader := &recordlayer.Header{ - Version: p.record.Header.Version, - ContentType: p.record.Header.ContentType, - ContentLen: uint16(len(handshakeFragment)), - Epoch: p.record.Header.Epoch, + Version: pkt.record.Header.Version, + ContentType: pkt.record.Header.ContentType, + ContentLen: uint16(len(handshakeFragment)), //nolint:gosec // G115 + Epoch: pkt.record.Header.Epoch, SequenceNumber: seq, } @@ -641,13 +665,13 @@ func (c *Conn) processHandshakePacket(p *packet, h *handshake.Handshake) ([][]by return nil, err } - p.record.Header = *recordlayerHeader + pkt.record.Header = *recordlayerHeader rawPacket = append(rawPacket, handshakeFragment...) } - if p.shouldEncrypt { + if pkt.shouldEncrypt { var err error - rawPacket, err = c.state.cipherSuite.Encrypt(p.record, rawPacket) + rawPacket, err = c.state.cipherSuite.Encrypt(pkt.record, rawPacket) if err != nil { return nil, err } @@ -659,8 +683,8 @@ func (c *Conn) processHandshakePacket(p *packet, h *handshake.Handshake) ([][]by return rawPackets, nil } -func (c *Conn) fragmentHandshake(h *handshake.Handshake) ([][]byte, error) { - content, err := h.Message.Marshal() +func (c *Conn) fragmentHandshake(dtlsHandshake *handshake.Handshake) ([][]byte, error) { + content, err := dtlsHandshake.Message.Marshal() if err != nil { return nil, err } @@ -679,11 +703,11 @@ func (c *Conn) fragmentHandshake(h *handshake.Handshake) ([][]byte, error) { contentFragmentLen := len(contentFragment) headerFragment := &handshake.Header{ - Type: h.Header.Type, - Length: h.Header.Length, - MessageSequence: h.Header.MessageSequence, - FragmentOffset: uint32(offset), - FragmentLength: uint32(contentFragmentLen), + Type: dtlsHandshake.Header.Type, + Length: dtlsHandshake.Header.Length, + MessageSequence: dtlsHandshake.Header.MessageSequence, + FragmentOffset: uint32(offset), //nolint:gosec // G115 + FragmentLength: uint32(contentFragmentLen), //nolint:gosec // G115 } offset += contentFragmentLen @@ -703,11 +727,12 @@ func (c *Conn) fragmentHandshake(h *handshake.Handshake) ([][]byte, error) { var poolReadBuffer = sync.Pool{ //nolint:gochecknoglobals New: func() interface{} { b := make([]byte, inboundBufferSize) + return &b }, } -func (c *Conn) readAndBuffer(ctx context.Context) error { +func (c *Conn) readAndBuffer(ctx context.Context) error { //nolint:cyclop bufptr, ok := poolReadBuffer.Get().(*[]byte) if !ok { return errFailedToAccessPoolReadBuffer @@ -763,6 +788,7 @@ func (c *Conn) readAndBuffer(ctx context.Context) error { case <-c.fsm.Done(): } } + return nil } @@ -787,37 +813,48 @@ func (c *Conn) handleQueuedPackets(ctx context.Context) error { return err } } + return nil } func (c *Conn) enqueueEncryptedPackets(packet addrPkt) bool { if len(c.encryptedPackets) < maxAppDataPacketQueueSize { c.encryptedPackets = append(c.encryptedPackets, packet) + return true } + return false } -func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.Addr, enqueue bool) (bool, bool, *alert.Alert, error) { //nolint:gocognit - h := &recordlayer.Header{} +//nolint:gocognit,gocyclo,cyclop,maintidx +func (c *Conn) handleIncomingPacket( + ctx context.Context, + buf []byte, + rAddr net.Addr, + enqueue bool, +) (bool, bool, *alert.Alert, error) { + header := &recordlayer.Header{} // Set connection ID size so that records of content type tls12_cid will // be parsed correctly. if len(c.state.getLocalConnectionID()) > 0 { - h.ConnectionID = make([]byte, len(c.state.getLocalConnectionID())) + header.ConnectionID = make([]byte, len(c.state.getLocalConnectionID())) } - if err := h.Unmarshal(buf); err != nil { + if err := header.Unmarshal(buf); err != nil { // Decode error must be silently discarded // [RFC6347 Section-4.1.2.7] c.log.Debugf("discarded broken packet: %v", err) + return false, false, nil, nil } // Validate epoch remoteEpoch := c.state.getRemoteEpoch() - if h.Epoch > remoteEpoch { - if h.Epoch > remoteEpoch+1 { + if header.Epoch > remoteEpoch { + if header.Epoch > remoteEpoch+1 { c.log.Debugf("discarded future packet (epoch: %d, seq: %d)", - h.Epoch, h.SequenceNumber, + header.Epoch, header.SequenceNumber, ) + return false, false, nil, nil } if enqueue { @@ -825,20 +862,22 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A c.log.Debug("received packet of next epoch, queuing packet") } } + return false, false, nil, nil } // Anti-replay protection - for len(c.state.replayDetector) <= int(h.Epoch) { + for len(c.state.replayDetector) <= int(header.Epoch) { c.state.replayDetector = append(c.state.replayDetector, replaydetector.New(c.replayProtectionWindow, recordlayer.MaxSequenceNumber), ) } - markPacketAsValid, ok := c.state.replayDetector[int(h.Epoch)].Check(h.SequenceNumber) + markPacketAsValid, ok := c.state.replayDetector[int(header.Epoch)].Check(header.SequenceNumber) if !ok { c.log.Debugf("discarded duplicated packet (epoch: %d, seq: %d)", - h.Epoch, h.SequenceNumber, + header.Epoch, header.SequenceNumber, ) + return false, false, nil, nil } @@ -847,60 +886,66 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A originalCID := false // Decrypt - if h.Epoch != 0 { + if header.Epoch != 0 { //nolint:nestif if c.state.cipherSuite == nil || !c.state.cipherSuite.IsInitialized() { if enqueue { if ok := c.enqueueEncryptedPackets(addrPkt{rAddr, buf}); ok { c.log.Debug("handshake not finished, queuing packet") } } + return false, false, nil, nil } // If a connection identifier had been negotiated and encryption is // enabled, the connection identifier MUST be sent. - if len(c.state.getLocalConnectionID()) > 0 && h.ContentType != protocol.ContentTypeConnectionID { + if len(c.state.getLocalConnectionID()) > 0 && header.ContentType != protocol.ContentTypeConnectionID { c.log.Debug("discarded packet missing connection ID after value negotiated") + return false, false, nil, nil } var err error var hdr recordlayer.Header - if h.ContentType == protocol.ContentTypeConnectionID { + if header.ContentType == protocol.ContentTypeConnectionID { hdr.ConnectionID = make([]byte, len(c.state.getLocalConnectionID())) } buf, err = c.state.cipherSuite.Decrypt(hdr, buf) if err != nil { c.log.Debugf("%s: decrypt failed: %s", srvCliStr(c.state.isClient), err) + return false, false, nil, nil } // If this is a connection ID record, make it look like a normal record for // further processing. - if h.ContentType == protocol.ContentTypeConnectionID { + if header.ContentType == protocol.ContentTypeConnectionID { originalCID = true ip := &recordlayer.InnerPlaintext{} - if err := ip.Unmarshal(buf[h.Size():]); err != nil { //nolint:govet + if err := ip.Unmarshal(buf[header.Size():]); err != nil { //nolint:govet c.log.Debugf("unpacking inner plaintext failed: %s", err) + return false, false, nil, nil } unpacked := &recordlayer.Header{ ContentType: ip.RealType, - ContentLen: uint16(len(ip.Content)), - Version: h.Version, - Epoch: h.Epoch, - SequenceNumber: h.SequenceNumber, + ContentLen: uint16(len(ip.Content)), //nolint:gosec // G115 + Version: header.Version, + Epoch: header.Epoch, + SequenceNumber: header.SequenceNumber, } buf, err = unpacked.Marshal() if err != nil { c.log.Debugf("converting CID record to inner plaintext failed: %s", err) + return false, false, nil, nil } buf = append(buf, ip.Content...) } // If connection ID does not match discard the packet. - if !bytes.Equal(c.state.getLocalConnectionID(), h.ConnectionID) { + if !bytes.Equal(c.state.getLocalConnectionID(), header.ConnectionID) { c.log.Debug("unexpected connection ID") + return false, false, nil, nil } } @@ -910,6 +955,7 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A // Decode error must be silently discarded // [RFC6347 Section-4.1.2.7] c.log.Debugf("defragment failed: %s", err) + return false, false, nil, nil } else if isHandshake { markPacketAsValid() @@ -918,6 +964,7 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A header := &handshake.Header{} if err := header.Unmarshal(out); err != nil { c.log.Debugf("%s: handshake parse failed: %s", srvCliStr(c.state.isClient), err) + continue } c.handshakeCache.push(out, epoch, header.MessageSequence, header.Type, !c.state.isClient) @@ -941,6 +988,7 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A a = &alert.Alert{Level: alert.Warning, Description: alert.CloseNotify} } _ = markPacketAsValid() + return false, false, a, &alertError{content} case *protocol.ChangeCipherSpec: if c.state.cipherSuite == nil || !c.state.cipherSuite.IsInitialized() { @@ -949,10 +997,11 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A c.log.Debugf("CipherSuite not initialized, queuing packet") } } + return false, false, nil, nil } - newRemoteEpoch := h.Epoch + 1 + newRemoteEpoch := header.Epoch + 1 c.log.Tracef("%s: <- ChangeCipherSpec (epoch: %d)", srvCliStr(c.state.isClient), newRemoteEpoch) if c.state.getRemoteEpoch()+1 == newRemoteEpoch { @@ -960,8 +1009,10 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A isLatestSeqNum = markPacketAsValid() } case *protocol.ApplicationData: - if h.Epoch == 0 { - return false, false, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, errApplicationDataEpochZero + if header.Epoch == 0 { + return false, false, &alert.Alert{ + Level: alert.Fatal, Description: alert.UnexpectedMessage, + }, errApplicationDataEpochZero } isLatestSeqNum = markPacketAsValid() @@ -973,7 +1024,9 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A } default: - return false, false, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, fmt.Errorf("%w: %d", errUnhandledContextType, content.ContentType()) + return false, false, &alert.Alert{ + Level: alert.Fatal, Description: alert.UnexpectedMessage, + }, fmt.Errorf("%w: %d", errUnhandledContextType, content.ContentType()) } // Any valid connection ID record is a candidate for updating the remote @@ -1005,6 +1058,7 @@ func (c *Conn) notify(ctx context.Context, level alert.Level, desc alert.Descrip } } } + return c.writePackets(ctx, []*packet{ { record: &recordlayer.RecordLayer{ @@ -1029,10 +1083,17 @@ func (c *Conn) setHandshakeCompletedSuccessfully() { func (c *Conn) isHandshakeCompletedSuccessfully() bool { boolean, _ := c.handshakeCompletedSuccessfully.Load().(struct{ bool }) + return boolean.bool } -func (c *Conn) handshake(ctx context.Context, cfg *handshakeConfig, initialFlight flightVal, initialState handshakeState) error { //nolint:gocognit,contextcheck +//nolint:cyclop,gocognit,contextcheck +func (c *Conn) handshake( + ctx context.Context, + cfg *handshakeConfig, + initialFlight flightVal, + initialState handshakeState, +) error { c.fsm = newHandshakeFSM(&c.state, c.handshakeCache, cfg, initialFlight) done := make(chan struct{}) @@ -1081,10 +1142,10 @@ func (c *Conn) handshake(ctx context.Context, cfg *handshakeConfig, initialFligh }() defer handshakeLoopsFinished.Done() for { - if err := c.readAndBuffer(ctxRead); err != nil { - var e *alertError - if errors.As(err, &e) { - if !e.IsFatalOrCloseNotify() { + if err := c.readAndBuffer(ctxRead); err != nil { //nolint:nestif + var alertErr *alertError + if errors.As(err, &alertErr) { + if !alertErr.IsFatalOrCloseNotify() { if c.isHandshakeCompletedSuccessfully() { // Pass the error to Read() select { @@ -1093,11 +1154,15 @@ func (c *Conn) handshake(ctx context.Context, cfg *handshakeConfig, initialFligh case <-ctxRead.Done(): } } + continue // non-fatal alert must not stop read loop } } else { switch { - case errors.Is(err, context.DeadlineExceeded), errors.Is(err, context.Canceled), errors.Is(err, io.EOF), errors.Is(err, net.ErrClosed): + case errors.Is(err, context.DeadlineExceeded), + errors.Is(err, context.Canceled), + errors.Is(err, io.EOF), + errors.Is(err, net.ErrClosed): case errors.Is(err, recordlayer.ErrInvalidPacketLength): // Decode error must be silently discarded // [RFC6347 Section-4.1.2.7] @@ -1110,6 +1175,7 @@ func (c *Conn) handshake(ctx context.Context, cfg *handshakeConfig, initialFligh case <-c.closed.Done(): case <-ctxRead.Done(): } + continue // non-fatal alert must not stop read loop } } @@ -1120,8 +1186,8 @@ func (c *Conn) handshake(ctx context.Context, cfg *handshakeConfig, initialFligh default: } - if e != nil { - if e.IsFatalOrCloseNotify() { + if alertErr != nil { + if alertErr.IsFatalOrCloseNotify() { _ = c.close(false) //nolint:contextcheck } } @@ -1129,6 +1195,7 @@ func (c *Conn) handshake(ctx context.Context, cfg *handshakeConfig, initialFligh c.log.Trace("handshake timeouts - closing underline connection") _ = c.close(false) //nolint:contextcheck } + return } } @@ -1139,11 +1206,13 @@ func (c *Conn) handshake(ctx context.Context, cfg *handshakeConfig, initialFligh cancelRead() cancel() handshakeLoopsFinished.Wait() + return c.translateHandshakeCtxError(err) case <-ctx.Done(): cancelRead() cancel() handshakeLoopsFinished.Wait() + return c.translateHandshakeCtxError(ctx.Err()) case <-done: return nil @@ -1157,6 +1226,7 @@ func (c *Conn) translateHandshakeCtxError(err error) error { if errors.Is(err, context.Canceled) && c.isHandshakeCompletedSuccessfully() { return nil } + return &HandshakeError{Err: err} } @@ -1213,15 +1283,16 @@ func (c *Conn) setRemoteEpoch(epoch uint16) { c.state.remoteEpoch.Store(epoch) } -// LocalAddr implements net.Conn.LocalAddr +// LocalAddr implements net.Conn.LocalAddr. func (c *Conn) LocalAddr() net.Addr { return c.nextConn.LocalAddr() } -// RemoteAddr implements net.Conn.RemoteAddr +// RemoteAddr implements net.Conn.RemoteAddr. func (c *Conn) RemoteAddr() net.Addr { c.lock.RLock() defer c.lock.RUnlock() + return c.rAddr } @@ -1232,16 +1303,18 @@ func (c *Conn) sessionKey() []byte { // neither address or domain name. return []byte(c.rAddr.String() + "_" + c.fsm.cfg.serverName) } + return c.state.SessionID } -// SetDeadline implements net.Conn.SetDeadline +// SetDeadline implements net.Conn.SetDeadline. func (c *Conn) SetDeadline(t time.Time) error { c.readDeadline.Set(t) + return c.SetWriteDeadline(t) } -// SetReadDeadline implements net.Conn.SetReadDeadline +// SetReadDeadline implements net.Conn.SetReadDeadline. func (c *Conn) SetReadDeadline(t time.Time) error { c.readDeadline.Set(t) // Read deadline is fully managed by this layer. @@ -1249,7 +1322,7 @@ func (c *Conn) SetReadDeadline(t time.Time) error { return nil } -// SetWriteDeadline implements net.Conn.SetWriteDeadline +// SetWriteDeadline implements net.Conn.SetWriteDeadline. func (c *Conn) SetWriteDeadline(t time.Time) error { c.writeDeadline.Set(t) // Write deadline is also fully managed by this layer. diff --git a/conn_go_test.go b/conn_go_test.go index 6b05d8d76..b22e7c71e 100644 --- a/conn_go_test.go +++ b/conn_go_test.go @@ -21,7 +21,7 @@ import ( "github.com/pion/transport/v3/test" ) -func TestContextConfig(t *testing.T) { +func TestContextConfig(t *testing.T) { //nolint:cyclop // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() @@ -62,11 +62,13 @@ func TestContextConfig(t *testing.T) { "Dial": { f: func() (func() (net.Conn, error), func()) { ctx, cancel := context.WithTimeout(context.Background(), 40*time.Millisecond) + return func() (net.Conn, error) { conn, err := Dial("udp", addr, config) if err != nil { return nil, err } + return conn, conn.HandshakeContext(ctx) }, func() { cancel() @@ -78,11 +80,13 @@ func TestContextConfig(t *testing.T) { f: func() (func() (net.Conn, error), func()) { ca, _ := dpipe.Pipe() ctx, cancel := context.WithTimeout(context.Background(), 40*time.Millisecond) + return func() (net.Conn, error) { conn, err := Client(dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), config) if err != nil { return nil, err } + return conn, conn.HandshakeContext(ctx) }, func() { _ = ca.Close() @@ -95,11 +99,13 @@ func TestContextConfig(t *testing.T) { f: func() (func() (net.Conn, error), func()) { ca, _ := dpipe.Pipe() ctx, cancel := context.WithTimeout(context.Background(), 40*time.Millisecond) + return func() (net.Conn, error) { conn, err := Server(dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), config) if err != nil { return nil, err } + return conn, conn.HandshakeContext(ctx) }, func() { _ = ca.Close() @@ -123,6 +129,7 @@ func TestContextConfig(t *testing.T) { if !errors.As(err, &netError) || !netError.Temporary() { //nolint:staticcheck t.Errorf("Client error exp(Temporary network error) failed(%v)", err) close(done) + return } done <- struct{}{} diff --git a/conn_test.go b/conn_test.go index 9edb3aee0..47538a80a 100644 --- a/conn_test.go +++ b/conn_test.go @@ -63,6 +63,8 @@ func TestStressDuplex(t *testing.T) { } func stressDuplex(t *testing.T) { + t.Helper() + ca, cb, err := pipeMemory() if err != nil { t.Fatal(err) @@ -117,7 +119,7 @@ func TestRoutineLeakOnClose(t *testing.T) { // inboundLoop routine should not be leaked. } -func TestReadWriteDeadline(t *testing.T) { +func TestReadWriteDeadline(t *testing.T) { //nolint:cyclop // Limit runtime in case of deadlocks lim := test.TimeOut(5 * time.Second) defer lim.Stop() @@ -126,7 +128,7 @@ func TestReadWriteDeadline(t *testing.T) { report := test.CheckRoutines(t) defer report() - var e net.Error + var netErr net.Error ca, cb, err := pipeMemory() if err != nil { @@ -137,22 +139,22 @@ func TestReadWriteDeadline(t *testing.T) { t.Fatal(err) } _, werr := ca.Write(make([]byte, 100)) - if errors.As(werr, &e) { - if !e.Timeout() { + if errors.As(werr, &netErr) { + if !netErr.Timeout() { t.Error("Deadline exceeded Write must return Timeout error") } - if !e.Temporary() { //nolint:staticcheck + if !netErr.Temporary() { //nolint:staticcheck t.Error("Deadline exceeded Write must return Temporary error") } } else { t.Error("Write must return net.Error error") } _, rerr := ca.Read(make([]byte, 100)) - if errors.As(rerr, &e) { - if !e.Timeout() { + if errors.As(rerr, &netErr) { + if !netErr.Timeout() { t.Error("Deadline exceeded Read must return Timeout error") } - if !e.Temporary() { //nolint:staticcheck + if !netErr.Temporary() { //nolint:staticcheck t.Error("Deadline exceeded Read must return Temporary error") } } else { @@ -251,6 +253,7 @@ func TestSequenceNumberOverflow(t *testing.T) { func pipeMemory() (*Conn, *Conn, error) { // In memory pipe ca, cb := dpipe.Pipe() + return pipeConn(ca, cb) } @@ -260,33 +263,44 @@ func pipeConn(ca, cb net.Conn) (*Conn, *Conn, error) { err error } - c := make(chan result) + resultCh := make(chan result) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() // Setup client go func() { - client, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{SRTPProtectionProfiles: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}}, true) - c <- result{client, err} + client, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{ + SRTPProtectionProfiles: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}, + }, true) + resultCh <- result{client, err} }() // Setup server - server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{SRTPProtectionProfiles: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}}, true) + server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ + SRTPProtectionProfiles: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}, + }, true) if err != nil { return nil, nil, err } // Receive client - res := <-c + res := <-resultCh if res.err != nil { _ = server.Close() + return nil, nil, res.err } return res.c, server, nil } -func testClient(ctx context.Context, c net.PacketConn, rAddr net.Addr, cfg *Config, generateCertificate bool) (*Conn, error) { +func testClient( + ctx context.Context, + pktConn net.PacketConn, + rAddr net.Addr, + cfg *Config, + generateCertificate bool, +) (*Conn, error) { if generateCertificate { clientCert, err := selfsign.GenerateSelfSigned() if err != nil { @@ -295,14 +309,21 @@ func testClient(ctx context.Context, c net.PacketConn, rAddr net.Addr, cfg *Conf cfg.Certificates = []tls.Certificate{clientCert} } cfg.InsecureSkipVerify = true - conn, err := Client(c, rAddr, cfg) + conn, err := Client(pktConn, rAddr, cfg) if err != nil { return nil, err } + return conn, conn.HandshakeContext(ctx) } -func testServer(ctx context.Context, c net.PacketConn, rAddr net.Addr, cfg *Config, generateCertificate bool) (*Conn, error) { +func testServer( + ctx context.Context, + c net.PacketConn, + rAddr net.Addr, + cfg *Config, + generateCertificate bool, +) (*Conn, error) { if generateCertificate { serverCert, err := selfsign.GenerateSelfSigned() if err != nil { @@ -314,6 +335,7 @@ func testServer(ctx context.Context, c net.PacketConn, rAddr net.Addr, cfg *Conf if err != nil { return nil, err } + return conn, conn.HandshakeContext(ctx) } @@ -325,7 +347,7 @@ func sendClientHello(cookie []byte, ca net.Conn, sequenceNumber uint64, extensio }, Content: &handshake.Handshake{ Header: handshake.Header{ - MessageSequence: uint16(sequenceNumber), + MessageSequence: uint16(sequenceNumber), //nolint:gosec // G115 }, Message: &handshake.MessageClientHello{ Version: protocol.Version1_2, @@ -343,6 +365,7 @@ func sendClientHello(cookie []byte, ca net.Conn, sequenceNumber uint64, extensio if _, err = ca.Write(packet); err != nil { return err } + return nil } @@ -440,9 +463,13 @@ func TestHandshakeWithInvalidRecord(t *testing.T) { } } go func() { - client, err := testClient(ctx, dtlsnet.PacketConnFromConn(caWithInvalidRecord), caWithInvalidRecord.RemoteAddr(), &Config{ - CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, - }, true) + client, err := testClient( + ctx, + dtlsnet.PacketConnFromConn(caWithInvalidRecord), + caWithInvalidRecord.RemoteAddr(), + &Config{CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}}, + true, + ) clientErr <- result{client, err} }() @@ -475,7 +502,7 @@ func TestHandshakeWithInvalidRecord(t *testing.T) { } } -func TestExportKeyingMaterial(t *testing.T) { +func TestExportKeyingMaterial(t *testing.T) { //nolint:cyclop // Check for leaking routines report := test.CheckRoutines(t) defer report() @@ -486,7 +513,7 @@ func TestExportKeyingMaterial(t *testing.T) { expectedServerKey := []byte{0x61, 0x09, 0x9d, 0x7d, 0xcb, 0x08, 0x52, 0x2c, 0xe7, 0x7b} expectedClientKey := []byte{0x87, 0xf0, 0x40, 0x02, 0xf6, 0x1c, 0xf1, 0xfe, 0x8c, 0x77} - c := &Conn{ + conn := &Conn{ state: State{ localRandom: handshake.Random{GMTUnixTime: time.Unix(500, 0), RandomBytes: rand}, remoteRandom: handshake.Random{GMTUnixTime: time.Unix(1000, 0), RandomBytes: rand}, @@ -494,10 +521,10 @@ func TestExportKeyingMaterial(t *testing.T) { cipherSuite: &ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256{}, }, } - c.setLocalEpoch(0) - c.setRemoteEpoch(0) + conn.setLocalEpoch(0) + conn.setRemoteEpoch(0) - state, ok := c.ConnectionState() + state, ok := conn.ConnectionState() if !ok { t.Fatal("ConnectionState failed") } @@ -506,8 +533,8 @@ func TestExportKeyingMaterial(t *testing.T) { t.Errorf("ExportKeyingMaterial when epoch == 0: expected '%s' actual '%s'", errHandshakeInProgress, err) } - c.setLocalEpoch(1) - state, ok = c.ConnectionState() + conn.setLocalEpoch(1) + state, ok = conn.ConnectionState() if !ok { t.Fatal("ConnectionState failed") } @@ -517,7 +544,7 @@ func TestExportKeyingMaterial(t *testing.T) { } for k := range invalidKeyingLabels() { - state, ok = c.ConnectionState() + state, ok = conn.ConnectionState() if !ok { t.Fatal("ConnectionState failed") } @@ -527,7 +554,7 @@ func TestExportKeyingMaterial(t *testing.T) { } } - state, ok = c.ConnectionState() + state, ok = conn.ConnectionState() if !ok { t.Fatal("ConnectionState failed") } @@ -538,8 +565,8 @@ func TestExportKeyingMaterial(t *testing.T) { t.Errorf("ExportKeyingMaterial client export: expected (% 02x) actual (% 02x)", expectedServerKey, keyingMaterial) } - c.state.isClient = true - state, ok = c.ConnectionState() + conn.state.isClient = true + state, ok = conn.ConnectionState() if !ok { t.Fatal("ConnectionState failed") } @@ -551,7 +578,7 @@ func TestExportKeyingMaterial(t *testing.T) { } } -func TestPSK(t *testing.T) { +func TestPSK(t *testing.T) { //nolint:cyclop // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() @@ -642,7 +669,10 @@ func TestPSK(t *testing.T) { conf := &Config{ PSK: func(hint []byte) ([]byte, error) { if !bytes.Equal(test.ServerIdentity, hint) { - return nil, fmt.Errorf("TestPSK: Client got invalid identity expected(% 02x) actual(% 02x)", test.ServerIdentity, hint) //nolint:goerr113 + return nil, fmt.Errorf( //nolint:goerr113 + "TestPSK: Client got invalid identity expected(% 02x) actual(% 02x)", + test.ServerIdentity, hint, + ) } return []byte{0xAB, 0xC1, 0x23}, nil @@ -662,6 +692,7 @@ func TestPSK(t *testing.T) { if !bytes.Equal(test.ClientIdentity, hint) { return nil, fmt.Errorf("%w: expected(% 02x) actual(% 02x)", errTestPSKInvalidIdentity, test.ClientIdentity, hint) } + return []byte{0xAB, 0xC1, 0x23}, nil }, PSKIdentityHint: test.ServerIdentity, @@ -678,6 +709,7 @@ func TestPSK(t *testing.T) { if res.err == nil || !strings.Contains(res.err.Error(), test.ExpectedClientErr) { t.Fatalf("TestPSK: Client expected(%v) actual(%v)", test.ExpectedClientErr, res.err) } + return } if err != nil { @@ -690,7 +722,10 @@ func TestPSK(t *testing.T) { } actualPSKIdentityHint := state.IdentityHint if !bytes.Equal(actualPSKIdentityHint, test.ClientIdentity) { - t.Errorf("TestPSK: Server ClientPSKIdentity Mismatch '%s': expected(%v) actual(%v)", test.Name, test.ClientIdentity, actualPSKIdentityHint) + t.Errorf( + "TestPSK: Server ClientPSKIdentity Mismatch '%s': expected(%v) actual(%v)", + test.Name, test.ClientIdentity, actualPSKIdentityHint, + ) } defer func() { @@ -744,7 +779,9 @@ func TestPSKHintFail(t *testing.T) { CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8}, } - if _, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), config, false); !errors.Is(err, serverAlertError) { + if _, err := testServer( + ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), config, false, + ); !errors.Is(err, serverAlertError) { t.Fatalf("TestPSK: Server error exp(%v) failed(%v)", serverAlertError, err) } @@ -753,8 +790,8 @@ func TestPSKHintFail(t *testing.T) { } } -// Assert that ServerKeyExchange is only sent if Identity is set on server side -func TestPSKServerKeyExchange(t *testing.T) { +// Assert that ServerKeyExchange is only sent if Identity is set on server side. +func TestPSKServerKeyExchange(t *testing.T) { //nolint:cyclop // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() @@ -827,7 +864,9 @@ func TestPSKServerKeyExchange(t *testing.T) { config.PSKIdentityHint = []byte{0xAB, 0xC1, 0x23} } - if server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cbAnalyzer), cbAnalyzer.RemoteAddr(), config, false); err != nil { + if server, err := testServer( + ctx, dtlsnet.PacketConnFromConn(cbAnalyzer), cbAnalyzer.RemoteAddr(), config, false, + ); err != nil { t.Fatalf("TestPSK: Server error %v", err) } else { if err = server.Close(); err != nil { @@ -840,7 +879,10 @@ func TestPSKServerKeyExchange(t *testing.T) { } if gotServerKeyExchange != test.SetIdentity { - t.Fatalf("Mismatch between setting Identity and getting a ServerKeyExchange exp(%t) actual(%t)", test.SetIdentity, gotServerKeyExchange) + t.Fatalf( + "Mismatch between setting Identity and getting a ServerKeyExchange exp(%t) actual(%t)", + test.SetIdentity, gotServerKeyExchange, + ) } }) } @@ -879,7 +921,7 @@ func TestClientTimeout(t *testing.T) { } } -func TestSRTPConfiguration(t *testing.T) { +func TestSRTPConfiguration(t *testing.T) { //nolint:cyclop // Check for leaking routines report := test.CheckRoutines(t) defer report() @@ -953,16 +995,23 @@ func TestSRTPConfiguration(t *testing.T) { c *Conn err error } - c := make(chan result) + resultCh := make(chan result) go func() { - client, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{SRTPProtectionProfiles: test.ClientSRTP, SRTPMasterKeyIdentifier: test.ServerSRTPMasterKeyIdentifier}, true) - c <- result{client, err} + client, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{ + SRTPProtectionProfiles: test.ClientSRTP, SRTPMasterKeyIdentifier: test.ServerSRTPMasterKeyIdentifier, + }, true) + resultCh <- result{client, err} }() - server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{SRTPProtectionProfiles: test.ServerSRTP, SRTPMasterKeyIdentifier: test.ClientSRTPMasterKeyIdentifier}, true) + server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ + SRTPProtectionProfiles: test.ServerSRTP, SRTPMasterKeyIdentifier: test.ClientSRTPMasterKeyIdentifier, + }, true) if !errors.Is(err, test.WantServerError) { - t.Errorf("TestSRTPConfiguration: Server Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantServerError, err) + t.Errorf( + "TestSRTPConfiguration: Server Error Mismatch '%s': expected(%v) actual(%v)", + test.Name, test.WantServerError, err, + ) } if err == nil { defer func() { @@ -970,14 +1019,17 @@ func TestSRTPConfiguration(t *testing.T) { }() } - res := <-c + res := <-resultCh if res.err == nil { defer func() { _ = res.c.Close() }() } if !errors.Is(res.err, test.WantClientError) { - t.Fatalf("TestSRTPConfiguration: Client Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantClientError, res.err) + t.Fatalf( + "TestSRTPConfiguration: Client Error Mismatch '%s': expected(%v) actual(%v)", + test.Name, test.WantClientError, res.err, + ) } if res.c == nil { return @@ -985,27 +1037,39 @@ func TestSRTPConfiguration(t *testing.T) { actualClientSRTP, _ := res.c.SelectedSRTPProtectionProfile() if actualClientSRTP != test.ExpectedProfile { - t.Errorf("TestSRTPConfiguration: Client SRTPProtectionProfile Mismatch '%s': expected(%v) actual(%v)", test.Name, test.ExpectedProfile, actualClientSRTP) + t.Errorf( + "TestSRTPConfiguration: Client SRTPProtectionProfile Mismatch '%s': expected(%v) actual(%v)", + test.Name, test.ExpectedProfile, actualClientSRTP, + ) } actualServerSRTP, _ := server.SelectedSRTPProtectionProfile() if actualServerSRTP != test.ExpectedProfile { - t.Errorf("TestSRTPConfiguration: Server SRTPProtectionProfile Mismatch '%s': expected(%v) actual(%v)", test.Name, test.ExpectedProfile, actualServerSRTP) + t.Errorf( + "TestSRTPConfiguration: Server SRTPProtectionProfile Mismatch '%s': expected(%v) actual(%v)", + test.Name, test.ExpectedProfile, actualServerSRTP, + ) } actualServerMKI, _ := server.RemoteSRTPMasterKeyIdentifier() if !bytes.Equal(actualServerMKI, test.ServerSRTPMasterKeyIdentifier) { - t.Errorf("TestSRTPConfiguration: Server SRTPMKI Mismatch '%s': expected(%v) actual(%v)", test.Name, test.ServerSRTPMasterKeyIdentifier, actualServerMKI) + t.Errorf( + "TestSRTPConfiguration: Server SRTPMKI Mismatch '%s': expected(%v) actual(%v)", + test.Name, test.ServerSRTPMasterKeyIdentifier, actualServerMKI, + ) } actualClientMKI, _ := res.c.RemoteSRTPMasterKeyIdentifier() if !bytes.Equal(actualClientMKI, test.ClientSRTPMasterKeyIdentifier) { - t.Errorf("TestSRTPConfiguration: Client SRTPMKI Mismatch '%s': expected(%v) actual(%v)", test.Name, test.ClientSRTPMasterKeyIdentifier, actualClientMKI) + t.Errorf( + "TestSRTPConfiguration: Client SRTPMKI Mismatch '%s': expected(%v) actual(%v)", + test.Name, test.ClientSRTPMasterKeyIdentifier, actualClientMKI, + ) } } } -func TestClientCertificate(t *testing.T) { +func TestClientCertificate(t *testing.T) { //nolint:gocyclo,cyclop,maintidx // Check for leaking routines report := test.CheckRoutines(t) defer report() @@ -1139,12 +1203,17 @@ func TestClientCertificate(t *testing.T) { wantErr: true, }, "RequireAndVerifyClientCert": { - clientCfg: &Config{RootCAs: srvCAPool, Certificates: []tls.Certificate{cert}, VerifyConnection: func(s *State) error { - if ok := bytes.Equal(s.PeerCertificates[0], srvCertificate.Raw); !ok { - return errExample - } - return nil - }}, + clientCfg: &Config{ + RootCAs: srvCAPool, + Certificates: []tls.Certificate{cert}, + VerifyConnection: func(s *State) error { + if ok := bytes.Equal(s.PeerCertificates[0], srvCertificate.Raw); !ok { + return errExample + } + + return nil + }, + }, serverCfg: &Config{ Certificates: []tls.Certificate{srvCert}, ClientAuth: RequireAndVerifyClientCert, @@ -1153,6 +1222,7 @@ func TestClientCertificate(t *testing.T) { if ok := bytes.Equal(s.PeerCertificates[0], certificate.Raw); !ok { return errExample } + return nil }, }, @@ -1218,7 +1288,9 @@ func TestClientCertificate(t *testing.T) { t.Error("Server connection state not available") } actualClientCert := state.PeerCertificates - if tt.serverCfg.ClientAuth == RequireAnyClientCert || tt.serverCfg.ClientAuth == RequireAndVerifyClientCert { + //nolint:nestif + if tt.serverCfg.ClientAuth == RequireAnyClientCert || + tt.serverCfg.ClientAuth == RequireAndVerifyClientCert { if actualClientCert == nil { t.Errorf("Client did not provide a certificate") } @@ -1370,16 +1442,28 @@ func TestConnectionID(t *testing.T) { }() if !bytes.Equal(res.c.state.getLocalConnectionID(), tt.clientConnectionID) { - t.Errorf("Unexpected client local connection ID\nwant: %v\ngot:%v", tt.clientConnectionID, res.c.state.localConnectionID) + t.Errorf( + "Unexpected client local connection ID\nwant: %v\ngot:%v", + tt.clientConnectionID, res.c.state.localConnectionID, + ) } if !bytes.Equal(res.c.state.remoteConnectionID, tt.serverConnectionID) { - t.Errorf("Unexpected client remote connection ID\nwant: %v\ngot:%v", tt.serverConnectionID, res.c.state.remoteConnectionID) + t.Errorf( + "Unexpected client remote connection ID\nwant: %v\ngot:%v", + tt.serverConnectionID, res.c.state.remoteConnectionID, + ) } if !bytes.Equal(server.state.getLocalConnectionID(), tt.serverConnectionID) { - t.Errorf("Unexpected server local connection ID\nwant: %v\ngot:%v", tt.serverConnectionID, server.state.localConnectionID) + t.Errorf( + "Unexpected server local connection ID\nwant: %v\ngot:%v", + tt.serverConnectionID, server.state.localConnectionID, + ) } if !bytes.Equal(server.state.remoteConnectionID, tt.clientConnectionID) { - t.Errorf("Unexpected server remote connection ID\nwant: %v\ngot:%v", tt.clientConnectionID, server.state.remoteConnectionID) + t.Errorf( + "Unexpected server remote connection ID\nwant: %v\ngot:%v", + tt.clientConnectionID, server.state.remoteConnectionID, + ) } }) } @@ -1527,7 +1611,7 @@ func TestExtendedMasterSecret(t *testing.T) { } } -func TestServerCertificate(t *testing.T) { +func TestServerCertificate(t *testing.T) { //nolint:cyclop // Check for leaking routines report := test.CheckRoutines(t) defer report() @@ -1564,21 +1648,32 @@ func TestServerCertificate(t *testing.T) { }, "good_ca_skip_verify_custom_verify_peer": { clientCfg: &Config{RootCAs: caPool, Certificates: []tls.Certificate{cert}}, - serverCfg: &Config{Certificates: []tls.Certificate{cert}, ClientAuth: RequireAnyClientCert, VerifyPeerCertificate: func(_ [][]byte, chain [][]*x509.Certificate) error { - if len(chain) != 0 { - return errNotExpectedChain - } - return nil - }}, + serverCfg: &Config{ + Certificates: []tls.Certificate{cert}, + ClientAuth: RequireAnyClientCert, + VerifyPeerCertificate: func(_ [][]byte, chain [][]*x509.Certificate) error { + if len(chain) != 0 { + return errNotExpectedChain + } + + return nil + }, + }, }, "good_ca_verify_custom_verify_peer": { clientCfg: &Config{RootCAs: caPool, Certificates: []tls.Certificate{cert}}, - serverCfg: &Config{ClientCAs: caPool, Certificates: []tls.Certificate{cert}, ClientAuth: RequireAndVerifyClientCert, VerifyPeerCertificate: func(_ [][]byte, chain [][]*x509.Certificate) error { - if len(chain) == 0 { - return errExpecedChain - } - return nil - }}, + serverCfg: &Config{ + ClientCAs: caPool, + Certificates: []tls.Certificate{cert}, + ClientAuth: RequireAndVerifyClientCert, + VerifyPeerCertificate: func(_ [][]byte, chain [][]*x509.Certificate) error { + if len(chain) == 0 { + return errExpecedChain + } + + return nil + }, + }, }, "good_ca_custom_verify_peer": { clientCfg: &Config{ @@ -1695,8 +1790,10 @@ func TestCipherSuiteConfiguration(t *testing.T) { WantSelectedCipherSuite: TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8, }, { - Name: "Server supports subset of client suites", - ClientCipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA}, + Name: "Server supports subset of client suites", + ClientCipherSuites: []CipherSuiteID{ + TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, + }, ServerCipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA}, WantClientError: nil, WantServerError: nil, @@ -1713,33 +1810,46 @@ func TestCipherSuiteConfiguration(t *testing.T) { c *Conn err error } - c := make(chan result) + resultCh := make(chan result) go func() { - client, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{CipherSuites: test.ClientCipherSuites}, true) - c <- result{client, err} + client, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{ + CipherSuites: test.ClientCipherSuites, + }, true) + resultCh <- result{client, err} }() - server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{CipherSuites: test.ServerCipherSuites}, true) + server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ + CipherSuites: test.ServerCipherSuites, + }, true) if err == nil { defer func() { _ = server.Close() }() } if !errors.Is(err, test.WantServerError) { - t.Errorf("TestCipherSuiteConfiguration: Server Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantServerError, err) + t.Errorf( + "TestCipherSuiteConfiguration: Server Error Mismatch '%s': expected(%v) actual(%v)", + test.Name, test.WantServerError, err, + ) } - res := <-c + res := <-resultCh if res.err == nil { _ = server.Close() _ = res.c.Close() } if !errors.Is(res.err, test.WantClientError) { - t.Errorf("TestSRTPConfiguration: Client Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantClientError, res.err) + t.Errorf( + "TestSRTPConfiguration: Client Error Mismatch '%s': expected(%v) actual(%v)", + test.Name, test.WantClientError, res.err, + ) } if test.WantSelectedCipherSuite != 0x00 && res.c.state.cipherSuite.ID() != test.WantSelectedCipherSuite { - t.Errorf("TestCipherSuiteConfiguration: Server Selected Bad Cipher Suite '%s': expected(%v) actual(%v)", test.Name, test.WantSelectedCipherSuite, res.c.state.cipherSuite.ID()) + t.Errorf( + "TestCipherSuiteConfiguration: Server Selected Bad Cipher Suite '%s': expected(%v) actual(%v)", + test.Name, test.WantSelectedCipherSuite, res.c.state.cipherSuite.ID(), + ) } }) } @@ -1773,7 +1883,7 @@ func TestCertificateAndPSKServer(t *testing.T) { c *Conn err error } - c := make(chan result) + resultCh := make(chan result) go func() { config := &Config{CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}} @@ -1786,7 +1896,7 @@ func TestCertificateAndPSKServer(t *testing.T) { } client, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), config, false) - c <- result{client, err} + resultCh <- result{client, err} }() config := &Config{ @@ -1805,18 +1915,21 @@ func TestCertificateAndPSKServer(t *testing.T) { t.Errorf("TestCertificateAndPSKServer: Server Error Mismatch '%s': expected(%v) actual(%v)", test.Name, nil, err) } - res := <-c + res := <-resultCh if res.err == nil { _ = server.Close() _ = res.c.Close() } else { - t.Errorf("TestCertificateAndPSKServer: Client Error Mismatch '%s': expected(%v) actual(%v)", test.Name, nil, res.err) + t.Errorf( + "TestCertificateAndPSKServer: Client Error Mismatch '%s': expected(%v) actual(%v)", + test.Name, nil, res.err, + ) } }) } } -func TestPSKConfiguration(t *testing.T) { +func TestPSKConfiguration(t *testing.T) { //nolint:cyclop // Check for leaking routines report := test.CheckRoutines(t) defer report() @@ -1885,30 +1998,50 @@ func TestPSKConfiguration(t *testing.T) { c *Conn err error } - c := make(chan result) + resultCh := make(chan result) go func() { - client, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{PSK: test.ClientPSK, PSKIdentityHint: test.ClientPSKIdentity}, test.ClientHasCertificate) - c <- result{client, err} + client, err := testClient( + ctx, + dtlsnet.PacketConnFromConn(ca), + ca.RemoteAddr(), + &Config{PSK: test.ClientPSK, PSKIdentityHint: test.ClientPSKIdentity}, + test.ClientHasCertificate, + ) + resultCh <- result{client, err} }() - _, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{PSK: test.ServerPSK, PSKIdentityHint: test.ServerPSKIdentity}, test.ServerHasCertificate) + _, err := testServer( + ctx, + dtlsnet.PacketConnFromConn(cb), + cb.RemoteAddr(), + &Config{PSK: test.ServerPSK, PSKIdentityHint: test.ServerPSKIdentity}, + test.ServerHasCertificate, + ) if err != nil || test.WantServerError != nil { if !(err != nil && test.WantServerError != nil && err.Error() == test.WantServerError.Error()) { - t.Fatalf("TestPSKConfiguration: Server Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantServerError, err) + t.Fatalf( + "TestPSKConfiguration: Server Error Mismatch '%s': expected(%v) actual(%v)", + test.Name, test.WantServerError, err, + ) } } - res := <-c + res := <-resultCh if res.err != nil || test.WantClientError != nil { if !(res.err != nil && test.WantClientError != nil && res.err.Error() == test.WantClientError.Error()) { - t.Fatalf("TestPSKConfiguration: Client Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantClientError, res.err) + t.Fatalf( + "TestPSKConfiguration: Client Error Mismatch '%s': expected(%v) actual(%v)", + test.Name, + test.WantClientError, + res.err, + ) } } } } -func TestServerTimeout(t *testing.T) { +func TestServerTimeout(t *testing.T) { //nolint:cyclop // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() @@ -2037,7 +2170,7 @@ func TestServerTimeout(t *testing.T) { } } -func TestProtocolVersionValidation(t *testing.T) { +func TestProtocolVersionValidation(t *testing.T) { //nolint:cyclop,maintidx // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() @@ -2118,8 +2251,8 @@ func TestProtocolVersionValidation(t *testing.T) { }, }, } - for name, c := range serverCases { - c := c + for name, serverCase := range serverCases { + serverCase := serverCase t.Run(name, func(t *testing.T) { ca, cb := dpipe.Pipe() defer func() { @@ -2137,7 +2270,13 @@ func TestProtocolVersionValidation(t *testing.T) { defer wg.Wait() go func() { defer wg.Done() - if _, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), config, true); !errors.Is(err, errUnsupportedProtocolVersion) { + if _, err := testServer( + ctx, + dtlsnet.PacketConnFromConn(cb), + cb.RemoteAddr(), + config, + true, + ); !errors.Is(err, errUnsupportedProtocolVersion) { t.Errorf("Client error exp(%v) failed(%v)", errUnsupportedProtocolVersion, err) } }() @@ -2145,7 +2284,7 @@ func TestProtocolVersionValidation(t *testing.T) { time.Sleep(50 * time.Millisecond) resp := make([]byte, 1024) - for _, record := range c.records { + for _, record := range serverCase.records { packet, err := record.Marshal() if err != nil { t.Fatal(err) @@ -2198,9 +2337,13 @@ func TestProtocolVersionValidation(t *testing.T) { MessageSequence: 1, }, Message: &handshake.MessageServerHello{ - Version: protocol.Version{Major: 0xfe, Minor: 0xff}, // try to downgrade - Random: random, - CipherSuiteID: func() *uint16 { id := uint16(TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256); return &id }(), + Version: protocol.Version{Major: 0xfe, Minor: 0xff}, // try to downgrade + Random: random, + CipherSuiteID: func() *uint16 { + id := uint16(TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256) + + return &id + }(), CompressionMethod: defaultCompressionMethods()[0], }, }, @@ -2208,8 +2351,8 @@ func TestProtocolVersionValidation(t *testing.T) { }, }, } - for name, c := range clientCases { - c := c + for name, clientCase := range clientCases { + clientCase := clientCase t.Run(name, func(t *testing.T) { ca, cb := dpipe.Pipe() defer func() { @@ -2227,14 +2370,16 @@ func TestProtocolVersionValidation(t *testing.T) { defer wg.Wait() go func() { defer wg.Done() - if _, err := testClient(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), config, true); !errors.Is(err, errUnsupportedProtocolVersion) { + if _, err := testClient(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), config, true); !errors.Is( + err, errUnsupportedProtocolVersion, + ) { t.Errorf("Server error exp(%v) failed(%v)", errUnsupportedProtocolVersion, err) } }() time.Sleep(50 * time.Millisecond) - for _, record := range c.records { + for _, record := range clientCase.records { if _, err := ca.Read(make([]byte, 1024)); err != nil { t.Fatal(err) } @@ -2266,7 +2411,7 @@ func TestProtocolVersionValidation(t *testing.T) { }) } -func TestMultipleHelloVerifyRequest(t *testing.T) { +func TestMultipleHelloVerifyRequest(t *testing.T) { //nolint:cyclop // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() @@ -2294,7 +2439,7 @@ func TestMultipleHelloVerifyRequest(t *testing.T) { }, Content: &handshake.Handshake{ Header: handshake.Header{ - MessageSequence: uint16(i), + MessageSequence: uint16(i), //nolint:gosec // G115 }, Message: &handshake.MessageHelloVerifyRequest{ Version: protocol.Version1_2, @@ -2359,8 +2504,8 @@ func TestMultipleHelloVerifyRequest(t *testing.T) { } // Assert that a DTLS Server always responds with RenegotiationInfo if -// a ClientHello contained that extension or not -func TestRenegotationInfo(t *testing.T) { +// a ClientHello contained that extension or not. +func TestRenegotationInfo(t *testing.T) { //nolint:cyclop // Limit runtime in case of deadlocks lim := test.TimeOut(10 * time.Second) defer lim.Stop() @@ -2397,7 +2542,13 @@ func TestRenegotationInfo(t *testing.T) { defer cancel() go func() { - if _, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{}, true); !errors.Is(err, context.Canceled) { + if _, err := testServer( + ctx, + dtlsnet.PacketConnFromConn(cb), + cb.RemoteAddr(), + &Config{}, + true, + ); !errors.Is(err, context.Canceled) { t.Error(err) } }() @@ -2418,12 +2569,12 @@ func TestRenegotationInfo(t *testing.T) { if err != nil { t.Fatal(err) } - r := &recordlayer.RecordLayer{} - if err = r.Unmarshal(resp[:n]); err != nil { + record := &recordlayer.RecordLayer{} + if err = record.Unmarshal(resp[:n]); err != nil { t.Fatal(err) } - helloVerifyRequest, ok := r.Content.(*handshake.Handshake).Message.(*handshake.MessageHelloVerifyRequest) + helloVerifyRequest, ok := record.Content.(*handshake.Handshake).Message.(*handshake.MessageHelloVerifyRequest) if !ok { t.Fatal("Failed to cast MessageHelloVerifyRequest") } @@ -2441,11 +2592,11 @@ func TestRenegotationInfo(t *testing.T) { t.Fatal(err) } - if err := r.Unmarshal(messages[0]); err != nil { + if err := record.Unmarshal(messages[0]); err != nil { t.Fatal(err) } - serverHello, ok := r.Content.(*handshake.Handshake).Message.(*handshake.MessageServerHello) + serverHello, ok := record.Content.(*handshake.Handshake).Message.(*handshake.MessageServerHello) if !ok { t.Fatal("Failed to cast MessageServerHello") } @@ -2553,7 +2704,7 @@ func TestServerNameIndicationExtension(t *testing.T) { } } -func TestALPNExtension(t *testing.T) { +func TestALPNExtension(t *testing.T) { //nolint:cyclop,maintidx // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() @@ -2645,7 +2796,9 @@ func TestALPNExtension(t *testing.T) { conf := &Config{ SupportedProtocols: test.ServerProtocolNameList, } - if _, err2 := testServer(ctx2, dtlsnet.PacketConnFromConn(cb2), cb2.RemoteAddr(), conf, true); !errors.Is(err2, context.Canceled) { + if _, err2 := testServer(ctx2, dtlsnet.PacketConnFromConn(cb2), cb2.RemoteAddr(), conf, true); !errors.Is( + err2, context.Canceled, + ) { if test.ExpectAlertFromServer { //nolint // Assert the error type? } else { @@ -2697,13 +2850,13 @@ func TestALPNExtension(t *testing.T) { t.Fatal(err) } - r := &recordlayer.RecordLayer{} - if err := r.Unmarshal(messages[0]); err != nil { + record := &recordlayer.RecordLayer{} + if err := record.Unmarshal(messages[0]); err != nil { t.Fatal(err) } - if test.ExpectAlertFromServer { - a, ok := r.Content.(*alert.Alert) + if test.ExpectAlertFromServer { //nolint:nestif + a, ok := record.Content.(*alert.Alert) if !ok { t.Fatal("Failed to cast alert.Alert") } @@ -2712,7 +2865,7 @@ func TestALPNExtension(t *testing.T) { t.Errorf("ALPN %v: expected(%v) actual(%v)", test.Name, test.Alert, a.Description) } } else { - serverHello, ok := r.Content.(*handshake.Handshake).Message.(*handshake.MessageServerHello) + serverHello, ok := record.Content.(*handshake.Handshake).Message.(*handshake.MessageServerHello) if !ok { t.Fatal("Failed to cast handshake.MessageServerHello") } @@ -2738,7 +2891,7 @@ func TestALPNExtension(t *testing.T) { t.Errorf("ALPN %v: expected(%v) actual(%v)", test.Name, test.ExpectedProtocol, negotiatedProtocol) } - s, err := r.Marshal() + s, err := record.Marshal() if err != nil { t.Fatal(err) } @@ -2776,8 +2929,8 @@ func TestALPNExtension(t *testing.T) { } } -// Make sure the supported_groups extension is not included in the ServerHello -func TestSupportedGroupsExtension(t *testing.T) { +// Make sure the supported_groups extension is not included in the ServerHello. +func TestSupportedGroupsExtension(t *testing.T) { //nolint:cyclop // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() @@ -2792,7 +2945,9 @@ func TestSupportedGroupsExtension(t *testing.T) { ca, cb := dpipe.Pipe() go func() { - if _, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{}, true); !errors.Is(err, context.Canceled) { + if _, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{}, true); !errors.Is( + err, context.Canceled, + ) { t.Error(err) } }() @@ -2818,12 +2973,12 @@ func TestSupportedGroupsExtension(t *testing.T) { if err != nil { t.Fatal(err) } - r := &recordlayer.RecordLayer{} - if err = r.Unmarshal(resp[:n]); err != nil { + record := &recordlayer.RecordLayer{} + if err = record.Unmarshal(resp[:n]); err != nil { t.Fatal(err) } - helloVerifyRequest, ok := r.Content.(*handshake.Handshake).Message.(*handshake.MessageHelloVerifyRequest) + helloVerifyRequest, ok := record.Content.(*handshake.Handshake).Message.(*handshake.MessageHelloVerifyRequest) if !ok { t.Fatal("Failed to cast MessageHelloVerifyRequest") } @@ -2841,11 +2996,11 @@ func TestSupportedGroupsExtension(t *testing.T) { t.Fatal(err) } - if err := r.Unmarshal(messages[0]); err != nil { + if err := record.Unmarshal(messages[0]); err != nil { t.Fatal(err) } - serverHello, ok := r.Content.(*handshake.Handshake).Message.(*handshake.MessageServerHello) + serverHello, ok := record.Content.(*handshake.Handshake).Message.(*handshake.MessageServerHello) if !ok { t.Fatal("Failed to cast MessageServerHello") } @@ -2863,7 +3018,7 @@ func TestSupportedGroupsExtension(t *testing.T) { }) } -func TestSessionResume(t *testing.T) { +func TestSessionResume(t *testing.T) { //nolint:cyclop // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() @@ -2885,7 +3040,9 @@ func TestSessionResume(t *testing.T) { ss := &memSessStore{} id, _ := hex.DecodeString("9b9fc92255634d9fb109febed42166717bb8ded8c738ba71bc7f2a0d9dae0306") - secret, _ := hex.DecodeString("2e942a37aca5241deb2295b5fcedac221c7078d2503d2b62aeb48c880d7da73c001238b708559686b9da6e829c05ead7") + secret, _ := hex.DecodeString( + "2e942a37aca5241deb2295b5fcedac221c7078d2503d2b62aeb48c880d7da73c001238b708559686b9da6e829c05ead7", + ) s := Session{ID: id, Secret: secret} @@ -3034,7 +3191,8 @@ func (ms *memSessStore) Del(key []byte) error { // Assert that the server only uses CipherSuites with a hash+signature that matches // the certificate. As specified in rfc5246#section-7.4.3 -func TestCipherSuiteMatchesCertificateType(t *testing.T) { +// . +func TestCipherSuiteMatchesCertificateType(t *testing.T) { //nolint:cyclop // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() @@ -3068,7 +3226,9 @@ func TestCipherSuiteMatchesCertificateType(t *testing.T) { ca, cb := dpipe.Pipe() go func() { - c, err := testClient(context.TODO(), dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{CipherSuites: test.cipherList}, false) + c, err := testClient(context.TODO(), dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{ + CipherSuites: test.cipherList, + }, false) clientErr <- err client <- c }() @@ -3113,8 +3273,8 @@ func TestCipherSuiteMatchesCertificateType(t *testing.T) { } } -// Test that we return the proper certificate if we are serving multiple ServerNames on a single Server -func TestMultipleServerCertificates(t *testing.T) { +// Test that we return the proper certificate if we are serving multiple ServerNames on a single Server. +func TestMultipleServerCertificates(t *testing.T) { //nolint:cyclop fooCert, err := selfsign.GenerateSelfSignedWithDNS("foo") if err != nil { t.Fatal(err) @@ -3158,7 +3318,7 @@ func TestMultipleServerCertificates(t *testing.T) { ca, cb := dpipe.Pipe() go func() { - c, err := testClient(context.TODO(), dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{ + clientConn, err := testClient(context.TODO(), dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{ RootCAs: caPool, ServerName: test.RequestServerName, VerifyPeerCertificate: func(rawCerts [][]byte, _ [][]*x509.Certificate) error { @@ -3175,10 +3335,12 @@ func TestMultipleServerCertificates(t *testing.T) { }, }, false) clientErr <- err - client <- c + client <- clientConn }() - if s, err := testServer(context.TODO(), dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{Certificates: []tls.Certificate{fooCert, barCert}}, false); err != nil { + if s, err := testServer(context.TODO(), dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ + Certificates: []tls.Certificate{fooCert, barCert}, + }, false); err != nil { t.Fatal(err) } else if err = s.Close(); err != nil { t.Fatal(err) @@ -3193,7 +3355,7 @@ func TestMultipleServerCertificates(t *testing.T) { } } -func TestEllipticCurveConfiguration(t *testing.T) { +func TestEllipticCurveConfiguration(t *testing.T) { //nolint:cyclop // Check for leaking routines report := test.CheckRoutines(t) defer report() @@ -3227,25 +3389,39 @@ func TestEllipticCurveConfiguration(t *testing.T) { c *Conn err error } - c := make(chan result) + resultCh := make(chan result) go func() { - client, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, EllipticCurves: test.ConfigCurves}, true) - c <- result{client, err} + client, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{ + CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, + EllipticCurves: test.ConfigCurves, + }, true) + resultCh <- result{client, err} }() - server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, EllipticCurves: test.ConfigCurves}, true) + server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ + CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, + EllipticCurves: test.ConfigCurves, + }, true) if err != nil { t.Fatalf("Server error: %v", err) } if len(test.ConfigCurves) == 0 && len(test.HandshakeCurves) != len(server.fsm.cfg.ellipticCurves) { - t.Fatalf("Failed to default Elliptic curves, expected %d, got: %d", len(test.HandshakeCurves), len(server.fsm.cfg.ellipticCurves)) + t.Fatalf( + "Failed to default Elliptic curves, expected %d, got: %d", + len(test.HandshakeCurves), + len(server.fsm.cfg.ellipticCurves), + ) } if len(test.ConfigCurves) != 0 { if len(test.HandshakeCurves) != len(server.fsm.cfg.ellipticCurves) { - t.Fatalf("Failed to configure Elliptic curves, expect %d, got %d", len(test.HandshakeCurves), len(server.fsm.cfg.ellipticCurves)) + t.Fatalf( + "Failed to configure Elliptic curves, expect %d, got %d", + len(test.HandshakeCurves), + len(server.fsm.cfg.ellipticCurves), + ) } for i, c := range test.ConfigCurves { if c != server.fsm.cfg.ellipticCurves[i] { @@ -3254,7 +3430,7 @@ func TestEllipticCurveConfiguration(t *testing.T) { } } - res := <-c + res := <-resultCh if res.err != nil { t.Fatalf("Client error; %v", err) } @@ -3293,6 +3469,7 @@ func TestSkipHelloVerify(t *testing.T) { }, false) if sErr != nil { t.Error(sErr) + return } buf := make([]byte, 1024) @@ -3336,6 +3513,7 @@ func (c *connWithCallback) Write(b []byte) (int, error) { if c.onWrite != nil { c.onWrite(b) } + return c.Conn.Write(b) } @@ -3360,6 +3538,7 @@ func TestApplicationDataQueueLimited(t *testing.T) { serverCert, err := selfsign.GenerateSelfSigned() if err != nil { t.Error(err) + return } cfg := &Config{} @@ -3368,6 +3547,7 @@ func TestApplicationDataQueueLimited(t *testing.T) { dconn, err := createConn(dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), cfg, false, nil) if err != nil { t.Error(err) + return } go func() { @@ -3422,7 +3602,7 @@ func TestApplicationDataQueueLimited(t *testing.T) { <-done } -func TestHelloRandom(t *testing.T) { +func TestHelloRandom(t *testing.T) { //nolint:cyclop report := test.CheckRoutines(t) defer report() ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) @@ -3458,6 +3638,7 @@ func TestHelloRandom(t *testing.T) { }, false) if sErr != nil { t.Error(sErr) + return } buf := make([]byte, 1024) @@ -3510,6 +3691,7 @@ func TestOnConnectionAttempt(t *testing.T) { if in == nil { t.Fatal("net.Addr is nil") //nolint: govet } + return nil }, }, true) @@ -3523,6 +3705,7 @@ func TestOnConnectionAttempt(t *testing.T) { if in == nil { t.Fatal("net.Addr is nil") //nolint: govet } + return expectedErr }, }, true); !errors.Is(err, expectedErr) { @@ -3544,7 +3727,10 @@ func TestOnConnectionAttempt(t *testing.T) { func TestFragmentBuffer_Retransmission(t *testing.T) { fragmentBuffer := newFragmentBuffer() - frag := []byte{0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x30, 0x03, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0xfe, 0xff, 0x01, 0x01} + frag := []byte{ + 0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x30, 0x03, 0x00, + 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0xfe, 0xff, 0x01, 0x01, + } if _, isRetransmission, err := fragmentBuffer.push(frag); err != nil { t.Fatal(err) @@ -3589,10 +3775,10 @@ func TestConnectionState(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - c := make(chan error) + errorChannel := make(chan error) go func() { errC := client.HandshakeContext(ctx) - c <- errC + errorChannel <- errC }() // Setup server @@ -3604,7 +3790,7 @@ func TestConnectionState(t *testing.T) { _ = server.Close() }() - err = <-c + err = <-errorChannel if err != nil { t.Fatal(err) } diff --git a/connection_id.go b/connection_id.go index f0e94631b..c590499b4 100644 --- a/connection_id.go +++ b/connection_id.go @@ -21,6 +21,7 @@ func RandomCIDGenerator(size int) func() []byte { if _, err := rand.Read(cid); err != nil { panic(err) //nolint -- nonrecoverable } + return cid } } @@ -54,8 +55,10 @@ func cidDatagramRouter(size int) func([]byte) (string, bool) { if h.ContentType != protocol.ContentTypeConnectionID { continue } + return string(h.ConnectionID), true } + return "", false } } @@ -65,7 +68,7 @@ func cidDatagramRouter(size int) func([]byte) (string, bool) { // NOTE: a ServerHello should always be the first record in a datagram if // multiple are present, so we avoid iterating through all packets if the first // is not a ServerHello. -func cidConnIdentifier() func([]byte) (string, bool) { +func cidConnIdentifier() func([]byte) (string, bool) { //nolint:cyclop return func(packet []byte) (string, bool) { pkts, err := recordlayer.UnpackDatagram(packet) if err != nil || len(pkts) < 1 { @@ -96,6 +99,7 @@ func cidConnIdentifier() func([]byte) (string, bool) { return string(e.CID), true } } + return "", false } } diff --git a/connection_id_test.go b/connection_id_test.go index 8936aba98..aba5f72d0 100644 --- a/connection_id_test.go +++ b/connection_id_test.go @@ -85,7 +85,7 @@ func TestCIDDatagramRouter(t *testing.T) { Epoch: 1, Version: protocol.Version1_2, ContentType: protocol.ContentTypeConnectionID, - ContentLen: uint16(len(inner)), + ContentLen: uint16(len(inner)), //nolint:gosec // G115 ConnectionID: cid, SequenceNumber: 1, }).Marshal() @@ -128,6 +128,7 @@ func TestCIDDatagramRouter(t *testing.T) { want: string(cid), }, "OneRecordConnectionIDAltLength": { + //nolint:lll reason: "If datagram contains one Connection ID record, but it has the wrong length we should not be able to extract it.", size: cidLen, datagram: func() []byte { @@ -135,19 +136,21 @@ func TestCIDDatagramRouter(t *testing.T) { Epoch: 1, Version: protocol.Version1_2, ContentType: protocol.ContentTypeConnectionID, - ContentLen: uint16(len(inner)), + ContentLen: uint16(len(inner)), //nolint:gosec // G115 ConnectionID: []byte("abcd"), SequenceNumber: 1, }).Marshal() if err != nil { t.Fatal(err) } + return append(altCIDHeader, inner...) }(), ok: false, want: "", }, "MultipleRecordOneConnectionID": { + //nolint:lll reason: "If datagram contains multiple records and one is a Connection ID record, we should be able to extract it.", size: 8, datagram: append(append(appRecord, cidHeader...), inner...), @@ -155,6 +158,7 @@ func TestCIDDatagramRouter(t *testing.T) { want: string(cid), }, "MultipleRecordMultipleConnectionID": { + //nolint:lll reason: "If datagram contains multiple records and multiple are Connection ID records, we should extract the first one.", size: 8, datagram: append(append(append(appRecord, func() []byte { @@ -162,13 +166,14 @@ func TestCIDDatagramRouter(t *testing.T) { Epoch: 1, Version: protocol.Version1_2, ContentType: protocol.ContentTypeConnectionID, - ContentLen: uint16(len(inner)), + ContentLen: uint16(len(inner)), //nolint:gosec // G115 ConnectionID: []byte("1234abcd"), SequenceNumber: 1, }).Marshal() if err != nil { t.Fatal(err) } + return append(altCIDHeader, inner...) }()...), cidHeader...), inner...), ok: true, @@ -257,12 +262,14 @@ func TestCIDConnIdentifier(t *testing.T) { want: string(cid), }, "MultipleRecordFirstServerHello": { + //nolint:lll reason: "If datagram contains multiple records and the first is a ServerHello record, we should be able to extract an identifier.", datagram: append(sh, appRecord...), ok: true, want: string(cid), }, "MultipleRecordNotFirstServerHello": { + //nolint:lll reason: "If datagram contains multiple records and the first is not a ServerHello record, we should not be able to extract an identifier.", datagram: append(appRecord, sh...), ok: false, diff --git a/crypto.go b/crypto.go index 25b2a1f9f..04b39d442 100644 --- a/crypto.go +++ b/crypto.go @@ -43,7 +43,12 @@ func valueKeyMessage(clientRandom, serverRandom, publicKey []byte, namedCurve el // hash/signature algorithm pair that appears in that extension // // https://tools.ietf.org/html/rfc5246#section-7.4.2 -func generateKeySignature(clientRandom, serverRandom, publicKey []byte, namedCurve elliptic.Curve, privateKey crypto.PrivateKey, hashAlgorithm hash.Algorithm) ([]byte, error) { +func generateKeySignature( + clientRandom, serverRandom, publicKey []byte, + namedCurve elliptic.Curve, + privateKey crypto.PrivateKey, + hashAlgorithm hash.Algorithm, +) ([]byte, error) { msg := valueKeyMessage(clientRandom, serverRandom, publicKey, namedCurve) switch p := privateKey.(type) { case ed25519.PrivateKey: @@ -51,16 +56,23 @@ func generateKeySignature(clientRandom, serverRandom, publicKey []byte, namedCur return p.Sign(rand.Reader, msg, crypto.Hash(0)) case *ecdsa.PrivateKey: hashed := hashAlgorithm.Digest(msg) + return p.Sign(rand.Reader, hashed, hashAlgorithm.CryptoHash()) case *rsa.PrivateKey: hashed := hashAlgorithm.Digest(msg) + return p.Sign(rand.Reader, hashed, hashAlgorithm.CryptoHash()) } return nil, errKeySignatureGenerateUnimplemented } -func verifyKeySignature(message, remoteKeySignature []byte, hashAlgorithm hash.Algorithm, rawCertificates [][]byte) error { //nolint:dupl +//nolint:dupl,cyclop +func verifyKeySignature( + message, remoteKeySignature []byte, + hashAlgorithm hash.Algorithm, + rawCertificates [][]byte, +) error { if len(rawCertificates) == 0 { return errLengthMismatch } @@ -69,11 +81,12 @@ func verifyKeySignature(message, remoteKeySignature []byte, hashAlgorithm hash.A return err } - switch p := certificate.PublicKey.(type) { + switch pubKey := certificate.PublicKey.(type) { case ed25519.PublicKey: - if ok := ed25519.Verify(p, message, remoteKeySignature); !ok { + if ok := ed25519.Verify(pubKey, message, remoteKeySignature); !ok { return errKeySignatureMismatch } + return nil case *ecdsa.PublicKey: ecdsaSig := &ecdsaSignature{} @@ -84,15 +97,17 @@ func verifyKeySignature(message, remoteKeySignature []byte, hashAlgorithm hash.A return errInvalidECDSASignature } hashed := hashAlgorithm.Digest(message) - if !ecdsa.Verify(p, hashed, ecdsaSig.R, ecdsaSig.S) { + if !ecdsa.Verify(pubKey, hashed, ecdsaSig.R, ecdsaSig.S) { return errKeySignatureMismatch } + return nil case *rsa.PublicKey: hashed := hashAlgorithm.Digest(message) - if rsa.VerifyPKCS1v15(p, hashAlgorithm.CryptoHash(), hashed, remoteKeySignature) != nil { + if rsa.VerifyPKCS1v15(pubKey, hashAlgorithm.CryptoHash(), hashed, remoteKeySignature) != nil { return errKeySignatureMismatch } + return nil } @@ -107,7 +122,11 @@ func verifyKeySignature(message, remoteKeySignature []byte, hashAlgorithm hash.A // CertificateVerify message is sent to explicitly verify possession of // the private key in the certificate. // https://tools.ietf.org/html/rfc5246#section-7.3 -func generateCertificateVerify(handshakeBodies []byte, privateKey crypto.PrivateKey, hashAlgorithm hash.Algorithm) ([]byte, error) { +func generateCertificateVerify( + handshakeBodies []byte, + privateKey crypto.PrivateKey, + hashAlgorithm hash.Algorithm, +) ([]byte, error) { if p, ok := privateKey.(ed25519.PrivateKey); ok { // https://pkg.go.dev/crypto/ed25519#PrivateKey.Sign // Sign signs the given message with priv. Ed25519 performs two passes over @@ -127,7 +146,13 @@ func generateCertificateVerify(handshakeBodies []byte, privateKey crypto.Private return nil, errInvalidSignatureAlgorithm } -func verifyCertificateVerify(handshakeBodies []byte, hashAlgorithm hash.Algorithm, remoteKeySignature []byte, rawCertificates [][]byte) error { //nolint:dupl +//nolint:dupl,cyclop +func verifyCertificateVerify( + handshakeBodies []byte, + hashAlgorithm hash.Algorithm, + remoteKeySignature []byte, + rawCertificates [][]byte, +) error { if len(rawCertificates) == 0 { return errLengthMismatch } @@ -136,11 +161,12 @@ func verifyCertificateVerify(handshakeBodies []byte, hashAlgorithm hash.Algorith return err } - switch p := certificate.PublicKey.(type) { + switch pubKey := certificate.PublicKey.(type) { case ed25519.PublicKey: - if ok := ed25519.Verify(p, handshakeBodies, remoteKeySignature); !ok { + if ok := ed25519.Verify(pubKey, handshakeBodies, remoteKeySignature); !ok { return errKeySignatureMismatch } + return nil case *ecdsa.PublicKey: ecdsaSig := &ecdsaSignature{} @@ -151,15 +177,17 @@ func verifyCertificateVerify(handshakeBodies []byte, hashAlgorithm hash.Algorith return errInvalidECDSASignature } hash := hashAlgorithm.Digest(handshakeBodies) - if !ecdsa.Verify(p, hash, ecdsaSig.R, ecdsaSig.S) { + if !ecdsa.Verify(pubKey, hash, ecdsaSig.R, ecdsaSig.S) { return errKeySignatureMismatch } + return nil case *rsa.PublicKey: hash := hashAlgorithm.Digest(handshakeBodies) - if rsa.VerifyPKCS1v15(p, hashAlgorithm.CryptoHash(), hash, remoteKeySignature) != nil { + if rsa.VerifyPKCS1v15(pubKey, hashAlgorithm.CryptoHash(), hash, remoteKeySignature) != nil { return errKeySignatureMismatch } + return nil } @@ -179,6 +207,7 @@ func loadCerts(rawCertificates [][]byte) ([]*x509.Certificate, error) { } certs = append(certs, cert) } + return certs, nil } @@ -197,10 +226,15 @@ func verifyClientCert(rawCertificates [][]byte, roots *x509.CertPool) (chains [] Intermediates: intermediateCAPool, KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, } + return certificate[0].Verify(opts) } -func verifyServerCert(rawCertificates [][]byte, roots *x509.CertPool, serverName string) (chains [][]*x509.Certificate, err error) { +func verifyServerCert( + rawCertificates [][]byte, + roots *x509.CertPool, + serverName string, +) (chains [][]*x509.Certificate, err error) { certificate, err := loadCerts(rawCertificates) if err != nil { return nil, err @@ -215,5 +249,6 @@ func verifyServerCert(rawCertificates [][]byte, roots *x509.CertPool, serverName DNSName: serverName, Intermediates: intermediateCAPool, } + return certificate[0].Verify(opts) } diff --git a/crypto_test.go b/crypto_test.go index e3f572eb5..249ca2cdc 100644 --- a/crypto_test.go +++ b/crypto_test.go @@ -51,21 +51,34 @@ func TestGenerateKeySignature(t *testing.T) { t.Error(err) } - clientRandom := []byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f} - serverRandom := []byte{0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f, 0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87, 0x88, 0x89, 0x8a, 0x8b, 0x8c, 0x8d, 0x8e, 0x8f} - publicKey := []byte{0x20, 0x9f, 0xd7, 0xad, 0x6d, 0xcf, 0xf4, 0x29, 0x8d, 0xd3, 0xf9, 0x6d, 0x5b, 0x1b, 0x2a, 0xf9, 0x10, 0xa0, 0x53, 0x5b, 0x14, 0x88, 0xd7, 0xf8, 0xfa, 0xbb, 0x34, 0x9a, 0x98, 0x28, 0x80, 0xb6, 0x15} + clientRandom := []byte{ + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, + 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, + } + serverRandom := []byte{ + 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f, + 0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87, 0x88, 0x89, 0x8a, 0x8b, 0x8c, 0x8d, 0x8e, 0x8f, + } + publicKey := []byte{ + 0x20, 0x9f, 0xd7, 0xad, 0x6d, 0xcf, 0xf4, 0x29, 0x8d, 0xd3, 0xf9, 0x6d, 0x5b, 0x1b, 0x2a, 0xf9, 0x10, + 0xa0, 0x53, 0x5b, 0x14, 0x88, 0xd7, 0xf8, 0xfa, 0xbb, 0x34, 0x9a, 0x98, 0x28, 0x80, 0xb6, 0x15, + } expectedSignature := []byte{ - 0x6f, 0x47, 0x97, 0x85, 0xcc, 0x76, 0x50, 0x93, 0xbd, 0xe2, 0x6a, 0x69, 0x0b, 0xc3, 0x03, 0xd1, 0xb7, 0xe4, 0xab, 0x88, 0x7b, 0xa6, 0x52, 0x80, 0xdf, - 0xaa, 0x25, 0x7a, 0xdb, 0x29, 0x32, 0xe4, 0xd8, 0x28, 0x28, 0xb3, 0xe8, 0x04, 0x3c, 0x38, 0x16, 0xfc, 0x78, 0xe9, 0x15, 0x7b, 0xc5, 0xbd, 0x7d, 0xfc, - 0xcd, 0x83, 0x00, 0x57, 0x4a, 0x3c, 0x23, 0x85, 0x75, 0x6b, 0x37, 0xd5, 0x89, 0x72, 0x73, 0xf0, 0x44, 0x8c, 0x00, 0x70, 0x1f, 0x6e, 0xa2, 0x81, 0xd0, - 0x09, 0xc5, 0x20, 0x36, 0xab, 0x23, 0x09, 0x40, 0x1f, 0x4d, 0x45, 0x96, 0x62, 0xbb, 0x81, 0xb0, 0x30, 0x72, 0xad, 0x3a, 0x0a, 0xac, 0x31, 0x63, 0x40, - 0x52, 0x0a, 0x27, 0xf3, 0x34, 0xde, 0x27, 0x7d, 0xb7, 0x54, 0xff, 0x0f, 0x9f, 0x5a, 0xfe, 0x07, 0x0f, 0x4e, 0x9f, 0x53, 0x04, 0x34, 0x62, 0xf4, 0x30, - 0x74, 0x83, 0x35, 0xfc, 0xe4, 0x7e, 0xbf, 0x5a, 0xc4, 0x52, 0xd0, 0xea, 0xf9, 0x61, 0x4e, 0xf5, 0x1c, 0x0e, 0x58, 0x02, 0x71, 0xfb, 0x1f, 0x34, 0x55, - 0xe8, 0x36, 0x70, 0x3c, 0xc1, 0xcb, 0xc9, 0xb7, 0xbb, 0xb5, 0x1c, 0x44, 0x9a, 0x6d, 0x88, 0x78, 0x98, 0xd4, 0x91, 0x2e, 0xeb, 0x98, 0x81, 0x23, 0x30, - 0x73, 0x39, 0x43, 0xd5, 0xbb, 0x70, 0x39, 0xba, 0x1f, 0xdb, 0x70, 0x9f, 0x91, 0x83, 0x56, 0xc2, 0xde, 0xed, 0x17, 0x6d, 0x2c, 0x3e, 0x21, 0xea, 0x36, - 0xb4, 0x91, 0xd8, 0x31, 0x05, 0x60, 0x90, 0xfd, 0xc6, 0x74, 0xa9, 0x7b, 0x18, 0xfc, 0x1c, 0x6a, 0x1c, 0x6e, 0xec, 0xd3, 0xc1, 0xc0, 0x0d, 0x11, 0x25, - 0x48, 0x37, 0x3d, 0x45, 0x11, 0xa2, 0x31, 0x14, 0x0a, 0x66, 0x9f, 0xd8, 0xac, 0x74, 0xa2, 0xcd, 0xc8, 0x79, 0xb3, 0x9e, 0xc6, 0x66, 0x25, 0xcf, 0x2c, - 0x87, 0x5e, 0x5c, 0x36, 0x75, 0x86, + 0x6f, 0x47, 0x97, 0x85, 0xcc, 0x76, 0x50, 0x93, 0xbd, 0xe2, 0x6a, 0x69, 0x0b, 0xc3, 0x03, 0xd1, 0xb7, 0xe4, + 0xab, 0x88, 0x7b, 0xa6, 0x52, 0x80, 0xdf, 0xaa, 0x25, 0x7a, 0xdb, 0x29, 0x32, 0xe4, 0xd8, 0x28, 0x28, 0xb3, + 0xe8, 0x04, 0x3c, 0x38, 0x16, 0xfc, 0x78, 0xe9, 0x15, 0x7b, 0xc5, 0xbd, 0x7d, 0xfc, 0xcd, 0x83, 0x00, 0x57, + 0x4a, 0x3c, 0x23, 0x85, 0x75, 0x6b, 0x37, 0xd5, 0x89, 0x72, 0x73, 0xf0, 0x44, 0x8c, 0x00, 0x70, 0x1f, 0x6e, + 0xa2, 0x81, 0xd0, 0x09, 0xc5, 0x20, 0x36, 0xab, 0x23, 0x09, 0x40, 0x1f, 0x4d, 0x45, 0x96, 0x62, 0xbb, 0x81, + 0xb0, 0x30, 0x72, 0xad, 0x3a, 0x0a, 0xac, 0x31, 0x63, 0x40, 0x52, 0x0a, 0x27, 0xf3, 0x34, 0xde, 0x27, 0x7d, + 0xb7, 0x54, 0xff, 0x0f, 0x9f, 0x5a, 0xfe, 0x07, 0x0f, 0x4e, 0x9f, 0x53, 0x04, 0x34, 0x62, 0xf4, 0x30, 0x74, + 0x83, 0x35, 0xfc, 0xe4, 0x7e, 0xbf, 0x5a, 0xc4, 0x52, 0xd0, 0xea, 0xf9, 0x61, 0x4e, 0xf5, 0x1c, 0x0e, 0x58, + 0x02, 0x71, 0xfb, 0x1f, 0x34, 0x55, 0xe8, 0x36, 0x70, 0x3c, 0xc1, 0xcb, 0xc9, 0xb7, 0xbb, 0xb5, 0x1c, 0x44, + 0x9a, 0x6d, 0x88, 0x78, 0x98, 0xd4, 0x91, 0x2e, 0xeb, 0x98, 0x81, 0x23, 0x30, 0x73, 0x39, 0x43, 0xd5, 0xbb, + 0x70, 0x39, 0xba, 0x1f, 0xdb, 0x70, 0x9f, 0x91, 0x83, 0x56, 0xc2, 0xde, 0xed, 0x17, 0x6d, 0x2c, 0x3e, 0x21, + 0xea, 0x36, 0xb4, 0x91, 0xd8, 0x31, 0x05, 0x60, 0x90, 0xfd, 0xc6, 0x74, 0xa9, 0x7b, 0x18, 0xfc, 0x1c, 0x6a, + 0x1c, 0x6e, 0xec, 0xd3, 0xc1, 0xc0, 0x0d, 0x11, 0x25, 0x48, 0x37, 0x3d, 0x45, 0x11, 0xa2, 0x31, 0x14, 0x0a, + 0x66, 0x9f, 0xd8, 0xac, 0x74, 0xa2, 0xcd, 0xc8, 0x79, 0xb3, 0x9e, 0xc6, 0x66, 0x25, 0xcf, 0x2c, 0x87, 0x5e, + 0x5c, 0x36, 0x75, 0x86, } signature, err := generateKeySignature(clientRandom, serverRandom, publicKey, elliptic.X25519, key, hash.SHA256) diff --git a/e2e/e2e_lossy_test.go b/e2e/e2e_lossy_test.go index 5be97eb4c..e70130cdc 100644 --- a/e2e/e2e_lossy_test.go +++ b/e2e/e2e_lossy_test.go @@ -21,10 +21,9 @@ const ( lossyTestTimeout = 30 * time.Second ) -/* -DTLS Client/Server over a lossy transport, just asserts it can handle at increasing increments -*/ -func TestPionE2ELossy(t *testing.T) { +// DTLS Client/Server over a lossy transport, just asserts it can handle at increasing increments + +func TestPionE2ELossy(t *testing.T) { //nolint:cyclop // Check for leaking routines report := transportTest.CheckRoutines(t) defer report() @@ -213,20 +212,32 @@ func TestPionE2ELossy(t *testing.T) { select { case serverResult := <-serverDone: if serverResult.err != nil { - t.Errorf("Fail, serverError: clientComplete(%t) serverComplete(%t) LossChance(%d) error(%v)", clientConn != nil, serverConn != nil, chosenLoss, serverResult.err) + t.Errorf( + "Fail, serverError: clientComplete(%t) serverComplete(%t) LossChance(%d) error(%v)", + clientConn != nil, serverConn != nil, chosenLoss, serverResult.err, + ) + return } serverConn = serverResult.dtlsConn case clientResult := <-clientDone: if clientResult.err != nil { - t.Errorf("Fail, clientError: clientComplete(%t) serverComplete(%t) LossChance(%d) error(%v)", clientConn != nil, serverConn != nil, chosenLoss, clientResult.err) + t.Errorf( + "Fail, clientError: clientComplete(%t) serverComplete(%t) LossChance(%d) error(%v)", + clientConn != nil, serverConn != nil, chosenLoss, clientResult.err, + ) + return } clientConn = clientResult.dtlsConn case <-testTimer.C: - t.Errorf("Test expired: clientComplete(%t) serverComplete(%t) LossChance(%d)", clientConn != nil, serverConn != nil, chosenLoss) + t.Errorf( + "Test expired: clientComplete(%t) serverComplete(%t) LossChance(%d)", + clientConn != nil, serverConn != nil, chosenLoss, + ) + return case <-time.After(10 * time.Millisecond): } diff --git a/e2e/e2e_test.go b/e2e/e2e_test.go index 06f9b78fb..f02d1507f 100644 --- a/e2e/e2e_test.go +++ b/e2e/e2e_test.go @@ -41,11 +41,11 @@ var ( errHookAPLNFailed = errors.New("hook failed to modify APLN extension") ) -func randomPort(t testing.TB) int { - t.Helper() +func randomPort(tb testing.TB) int { + tb.Helper() conn, err := net.ListenPacket("udp4", "127.0.0.1:0") if err != nil { - t.Fatalf("failed to pickPort: %v", err) + tb.Fatalf("failed to pickPort: %v", err) } defer func() { _ = conn.Close() @@ -54,7 +54,8 @@ func randomPort(t testing.TB) int { case *net.UDPAddr: return addr.Port default: - t.Fatalf("unknown addr type %T", addr) + tb.Fatalf("unknown addr type %T", addr) + return 0 } } @@ -65,6 +66,7 @@ func simpleReadWrite(errChan chan error, outChan chan string, conn io.ReadWriter n, err := conn.Read(buffer) if err != nil { errChan <- err + return } @@ -77,6 +79,7 @@ func simpleReadWrite(errChan chan error, outChan chan string, conn io.ReadWriter break } else if _, err := conn.Write([]byte(testMessage)); err != nil { errChan <- err + break } @@ -85,7 +88,7 @@ func simpleReadWrite(errChan chan error, outChan chan string, conn io.ReadWriter } type comm struct { - ctx context.Context + ctx context.Context //nolint:containedctx clientConfig, serverConfig *dtls.Config serverPort int messageRecvCount *uint64 // Counter to make sure both sides got a message @@ -104,9 +107,15 @@ type comm struct { server func(*comm) } -func newComm(ctx context.Context, clientConfig, serverConfig *dtls.Config, serverPort int, server, client func(*comm)) *comm { +func newComm( + ctx context.Context, + clientConfig, serverConfig *dtls.Config, + serverPort int, + server, client func(*comm), +) *comm { messageRecvCount := uint64(0) - c := &comm{ + + com := &comm{ ctx: ctx, clientConfig: clientConfig, serverConfig: serverConfig, @@ -123,10 +132,13 @@ func newComm(ctx context.Context, clientConfig, serverConfig *dtls.Config, serve server: server, client: client, } - return c + + return com } -func (c *comm) assert(t *testing.T) { +func (c *comm) assert(t *testing.T) { //nolint:cyclop + t.Helper() + // DTLS Client go c.client(c) @@ -182,7 +194,9 @@ func (c *comm) assert(t *testing.T) { }() } -func (c *comm) cleanup(t *testing.T) { +func (c *comm) cleanup(t *testing.T) { //nolint:cyclop + t.Helper() + clientDone, serverDone := false, false for { select { @@ -208,7 +222,7 @@ func (c *comm) cleanup(t *testing.T) { } } -func clientPion(c *comm) { +func clientPion(c *comm) { //nolint:varnamelen select { case <-c.serverReady: // OK @@ -225,11 +239,13 @@ func clientPion(c *comm) { ) if err != nil { c.errChan <- err + return } if err := conn.HandshakeContext(c.ctx); err != nil { c.errChan <- err + return } @@ -240,7 +256,7 @@ func clientPion(c *comm) { close(c.clientDone) } -func serverPion(c *comm) { +func serverPion(c *comm) { //nolint:varnamelen c.serverMutex.Lock() defer c.serverMutex.Unlock() @@ -251,12 +267,14 @@ func serverPion(c *comm) { ) if err != nil { c.errChan <- err + return } c.serverReady <- struct{}{} c.serverConn, err = c.serverListener.Accept() if err != nil { c.errChan <- err + return } @@ -264,6 +282,7 @@ func serverPion(c *comm) { if ok { if err := dtlsConn.HandshakeContext(c.ctx); err != nil { c.errChan <- err + return } } @@ -281,13 +300,12 @@ func withConnectionIDGenerator(g func() []byte) dtlsConfOpts { } } -/* - Simple DTLS Client/Server can communicate - - Assert that you can send messages both ways - - Assert that Close() on both ends work - - Assert that no Goroutines are leaked -*/ +// Simple DTLS Client/Server can communicate +// - Assert that you can send messages both ways +// - Assert that Close() on both ends work +// - Assert that no Goroutines are leaked func testPionE2ESimple(t *testing.T, server, client func(*comm), opts ...dtlsConfOpts) { + t.Helper() lim := test.TimeOut(time.Second * 30) defer lim.Stop() @@ -326,6 +344,8 @@ func testPionE2ESimple(t *testing.T, server, client func(*comm), opts ...dtlsCon } func testPionE2ESimplePSK(t *testing.T, server, client func(*comm), opts ...dtlsConfOpts) { + t.Helper() + lim := test.TimeOut(time.Second * 30) defer lim.Stop() @@ -363,6 +383,8 @@ func testPionE2ESimplePSK(t *testing.T, server, client func(*comm), opts ...dtls } func testPionE2EMTUs(t *testing.T, server, client func(*comm), opts ...dtlsConfOpts) { + t.Helper() + lim := test.TimeOut(time.Second * 30) defer lim.Stop() @@ -402,6 +424,8 @@ func testPionE2EMTUs(t *testing.T, server, client func(*comm), opts ...dtlsConfO } func testPionE2ESimpleED25519(t *testing.T, server, client func(*comm), opts ...dtlsConfOpts) { + t.Helper() + lim := test.TimeOut(time.Second * 30) defer lim.Stop() @@ -446,6 +470,8 @@ func testPionE2ESimpleED25519(t *testing.T, server, client func(*comm), opts ... } func testPionE2ESimpleED25519ClientCert(t *testing.T, server, client func(*comm), opts ...dtlsConfOpts) { + t.Helper() + lim := test.TimeOut(time.Second * 30) defer lim.Stop() @@ -494,6 +520,8 @@ func testPionE2ESimpleED25519ClientCert(t *testing.T, server, client func(*comm) } func testPionE2ESimpleECDSAClientCert(t *testing.T, server, client func(*comm), opts ...dtlsConfOpts) { + t.Helper() + lim := test.TimeOut(time.Second * 30) defer lim.Stop() @@ -542,6 +570,8 @@ func testPionE2ESimpleECDSAClientCert(t *testing.T, server, client func(*comm), } func testPionE2ESimpleRSAClientCert(t *testing.T, server, client func(*comm), opts ...dtlsConfOpts) { + t.Helper() + lim := test.TimeOut(time.Second * 30) defer lim.Stop() @@ -590,6 +620,8 @@ func testPionE2ESimpleRSAClientCert(t *testing.T, server, client func(*comm), op } func testPionE2ESimpleClientHelloHook(t *testing.T, server, client func(*comm), opts ...dtlsConfOpts) { + t.Helper() + lim := test.TimeOut(time.Second * 30) defer lim.Stop() @@ -617,11 +649,13 @@ func testPionE2ESimpleClientHelloHook(t *testing.T, server, client func(*comm), if s.CipherSuiteID != modifiedCipher { return errHookCiphersFailed } + return nil }, CipherSuites: supportedList, ClientHelloMessageHook: func(ch handshake.MessageClientHello) handshake.Message { ch.CipherSuiteIDs = []uint16{uint16(modifiedCipher)} + return &ch }, InsecureSkipVerify: true, @@ -645,6 +679,8 @@ func testPionE2ESimpleClientHelloHook(t *testing.T, server, client func(*comm), } func testPionE2ESimpleServerHelloHook(t *testing.T, server, client func(*comm), opts ...dtlsConfOpts) { + t.Helper() + lim := test.TimeOut(time.Second * 30) defer lim.Stop() @@ -670,6 +706,7 @@ func testPionE2ESimpleServerHelloHook(t *testing.T, server, client func(*comm), if s.NegotiatedProtocol != apln { return errHookAPLNFailed } + return nil }, CipherSuites: supportedList, @@ -683,6 +720,7 @@ func testPionE2ESimpleServerHelloHook(t *testing.T, server, client func(*comm), sh.Extensions = append(sh.Extensions, &extension.ALPN{ ProtocolNameList: []string{apln}, }) + return &sh }, InsecureSkipVerify: true, diff --git a/errors.go b/errors.go index f03fb11c0..b7f93b7b4 100644 --- a/errors.go +++ b/errors.go @@ -15,65 +15,133 @@ import ( "github.com/pion/dtls/v3/pkg/protocol/alert" ) -// Typed errors +// Typed errors. var ( ErrConnClosed = &FatalError{Err: errors.New("conn is closed")} //nolint:goerr113 errDeadlineExceeded = &TimeoutError{Err: fmt.Errorf("read/write timeout: %w", context.DeadlineExceeded)} errInvalidContentType = &TemporaryError{Err: errors.New("invalid content type")} //nolint:goerr113 - errBufferTooSmall = &TemporaryError{Err: errors.New("buffer is too small")} //nolint:goerr113 - errContextUnsupported = &TemporaryError{Err: errors.New("context is not supported for ExportKeyingMaterial")} //nolint:goerr113 - errHandshakeInProgress = &TemporaryError{Err: errors.New("handshake is in progress")} //nolint:goerr113 - errReservedExportKeyingMaterial = &TemporaryError{Err: errors.New("ExportKeyingMaterial can not be used with a reserved label")} //nolint:goerr113 - errApplicationDataEpochZero = &TemporaryError{Err: errors.New("ApplicationData with epoch of 0")} //nolint:goerr113 - errUnhandledContextType = &TemporaryError{Err: errors.New("unhandled contentType")} //nolint:goerr113 - - errCertificateVerifyNoCertificate = &FatalError{Err: errors.New("client sent certificate verify but we have no certificate to verify")} //nolint:goerr113 - errCipherSuiteNoIntersection = &FatalError{Err: errors.New("client+server do not support any shared cipher suites")} //nolint:goerr113 - errClientCertificateNotVerified = &FatalError{Err: errors.New("client sent certificate but did not verify it")} //nolint:goerr113 - errClientCertificateRequired = &FatalError{Err: errors.New("server required client verification, but got none")} //nolint:goerr113 - errClientNoMatchingSRTPProfile = &FatalError{Err: errors.New("server responded with SRTP Profile we do not support")} //nolint:goerr113 - errClientRequiredButNoServerEMS = &FatalError{Err: errors.New("client required Extended Master Secret extension, but server does not support it")} //nolint:goerr113 - errCookieMismatch = &FatalError{Err: errors.New("client+server cookie does not match")} //nolint:goerr113 - errIdentityNoPSK = &FatalError{Err: errors.New("PSK Identity Hint provided but PSK is nil")} //nolint:goerr113 - errInvalidCertificate = &FatalError{Err: errors.New("no certificate provided")} //nolint:goerr113 - errInvalidCipherSuite = &FatalError{Err: errors.New("invalid or unknown cipher suite")} //nolint:goerr113 - errInvalidECDSASignature = &FatalError{Err: errors.New("ECDSA signature contained zero or negative values")} //nolint:goerr113 - errInvalidPrivateKey = &FatalError{Err: errors.New("invalid private key type")} //nolint:goerr113 - errInvalidSignatureAlgorithm = &FatalError{Err: errors.New("invalid signature algorithm")} //nolint:goerr113 - errKeySignatureMismatch = &FatalError{Err: errors.New("expected and actual key signature do not match")} //nolint:goerr113 - errNilNextConn = &FatalError{Err: errors.New("Conn can not be created with a nil nextConn")} //nolint:goerr113 - errNoAvailableCipherSuites = &FatalError{Err: errors.New("connection can not be created, no CipherSuites satisfy this Config")} //nolint:goerr113 - errNoAvailablePSKCipherSuite = &FatalError{Err: errors.New("connection can not be created, pre-shared key present but no compatible CipherSuite")} //nolint:goerr113 - errNoAvailableCertificateCipherSuite = &FatalError{Err: errors.New("connection can not be created, certificate present but no compatible CipherSuite")} //nolint:goerr113 - errNoAvailableSignatureSchemes = &FatalError{Err: errors.New("connection can not be created, no SignatureScheme satisfy this Config")} //nolint:goerr113 - errNoCertificates = &FatalError{Err: errors.New("no certificates configured")} //nolint:goerr113 - errNoConfigProvided = &FatalError{Err: errors.New("no config provided")} //nolint:goerr113 - errNoSupportedEllipticCurves = &FatalError{Err: errors.New("client requested zero or more elliptic curves that are not supported by the server")} //nolint:goerr113 - errUnsupportedProtocolVersion = &FatalError{Err: errors.New("unsupported protocol version")} //nolint:goerr113 - errPSKAndIdentityMustBeSetForClient = &FatalError{Err: errors.New("PSK and PSK Identity Hint must both be set for client")} //nolint:goerr113 - errRequestedButNoSRTPExtension = &FatalError{Err: errors.New("SRTP support was requested but server did not respond with use_srtp extension")} //nolint:goerr113 - errServerNoMatchingSRTPProfile = &FatalError{Err: errors.New("client requested SRTP but we have no matching profiles")} //nolint:goerr113 - errServerRequiredButNoClientEMS = &FatalError{Err: errors.New("server requires the Extended Master Secret extension, but the client does not support it")} //nolint:goerr113 - errVerifyDataMismatch = &FatalError{Err: errors.New("expected and actual verify data does not match")} //nolint:goerr113 - errNotAcceptableCertificateChain = &FatalError{Err: errors.New("certificate chain is not signed by an acceptable CA")} //nolint:goerr113 - - errInvalidFlight = &InternalError{Err: errors.New("invalid flight number")} //nolint:goerr113 - errKeySignatureGenerateUnimplemented = &InternalError{Err: errors.New("unable to generate key signature, unimplemented")} //nolint:goerr113 - errKeySignatureVerifyUnimplemented = &InternalError{Err: errors.New("unable to verify key signature, unimplemented")} //nolint:goerr113 - errLengthMismatch = &InternalError{Err: errors.New("data length and declared length do not match")} //nolint:goerr113 - errSequenceNumberOverflow = &InternalError{Err: errors.New("sequence number overflow")} //nolint:goerr113 - errInvalidFSMTransition = &InternalError{Err: errors.New("invalid state machine transition")} //nolint:goerr113 - errFailedToAccessPoolReadBuffer = &InternalError{Err: errors.New("failed to access pool read buffer")} //nolint:goerr113 - errFragmentBufferOverflow = &InternalError{Err: errors.New("fragment buffer overflow")} //nolint:goerr113 + //nolint:goerr113 + errBufferTooSmall = &TemporaryError{Err: errors.New("buffer is too small")} + //nolint:goerr113 + errContextUnsupported = &TemporaryError{Err: errors.New("context is not supported for ExportKeyingMaterial")} + //nolint:goerr113 + errHandshakeInProgress = &TemporaryError{Err: errors.New("handshake is in progress")} + //nolint:goerr113 + errReservedExportKeyingMaterial = &TemporaryError{ + Err: errors.New("ExportKeyingMaterial can not be used with a reserved label"), + } + //nolint:goerr113 + errApplicationDataEpochZero = &TemporaryError{Err: errors.New("ApplicationData with epoch of 0")} + //nolint:goerr113 + errUnhandledContextType = &TemporaryError{Err: errors.New("unhandled contentType")} + + //nolint:goerr113 + errCertificateVerifyNoCertificate = &FatalError{ + Err: errors.New("client sent certificate verify but we have no certificate to verify"), + } + //nolint:goerr113 + errCipherSuiteNoIntersection = &FatalError{Err: errors.New("client+server do not support any shared cipher suites")} + //nolint:goerr113 + errClientCertificateNotVerified = &FatalError{Err: errors.New("client sent certificate but did not verify it")} + //nolint:goerr113 + errClientCertificateRequired = &FatalError{Err: errors.New("server required client verification, but got none")} + //nolint:goerr113 + errClientNoMatchingSRTPProfile = &FatalError{Err: errors.New("server responded with SRTP Profile we do not support")} + //nolint:goerr113 + errClientRequiredButNoServerEMS = &FatalError{ + Err: errors.New("client required Extended Master Secret extension, but server does not support it"), + } + //nolint:goerr113 + errCookieMismatch = &FatalError{Err: errors.New("client+server cookie does not match")} + //nolint:goerr113 + errIdentityNoPSK = &FatalError{Err: errors.New("PSK Identity Hint provided but PSK is nil")} + //nolint:goerr113 + errInvalidCertificate = &FatalError{Err: errors.New("no certificate provided")} + //nolint:goerr113 + errInvalidCipherSuite = &FatalError{Err: errors.New("invalid or unknown cipher suite")} + //nolint:goerr113 + errInvalidECDSASignature = &FatalError{Err: errors.New("ECDSA signature contained zero or negative values")} + //nolint:goerr113 + errInvalidPrivateKey = &FatalError{Err: errors.New("invalid private key type")} + //nolint:goerr113 + errInvalidSignatureAlgorithm = &FatalError{Err: errors.New("invalid signature algorithm")} + //nolint:goerr113 + errKeySignatureMismatch = &FatalError{Err: errors.New("expected and actual key signature do not match")} + //nolint:goerr113 + errNilNextConn = &FatalError{Err: errors.New("Conn can not be created with a nil nextConn")} + //nolint:goerr113 + errNoAvailableCipherSuites = &FatalError{ + Err: errors.New("connection can not be created, no CipherSuites satisfy this Config"), + } + //nolint:goerr113 + errNoAvailablePSKCipherSuite = &FatalError{ + Err: errors.New("connection can not be created, pre-shared key present but no compatible CipherSuite"), + } + //nolint:goerr113 + errNoAvailableCertificateCipherSuite = &FatalError{ + Err: errors.New("connection can not be created, certificate present but no compatible CipherSuite"), + } + //nolint:goerr113 + errNoAvailableSignatureSchemes = &FatalError{ + Err: errors.New("connection can not be created, no SignatureScheme satisfy this Config"), + } + //nolint:goerr113 + errNoCertificates = &FatalError{Err: errors.New("no certificates configured")} + //nolint:goerr113 + errNoConfigProvided = &FatalError{Err: errors.New("no config provided")} + //nolint:goerr113 + errNoSupportedEllipticCurves = &FatalError{ + Err: errors.New("client requested zero or more elliptic curves that are not supported by the server"), + } + //nolint:goerr113 + errUnsupportedProtocolVersion = &FatalError{Err: errors.New("unsupported protocol version")} + //nolint:goerr113 + errPSKAndIdentityMustBeSetForClient = &FatalError{ + Err: errors.New("PSK and PSK Identity Hint must both be set for client"), + } + //nolint:goerr113 + errRequestedButNoSRTPExtension = &FatalError{ + Err: errors.New("SRTP support was requested but server did not respond with use_srtp extension"), + } + //nolint:goerr113 + errServerNoMatchingSRTPProfile = &FatalError{Err: errors.New("client requested SRTP but we have no matching profiles")} + //nolint:goerr113 + errServerRequiredButNoClientEMS = &FatalError{ + Err: errors.New("server requires the Extended Master Secret extension, but the client does not support it"), + } + //nolint:goerr113 + errVerifyDataMismatch = &FatalError{Err: errors.New("expected and actual verify data does not match")} + //nolint:goerr113 + errNotAcceptableCertificateChain = &FatalError{Err: errors.New("certificate chain is not signed by an acceptable CA")} + + //nolint:goerr113 + errInvalidFlight = &InternalError{Err: errors.New("invalid flight number")} + //nolint:goerr113 + errKeySignatureGenerateUnimplemented = &InternalError{ + Err: errors.New("unable to generate key signature, unimplemented"), + } + //nolint:goerr113 + errKeySignatureVerifyUnimplemented = &InternalError{Err: errors.New("unable to verify key signature, unimplemented")} + //nolint:goerr113 + errLengthMismatch = &InternalError{Err: errors.New("data length and declared length do not match")} + //nolint:goerr113 + errSequenceNumberOverflow = &InternalError{Err: errors.New("sequence number overflow")} + //nolint:goerr113 + errInvalidFSMTransition = &InternalError{Err: errors.New("invalid state machine transition")} + //nolint:goerr113 + errFailedToAccessPoolReadBuffer = &InternalError{Err: errors.New("failed to access pool read buffer")} + //nolint:goerr113 + errFragmentBufferOverflow = &InternalError{Err: errors.New("fragment buffer overflow")} ) // FatalError indicates that the DTLS connection is no longer available. // It is mainly caused by wrong configuration of server or client. type FatalError = protocol.FatalError -// InternalError indicates and internal error caused by the implementation, and the DTLS connection is no longer available. +// InternalError indicates and internal error caused by the implementation, +// and the DTLS connection is no longer available. // It is mainly caused by bugs or tried to use unimplemented features. type InternalError = protocol.InternalError @@ -100,10 +168,11 @@ func (e *invalidCipherSuiteError) Is(err error) bool { if errors.As(err, &other) { return e.id == other.id } + return false } -// errAlert wraps DTLS alert notification as an error +// errAlert wraps DTLS alert notification as an error. type alertError struct { *alert.Alert } @@ -121,6 +190,7 @@ func (e *alertError) Is(err error) bool { if errors.As(err, &other) { return e.Level == other.Level && e.Description == other.Description } + return false } @@ -138,7 +208,7 @@ func netError(err error) error { se *os.SyscallError ) - if errors.As(err, &opError) { + if errors.As(err, &opError) { //nolint:nestif if errors.As(opError, &se) { if se.Timeout() { return &TimeoutError{Err: err} diff --git a/errors_test.go b/errors_test.go index 05c2c2745..db3bffc59 100644 --- a/errors_test.go +++ b/errors_test.go @@ -65,21 +65,21 @@ func TestErrorNetError(t *testing.T) { {&HandshakeError{Err: errExample}, "handshake error: an example error", false, false}, {&HandshakeError{Err: &TimeoutError{Err: errExample}}, "handshake error: dtls timeout: an example error", true, true}, } - for _, c := range cases { - c := c - t.Run(fmt.Sprintf("%T", c.err), func(t *testing.T) { + for _, testCase := range cases { + testCase := testCase + t.Run(fmt.Sprintf("%T", testCase.err), func(t *testing.T) { var ne net.Error - if !errors.As(c.err, &ne) { - t.Fatalf("%T doesn't implement net.Error", c.err) + if !errors.As(testCase.err, &ne) { + t.Fatalf("%T doesn't implement net.Error", testCase.err) } - if ne.Timeout() != c.timeout { - t.Errorf("%T.Timeout() should be %v", c.err, c.timeout) + if ne.Timeout() != testCase.timeout { + t.Errorf("%T.Timeout() should be %v", testCase.err, testCase.timeout) } - if ne.Temporary() != c.temporary { //nolint:staticcheck - t.Errorf("%T.Temporary() should be %v", c.err, c.temporary) + if ne.Temporary() != testCase.temporary { //nolint:staticcheck + t.Errorf("%T.Temporary() should be %v", testCase.err, testCase.temporary) } - if ne.Error() != c.str { - t.Errorf("%T.Error() should be %v", c.err, c.str) + if ne.Error() != testCase.str { + t.Errorf("%T.Error() should be %v", testCase.err, testCase.str) } }) } diff --git a/examples/dial/cid/main.go b/examples/dial/cid/main.go index b52631f1a..15f316137 100644 --- a/examples/dial/cid/main.go +++ b/examples/dial/cid/main.go @@ -26,6 +26,7 @@ func main() { config := &dtls.Config{ PSK: func(hint []byte) ([]byte, error) { fmt.Printf("Server's hint: %s \n", hint) + return []byte{0xAB, 0xC1, 0x23}, nil }, PSKIdentityHint: []byte("Pion DTLS Client"), @@ -45,6 +46,7 @@ func main() { if err := dtlsConn.HandshakeContext(ctx); err != nil { fmt.Printf("Failed to handshake with server: %v\n", err) + return } diff --git a/examples/dial/psk/main.go b/examples/dial/psk/main.go index 112826bbc..2847d40c7 100644 --- a/examples/dial/psk/main.go +++ b/examples/dial/psk/main.go @@ -26,6 +26,7 @@ func main() { config := &dtls.Config{ PSK: func(hint []byte) ([]byte, error) { fmt.Printf("Server's hint: %s \n", hint) + return []byte{0xAB, 0xC1, 0x23}, nil }, PSKIdentityHint: []byte{}, @@ -44,6 +45,7 @@ func main() { if err := dtlsConn.HandshakeContext(ctx); err != nil { fmt.Printf("Failed to handshake with server: %v\n", err) + return } diff --git a/examples/dial/selfsign/main.go b/examples/dial/selfsign/main.go index 7d86adf13..66ff80a70 100644 --- a/examples/dial/selfsign/main.go +++ b/examples/dial/selfsign/main.go @@ -46,6 +46,7 @@ func main() { if err := dtlsConn.HandshakeContext(ctx); err != nil { fmt.Printf("Failed to handshake with server: %v\n", err) + return } diff --git a/examples/dial/verify/main.go b/examples/dial/verify/main.go index e572aec2d..32992a9d1 100644 --- a/examples/dial/verify/main.go +++ b/examples/dial/verify/main.go @@ -53,6 +53,7 @@ func main() { if err := dtlsConn.HandshakeContext(ctx); err != nil { fmt.Printf("Failed to handshake with server: %v\n", err) + return } diff --git a/examples/listen/cid/main.go b/examples/listen/cid/main.go index 7418d895b..2c0b41ee6 100644 --- a/examples/listen/cid/main.go +++ b/examples/listen/cid/main.go @@ -26,6 +26,7 @@ func main() { config := &dtls.Config{ PSK: func(hint []byte) ([]byte, error) { fmt.Printf("Client's hint: %s \n", hint) + return []byte{0xAB, 0xC1, 0x23}, nil }, PSKIdentityHint: []byte("Pion DTLS Server"), diff --git a/examples/listen/psk/main.go b/examples/listen/psk/main.go index 041b1fb4e..4098e0dfe 100644 --- a/examples/listen/psk/main.go +++ b/examples/listen/psk/main.go @@ -26,6 +26,7 @@ func main() { config := &dtls.Config{ PSK: func(hint []byte) ([]byte, error) { fmt.Printf("Client's hint: %s \n", hint) + return []byte{0xAB, 0xC1, 0x23}, nil }, PSKIdentityHint: []byte("Pion DTLS Server"), diff --git a/examples/util/hub.go b/examples/util/hub.go index 19e33db0c..4e656a21f 100644 --- a/examples/util/hub.go +++ b/examples/util/hub.go @@ -12,18 +12,18 @@ import ( "sync" ) -// Hub is a helper to handle one to many chat +// Hub is a helper to handle one to many chat. type Hub struct { conns map[string]net.Conn lock sync.RWMutex } -// NewHub builds a new hub +// NewHub builds a new hub. func NewHub() *Hub { return &Hub{conns: make(map[string]net.Conn)} } -// Register adds a new conn to the Hub +// Register adds a new conn to the Hub. func (h *Hub) Register(conn net.Conn) { fmt.Printf("Connected to %s\n", conn.RemoteAddr()) h.lock.Lock() @@ -40,6 +40,7 @@ func (h *Hub) readLoop(conn net.Conn) { n, err := conn.Read(b) if err != nil { h.unregister(conn) + return } fmt.Printf("Got message: %s\n", string(b[:n])) @@ -69,7 +70,7 @@ func (h *Hub) broadcast(msg []byte) { } } -// Chat starts the stdin readloop to dispatch messages to the hub +// Chat starts the stdin readloop to dispatch messages to the hub. func (h *Hub) Chat() { reader := bufio.NewReader(os.Stdin) for { diff --git a/examples/util/util.go b/examples/util/util.go index e15a2f360..e3a85cb0e 100644 --- a/examples/util/util.go +++ b/examples/util/util.go @@ -24,7 +24,7 @@ var ( errNoCertificateFound = errors.New("no certificate found, unable to load certificates") ) -// Chat simulates a simple text chat session over the connection +// Chat simulates a simple text chat session over the connection. func Chat(conn io.ReadWriter) { go func() { b := make([]byte, bufSize) @@ -51,7 +51,7 @@ func Chat(conn io.ReadWriter) { } } -// Check is a helper to throw errors in the examples +// Check is a helper to throw errors in the examples. func Check(err error) { var netError net.Error if errors.As(err, &netError) && netError.Temporary() { //nolint:staticcheck @@ -62,12 +62,12 @@ func Check(err error) { } } -// LoadKeyAndCertificate reads certificates or key from file +// LoadKeyAndCertificate reads certificates or key from file. func LoadKeyAndCertificate(keyPath string, certificatePath string) (tls.Certificate, error) { return tls.LoadX509KeyPair(certificatePath, keyPath) } -// LoadCertificate Load/read certificate(s) from file +// LoadCertificate Load/read certificate(s) from file. func LoadCertificate(path string) (*tls.Certificate, error) { rawData, err := os.ReadFile(filepath.Clean(path)) if err != nil { diff --git a/flight.go b/flight.go index cfa58c574..7ecc9489d 100644 --- a/flight.go +++ b/flight.go @@ -70,7 +70,7 @@ const ( flight6 ) -func (f flightVal) String() string { +func (f flightVal) String() string { //nolint:cyclop switch f { case flight0: return "Flight 0" diff --git a/flight0handler.go b/flight0handler.go index 7bb528f1a..ce6ad2031 100644 --- a/flight0handler.go +++ b/flight0handler.go @@ -14,7 +14,14 @@ import ( "github.com/pion/dtls/v3/pkg/protocol/handshake" ) -func flight0Parse(_ context.Context, _ flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) { +//nolint:cyclop +func flight0Parse( + _ context.Context, + _ flightConn, + state *State, + cache *handshakeCache, + cfg *handshakeConfig, +) (flightVal, *alert.Alert, error) { seq, msgs, ok := cache.fullPullMap(0, state.cipherSuite, handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false}, ) @@ -55,32 +62,32 @@ func flight0Parse(_ context.Context, _ flightConn, state *State, cache *handshak } for _, val := range clientHello.Extensions { - switch e := val.(type) { + switch ext := val.(type) { case *extension.SupportedEllipticCurves: - if len(e.EllipticCurves) == 0 { + if len(ext.EllipticCurves) == 0 { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errNoSupportedEllipticCurves } - state.namedCurve = e.EllipticCurves[0] + state.namedCurve = ext.EllipticCurves[0] case *extension.UseSRTP: - profile, ok := findMatchingSRTPProfile(e.ProtectionProfiles, cfg.localSRTPProtectionProfiles) + profile, ok := findMatchingSRTPProfile(ext.ProtectionProfiles, cfg.localSRTPProtectionProfiles) if !ok { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errServerNoMatchingSRTPProfile } state.setSRTPProtectionProfile(profile) - state.remoteSRTPMasterKeyIdentifier = e.MasterKeyIdentifier + state.remoteSRTPMasterKeyIdentifier = ext.MasterKeyIdentifier case *extension.UseExtendedMasterSecret: if cfg.extendedMasterSecret != DisableExtendedMasterSecret { state.extendedMasterSecret = true } case *extension.ServerName: - state.serverName = e.ServerName // remote server name + state.serverName = ext.ServerName // remote server name case *extension.ALPN: - state.peerSupportedProtocols = e.ProtocolNameList + state.peerSupportedProtocols = ext.ProtocolNameList case *extension.ConnectionID: // Only set connection ID to be sent if server supports connection // IDs. if cfg.connectionIDGenerator != nil { - state.remoteConnectionID = e.CID + state.remoteConnectionID = ext.CID } } } @@ -112,7 +119,12 @@ func flight0Parse(_ context.Context, _ flightConn, state *State, cache *handshak return handleHelloResume(clientHello.SessionID, state, cfg, nextFlight) } -func handleHelloResume(sessionID []byte, state *State, cfg *handshakeConfig, next flightVal) (flightVal, *alert.Alert, error) { +func handleHelloResume( + sessionID []byte, + state *State, + cfg *handshakeConfig, + next flightVal, +) (flightVal, *alert.Alert, error) { if len(sessionID) > 0 && cfg.sessionStore != nil { if s, err := cfg.sessionStore.Get(sessionID); err != nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err @@ -132,10 +144,16 @@ func handleHelloResume(sessionID []byte, state *State, cfg *handshakeConfig, nex return flight4b, nil, nil } } + return next, nil, nil } -func flight0Generate(_ flightConn, state *State, _ *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) { +func flight0Generate( + _ flightConn, + state *State, + _ *handshakeCache, + cfg *handshakeConfig, +) ([]*packet, *alert.Alert, error) { // Initialize if !cfg.insecureSkipHelloVerify { state.cookie = make([]byte, cookieLength) diff --git a/flight1handler.go b/flight1handler.go index 60215c084..6c55a6430 100644 --- a/flight1handler.go +++ b/flight1handler.go @@ -14,7 +14,13 @@ import ( "github.com/pion/dtls/v3/pkg/protocol/recordlayer" ) -func flight1Parse(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) { +func flight1Parse( + ctx context.Context, + conn flightConn, + state *State, + cache *handshakeCache, + cfg *handshakeConfig, +) (flightVal, *alert.Alert, error) { // HelloVerifyRequest can be skipped by the server, // so allow ServerHello during flight1 also seq, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite, @@ -29,7 +35,7 @@ func flight1Parse(ctx context.Context, c flightConn, state *State, cache *handsh if _, ok := msgs[handshake.TypeServerHello]; ok { // Flight1 and flight2 were skipped. // Parse as flight3. - return flight3Parse(ctx, c, state, cache, cfg) + return flight3Parse(ctx, conn, state, cache, cfg) } if h, ok := msgs[handshake.TypeHelloVerifyRequest].(*handshake.MessageHelloVerifyRequest); ok { @@ -40,13 +46,20 @@ func flight1Parse(ctx context.Context, c flightConn, state *State, cache *handsh } state.cookie = append([]byte{}, h.Cookie...) state.handshakeRecvSequence = seq + return flight3, nil, nil } return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil } -func flight1Generate(c flightConn, state *State, _ *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) { +//nolint:cyclop +func flight1Generate( + conn flightConn, + state *State, + _ *handshakeCache, + cfg *handshakeConfig, +) ([]*packet, *alert.Alert, error) { var zeroEpoch uint16 state.localEpoch.Store(zeroEpoch) state.remoteEpoch.Store(zeroEpoch) @@ -74,6 +87,7 @@ func flight1Generate(c flightConn, state *State, _ *handshakeCache, cfg *handsha for _, c := range cfg.localCipherSuites { if c.ECC() { setEllipticCurveCryptographyClientHelloExtensions = true + break } } @@ -113,7 +127,7 @@ func flight1Generate(c flightConn, state *State, _ *handshakeCache, cfg *handsha if cfg.sessionStore != nil { cfg.log.Tracef("[handshake] try to resume session") - if s, err := cfg.sessionStore.Get(c.sessionKey()); err != nil { + if s, err := cfg.sessionStore.Get(conn.sessionKey()); err != nil { return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } else if s.ID != nil { cfg.log.Tracef("[handshake] get saved session: %x", s.ID) diff --git a/flight1handler_test.go b/flight1handler_test.go index f6dc46c3e..457ee413b 100644 --- a/flight1handler_test.go +++ b/flight1handler_test.go @@ -34,13 +34,14 @@ type flight1TestMockCipherSuite struct { func (f *flight1TestMockCipherSuite) IsInitialized() bool { f.t.Fatal("IsInitialized called with Certificate but not CertificateVerify") + return true } // When "server hello" arrives later than "certificate", // "server key exchange", "certificate request", "server hello done", -// is it normal for the flight1Parse method to handle it -func TestFlight1_Process_ServerHelloLateArrival(t *testing.T) { +// is it normal for the flight1Parse method to handle it. +func TestFlight1_Process_ServerHelloLateArrival(t *testing.T) { //nolint:maintidx // Limit runtime in case of deadlocks lim := test.TimeOut(5 * time.Second) defer lim.Stop() diff --git a/flight2handler.go b/flight2handler.go index 0af457c35..8d50befba 100644 --- a/flight2handler.go +++ b/flight2handler.go @@ -13,7 +13,13 @@ import ( "github.com/pion/dtls/v3/pkg/protocol/recordlayer" ) -func flight2Parse(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) { +func flight2Parse( + ctx context.Context, + c flightConn, + state *State, + cache *handshakeCache, + cfg *handshakeConfig, +) (flightVal, *alert.Alert, error) { seq, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite, handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false}, ) @@ -41,11 +47,18 @@ func flight2Parse(ctx context.Context, c flightConn, state *State, cache *handsh if !bytes.Equal(state.cookie, clientHello.Cookie) { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.AccessDenied}, errCookieMismatch } + return flight4, nil, nil } -func flight2Generate(_ flightConn, state *State, _ *handshakeCache, _ *handshakeConfig) ([]*packet, *alert.Alert, error) { +func flight2Generate( + _ flightConn, + state *State, + _ *handshakeCache, + _ *handshakeConfig, +) ([]*packet, *alert.Alert, error) { state.handshakeSendSequence = 0 + return []*packet{ { record: &recordlayer.RecordLayer{ diff --git a/flight3handler.go b/flight3handler.go index f27c01a7e..7301e34b6 100644 --- a/flight3handler.go +++ b/flight3handler.go @@ -17,7 +17,14 @@ import ( "github.com/pion/dtls/v3/pkg/protocol/recordlayer" ) -func flight3Parse(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) { //nolint:gocognit +//nolint:gocognit,gocyclo,maintidx,cyclop +func flight3Parse( + ctx context.Context, + conn flightConn, + state *State, + cache *handshakeCache, + cfg *handshakeConfig, +) (flightVal, *alert.Alert, error) { // Clients may receive multiple HelloVerifyRequest messages with different cookies. // Clients SHOULD handle this by sending a new ClientHello with a cookie in response // to the new HelloVerifyRequest. RFC 6347 Section 4.2.1 @@ -33,6 +40,7 @@ func flight3Parse(ctx context.Context, c flightConn, state *State, cache *handsh } state.cookie = append([]byte{}, h.Cookie...) state.handshakeRecvSequence = seq + return flight3, nil, nil } } @@ -45,33 +53,36 @@ func flight3Parse(ctx context.Context, c flightConn, state *State, cache *handsh return 0, nil, nil } - if h, msgOk := msgs[handshake.TypeServerHello].(*handshake.MessageServerHello); msgOk { - if !h.Version.Equal(protocol.Version1_2) { + if serverHelloMsg, msgOk := msgs[handshake.TypeServerHello].(*handshake.MessageServerHello); msgOk { //nolint:nestif + if !serverHelloMsg.Version.Equal(protocol.Version1_2) { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.ProtocolVersion}, errUnsupportedProtocolVersion } - for _, v := range h.Extensions { - switch e := v.(type) { + for _, v := range serverHelloMsg.Extensions { + switch ext := v.(type) { case *extension.UseSRTP: - profile, found := findMatchingSRTPProfile(e.ProtectionProfiles, cfg.localSRTPProtectionProfiles) + profile, found := findMatchingSRTPProfile(ext.ProtectionProfiles, cfg.localSRTPProtectionProfiles) if !found { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.IllegalParameter}, errClientNoMatchingSRTPProfile } state.setSRTPProtectionProfile(profile) - state.remoteSRTPMasterKeyIdentifier = e.MasterKeyIdentifier + state.remoteSRTPMasterKeyIdentifier = ext.MasterKeyIdentifier case *extension.UseExtendedMasterSecret: if cfg.extendedMasterSecret != DisableExtendedMasterSecret { state.extendedMasterSecret = true } case *extension.ALPN: - if len(e.ProtocolNameList) > 1 { // This should be exactly 1, the zero case is handle when unmarshalling - return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, extension.ErrALPNInvalidFormat // Meh, internal error? + if len(ext.ProtocolNameList) > 1 { // This should be exactly 1, the zero case is handle when unmarshalling + return 0, &alert.Alert{ + Level: alert.Fatal, + Description: alert.InternalError, + }, extension.ErrALPNInvalidFormat // Meh, internal error? } - state.NegotiatedProtocol = e.ProtocolNameList[0] + state.NegotiatedProtocol = ext.ProtocolNameList[0] case *extension.ConnectionID: // Only set connection ID to be sent if client supports connection // IDs. if cfg.connectionIDGenerator != nil { - state.remoteConnectionID = e.CID + state.remoteConnectionID = ext.CID } } } @@ -88,7 +99,7 @@ func flight3Parse(ctx context.Context, c flightConn, state *State, cache *handsh return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errRequestedButNoSRTPExtension } - remoteCipherSuite := cipherSuiteForID(CipherSuiteID(*h.CipherSuiteID), cfg.customCipherSuites) + remoteCipherSuite := cipherSuiteForID(CipherSuiteID(*serverHelloMsg.CipherSuiteID), cfg.customCipherSuites) if remoteCipherSuite == nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errCipherSuiteNoIntersection } @@ -99,11 +110,11 @@ func flight3Parse(ctx context.Context, c flightConn, state *State, cache *handsh } state.cipherSuite = selectedCipherSuite - state.remoteRandom = h.Random + state.remoteRandom = serverHelloMsg.Random cfg.log.Tracef("[handshake] use cipher suite: %s", selectedCipherSuite.String()) - if len(h.SessionID) > 0 && bytes.Equal(state.SessionID, h.SessionID) { - return handleResumption(ctx, c, state, cache, cfg) + if len(serverHelloMsg.SessionID) > 0 && bytes.Equal(state.SessionID, serverHelloMsg.SessionID) { + return handleResumption(ctx, conn, state, cache, cfg) } if len(state.SessionID) > 0 { @@ -116,7 +127,7 @@ func flight3Parse(ctx context.Context, c flightConn, state *State, cache *handsh if cfg.sessionStore == nil { state.SessionID = []byte{} } else { - state.SessionID = h.SessionID + state.SessionID = serverHelloMsg.SessionID } state.masterSecret = []byte{} @@ -148,7 +159,7 @@ func flight3Parse(ctx context.Context, c flightConn, state *State, cache *handsh } if h, ok := msgs[handshake.TypeServerKeyExchange].(*handshake.MessageServerKeyExchange); ok { - alertPtr, err := handleServerKeyExchange(c, state, cfg, h) + alertPtr, err := handleServerKeyExchange(conn, state, cfg, h) if err != nil { return 0, alertPtr, err } @@ -162,7 +173,13 @@ func flight3Parse(ctx context.Context, c flightConn, state *State, cache *handsh return flight5, nil, nil } -func handleResumption(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) { +func handleResumption( + ctx context.Context, + c flightConn, + state *State, + cache *handshakeCache, + cfg *handshakeConfig, +) (flightVal, *alert.Alert, error) { if err := state.initCipherSuite(); err != nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } @@ -203,25 +220,36 @@ func handleResumption(ctx context.Context, c flightConn, state *State, cache *ha return flight5b, nil, nil } -func handleServerKeyExchange(_ flightConn, state *State, cfg *handshakeConfig, h *handshake.MessageServerKeyExchange) (*alert.Alert, error) { +//nolint:cyclop +func handleServerKeyExchange( + _ flightConn, + state *State, + cfg *handshakeConfig, + keyExchangeMessage *handshake.MessageServerKeyExchange, +) (*alert.Alert, error) { var err error if state.cipherSuite == nil { return &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errInvalidCipherSuite } - if cfg.localPSKCallback != nil { + if cfg.localPSKCallback != nil { //nolint:nestif var psk []byte - if psk, err = cfg.localPSKCallback(h.IdentityHint); err != nil { + if psk, err = cfg.localPSKCallback(keyExchangeMessage.IdentityHint); err != nil { return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } - state.IdentityHint = h.IdentityHint + state.IdentityHint = keyExchangeMessage.IdentityHint switch state.cipherSuite.KeyExchangeAlgorithm() { case types.KeyExchangeAlgorithmPsk: state.preMasterSecret = prf.PSKPreMasterSecret(psk) case (types.KeyExchangeAlgorithmEcdhe | types.KeyExchangeAlgorithmPsk): - if state.localKeypair, err = elliptic.GenerateKeypair(h.NamedCurve); err != nil { + if state.localKeypair, err = elliptic.GenerateKeypair(keyExchangeMessage.NamedCurve); err != nil { return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } - state.preMasterSecret, err = prf.EcdhePSKPreMasterSecret(psk, h.PublicKey, state.localKeypair.PrivateKey, state.localKeypair.Curve) + state.preMasterSecret, err = prf.EcdhePSKPreMasterSecret( + psk, + keyExchangeMessage.PublicKey, + state.localKeypair.PrivateKey, + state.localKeypair.Curve, + ) if err != nil { return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } @@ -229,11 +257,15 @@ func handleServerKeyExchange(_ flightConn, state *State, cfg *handshakeConfig, h return &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errInvalidCipherSuite } } else { - if state.localKeypair, err = elliptic.GenerateKeypair(h.NamedCurve); err != nil { + if state.localKeypair, err = elliptic.GenerateKeypair(keyExchangeMessage.NamedCurve); err != nil { return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } - if state.preMasterSecret, err = prf.PreMasterSecret(h.PublicKey, state.localKeypair.PrivateKey, state.localKeypair.Curve); err != nil { + if state.preMasterSecret, err = prf.PreMasterSecret( + keyExchangeMessage.PublicKey, + state.localKeypair.PrivateKey, + state.localKeypair.Curve, + ); err != nil { return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } } @@ -241,7 +273,12 @@ func handleServerKeyExchange(_ flightConn, state *State, cfg *handshakeConfig, h return nil, nil //nolint:nilnil } -func flight3Generate(_ flightConn, state *State, _ *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) { +func flight3Generate( + _ flightConn, + state *State, + _ *handshakeCache, + cfg *handshakeConfig, +) ([]*packet, *alert.Alert, error) { extensions := []extension.Extension{ &extension.SupportedSignatureAlgorithms{ SignatureHashAlgorithms: cfg.localSignatureSchemes, diff --git a/flight3handler_test.go b/flight3handler_test.go index e801be00a..af7374d6d 100644 --- a/flight3handler_test.go +++ b/flight3handler_test.go @@ -18,8 +18,8 @@ import ( "github.com/pion/transport/v3/test" ) -// Assert that SupportedEllipticCurves is only sent when a ECC CipherSuite is available -func TestSupportedEllipticCurves(t *testing.T) { +// Assert that SupportedEllipticCurves is only sent when a ECC CipherSuite is available. +func TestSupportedEllipticCurves(t *testing.T) { //nolint:cyclop // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() @@ -51,7 +51,7 @@ func TestSupportedEllipticCurves(t *testing.T) { h := &handshake.Handshake{} _ = h.Unmarshal(messages[i][recordlayer.FixedHeaderSize:]) - if h.Header.Type == handshake.TypeClientHello { + if h.Header.Type == handshake.TypeClientHello { //nolint:nestif clientHello := &handshake.MessageClientHello{} msg, err := h.Message.Marshal() @@ -78,7 +78,13 @@ func TestSupportedEllipticCurves(t *testing.T) { EllipticCurves: expectedCurves, } - if client, err := testClient(ctx, dtlsnet.PacketConnFromConn(caAnalyzer), caAnalyzer.RemoteAddr(), conf, false); err != nil { + if client, err := testClient( + ctx, + dtlsnet.PacketConnFromConn(caAnalyzer), + caAnalyzer.RemoteAddr(), + conf, + false, + ); err != nil { clientErr <- err } else { clientErr <- client.Close() //nolint diff --git a/flight4bhandler.go b/flight4bhandler.go index d87a1feee..681533b0e 100644 --- a/flight4bhandler.go +++ b/flight4bhandler.go @@ -15,7 +15,13 @@ import ( "github.com/pion/dtls/v3/pkg/protocol/recordlayer" ) -func flight4bParse(_ context.Context, _ flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) { +func flight4bParse( + _ context.Context, + _ flightConn, + state *State, + cache *handshakeCache, + cfg *handshakeConfig, +) (flightVal, *alert.Alert, error) { _, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite, handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, true, false}, ) @@ -47,7 +53,13 @@ func flight4bParse(_ context.Context, _ flightConn, state *State, cache *handsha return flight4b, nil, nil } -func flight4bGenerate(_ flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) { +//nolint:cyclop +func flight4bGenerate( + _ flightConn, + state *State, + cache *handshakeCache, + cfg *handshakeConfig, +) ([]*packet, *alert.Alert, error) { var pkts []*packet extensions := []extension.Extension{&extension.RenegotiationInfo{ @@ -95,7 +107,7 @@ func flight4bGenerate(_ flightConn, state *State, cache *handshakeCache, cfg *ha serverHello = handshake.Handshake{Message: serverHelloMessage} } - serverHello.Header.MessageSequence = uint16(state.handshakeSendSequence) + serverHello.Header.MessageSequence = uint16(state.handshakeSendSequence) //nolint:gosec // G115 if len(state.localVerifyData) == 0 { plainText := cache.pullAndMerge( diff --git a/flight4handler.go b/flight4handler.go index 7e4ae12f1..75e2e8b1a 100644 --- a/flight4handler.go +++ b/flight4handler.go @@ -20,7 +20,14 @@ import ( "github.com/pion/dtls/v3/pkg/protocol/recordlayer" ) -func flight4Parse(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) { //nolint:gocognit +//nolint:gocognit,gocyclo,lll,cyclop,maintidx +func flight4Parse( + ctx context.Context, + conn flightConn, + state *State, + cache *handshakeCache, + cfg *handshakeConfig, +) (flightVal, *alert.Alert, error) { seq, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite, handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, true, true}, handshakeCachePullRule{handshake.TypeClientKeyExchange, cfg.initialEpoch, true, false}, @@ -47,7 +54,8 @@ func flight4Parse(ctx context.Context, c flightConn, state *State, cache *handsh state.SessionID = nil } - if h, hasCertVerify := msgs[handshake.TypeCertificateVerify].(*handshake.MessageCertificateVerify); hasCertVerify { + //nolint:nestif + if verify, hasVerify := msgs[handshake.TypeCertificateVerify].(*handshake.MessageCertificateVerify); hasVerify { if state.PeerCertificates == nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.NoCertificate}, errCertificateVerifyNoCertificate } @@ -66,8 +74,9 @@ func flight4Parse(ctx context.Context, c flightConn, state *State, cache *handsh // Verify that the pair of hash algorithm and signiture is listed. var validSignatureScheme bool for _, ss := range cfg.localSignatureSchemes { - if ss.Hash == h.HashAlgorithm && ss.Signature == h.SignatureAlgorithm { + if ss.Hash == verify.HashAlgorithm && ss.Signature == verify.SignatureAlgorithm { validSignatureScheme = true + break } } @@ -75,7 +84,12 @@ func flight4Parse(ctx context.Context, c flightConn, state *State, cache *handsh return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errNoAvailableSignatureSchemes } - if err := verifyCertificateVerify(plainText, h.HashAlgorithm, h.Signature, state.PeerCertificates); err != nil { + if err := verifyCertificateVerify( + plainText, + verify.HashAlgorithm, + verify.Signature, + state.PeerCertificates, + ); err != nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err } var chains [][]*x509.Certificate @@ -99,7 +113,7 @@ func flight4Parse(ctx context.Context, c flightConn, state *State, cache *handsh return 0, nil, nil } - if !state.cipherSuite.IsInitialized() { + if !state.cipherSuite.IsInitialized() { //nolint:nestif serverRandom := state.localRandom.MarshalFixed() clientRandom := state.remoteRandom.MarshalFixed() @@ -115,14 +129,23 @@ func flight4Parse(ctx context.Context, c flightConn, state *State, cache *handsh case CipherSuiteKeyExchangeAlgorithmPsk: preMasterSecret = prf.PSKPreMasterSecret(psk) case (CipherSuiteKeyExchangeAlgorithmPsk | CipherSuiteKeyExchangeAlgorithmEcdhe): - if preMasterSecret, err = prf.EcdhePSKPreMasterSecret(psk, clientKeyExchange.PublicKey, state.localKeypair.PrivateKey, state.localKeypair.Curve); err != nil { + if preMasterSecret, err = prf.EcdhePSKPreMasterSecret( + psk, + clientKeyExchange.PublicKey, + state.localKeypair.PrivateKey, + state.localKeypair.Curve, + ); err != nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } default: return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, errInvalidCipherSuite } } else { - preMasterSecret, err = prf.PreMasterSecret(clientKeyExchange.PublicKey, state.localKeypair.PrivateKey, state.localKeypair.Curve) + preMasterSecret, err = prf.PreMasterSecret( + clientKeyExchange.PublicKey, + state.localKeypair.PrivateKey, + state.localKeypair.Curve, + ) if err != nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.IllegalParameter}, err } @@ -140,7 +163,12 @@ func flight4Parse(ctx context.Context, c flightConn, state *State, cache *handsh return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } } else { - state.masterSecret, err = prf.MasterSecret(preMasterSecret, clientRandom[:], serverRandom[:], state.cipherSuite.HashFunc()) + state.masterSecret, err = prf.MasterSecret( + preMasterSecret, + clientRandom[:], + serverRandom[:], + state.cipherSuite.HashFunc(), + ) if err != nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } @@ -164,7 +192,7 @@ func flight4Parse(ctx context.Context, c flightConn, state *State, cache *handsh } // Now, encrypted packets can be handled - if err := c.handleQueuedPackets(ctx); err != nil { + if err := conn.handleQueuedPackets(ctx); err != nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } @@ -181,7 +209,7 @@ func flight4Parse(ctx context.Context, c flightConn, state *State, cache *handsh return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil } - if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeAnonymous { + if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeAnonymous { //nolint:nestif if cfg.verifyConnection != nil { stateClone, err := state.clone() if err != nil { @@ -191,6 +219,7 @@ func flight4Parse(ctx context.Context, c flightConn, state *State, cache *handsh return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err } } + return flight6, nil, nil } @@ -226,7 +255,13 @@ func flight4Parse(ctx context.Context, c flightConn, state *State, cache *handsh return flight6, nil, nil } -func flight4Generate(_ flightConn, state *State, _ *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) { //nolint:gocognit +//nolint:gocognit,cyclop,maintidx +func flight4Generate( + _ flightConn, + state *State, + _ *handshakeCache, + cfg *handshakeConfig, +) ([]*packet, *alert.Alert, error) { extensions := []extension.Extension{&extension.RenegotiationInfo{ RenegotiatedConnection: 0, }} @@ -337,7 +372,14 @@ func flight4Generate(_ flightConn, state *State, _ *handshakeCache, cfg *handsha return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, err } - signature, err := generateKeySignature(clientRandom[:], serverRandom[:], state.localKeypair.PublicKey, state.namedCurve, certificate.PrivateKey, signatureHashAlgo.Hash) + signature, err := generateKeySignature( + clientRandom[:], + serverRandom[:], + state.localKeypair.PublicKey, + state.namedCurve, + certificate.PrivateKey, + signatureHashAlgo.Hash, + ) if err != nil { return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } @@ -369,7 +411,9 @@ func flight4Generate(_ flightConn, state *State, _ *handshakeCache, cfg *handsha // an appropriate certificate to give to us. var certificateAuthorities [][]byte if cfg.clientCAs != nil { - // nolint:staticcheck // ignoring tlsCert.RootCAs.Subjects is deprecated ERR because cert does not come from SystemCertPool and it's ok if certificate authorities is empty. + // nolint:staticcheck // ignoring tlsCert.RootCAs.Subjects is deprecated ERR + // because cert does not come from SystemCertPool and it's ok if certificate + // authorities is empty. certificateAuthorities = cfg.clientCAs.Subjects() } @@ -396,7 +440,8 @@ func flight4Generate(_ flightConn, state *State, _ *handshakeCache, cfg *handsha }, }) } - case cfg.localPSKIdentityHint != nil || state.cipherSuite.KeyExchangeAlgorithm().Has(CipherSuiteKeyExchangeAlgorithmEcdhe): + case cfg.localPSKIdentityHint != nil || + state.cipherSuite.KeyExchangeAlgorithm().Has(CipherSuiteKeyExchangeAlgorithmEcdhe): // To help the client in selecting which identity to use, the server // can provide a "PSK identity hint" in the ServerKeyExchange message. // If no hint is provided and cipher suite doesn't use elliptic curve, diff --git a/flight4handler_test.go b/flight4handler_test.go index 20e9d2253..458292b69 100644 --- a/flight4handler_test.go +++ b/flight4handler_test.go @@ -40,13 +40,14 @@ type flight4TestMockCipherSuite struct { func (f *flight4TestMockCipherSuite) IsInitialized() bool { f.t.Fatal("IsInitialized called with Certificate but not CertificateVerify") + return true } // Assert that if a Client sends a certificate they // must also send a CertificateVerify message. // The flight4handler must not interact with the CipherSuite -// if the CertificateVerify is missing +// if the CertificateVerify is missing. func TestFlight4_Process_CertificateVerify(t *testing.T) { // Limit runtime in case of deadlocks lim := test.TimeOut(5 * time.Second) @@ -156,6 +157,7 @@ func TestFlight4_CertificateRequestHook(t *testing.T) { clientAuth: 1, certificateRequestMessageHook: func(mcr handshake.MessageCertificateRequest) handshake.Message { mcr.SignatureHashAlgorithms = []signaturehash.Algorithm{} + return &mcr }, } @@ -166,7 +168,7 @@ func TestFlight4_CertificateRequestHook(t *testing.T) { } for _, p := range pkts { - if h, ok := p.record.Content.(*handshake.Handshake); ok { + if h, ok := p.record.Content.(*handshake.Handshake); ok { //nolint:nestif if h.Message.Type() == handshake.TypeCertificateRequest { mcr := &handshake.MessageCertificateRequest{} msg, err := h.Message.Marshal() diff --git a/flight5bhandler.go b/flight5bhandler.go index 27a05cc21..db6de367c 100644 --- a/flight5bhandler.go +++ b/flight5bhandler.go @@ -13,7 +13,13 @@ import ( "github.com/pion/dtls/v3/pkg/protocol/recordlayer" ) -func flight5bParse(_ context.Context, _ flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) { +func flight5bParse( + _ context.Context, + _ flightConn, + state *State, + cache *handshakeCache, + cfg *handshakeConfig, +) (flightVal, *alert.Alert, error) { _, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence-1, state.cipherSuite, handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, false, false}, ) @@ -30,7 +36,12 @@ func flight5bParse(_ context.Context, _ flightConn, state *State, cache *handsha return flight5b, nil, nil } -func flight5bGenerate(_ flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) { //nolint:gocognit +func flight5bGenerate( + _ flightConn, + state *State, + cache *handshakeCache, + cfg *handshakeConfig, +) ([]*packet, *alert.Alert, error) { //nolint:gocognit var pkts []*packet pkts = append(pkts, diff --git a/flight5handler.go b/flight5handler.go index 7e940cdc9..95f2466c5 100644 --- a/flight5handler.go +++ b/flight5handler.go @@ -17,7 +17,13 @@ import ( "github.com/pion/dtls/v3/pkg/protocol/recordlayer" ) -func flight5Parse(_ context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) { +func flight5Parse( + _ context.Context, + conn flightConn, + state *State, + cache *handshakeCache, + cfg *handshakeConfig, +) (flightVal, *alert.Alert, error) { _, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite, handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, false, false}, ) @@ -57,7 +63,7 @@ func flight5Parse(_ context.Context, c flightConn, state *State, cache *handshak Secret: state.masterSecret, } cfg.log.Tracef("[handshake] save new session: %x", s.ID) - if err := cfg.sessionStore.Set(c.sessionKey(), s); err != nil { + if err := cfg.sessionStore.Set(conn.sessionKey(), s); err != nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } } @@ -65,10 +71,16 @@ func flight5Parse(_ context.Context, c flightConn, state *State, cache *handshak return flight5, nil, nil } -func flight5Generate(c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) { //nolint:gocognit +//nolint:gocognit,cyclop,maintidx +func flight5Generate( + conn flightConn, + state *State, + cache *handshakeCache, + cfg *handshakeConfig, +) ([]*packet, *alert.Alert, error) { var privateKey crypto.PrivateKey var pkts []*packet - if state.remoteRequestedCertificate { + if state.remoteRequestedCertificate { //nolint:nestif _, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence-2, state.cipherSuite, handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, false}) if !ok { @@ -135,7 +147,7 @@ func flight5Generate(c flightConn, state *State, cache *handshakeCache, cfg *han // handshakeMessageServerKeyExchange is optional for PSK if len(serverKeyExchangeData) == 0 { - alertPtr, err := handleServerKeyExchange(c, state, cfg, &handshake.MessageServerKeyExchange{}) + alertPtr, err := handleServerKeyExchange(conn, state, cfg, &handshake.MessageServerKeyExchange{}) if err != nil { return nil, alertPtr, err } @@ -158,7 +170,7 @@ func flight5Generate(c flightConn, state *State, cache *handshakeCache, cfg *han // Append not-yet-sent packets merged := []byte{} - seqPred := uint16(state.handshakeSendSequence) + seqPred := uint16(state.handshakeSendSequence) //nolint:gosec // G115 for _, p := range pkts { h, ok := p.record.Content.(*handshake.Handshake) if !ok { @@ -205,7 +217,7 @@ func flight5Generate(c flightConn, state *State, cache *handshakeCache, cfg *han } state.localCertificatesVerify = certVerify - p := &packet{ + pkt := &packet{ record: &recordlayer.RecordLayer{ Header: recordlayer.Header{ Version: protocol.Version1_2, @@ -219,9 +231,9 @@ func flight5Generate(c flightConn, state *State, cache *handshakeCache, cfg *han }, }, } - pkts = append(pkts, p) + pkts = append(pkts, pkt) - h, ok := p.record.Content.(*handshake.Handshake) + h, ok := pkt.record.Content.(*handshake.Handshake) if !ok { return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, errInvalidContentType } @@ -259,7 +271,11 @@ func flight5Generate(c flightConn, state *State, cache *handshakeCache, cfg *han ) var err error - state.localVerifyData, err = prf.VerifyDataClient(state.masterSecret, append(plainText, merged...), state.cipherSuite.HashFunc()) + state.localVerifyData, err = prf.VerifyDataClient( + state.masterSecret, + append(plainText, merged...), + state.cipherSuite.HashFunc(), + ) if err != nil { return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } @@ -286,7 +302,14 @@ func flight5Generate(c flightConn, state *State, cache *handshakeCache, cfg *han return pkts, nil, nil } -func initializeCipherSuite(state *State, cache *handshakeCache, cfg *handshakeConfig, h *handshake.MessageServerKeyExchange, sendingPlainText []byte) (*alert.Alert, error) { //nolint:gocognit +//nolint:gocognit,cyclop +func initializeCipherSuite( + state *State, + cache *handshakeCache, + cfg *handshakeConfig, + handshakeKeyExchange *handshake.MessageServerKeyExchange, + sendingPlainText []byte, +) (*alert.Alert, error) { if state.cipherSuite.IsInitialized() { return nil, nil //nolint } @@ -308,18 +331,24 @@ func initializeCipherSuite(state *State, cache *handshakeCache, cfg *handshakeCo return &alert.Alert{Level: alert.Fatal, Description: alert.IllegalParameter}, err } } else { - state.masterSecret, err = prf.MasterSecret(state.preMasterSecret, clientRandom[:], serverRandom[:], state.cipherSuite.HashFunc()) + state.masterSecret, err = prf.MasterSecret( + state.preMasterSecret, + clientRandom[:], + serverRandom[:], + state.cipherSuite.HashFunc(), + ) if err != nil { return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } } - if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeCertificate { + if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeCertificate { //nolint:nestif // Verify that the pair of hash algorithm and signiture is listed. var validSignatureScheme bool for _, ss := range cfg.localSignatureSchemes { - if ss.Hash == h.HashAlgorithm && ss.Signature == h.SignatureAlgorithm { + if ss.Hash == handshakeKeyExchange.HashAlgorithm && ss.Signature == handshakeKeyExchange.SignatureAlgorithm { validSignatureScheme = true + break } } @@ -327,8 +356,19 @@ func initializeCipherSuite(state *State, cache *handshakeCache, cfg *handshakeCo return &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errNoAvailableSignatureSchemes } - expectedMsg := valueKeyMessage(clientRandom[:], serverRandom[:], h.PublicKey, h.NamedCurve) - if err = verifyKeySignature(expectedMsg, h.Signature, h.HashAlgorithm, state.PeerCertificates); err != nil { + expectedMsg := valueKeyMessage( + clientRandom[:], + serverRandom[:], + handshakeKeyExchange.PublicKey, + handshakeKeyExchange.NamedCurve, + ) + if err = verifyKeySignature( + expectedMsg, + handshakeKeyExchange. + Signature, + handshakeKeyExchange.HashAlgorithm, + state.PeerCertificates, + ); err != nil { return &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err } var chains [][]*x509.Certificate diff --git a/flight6handler.go b/flight6handler.go index 576dc551b..d7828e749 100644 --- a/flight6handler.go +++ b/flight6handler.go @@ -13,7 +13,13 @@ import ( "github.com/pion/dtls/v3/pkg/protocol/recordlayer" ) -func flight6Parse(_ context.Context, _ flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) { +func flight6Parse( + _ context.Context, + _ flightConn, + state *State, + cache *handshakeCache, + cfg *handshakeConfig, +) (flightVal, *alert.Alert, error) { _, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence-1, state.cipherSuite, handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, true, false}, ) @@ -30,7 +36,12 @@ func flight6Parse(_ context.Context, _ flightConn, state *State, cache *handshak return flight6, nil, nil } -func flight6Generate(_ flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) { +func flight6Generate( + _ flightConn, + state *State, + cache *handshakeCache, + cfg *handshakeConfig, +) ([]*packet, *alert.Alert, error) { var pkts []*packet pkts = append(pkts, @@ -82,5 +93,6 @@ func flight6Generate(_ flightConn, state *State, cache *handshakeCache, cfg *han resetLocalSequenceNumber: true, }, ) + return pkts, nil, nil } diff --git a/flighthandler.go b/flighthandler.go index 651ff17e0..b90cebd3b 100644 --- a/flighthandler.go +++ b/flighthandler.go @@ -9,13 +9,19 @@ import ( "github.com/pion/dtls/v3/pkg/protocol/alert" ) -// Parse received handshakes and return next flightVal -type flightParser func(context.Context, flightConn, *State, *handshakeCache, *handshakeConfig) (flightVal, *alert.Alert, error) +// Parse received handshakes and return next flightVal. +type flightParser func( + context.Context, + flightConn, + *State, + *handshakeCache, + *handshakeConfig, +) (flightVal, *alert.Alert, error) -// Generate flights +// Generate flights. type flightGenerator func(flightConn, *State, *handshakeCache, *handshakeConfig) ([]*packet, *alert.Alert, error) -func (f flightVal) getFlightParser() (flightParser, error) { +func (f flightVal) getFlightParser() (flightParser, error) { //nolint:cyclop switch f { case flight0: return flight0Parse, nil @@ -40,7 +46,7 @@ func (f flightVal) getFlightParser() (flightParser, error) { } } -func (f flightVal) getFlightGenerator() (gen flightGenerator, retransmit bool, err error) { +func (f flightVal) getFlightGenerator() (gen flightGenerator, retransmit bool, err error) { //nolint:cyclop switch f { case flight0: return flight0Generate, true, nil diff --git a/fragment_buffer.go b/fragment_buffer.go index 37223ab07..497d97107 100644 --- a/fragment_buffer.go +++ b/fragment_buffer.go @@ -9,7 +9,7 @@ import ( "github.com/pion/dtls/v3/pkg/protocol/recordlayer" ) -// 2 megabytes +// 2 megabytes. const fragmentBufferMaxSize = 2000000 type fragment struct { @@ -29,7 +29,7 @@ func newFragmentBuffer() *fragmentBuffer { return &fragmentBuffer{cache: map[uint16][]*fragment{}} } -// current total size of buffer +// current total size of buffer. func (f *fragmentBuffer) size() int { size := 0 for i := range f.cache { @@ -37,12 +37,13 @@ func (f *fragmentBuffer) size() int { size += len(f.cache[i][j].data) } } + return size } // Attempts to push a DTLS packet to the fragmentBuffer // when it returns true it means the fragmentBuffer has inserted and the buffer shouldn't be handled -// when an error returns it is fatal, and the DTLS connection should be stopped +// when an error returns it is fatal, and the DTLS connection should be stopped. func (f *fragmentBuffer) push(buf []byte) (isHandshake, isRetransmit bool, err error) { if f.size()+len(buf) >= fragmentBufferMaxSize { return false, false, errFragmentBufferOverflow @@ -64,7 +65,8 @@ func (f *fragmentBuffer) push(buf []byte) (isHandshake, isRetransmit bool, err e } // Fragment is a retransmission. We have already assembled it before successfully - isRetransmit = frag.handshakeHeader.FragmentOffset == 0 && frag.handshakeHeader.MessageSequence < f.currentMessageSequenceNumber + isRetransmit = frag.handshakeHeader.FragmentOffset == 0 && + frag.handshakeHeader.MessageSequence < f.currentMessageSequenceNumber if _, ok := f.cache[frag.handshakeHeader.MessageSequence]; !ok { f.cache[frag.handshakeHeader.MessageSequence] = []*fragment{} @@ -107,9 +109,11 @@ func (f *fragmentBuffer) pop() (content []byte, epoch uint16) { } rawMessage = append(f.data, rawMessage...) + return true } } + return false } @@ -131,5 +135,6 @@ func (f *fragmentBuffer) pop() (content []byte, epoch uint16) { delete(f.cache, f.currentMessageSequenceNumber) f.currentMessageSequenceNumber++ + return append(rawHeader, rawMessage...), messageEpoch } diff --git a/fragment_buffer_test.go b/fragment_buffer_test.go index 2b2f62c7e..9e842b0a2 100644 --- a/fragment_buffer_test.go +++ b/fragment_buffer_test.go @@ -19,7 +19,10 @@ func TestFragmentBuffer(t *testing.T) { { Name: "Single Fragment", In: [][]byte{ - {0x16, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0F, 0x03, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0xfe, 0xff, 0x00}, + { + 0x16, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0F, 0x03, + 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0xfe, 0xff, 0x00, + }, }, Expected: [][]byte{ {0x03, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0xfe, 0xff, 0x00}, @@ -29,7 +32,10 @@ func TestFragmentBuffer(t *testing.T) { { Name: "Single Fragment Epoch 3", In: [][]byte{ - {0x16, 0xfe, 0xff, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0F, 0x03, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0xfe, 0xff, 0x00}, + { + 0x16, 0xfe, 0xff, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0F, 0x03, + 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0xfe, 0xff, 0x00, + }, }, Expected: [][]byte{ {0x03, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0xfe, 0xff, 0x00}, @@ -39,24 +45,48 @@ func TestFragmentBuffer(t *testing.T) { { Name: "Multiple Fragments", In: [][]byte{ - {0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x81, 0x0b, 0x00, 0x00, 0x0F, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x01, 0x02, 0x03, 0x04}, - {0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x81, 0x0b, 0x00, 0x00, 0x0F, 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x05, 0x05, 0x06, 0x07, 0x08, 0x09}, - {0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x81, 0x0b, 0x00, 0x00, 0x0F, 0x00, 0x00, 0x00, 0x00, 0x0A, 0x00, 0x00, 0x05, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E}, + { + 0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x81, 0x0b, 0x00, + 0x00, 0x0F, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x01, 0x02, 0x03, 0x04, + }, + { + 0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x81, 0x0b, 0x00, + 0x00, 0x0F, 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x05, 0x05, 0x06, 0x07, 0x08, 0x09, + }, + { + 0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x81, 0x0b, 0x00, + 0x00, 0x0F, 0x00, 0x00, 0x00, 0x00, 0x0A, 0x00, 0x00, 0x05, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, + }, }, Expected: [][]byte{ - {0x0b, 0x00, 0x00, 0x0f, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0f, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e}, + { + 0x0b, 0x00, 0x00, 0x0f, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0f, 0x00, 0x01, + 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, + }, }, Epoch: 0, }, { Name: "Multiple Unordered Fragments", In: [][]byte{ - {0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x81, 0x0b, 0x00, 0x00, 0x0F, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x01, 0x02, 0x03, 0x04}, - {0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x81, 0x0b, 0x00, 0x00, 0x0F, 0x00, 0x00, 0x00, 0x00, 0x0A, 0x00, 0x00, 0x05, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E}, - {0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x81, 0x0b, 0x00, 0x00, 0x0F, 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x05, 0x05, 0x06, 0x07, 0x08, 0x09}, + { + 0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x81, 0x0b, 0x00, + 0x00, 0x0F, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x01, 0x02, 0x03, 0x04, + }, + { + 0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x81, 0x0b, 0x00, + 0x00, 0x0F, 0x00, 0x00, 0x00, 0x00, 0x0A, 0x00, 0x00, 0x05, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, + }, + { + 0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x81, 0x0b, 0x00, + 0x00, 0x0F, 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x05, 0x05, 0x06, 0x07, 0x08, 0x09, + }, }, Expected: [][]byte{ - {0x0b, 0x00, 0x00, 0x0f, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0f, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e}, + { + 0x0b, 0x00, 0x00, 0x0f, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0f, 0x00, 0x01, + 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, + }, }, Epoch: 0, }, @@ -122,7 +152,10 @@ func TestFragmentBuffer_Overflow(t *testing.T) { fragmentBuffer := newFragmentBuffer() // Push a buffer that doesn't exceed size limits - if _, _, err := fragmentBuffer.push([]byte{0x16, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0F, 0x03, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0xfe, 0xff, 0x00}); err != nil { + if _, _, err := fragmentBuffer.push([]byte{ + 0x16, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0F, 0x03, + 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0xfe, 0xff, 0x00, + }); err != nil { t.Fatal(err) } diff --git a/handshake_cache.go b/handshake_cache.go index 285c71331..95f20953f 100644 --- a/handshake_cache.go +++ b/handshake_cache.go @@ -49,7 +49,7 @@ func (h *handshakeCache) push(data []byte, epoch, messageSequence uint16, typ ha // returns a list handshakes that match the requested rules // the list will contain null entries for rules that can't be satisfied -// multiple entries may match a rule, but only the last match is returned (ie ClientHello with cookies) +// multiple entries may match a rule, but only the last match is returned (ie ClientHello with cookies). func (h *handshakeCache) pull(rules ...handshakeCachePullRule) []*handshakeCacheItem { h.mu.Lock() defer h.mu.Unlock() @@ -72,15 +72,21 @@ func (h *handshakeCache) pull(rules ...handshakeCachePullRule) []*handshakeCache } // fullPullMap pulls all handshakes between rules[0] to rules[len(rules)-1] as map. -func (h *handshakeCache) fullPullMap(startSeq int, cipherSuite CipherSuite, rules ...handshakeCachePullRule) (int, map[handshake.Type]handshake.Message, bool) { +// +//nolint:cyclop +func (h *handshakeCache) fullPullMap( + startSeq int, + cipherSuite CipherSuite, + rules ...handshakeCachePullRule, +) (int, map[handshake.Type]handshake.Message, bool) { h.mu.Lock() defer h.mu.Unlock() ci := make(map[handshake.Type]*handshakeCacheItem) - for _, r := range rules { + for _, rule := range rules { var item *handshakeCacheItem for _, c := range h.cache { - if c.typ == r.typ && c.isClient == r.isClient && c.epoch == r.epoch { + if c.typ == rule.typ && c.isClient == rule.isClient && c.epoch == rule.epoch { switch { case item == nil: item = c @@ -89,18 +95,18 @@ func (h *handshakeCache) fullPullMap(startSeq int, cipherSuite CipherSuite, rule } } } - if !r.optional && item == nil { + if !rule.optional && item == nil { // Missing mandatory message. return startSeq, nil, false } - ci[r.typ] = item + ci[rule.typ] = item } out := make(map[handshake.Type]handshake.Message) seq := startSeq ok := false for _, r := range rules { - t := r.typ - i := ci[t] + typ := r.typ + i := ci[typ] if i == nil { continue } @@ -114,21 +120,22 @@ func (h *handshakeCache) fullPullMap(startSeq int, cipherSuite CipherSuite, rule if err := rawHandshake.Unmarshal(i.data); err != nil { return startSeq, nil, false } - if uint16(seq) != rawHandshake.Header.MessageSequence { + if uint16(seq) != rawHandshake.Header.MessageSequence { //nolint:gosec // G115 // There is a gap. Some messages are not arrived. return startSeq, nil, false } seq++ ok = true - out[t] = rawHandshake.Message + out[typ] = rawHandshake.Message } if !ok { return seq, nil, false } + return seq, out, true } -// pullAndMerge calls pull and then merges the results, ignoring any null entries +// pullAndMerge calls pull and then merges the results, ignoring any null entries. func (h *handshakeCache) pullAndMerge(rules ...handshakeCachePullRule) []byte { merged := []byte{} @@ -137,6 +144,7 @@ func (h *handshakeCache) pullAndMerge(rules ...handshakeCachePullRule) []byte { merged = append(merged, p.data...) } } + return merged } diff --git a/handshake_cache_test.go b/handshake_cache_test.go index 647a3f35e..b655ac166 100644 --- a/handshake_cache_test.go +++ b/handshake_cache_test.go @@ -144,7 +144,10 @@ func TestHandshakeCacheSessionHash(t *testing.T) { {handshake.TypeServerHelloDone, false, 0, 4, []byte{0x04}}, {handshake.TypeClientKeyExchange, true, 0, 5, []byte{0x05}}, }, - Expected: []byte{0x17, 0xe8, 0x8d, 0xb1, 0x87, 0xaf, 0xd6, 0x2c, 0x16, 0xe5, 0xde, 0xbf, 0x3e, 0x65, 0x27, 0xcd, 0x00, 0x6b, 0xc0, 0x12, 0xbc, 0x90, 0xb5, 0x1a, 0x81, 0x0c, 0xd8, 0x0c, 0x2d, 0x51, 0x1f, 0x43}, + Expected: []byte{ + 0x17, 0xe8, 0x8d, 0xb1, 0x87, 0xaf, 0xd6, 0x2c, 0x16, 0xe5, 0xde, 0xbf, 0x3e, 0x65, 0x27, 0xcd, + 0x00, 0x6b, 0xc0, 0x12, 0xbc, 0x90, 0xb5, 0x1a, 0x81, 0x0c, 0xd8, 0x0c, 0x2d, 0x51, 0x1f, 0x43, + }, }, { Name: "Handshake With Client Cert Request", @@ -157,7 +160,10 @@ func TestHandshakeCacheSessionHash(t *testing.T) { {handshake.TypeServerHelloDone, false, 0, 5, []byte{0x05}}, {handshake.TypeClientKeyExchange, true, 0, 6, []byte{0x06}}, }, - Expected: []byte{0x57, 0x35, 0x5a, 0xc3, 0x30, 0x3c, 0x14, 0x8f, 0x11, 0xae, 0xf7, 0xcb, 0x17, 0x94, 0x56, 0xb9, 0x23, 0x2c, 0xde, 0x33, 0xa8, 0x18, 0xdf, 0xda, 0x2c, 0x2f, 0xcb, 0x93, 0x25, 0x74, 0x9a, 0x6b}, + Expected: []byte{ + 0x57, 0x35, 0x5a, 0xc3, 0x30, 0x3c, 0x14, 0x8f, 0x11, 0xae, 0xf7, 0xcb, 0x17, 0x94, 0x56, 0xb9, + 0x23, 0x2c, 0xde, 0x33, 0xa8, 0x18, 0xdf, 0xda, 0x2c, 0x2f, 0xcb, 0x93, 0x25, 0x74, 0x9a, 0x6b, + }, }, { Name: "Handshake Ignores after ClientKeyExchange", @@ -173,7 +179,10 @@ func TestHandshakeCacheSessionHash(t *testing.T) { {handshake.TypeFinished, true, 1, 7, []byte{0x08}}, {handshake.TypeFinished, false, 1, 7, []byte{0x09}}, }, - Expected: []byte{0x57, 0x35, 0x5a, 0xc3, 0x30, 0x3c, 0x14, 0x8f, 0x11, 0xae, 0xf7, 0xcb, 0x17, 0x94, 0x56, 0xb9, 0x23, 0x2c, 0xde, 0x33, 0xa8, 0x18, 0xdf, 0xda, 0x2c, 0x2f, 0xcb, 0x93, 0x25, 0x74, 0x9a, 0x6b}, + Expected: []byte{ + 0x57, 0x35, 0x5a, 0xc3, 0x30, 0x3c, 0x14, 0x8f, 0x11, 0xae, 0xf7, 0xcb, 0x17, 0x94, 0x56, 0xb9, + 0x23, 0x2c, 0xde, 0x33, 0xa8, 0x18, 0xdf, 0xda, 0x2c, 0x2f, 0xcb, 0x93, 0x25, 0x74, 0x9a, 0x6b, + }, }, { Name: "Handshake Ignores wrong epoch", @@ -193,7 +202,10 @@ func TestHandshakeCacheSessionHash(t *testing.T) { {handshake.TypeFinished, true, 0, 7, []byte{0xf0}}, {handshake.TypeFinished, false, 0, 7, []byte{0xf1}}, }, - Expected: []byte{0x57, 0x35, 0x5a, 0xc3, 0x30, 0x3c, 0x14, 0x8f, 0x11, 0xae, 0xf7, 0xcb, 0x17, 0x94, 0x56, 0xb9, 0x23, 0x2c, 0xde, 0x33, 0xa8, 0x18, 0xdf, 0xda, 0x2c, 0x2f, 0xcb, 0x93, 0x25, 0x74, 0x9a, 0x6b}, + Expected: []byte{ + 0x57, 0x35, 0x5a, 0xc3, 0x30, 0x3c, 0x14, 0x8f, 0x11, 0xae, 0xf7, 0xcb, 0x17, 0x94, 0x56, 0xb9, + 0x23, 0x2c, 0xde, 0x33, 0xa8, 0x18, 0xdf, 0xda, 0x2c, 0x2f, 0xcb, 0x93, 0x25, 0x74, 0x9a, 0x6b, + }, }, } { h := newHandshakeCache() diff --git a/handshake_test.go b/handshake_test.go index 4cadb610f..8c97d20b2 100644 --- a/handshake_test.go +++ b/handshake_test.go @@ -30,7 +30,10 @@ func TestHandshakeMessage(t *testing.T) { Version: protocol.Version{Major: 0xFE, Minor: 0xFD}, Random: handshake.Random{ GMTUnixTime: time.Unix(3056586332, 0), - RandomBytes: [28]byte{0x42, 0x54, 0xff, 0x86, 0xe1, 0x24, 0x41, 0x91, 0x42, 0x62, 0x15, 0xad, 0x16, 0xc9, 0x15, 0x8d, 0x95, 0x71, 0x8a, 0xbb, 0x22, 0xd7, 0x47, 0xec, 0xd8, 0x3d, 0xdc, 0x4b}, + RandomBytes: [28]byte{ + 0x42, 0x54, 0xff, 0x86, 0xe1, 0x24, 0x41, 0x91, 0x42, 0x62, 0x15, 0xad, 0x16, 0xc9, + 0x15, 0x8d, 0x95, 0x71, 0x8a, 0xbb, 0x22, 0xd7, 0x47, 0xec, 0xd8, 0x3d, 0xdc, 0x4b, + }, }, SessionID: []byte{}, Cookie: []byte{}, diff --git a/handshaker.go b/handshaker.go index 5e16808a9..e8b09b6c4 100644 --- a/handshaker.go +++ b/handshaker.go @@ -162,6 +162,7 @@ func srvCliStr(isClient bool) string { if isClient { return "client" } + return "server" } @@ -179,7 +180,7 @@ func newHandshakeFSM( } } -func (s *handshakeFSM) Run(ctx context.Context, c flightConn, initialState handshakeState) error { +func (s *handshakeFSM) Run(ctx context.Context, conn flightConn, initialState handshakeState) error { state := initialState defer func() { close(s.closed) @@ -192,13 +193,13 @@ func (s *handshakeFSM) Run(ctx context.Context, c flightConn, initialState hands var err error switch state { case handshakePreparing: - state, err = s.prepare(ctx, c) + state, err = s.prepare(ctx, conn) case handshakeSending: - state, err = s.send(ctx, c) + state, err = s.send(ctx, conn) case handshakeWaiting: - state, err = s.wait(ctx, c) + state, err = s.wait(ctx, conn) case handshakeFinished: - state, err = s.finish(ctx, c) + state, err = s.finish(ctx, conn) default: return errInvalidFSMTransition } @@ -212,24 +213,24 @@ func (s *handshakeFSM) Done() <-chan struct{} { return s.closed } -func (s *handshakeFSM) prepare(ctx context.Context, c flightConn) (handshakeState, error) { +func (s *handshakeFSM) prepare(ctx context.Context, conn flightConn) (handshakeState, error) { s.flights = nil // Prepare flights var ( - a *alert.Alert - err error - pkts []*packet + dtlsAlert *alert.Alert + err error + pkts []*packet ) gen, retransmit, errFlight := s.currentFlight.getFlightGenerator() if errFlight != nil { err = errFlight - a = &alert.Alert{Level: alert.Fatal, Description: alert.InternalError} + dtlsAlert = &alert.Alert{Level: alert.Fatal, Description: alert.InternalError} } else { - pkts, a, err = gen(c, s.state, s.cache, s.cfg) + pkts, dtlsAlert, err = gen(conn, s.state, s.cache, s.cfg) s.retransmit = retransmit } - if a != nil { - if alertErr := c.notify(ctx, a.Level, a.Description); alertErr != nil { + if dtlsAlert != nil { + if alertErr := conn.notify(ctx, dtlsAlert.Level, dtlsAlert.Description); alertErr != nil { if err != nil { err = alertErr } @@ -248,14 +249,15 @@ func (s *handshakeFSM) prepare(ctx context.Context, c flightConn) (handshakeStat nextEpoch = p.record.Header.Epoch } if h, ok := p.record.Content.(*handshake.Handshake); ok { - h.Header.MessageSequence = uint16(s.state.handshakeSendSequence) + h.Header.MessageSequence = uint16(s.state.handshakeSendSequence) //nolint:gosec // G115 s.state.handshakeSendSequence++ } } if epoch != nextEpoch { s.cfg.log.Tracef("[handshake:%s] -> changeCipherSpec (epoch: %d)", srvCliStr(s.state.isClient), nextEpoch) - c.setLocalEpoch(nextEpoch) + conn.setLocalEpoch(nextEpoch) } + return handshakeSending, nil } @@ -268,32 +270,35 @@ func (s *handshakeFSM) send(ctx context.Context, c flightConn) (handshakeState, if s.currentFlight.isLastSendFlight() { return handshakeFinished, nil } + return handshakeWaiting, nil } -func (s *handshakeFSM) wait(ctx context.Context, c flightConn) (handshakeState, error) { //nolint:gocognit +func (s *handshakeFSM) wait(ctx context.Context, conn flightConn) (handshakeState, error) { //nolint:gocognit,cyclop parse, errFlight := s.currentFlight.getFlightParser() if errFlight != nil { - if alertErr := c.notify(ctx, alert.Fatal, alert.InternalError); alertErr != nil { + if alertErr := conn.notify(ctx, alert.Fatal, alert.InternalError); alertErr != nil { return handshakeErrored, alertErr } + return handshakeErrored, errFlight } retransmitTimer := time.NewTimer(s.retransmitInterval) for { select { - case state := <-c.recvHandshake(): + case state := <-conn.recvHandshake(): if state.isRetransmit { close(state.done) + return handshakeSending, nil } - nextFlight, alert, err := parse(ctx, c, s.state, s.cache, s.cfg) + nextFlight, alert, err := parse(ctx, conn, s.state, s.cache, s.cfg) s.retransmitInterval = s.cfg.initialRetransmitInterval close(state.done) if alert != nil { - if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil { + if alertErr := conn.notify(ctx, alert.Level, alert.Description); alertErr != nil { if err != nil { err = alertErr } @@ -305,11 +310,17 @@ func (s *handshakeFSM) wait(ctx context.Context, c flightConn) (handshakeState, if nextFlight == 0 { break } - s.cfg.log.Tracef("[handshake:%s] %s -> %s", srvCliStr(s.state.isClient), s.currentFlight.String(), nextFlight.String()) + s.cfg.log.Tracef( + "[handshake:%s] %s -> %s", + srvCliStr(s.state.isClient), + s.currentFlight.String(), + nextFlight.String(), + ) if nextFlight.isLastRecvFlight() && s.currentFlight == nextFlight { return handshakeFinished, nil } s.currentFlight = nextFlight + return handshakePreparing, nil case <-retransmitTimer.C: @@ -326,9 +337,11 @@ func (s *handshakeFSM) wait(ctx context.Context, c flightConn) (handshakeState, if s.retransmitInterval > time.Second*60 { s.retransmitInterval = time.Second * 60 } + return handshakeSending, nil case <-ctx.Done(): s.retransmitInterval = s.cfg.initialRetransmitInterval + return handshakeErrored, ctx.Err() } } diff --git a/handshaker_test.go b/handshaker_test.go index de82a571a..88e69ee0c 100644 --- a/handshaker_test.go +++ b/handshaker_test.go @@ -44,7 +44,7 @@ func TestWriteKeyLog(t *testing.T) { cfg.writeKeyLog("LABEL", []byte{0xAA, 0xBB, 0xCC}, []byte{0xDD, 0xEE, 0xFF}) } -func TestHandshaker(t *testing.T) { +func TestHandshaker(t *testing.T) { //nolint:gocyclo,cyclop,maintidx // Check for leaking routines report := test.CheckRoutines(t) defer report() @@ -84,6 +84,7 @@ func TestHandshaker(t *testing.T) { cntClientHelloNoCookie++ } } + return true }, } @@ -96,17 +97,26 @@ func TestHandshaker(t *testing.T) { } if _, ok := h.Message.(*handshake.MessageHelloVerifyRequest); ok { cntHelloVerifyRequest++ + return cntHelloVerifyRequest > helloVerifyDrop } + return true }, } report := func(t *testing.T) { + t.Helper() + if cntHelloVerifyRequest != helloVerifyDrop+1 { - t.Errorf("Number of HelloVerifyRequest retransmit is wrong, expected: %d times, got: %d times", helloVerifyDrop+1, cntHelloVerifyRequest) + t.Errorf( + "Number of HelloVerifyRequest retransmit is wrong, expected: %d times, got: %d times", + helloVerifyDrop+1, + cntHelloVerifyRequest, + ) } if cntClientHelloNoCookie != cntHelloVerifyRequest { + ///nolint:lll t.Errorf( "HelloVerifyRequest must be triggered only by ClientHello, but HelloVerifyRequest was sent %d times and ClientHello was sent %d times", cntHelloVerifyRequest, cntClientHelloNoCookie, @@ -132,6 +142,7 @@ func TestHandshaker(t *testing.T) { if _, ok := h.Message.(*handshake.MessageFinished); ok { cntClientFinished++ } + return true }, } @@ -145,11 +156,14 @@ func TestHandshaker(t *testing.T) { if _, ok := h.Message.(*handshake.MessageFinished); ok { cntServerFinished++ } + return true }, } report := func(t *testing.T) { + t.Helper() + if cntClientFinished != 1 { t.Errorf("Number of client finished is wrong, expected: %d times, got: %d times", 1, cntClientFinished) } @@ -184,6 +198,7 @@ func TestHandshaker(t *testing.T) { cntClientFinished++ } } + return true }, Delay: 0, @@ -206,6 +221,7 @@ func TestHandshaker(t *testing.T) { cntServerFinished++ } } + return true }, Delay: 1000 * time.Millisecond, @@ -216,8 +232,11 @@ func TestHandshaker(t *testing.T) { } report := func(t *testing.T) { - // with one second server delay and 100 ms retransmit (+ exponential backoff), there should be close to 4 `Finished` from client - // using a range of 3 - 5 for checking + t.Helper() + + // with one second server delay and 100 ms retransmit (+ exponential backoff), + // there should be close to 4 `Finished` from client + // using a range of 3 - 5 for checking. if cntClientFinished < 3 || cntClientFinished > 5 { t.Errorf("Number of client finished is wrong, expected: %d - %d times, got: %d times", 3, 5, cntClientFinished) } @@ -226,7 +245,11 @@ func TestHandshaker(t *testing.T) { } // there should be no `Finished` last retransmit from client if cntClientFinishedLastRetransmit != 0 { - t.Errorf("Number of client finished last retransmit is wrong, expected: %d times, got: %d times", 0, cntClientFinishedLastRetransmit) + t.Errorf( + "Number of client finished last retransmit is wrong, expected: %d times, got: %d times", + 0, + cntClientFinishedLastRetransmit, + ) } if cntServerFinished < 1 { t.Errorf("Number of server finished is wrong, expected: at least %d times, got: %d times", 1, cntServerFinished) @@ -234,9 +257,14 @@ func TestHandshaker(t *testing.T) { if !isServerFinished { t.Errorf("Server is not finished") } - // there should be `Finished` last retransmit from server. Because of slow server, client would have sent several `Finished`. + // there should be `Finished` last retransmit from server. + // Because of slow server, client would have sent several `Finished`. if cntServerFinishedLastRetransmit < 1 { - t.Errorf("Number of server finished last retransmit is wrong, expected: at least %d times, got: %d times", 1, cntServerFinishedLastRetransmit) + t.Errorf( + "Number of server finished last retransmit is wrong, expected: at least %d times, got: %d times", + 1, + cntServerFinishedLastRetransmit, + ) } } @@ -346,11 +374,16 @@ type TestEndpoint struct { FinishWait time.Duration } -func flightTestPipe(ctx context.Context, clientEndpoint TestEndpoint, serverEndpoint TestEndpoint) (*flightTestConn, *flightTestConn) { +func flightTestPipe( + ctx context.Context, + clientEndpoint TestEndpoint, + serverEndpoint TestEndpoint, +) (*flightTestConn, *flightTestConn) { ca := newHandshakeCache() cb := newHandshakeCache() chA := make(chan recvHandshakeState) chB := make(chan recvHandshakeState) + return &flightTestConn{ handshakeCache: ca, otherEndCache: cb, @@ -399,30 +432,41 @@ func (c *flightTestConn) notify(context.Context, alert.Level, alert.Description) func (c *flightTestConn) writePackets(_ context.Context, pkts []*packet) error { time.Sleep(c.delay) - for _, p := range pkts { - if c.filter != nil && !c.filter(p) { + for _, pkt := range pkts { + if c.filter != nil && !c.filter(pkt) { continue } - if h, ok := p.record.Content.(*handshake.Handshake); ok { - handshakeRaw, err := p.record.Marshal() + if handshake, ok := pkt.record.Content.(*handshake.Handshake); ok { + handshakeRaw, err := pkt.record.Marshal() if err != nil { return err } - c.handshakeCache.push(handshakeRaw[recordlayer.FixedHeaderSize:], p.record.Header.Epoch, h.Header.MessageSequence, h.Header.Type, c.state.isClient) + c.handshakeCache.push( + handshakeRaw[recordlayer.FixedHeaderSize:], + pkt.record.Header.Epoch, + handshake.Header.MessageSequence, + handshake.Header.Type, + c.state.isClient, + ) - content, err := h.Message.Marshal() + content, err := handshake.Message.Marshal() if err != nil { return err } - h.Header.Length = uint32(len(content)) - h.Header.FragmentLength = uint32(len(content)) - hdr, err := h.Header.Marshal() + handshake.Header.Length = uint32(len(content)) //nolint:gosec // G115 + handshake.Header.FragmentLength = uint32(len(content)) //nolint:gosec // G115 + hdr, err := handshake.Header.Marshal() if err != nil { return err } c.otherEndCache.push( - append(hdr, content...), p.record.Header.Epoch, h.Header.MessageSequence, h.Header.Type, c.state.isClient) + append(hdr, content...), + pkt.record.Header.Epoch, + handshake.Header.MessageSequence, + handshake.Header.Type, + c.state.isClient, + ) } } go func() { diff --git a/internal/ciphersuite/aes_128_ccm.go b/internal/ciphersuite/aes_128_ccm.go index 0877c2c18..9805f36e8 100644 --- a/internal/ciphersuite/aes_128_ccm.go +++ b/internal/ciphersuite/aes_128_ccm.go @@ -8,12 +8,19 @@ import ( "github.com/pion/dtls/v3/pkg/crypto/clientcertificate" ) -// Aes128Ccm is a base class used by multiple AES-CCM Ciphers +// Aes128Ccm is a base class used by multiple AES-CCM Ciphers. type Aes128Ccm struct { AesCcm } -func newAes128Ccm(clientCertificateType clientcertificate.Type, id ID, psk bool, cryptoCCMTagLen ciphersuite.CCMTagLen, keyExchangeAlgorithm KeyExchangeAlgorithm, ecc bool) *Aes128Ccm { +func newAes128Ccm( + clientCertificateType clientcertificate.Type, + id ID, + psk bool, + cryptoCCMTagLen ciphersuite.CCMTagLen, + keyExchangeAlgorithm KeyExchangeAlgorithm, + ecc bool, +) *Aes128Ccm { return &Aes128Ccm{ AesCcm: AesCcm{ clientCertificateType: clientCertificateType, @@ -26,8 +33,9 @@ func newAes128Ccm(clientCertificateType clientcertificate.Type, id ID, psk bool, } } -// Init initializes the internal Cipher with keying material +// Init initializes the internal Cipher with keying material. func (c *Aes128Ccm) Init(masterSecret, clientRandom, serverRandom []byte, isClient bool) error { const prfKeyLen = 16 + return c.AesCcm.Init(masterSecret, clientRandom, serverRandom, isClient, prfKeyLen) } diff --git a/internal/ciphersuite/aes_256_ccm.go b/internal/ciphersuite/aes_256_ccm.go index bbdf06d81..58d5e0cee 100644 --- a/internal/ciphersuite/aes_256_ccm.go +++ b/internal/ciphersuite/aes_256_ccm.go @@ -8,12 +8,19 @@ import ( "github.com/pion/dtls/v3/pkg/crypto/clientcertificate" ) -// Aes256Ccm is a base class used by multiple AES-CCM Ciphers +// Aes256Ccm is a base class used by multiple AES-CCM Ciphers. type Aes256Ccm struct { AesCcm } -func newAes256Ccm(clientCertificateType clientcertificate.Type, id ID, psk bool, cryptoCCMTagLen ciphersuite.CCMTagLen, keyExchangeAlgorithm KeyExchangeAlgorithm, ecc bool) *Aes256Ccm { +func newAes256Ccm( + clientCertificateType clientcertificate.Type, + id ID, + psk bool, + cryptoCCMTagLen ciphersuite.CCMTagLen, + keyExchangeAlgorithm KeyExchangeAlgorithm, + ecc bool, +) *Aes256Ccm { return &Aes256Ccm{ AesCcm: AesCcm{ clientCertificateType: clientCertificateType, @@ -26,8 +33,9 @@ func newAes256Ccm(clientCertificateType clientcertificate.Type, id ID, psk bool, } } -// Init initializes the internal Cipher with keying material +// Init initializes the internal Cipher with keying material. func (c *Aes256Ccm) Init(masterSecret, clientRandom, serverRandom []byte, isClient bool) error { const prfKeyLen = 32 + return c.AesCcm.Init(masterSecret, clientRandom, serverRandom, isClient, prfKeyLen) } diff --git a/internal/ciphersuite/aes_ccm.go b/internal/ciphersuite/aes_ccm.go index 54eafcf80..ddda1c8e7 100644 --- a/internal/ciphersuite/aes_ccm.go +++ b/internal/ciphersuite/aes_ccm.go @@ -15,7 +15,7 @@ import ( "github.com/pion/dtls/v3/pkg/protocol/recordlayer" ) -// AesCcm is a base class used by multiple AES-CCM Ciphers +// AesCcm is a base class used by multiple AES-CCM Ciphers. type AesCcm struct { ccm atomic.Value // *cryptoCCM clientCertificateType clientcertificate.Type @@ -26,12 +26,12 @@ type AesCcm struct { ecc bool } -// CertificateType returns what type of certificate this CipherSuite exchanges +// CertificateType returns what type of certificate this CipherSuite exchanges. func (c *AesCcm) CertificateType() clientcertificate.Type { return c.clientCertificateType } -// ID returns the ID of the CipherSuite +// ID returns the ID of the CipherSuite. func (c *AesCcm) ID() ID { return c.id } @@ -40,59 +40,66 @@ func (c *AesCcm) String() string { return c.id.String() } -// ECC uses Elliptic Curve Cryptography +// ECC uses Elliptic Curve Cryptography. func (c *AesCcm) ECC() bool { return c.ecc } -// KeyExchangeAlgorithm controls what key exchange algorithm is using during the handshake +// KeyExchangeAlgorithm controls what key exchange algorithm is using during the handshake. func (c *AesCcm) KeyExchangeAlgorithm() KeyExchangeAlgorithm { return c.keyExchangeAlgorithm } -// HashFunc returns the hashing func for this CipherSuite +// HashFunc returns the hashing func for this CipherSuite. func (c *AesCcm) HashFunc() func() hash.Hash { return sha256.New } -// AuthenticationType controls what authentication method is using during the handshake +// AuthenticationType controls what authentication method is using during the handshake. func (c *AesCcm) AuthenticationType() AuthenticationType { if c.psk { return AuthenticationTypePreSharedKey } + return AuthenticationTypeCertificate } // IsInitialized returns if the CipherSuite has keying material and can -// encrypt/decrypt packets +// encrypt/decrypt packets. func (c *AesCcm) IsInitialized() bool { return c.ccm.Load() != nil } -// Init initializes the internal Cipher with keying material +// Init initializes the internal Cipher with keying material. func (c *AesCcm) Init(masterSecret, clientRandom, serverRandom []byte, isClient bool, prfKeyLen int) error { const ( prfMacLen = 0 prfIvLen = 4 ) - keys, err := prf.GenerateEncryptionKeys(masterSecret, clientRandom, serverRandom, prfMacLen, prfKeyLen, prfIvLen, c.HashFunc()) + keys, err := prf.GenerateEncryptionKeys( + masterSecret, clientRandom, serverRandom, prfMacLen, prfKeyLen, prfIvLen, c.HashFunc(), + ) if err != nil { return err } var ccm *ciphersuite.CCM if isClient { - ccm, err = ciphersuite.NewCCM(c.cryptoCCMTagLen, keys.ClientWriteKey, keys.ClientWriteIV, keys.ServerWriteKey, keys.ServerWriteIV) + ccm, err = ciphersuite.NewCCM( + c.cryptoCCMTagLen, keys.ClientWriteKey, keys.ClientWriteIV, keys.ServerWriteKey, keys.ServerWriteIV, + ) } else { - ccm, err = ciphersuite.NewCCM(c.cryptoCCMTagLen, keys.ServerWriteKey, keys.ServerWriteIV, keys.ClientWriteKey, keys.ClientWriteIV) + ccm, err = ciphersuite.NewCCM( + c.cryptoCCMTagLen, keys.ServerWriteKey, keys.ServerWriteIV, keys.ClientWriteKey, keys.ClientWriteIV, + ) } c.ccm.Store(ccm) return err } -// Encrypt encrypts a single TLS RecordLayer +// Encrypt encrypts a single TLS RecordLayer. func (c *AesCcm) Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error) { cipherSuite, ok := c.ccm.Load().(*ciphersuite.CCM) if !ok { @@ -102,7 +109,7 @@ func (c *AesCcm) Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, erro return cipherSuite.Encrypt(pkt, raw) } -// Decrypt decrypts a single TLS RecordLayer +// Decrypt decrypts a single TLS RecordLayer. func (c *AesCcm) Decrypt(h recordlayer.Header, raw []byte) ([]byte, error) { cipherSuite, ok := c.ccm.Load().(*ciphersuite.CCM) if !ok { diff --git a/internal/ciphersuite/ciphersuite.go b/internal/ciphersuite/ciphersuite.go index 4778be72a..27b7d57ce 100644 --- a/internal/ciphersuite/ciphersuite.go +++ b/internal/ciphersuite/ciphersuite.go @@ -1,7 +1,8 @@ // SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT -// Package ciphersuite provides TLS Ciphers as registered with the IANA https://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml#tls-parameters-4 +// Package ciphersuite provides TLS Ciphers as registered with the IANA +// https://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml#tls-parameters-4 package ciphersuite import ( @@ -12,12 +13,13 @@ import ( "github.com/pion/dtls/v3/pkg/protocol" ) -var errCipherSuiteNotInit = &protocol.TemporaryError{Err: errors.New("CipherSuite has not been initialized")} //nolint:goerr113 +//nolint:goerr113 +var errCipherSuiteNotInit = &protocol.TemporaryError{Err: errors.New("CipherSuite has not been initialized")} -// ID is an ID for our supported CipherSuites +// ID is an ID for our supported CipherSuites. type ID uint16 -func (i ID) String() string { +func (i ID) String() string { //nolint:cyclop switch i { case TLS_ECDHE_ECDSA_WITH_AES_128_CCM: return "TLS_ECDHE_ECDSA_WITH_AES_128_CCM" @@ -52,19 +54,19 @@ func (i ID) String() string { } } -// Supported Cipher Suites +// Supported Cipher Suites. const ( - // AES-128-CCM + // AES-128-CCM. TLS_ECDHE_ECDSA_WITH_AES_128_CCM ID = 0xc0ac //nolint:revive,stylecheck TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8 ID = 0xc0ae //nolint:revive,stylecheck - // AES-128-GCM-SHA256 + // AES-128-GCM-SHA256. TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 ID = 0xc02b //nolint:revive,stylecheck TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 ID = 0xc02f //nolint:revive,stylecheck TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 ID = 0xc02c //nolint:revive,stylecheck TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 ID = 0xc030 //nolint:revive,stylecheck - // AES-256-CBC-SHA + // AES-256-CBC-SHA. TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA ID = 0xc00a //nolint:revive,stylecheck TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA ID = 0xc014 //nolint:revive,stylecheck @@ -77,10 +79,10 @@ const ( TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256 ID = 0xC037 //nolint:revive,stylecheck ) -// AuthenticationType controls what authentication method is using during the handshake +// AuthenticationType controls what authentication method is using during the handshake. type AuthenticationType = types.AuthenticationType -// AuthenticationType Enums +// AuthenticationType Enums. const ( AuthenticationTypeCertificate AuthenticationType = types.AuthenticationTypeCertificate AuthenticationTypePreSharedKey AuthenticationType = types.AuthenticationTypePreSharedKey @@ -90,7 +92,7 @@ const ( // KeyExchangeAlgorithm controls what exchange algorithm was chosen. type KeyExchangeAlgorithm = types.KeyExchangeAlgorithm -// KeyExchangeAlgorithm Bitmask +// KeyExchangeAlgorithm Bitmask. const ( KeyExchangeAlgorithmNone KeyExchangeAlgorithm = types.KeyExchangeAlgorithmNone KeyExchangeAlgorithmPsk KeyExchangeAlgorithm = types.KeyExchangeAlgorithmPsk diff --git a/internal/ciphersuite/tls_ecdhe_ecdsa_with_aes_128_ccm.go b/internal/ciphersuite/tls_ecdhe_ecdsa_with_aes_128_ccm.go index e55290799..04a6ca40d 100644 --- a/internal/ciphersuite/tls_ecdhe_ecdsa_with_aes_128_ccm.go +++ b/internal/ciphersuite/tls_ecdhe_ecdsa_with_aes_128_ccm.go @@ -8,7 +8,14 @@ import ( "github.com/pion/dtls/v3/pkg/crypto/clientcertificate" ) -// NewTLSEcdheEcdsaWithAes128Ccm constructs a TLS_ECDHE_ECDSA_WITH_AES_128_CCM Cipher +// NewTLSEcdheEcdsaWithAes128Ccm constructs a TLS_ECDHE_ECDSA_WITH_AES_128_CCM Cipher. func NewTLSEcdheEcdsaWithAes128Ccm() *Aes128Ccm { - return newAes128Ccm(clientcertificate.ECDSASign, TLS_ECDHE_ECDSA_WITH_AES_128_CCM, false, ciphersuite.CCMTagLength, KeyExchangeAlgorithmEcdhe, true) + return newAes128Ccm( + clientcertificate.ECDSASign, + TLS_ECDHE_ECDSA_WITH_AES_128_CCM, + false, + ciphersuite.CCMTagLength, + KeyExchangeAlgorithmEcdhe, + true, + ) } diff --git a/internal/ciphersuite/tls_ecdhe_ecdsa_with_aes_128_ccm8.go b/internal/ciphersuite/tls_ecdhe_ecdsa_with_aes_128_ccm8.go index a423a13f1..38a166fad 100644 --- a/internal/ciphersuite/tls_ecdhe_ecdsa_with_aes_128_ccm8.go +++ b/internal/ciphersuite/tls_ecdhe_ecdsa_with_aes_128_ccm8.go @@ -8,7 +8,14 @@ import ( "github.com/pion/dtls/v3/pkg/crypto/clientcertificate" ) -// NewTLSEcdheEcdsaWithAes128Ccm8 creates a new TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8 CipherSuite +// NewTLSEcdheEcdsaWithAes128Ccm8 creates a new TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8 CipherSuite. func NewTLSEcdheEcdsaWithAes128Ccm8() *Aes128Ccm { - return newAes128Ccm(clientcertificate.ECDSASign, TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8, false, ciphersuite.CCMTagLength8, KeyExchangeAlgorithmEcdhe, true) + return newAes128Ccm( + clientcertificate.ECDSASign, + TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8, + false, + ciphersuite.CCMTagLength8, + KeyExchangeAlgorithmEcdhe, + true, + ) } diff --git a/internal/ciphersuite/tls_ecdhe_ecdsa_with_aes_128_gcm_sha256.go b/internal/ciphersuite/tls_ecdhe_ecdsa_with_aes_128_gcm_sha256.go index 9f7b27788..f47a67497 100644 --- a/internal/ciphersuite/tls_ecdhe_ecdsa_with_aes_128_gcm_sha256.go +++ b/internal/ciphersuite/tls_ecdhe_ecdsa_with_aes_128_gcm_sha256.go @@ -15,27 +15,27 @@ import ( "github.com/pion/dtls/v3/pkg/protocol/recordlayer" ) -// TLSEcdheEcdsaWithAes128GcmSha256 represents a TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 CipherSuite +// TLSEcdheEcdsaWithAes128GcmSha256 represents a TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 CipherSuite. type TLSEcdheEcdsaWithAes128GcmSha256 struct { gcm atomic.Value // *cryptoGCM } -// CertificateType returns what type of certficate this CipherSuite exchanges +// CertificateType returns what type of certficate this CipherSuite exchanges. func (c *TLSEcdheEcdsaWithAes128GcmSha256) CertificateType() clientcertificate.Type { return clientcertificate.ECDSASign } -// KeyExchangeAlgorithm controls what key exchange algorithm is using during the handshake +// KeyExchangeAlgorithm controls what key exchange algorithm is using during the handshake. func (c *TLSEcdheEcdsaWithAes128GcmSha256) KeyExchangeAlgorithm() KeyExchangeAlgorithm { return KeyExchangeAlgorithmEcdhe } -// ECC uses Elliptic Curve Cryptography +// ECC uses Elliptic Curve Cryptography. func (c *TLSEcdheEcdsaWithAes128GcmSha256) ECC() bool { return true } -// ID returns the ID of the CipherSuite +// ID returns the ID of the CipherSuite. func (c *TLSEcdheEcdsaWithAes128GcmSha256) ID() ID { return TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 } @@ -44,24 +44,31 @@ func (c *TLSEcdheEcdsaWithAes128GcmSha256) String() string { return "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256" } -// HashFunc returns the hashing func for this CipherSuite +// HashFunc returns the hashing func for this CipherSuite. func (c *TLSEcdheEcdsaWithAes128GcmSha256) HashFunc() func() hash.Hash { return sha256.New } -// AuthenticationType controls what authentication method is using during the handshake +// AuthenticationType controls what authentication method is using during the handshake. func (c *TLSEcdheEcdsaWithAes128GcmSha256) AuthenticationType() AuthenticationType { return AuthenticationTypeCertificate } // IsInitialized returns if the CipherSuite has keying material and can -// encrypt/decrypt packets +// encrypt/decrypt packets. func (c *TLSEcdheEcdsaWithAes128GcmSha256) IsInitialized() bool { return c.gcm.Load() != nil } -func (c *TLSEcdheEcdsaWithAes128GcmSha256) init(masterSecret, clientRandom, serverRandom []byte, isClient bool, prfMacLen, prfKeyLen, prfIvLen int, hashFunc func() hash.Hash) error { - keys, err := prf.GenerateEncryptionKeys(masterSecret, clientRandom, serverRandom, prfMacLen, prfKeyLen, prfIvLen, hashFunc) +func (c *TLSEcdheEcdsaWithAes128GcmSha256) init( + masterSecret, clientRandom, serverRandom []byte, + isClient bool, + prfMacLen, prfKeyLen, prfIvLen int, + hashFunc func() hash.Hash, +) error { + keys, err := prf.GenerateEncryptionKeys( + masterSecret, clientRandom, serverRandom, prfMacLen, prfKeyLen, prfIvLen, hashFunc, + ) if err != nil { return err } @@ -73,10 +80,11 @@ func (c *TLSEcdheEcdsaWithAes128GcmSha256) init(masterSecret, clientRandom, serv gcm, err = ciphersuite.NewGCM(keys.ServerWriteKey, keys.ServerWriteIV, keys.ClientWriteKey, keys.ClientWriteIV) } c.gcm.Store(gcm) + return err } -// Init initializes the internal Cipher with keying material +// Init initializes the internal Cipher with keying material. func (c *TLSEcdheEcdsaWithAes128GcmSha256) Init(masterSecret, clientRandom, serverRandom []byte, isClient bool) error { const ( prfMacLen = 0 @@ -87,7 +95,7 @@ func (c *TLSEcdheEcdsaWithAes128GcmSha256) Init(masterSecret, clientRandom, serv return c.init(masterSecret, clientRandom, serverRandom, isClient, prfMacLen, prfKeyLen, prfIvLen, c.HashFunc()) } -// Encrypt encrypts a single TLS RecordLayer +// Encrypt encrypts a single TLS RecordLayer. func (c *TLSEcdheEcdsaWithAes128GcmSha256) Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error) { cipherSuite, ok := c.gcm.Load().(*ciphersuite.GCM) if !ok { @@ -97,7 +105,7 @@ func (c *TLSEcdheEcdsaWithAes128GcmSha256) Encrypt(pkt *recordlayer.RecordLayer, return cipherSuite.Encrypt(pkt, raw) } -// Decrypt decrypts a single TLS RecordLayer +// Decrypt decrypts a single TLS RecordLayer. func (c *TLSEcdheEcdsaWithAes128GcmSha256) Decrypt(h recordlayer.Header, raw []byte) ([]byte, error) { cipherSuite, ok := c.gcm.Load().(*ciphersuite.GCM) if !ok { diff --git a/internal/ciphersuite/tls_ecdhe_ecdsa_with_aes_256_cbc_sha.go b/internal/ciphersuite/tls_ecdhe_ecdsa_with_aes_256_cbc_sha.go index 87f5ef395..6eeb91811 100644 --- a/internal/ciphersuite/tls_ecdhe_ecdsa_with_aes_256_cbc_sha.go +++ b/internal/ciphersuite/tls_ecdhe_ecdsa_with_aes_256_cbc_sha.go @@ -16,27 +16,27 @@ import ( "github.com/pion/dtls/v3/pkg/protocol/recordlayer" ) -// TLSEcdheEcdsaWithAes256CbcSha represents a TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA CipherSuite +// TLSEcdheEcdsaWithAes256CbcSha represents a TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA CipherSuite. type TLSEcdheEcdsaWithAes256CbcSha struct { cbc atomic.Value // *cryptoCBC } -// CertificateType returns what type of certficate this CipherSuite exchanges +// CertificateType returns what type of certficate this CipherSuite exchanges. func (c *TLSEcdheEcdsaWithAes256CbcSha) CertificateType() clientcertificate.Type { return clientcertificate.ECDSASign } -// KeyExchangeAlgorithm controls what key exchange algorithm is using during the handshake +// KeyExchangeAlgorithm controls what key exchange algorithm is using during the handshake. func (c *TLSEcdheEcdsaWithAes256CbcSha) KeyExchangeAlgorithm() KeyExchangeAlgorithm { return KeyExchangeAlgorithmEcdhe } -// ECC uses Elliptic Curve Cryptography +// ECC uses Elliptic Curve Cryptography. func (c *TLSEcdheEcdsaWithAes256CbcSha) ECC() bool { return true } -// ID returns the ID of the CipherSuite +// ID returns the ID of the CipherSuite. func (c *TLSEcdheEcdsaWithAes256CbcSha) ID() ID { return TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA } @@ -45,23 +45,23 @@ func (c *TLSEcdheEcdsaWithAes256CbcSha) String() string { return "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA" } -// HashFunc returns the hashing func for this CipherSuite +// HashFunc returns the hashing func for this CipherSuite. func (c *TLSEcdheEcdsaWithAes256CbcSha) HashFunc() func() hash.Hash { return sha256.New } -// AuthenticationType controls what authentication method is using during the handshake +// AuthenticationType controls what authentication method is using during the handshake. func (c *TLSEcdheEcdsaWithAes256CbcSha) AuthenticationType() AuthenticationType { return AuthenticationTypeCertificate } // IsInitialized returns if the CipherSuite has keying material and can -// encrypt/decrypt packets +// encrypt/decrypt packets. func (c *TLSEcdheEcdsaWithAes256CbcSha) IsInitialized() bool { return c.cbc.Load() != nil } -// Init initializes the internal Cipher with keying material +// Init initializes the internal Cipher with keying material. func (c *TLSEcdheEcdsaWithAes256CbcSha) Init(masterSecret, clientRandom, serverRandom []byte, isClient bool) error { const ( prfMacLen = 20 @@ -69,7 +69,9 @@ func (c *TLSEcdheEcdsaWithAes256CbcSha) Init(masterSecret, clientRandom, serverR prfIvLen = 16 ) - keys, err := prf.GenerateEncryptionKeys(masterSecret, clientRandom, serverRandom, prfMacLen, prfKeyLen, prfIvLen, c.HashFunc()) + keys, err := prf.GenerateEncryptionKeys( + masterSecret, clientRandom, serverRandom, prfMacLen, prfKeyLen, prfIvLen, c.HashFunc(), + ) if err != nil { return err } @@ -93,7 +95,7 @@ func (c *TLSEcdheEcdsaWithAes256CbcSha) Init(masterSecret, clientRandom, serverR return err } -// Encrypt encrypts a single TLS RecordLayer +// Encrypt encrypts a single TLS RecordLayer. func (c *TLSEcdheEcdsaWithAes256CbcSha) Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error) { cipherSuite, ok := c.cbc.Load().(*ciphersuite.CBC) if !ok { @@ -103,7 +105,7 @@ func (c *TLSEcdheEcdsaWithAes256CbcSha) Encrypt(pkt *recordlayer.RecordLayer, ra return cipherSuite.Encrypt(pkt, raw) } -// Decrypt decrypts a single TLS RecordLayer +// Decrypt decrypts a single TLS RecordLayer. func (c *TLSEcdheEcdsaWithAes256CbcSha) Decrypt(h recordlayer.Header, raw []byte) ([]byte, error) { cipherSuite, ok := c.cbc.Load().(*ciphersuite.CBC) if !ok { diff --git a/internal/ciphersuite/tls_ecdhe_ecdsa_with_aes_256_gcm_sha384.go b/internal/ciphersuite/tls_ecdhe_ecdsa_with_aes_256_gcm_sha384.go index 2a3cfa4f5..bf6f6c444 100644 --- a/internal/ciphersuite/tls_ecdhe_ecdsa_with_aes_256_gcm_sha384.go +++ b/internal/ciphersuite/tls_ecdhe_ecdsa_with_aes_256_gcm_sha384.go @@ -8,12 +8,12 @@ import ( "hash" ) -// TLSEcdheEcdsaWithAes256GcmSha384 represents a TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 CipherSuite +// TLSEcdheEcdsaWithAes256GcmSha384 represents a TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 CipherSuite. type TLSEcdheEcdsaWithAes256GcmSha384 struct { TLSEcdheEcdsaWithAes128GcmSha256 } -// ID returns the ID of the CipherSuite +// ID returns the ID of the CipherSuite. func (c *TLSEcdheEcdsaWithAes256GcmSha384) ID() ID { return TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 } @@ -22,12 +22,12 @@ func (c *TLSEcdheEcdsaWithAes256GcmSha384) String() string { return "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384" } -// HashFunc returns the hashing func for this CipherSuite +// HashFunc returns the hashing func for this CipherSuite. func (c *TLSEcdheEcdsaWithAes256GcmSha384) HashFunc() func() hash.Hash { return sha512.New384 } -// Init initializes the internal Cipher with keying material +// Init initializes the internal Cipher with keying material. func (c *TLSEcdheEcdsaWithAes256GcmSha384) Init(masterSecret, clientRandom, serverRandom []byte, isClient bool) error { const ( prfMacLen = 0 diff --git a/internal/ciphersuite/tls_ecdhe_psk_with_aes_128_cbc_sha256.go b/internal/ciphersuite/tls_ecdhe_psk_with_aes_128_cbc_sha256.go index 87b80f421..24f51e1eb 100644 --- a/internal/ciphersuite/tls_ecdhe_psk_with_aes_128_cbc_sha256.go +++ b/internal/ciphersuite/tls_ecdhe_psk_with_aes_128_cbc_sha256.go @@ -15,7 +15,7 @@ import ( "github.com/pion/dtls/v3/pkg/protocol/recordlayer" ) -// TLSEcdhePskWithAes128CbcSha256 implements the TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256 CipherSuite +// TLSEcdhePskWithAes128CbcSha256 implements the TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256 CipherSuite. type TLSEcdhePskWithAes128CbcSha256 struct { cbc atomic.Value // *cryptoCBC } @@ -25,22 +25,22 @@ func NewTLSEcdhePskWithAes128CbcSha256() *TLSEcdhePskWithAes128CbcSha256 { return &TLSEcdhePskWithAes128CbcSha256{} } -// CertificateType returns what type of certificate this CipherSuite exchanges +// CertificateType returns what type of certificate this CipherSuite exchanges. func (c *TLSEcdhePskWithAes128CbcSha256) CertificateType() clientcertificate.Type { return clientcertificate.Type(0) } -// KeyExchangeAlgorithm controls what key exchange algorithm is using during the handshake +// KeyExchangeAlgorithm controls what key exchange algorithm is using during the handshake. func (c *TLSEcdhePskWithAes128CbcSha256) KeyExchangeAlgorithm() KeyExchangeAlgorithm { return (KeyExchangeAlgorithmPsk | KeyExchangeAlgorithmEcdhe) } -// ECC uses Elliptic Curve Cryptography +// ECC uses Elliptic Curve Cryptography. func (c *TLSEcdhePskWithAes128CbcSha256) ECC() bool { return true } -// ID returns the ID of the CipherSuite +// ID returns the ID of the CipherSuite. func (c *TLSEcdhePskWithAes128CbcSha256) ID() ID { return TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256 } @@ -49,23 +49,23 @@ func (c *TLSEcdhePskWithAes128CbcSha256) String() string { return "TLS-ECDHE-PSK-WITH-AES-128-CBC-SHA256" } -// HashFunc returns the hashing func for this CipherSuite +// HashFunc returns the hashing func for this CipherSuite. func (c *TLSEcdhePskWithAes128CbcSha256) HashFunc() func() hash.Hash { return sha256.New } -// AuthenticationType controls what authentication method is using during the handshake +// AuthenticationType controls what authentication method is using during the handshake. func (c *TLSEcdhePskWithAes128CbcSha256) AuthenticationType() AuthenticationType { return AuthenticationTypePreSharedKey } // IsInitialized returns if the CipherSuite has keying material and can -// encrypt/decrypt packets +// encrypt/decrypt packets. func (c *TLSEcdhePskWithAes128CbcSha256) IsInitialized() bool { return c.cbc.Load() != nil } -// Init initializes the internal Cipher with keying material +// Init initializes the internal Cipher with keying material. func (c *TLSEcdhePskWithAes128CbcSha256) Init(masterSecret, clientRandom, serverRandom []byte, isClient bool) error { const ( prfMacLen = 32 @@ -73,7 +73,9 @@ func (c *TLSEcdhePskWithAes128CbcSha256) Init(masterSecret, clientRandom, server prfIvLen = 16 ) - keys, err := prf.GenerateEncryptionKeys(masterSecret, clientRandom, serverRandom, prfMacLen, prfKeyLen, prfIvLen, c.HashFunc()) + keys, err := prf.GenerateEncryptionKeys( + masterSecret, clientRandom, serverRandom, prfMacLen, prfKeyLen, prfIvLen, c.HashFunc(), + ) if err != nil { return err } @@ -97,7 +99,7 @@ func (c *TLSEcdhePskWithAes128CbcSha256) Init(masterSecret, clientRandom, server return err } -// Encrypt encrypts a single TLS RecordLayer +// Encrypt encrypts a single TLS RecordLayer. func (c *TLSEcdhePskWithAes128CbcSha256) Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error) { cipherSuite, ok := c.cbc.Load().(*ciphersuite.CBC) if !ok { // !c.isInitialized() @@ -107,7 +109,7 @@ func (c *TLSEcdhePskWithAes128CbcSha256) Encrypt(pkt *recordlayer.RecordLayer, r return cipherSuite.Encrypt(pkt, raw) } -// Decrypt decrypts a single TLS RecordLayer +// Decrypt decrypts a single TLS RecordLayer. func (c *TLSEcdhePskWithAes128CbcSha256) Decrypt(h recordlayer.Header, raw []byte) ([]byte, error) { cipherSuite, ok := c.cbc.Load().(*ciphersuite.CBC) if !ok { // !c.isInitialized() diff --git a/internal/ciphersuite/tls_ecdhe_rsa_with_aes_128_gcm_sha256.go b/internal/ciphersuite/tls_ecdhe_rsa_with_aes_128_gcm_sha256.go index 5ac17ed97..b78969111 100644 --- a/internal/ciphersuite/tls_ecdhe_rsa_with_aes_128_gcm_sha256.go +++ b/internal/ciphersuite/tls_ecdhe_rsa_with_aes_128_gcm_sha256.go @@ -5,17 +5,17 @@ package ciphersuite import "github.com/pion/dtls/v3/pkg/crypto/clientcertificate" -// TLSEcdheRsaWithAes128GcmSha256 implements the TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 CipherSuite +// TLSEcdheRsaWithAes128GcmSha256 implements the TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 CipherSuite. type TLSEcdheRsaWithAes128GcmSha256 struct { TLSEcdheEcdsaWithAes128GcmSha256 } -// CertificateType returns what type of certificate this CipherSuite exchanges +// CertificateType returns what type of certificate this CipherSuite exchanges. func (c *TLSEcdheRsaWithAes128GcmSha256) CertificateType() clientcertificate.Type { return clientcertificate.RSASign } -// ID returns the ID of the CipherSuite +// ID returns the ID of the CipherSuite. func (c *TLSEcdheRsaWithAes128GcmSha256) ID() ID { return TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 } diff --git a/internal/ciphersuite/tls_ecdhe_rsa_with_aes_256_cbc_sha.go b/internal/ciphersuite/tls_ecdhe_rsa_with_aes_256_cbc_sha.go index 545c2e46f..deb20dd94 100644 --- a/internal/ciphersuite/tls_ecdhe_rsa_with_aes_256_cbc_sha.go +++ b/internal/ciphersuite/tls_ecdhe_rsa_with_aes_256_cbc_sha.go @@ -5,17 +5,17 @@ package ciphersuite import "github.com/pion/dtls/v3/pkg/crypto/clientcertificate" -// TLSEcdheRsaWithAes256CbcSha implements the TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA CipherSuite +// TLSEcdheRsaWithAes256CbcSha implements the TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA CipherSuite. type TLSEcdheRsaWithAes256CbcSha struct { TLSEcdheEcdsaWithAes256CbcSha } -// CertificateType returns what type of certificate this CipherSuite exchanges +// CertificateType returns what type of certificate this CipherSuite exchanges. func (c *TLSEcdheRsaWithAes256CbcSha) CertificateType() clientcertificate.Type { return clientcertificate.RSASign } -// ID returns the ID of the CipherSuite +// ID returns the ID of the CipherSuite. func (c *TLSEcdheRsaWithAes256CbcSha) ID() ID { return TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA } diff --git a/internal/ciphersuite/tls_ecdhe_rsa_with_aes_256_gcm_sha384.go b/internal/ciphersuite/tls_ecdhe_rsa_with_aes_256_gcm_sha384.go index 5750cc386..f7d7049a8 100644 --- a/internal/ciphersuite/tls_ecdhe_rsa_with_aes_256_gcm_sha384.go +++ b/internal/ciphersuite/tls_ecdhe_rsa_with_aes_256_gcm_sha384.go @@ -5,17 +5,17 @@ package ciphersuite import "github.com/pion/dtls/v3/pkg/crypto/clientcertificate" -// TLSEcdheRsaWithAes256GcmSha384 implements the TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 CipherSuite +// TLSEcdheRsaWithAes256GcmSha384 implements the TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 CipherSuite. type TLSEcdheRsaWithAes256GcmSha384 struct { TLSEcdheEcdsaWithAes256GcmSha384 } -// CertificateType returns what type of certificate this CipherSuite exchanges +// CertificateType returns what type of certificate this CipherSuite exchanges. func (c *TLSEcdheRsaWithAes256GcmSha384) CertificateType() clientcertificate.Type { return clientcertificate.RSASign } -// ID returns the ID of the CipherSuite +// ID returns the ID of the CipherSuite. func (c *TLSEcdheRsaWithAes256GcmSha384) ID() ID { return TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 } diff --git a/internal/ciphersuite/tls_psk_with_aes_128_cbc_sha256.go b/internal/ciphersuite/tls_psk_with_aes_128_cbc_sha256.go index dc485c46b..32507cdc3 100644 --- a/internal/ciphersuite/tls_psk_with_aes_128_cbc_sha256.go +++ b/internal/ciphersuite/tls_psk_with_aes_128_cbc_sha256.go @@ -15,27 +15,27 @@ import ( "github.com/pion/dtls/v3/pkg/protocol/recordlayer" ) -// TLSPskWithAes128CbcSha256 implements the TLS_PSK_WITH_AES_128_CBC_SHA256 CipherSuite +// TLSPskWithAes128CbcSha256 implements the TLS_PSK_WITH_AES_128_CBC_SHA256 CipherSuite. type TLSPskWithAes128CbcSha256 struct { cbc atomic.Value // *cryptoCBC } -// CertificateType returns what type of certificate this CipherSuite exchanges +// CertificateType returns what type of certificate this CipherSuite exchanges. func (c *TLSPskWithAes128CbcSha256) CertificateType() clientcertificate.Type { return clientcertificate.Type(0) } -// KeyExchangeAlgorithm controls what key exchange algorithm is using during the handshake +// KeyExchangeAlgorithm controls what key exchange algorithm is using during the handshake. func (c *TLSPskWithAes128CbcSha256) KeyExchangeAlgorithm() KeyExchangeAlgorithm { return KeyExchangeAlgorithmPsk } -// ECC uses Elliptic Curve Cryptography +// ECC uses Elliptic Curve Cryptography. func (c *TLSPskWithAes128CbcSha256) ECC() bool { return false } -// ID returns the ID of the CipherSuite +// ID returns the ID of the CipherSuite. func (c *TLSPskWithAes128CbcSha256) ID() ID { return TLS_PSK_WITH_AES_128_CBC_SHA256 } @@ -44,23 +44,23 @@ func (c *TLSPskWithAes128CbcSha256) String() string { return "TLS_PSK_WITH_AES_128_CBC_SHA256" } -// HashFunc returns the hashing func for this CipherSuite +// HashFunc returns the hashing func for this CipherSuite. func (c *TLSPskWithAes128CbcSha256) HashFunc() func() hash.Hash { return sha256.New } -// AuthenticationType controls what authentication method is using during the handshake +// AuthenticationType controls what authentication method is using during the handshake. func (c *TLSPskWithAes128CbcSha256) AuthenticationType() AuthenticationType { return AuthenticationTypePreSharedKey } // IsInitialized returns if the CipherSuite has keying material and can -// encrypt/decrypt packets +// encrypt/decrypt packets. func (c *TLSPskWithAes128CbcSha256) IsInitialized() bool { return c.cbc.Load() != nil } -// Init initializes the internal Cipher with keying material +// Init initializes the internal Cipher with keying material. func (c *TLSPskWithAes128CbcSha256) Init(masterSecret, clientRandom, serverRandom []byte, isClient bool) error { const ( prfMacLen = 32 @@ -68,7 +68,9 @@ func (c *TLSPskWithAes128CbcSha256) Init(masterSecret, clientRandom, serverRando prfIvLen = 16 ) - keys, err := prf.GenerateEncryptionKeys(masterSecret, clientRandom, serverRandom, prfMacLen, prfKeyLen, prfIvLen, c.HashFunc()) + keys, err := prf.GenerateEncryptionKeys( + masterSecret, clientRandom, serverRandom, prfMacLen, prfKeyLen, prfIvLen, c.HashFunc(), + ) if err != nil { return err } @@ -92,7 +94,7 @@ func (c *TLSPskWithAes128CbcSha256) Init(masterSecret, clientRandom, serverRando return err } -// Encrypt encrypts a single TLS RecordLayer +// Encrypt encrypts a single TLS RecordLayer. func (c *TLSPskWithAes128CbcSha256) Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error) { cipherSuite, ok := c.cbc.Load().(*ciphersuite.CBC) if !ok { @@ -102,7 +104,7 @@ func (c *TLSPskWithAes128CbcSha256) Encrypt(pkt *recordlayer.RecordLayer, raw [] return cipherSuite.Encrypt(pkt, raw) } -// Decrypt decrypts a single TLS RecordLayer +// Decrypt decrypts a single TLS RecordLayer. func (c *TLSPskWithAes128CbcSha256) Decrypt(h recordlayer.Header, raw []byte) ([]byte, error) { cipherSuite, ok := c.cbc.Load().(*ciphersuite.CBC) if !ok { diff --git a/internal/ciphersuite/tls_psk_with_aes_128_ccm.go b/internal/ciphersuite/tls_psk_with_aes_128_ccm.go index 6344f11fb..0b802fc8c 100644 --- a/internal/ciphersuite/tls_psk_with_aes_128_ccm.go +++ b/internal/ciphersuite/tls_psk_with_aes_128_ccm.go @@ -8,7 +8,14 @@ import ( "github.com/pion/dtls/v3/pkg/crypto/clientcertificate" ) -// NewTLSPskWithAes128Ccm returns the TLS_PSK_WITH_AES_128_CCM CipherSuite +// NewTLSPskWithAes128Ccm returns the TLS_PSK_WITH_AES_128_CCM CipherSuite. func NewTLSPskWithAes128Ccm() *Aes128Ccm { - return newAes128Ccm(clientcertificate.Type(0), TLS_PSK_WITH_AES_128_CCM, true, ciphersuite.CCMTagLength, KeyExchangeAlgorithmPsk, false) + return newAes128Ccm( + clientcertificate.Type(0), + TLS_PSK_WITH_AES_128_CCM, + true, + ciphersuite.CCMTagLength, + KeyExchangeAlgorithmPsk, + false, + ) } diff --git a/internal/ciphersuite/tls_psk_with_aes_128_ccm8.go b/internal/ciphersuite/tls_psk_with_aes_128_ccm8.go index 4b0827533..c6bf6dc59 100644 --- a/internal/ciphersuite/tls_psk_with_aes_128_ccm8.go +++ b/internal/ciphersuite/tls_psk_with_aes_128_ccm8.go @@ -8,7 +8,14 @@ import ( "github.com/pion/dtls/v3/pkg/crypto/clientcertificate" ) -// NewTLSPskWithAes128Ccm8 returns the TLS_PSK_WITH_AES_128_CCM_8 CipherSuite +// NewTLSPskWithAes128Ccm8 returns the TLS_PSK_WITH_AES_128_CCM_8 CipherSuite. func NewTLSPskWithAes128Ccm8() *Aes128Ccm { - return newAes128Ccm(clientcertificate.Type(0), TLS_PSK_WITH_AES_128_CCM_8, true, ciphersuite.CCMTagLength8, KeyExchangeAlgorithmPsk, false) + return newAes128Ccm( + clientcertificate.Type(0), + TLS_PSK_WITH_AES_128_CCM_8, + true, + ciphersuite.CCMTagLength8, + KeyExchangeAlgorithmPsk, + false, + ) } diff --git a/internal/ciphersuite/tls_psk_with_aes_128_gcm_sha256.go b/internal/ciphersuite/tls_psk_with_aes_128_gcm_sha256.go index 3a5e7e753..bc50d562b 100644 --- a/internal/ciphersuite/tls_psk_with_aes_128_gcm_sha256.go +++ b/internal/ciphersuite/tls_psk_with_aes_128_gcm_sha256.go @@ -5,22 +5,22 @@ package ciphersuite import "github.com/pion/dtls/v3/pkg/crypto/clientcertificate" -// TLSPskWithAes128GcmSha256 implements the TLS_PSK_WITH_AES_128_GCM_SHA256 CipherSuite +// TLSPskWithAes128GcmSha256 implements the TLS_PSK_WITH_AES_128_GCM_SHA256 CipherSuite. type TLSPskWithAes128GcmSha256 struct { TLSEcdheEcdsaWithAes128GcmSha256 } -// CertificateType returns what type of certificate this CipherSuite exchanges +// CertificateType returns what type of certificate this CipherSuite exchanges. func (c *TLSPskWithAes128GcmSha256) CertificateType() clientcertificate.Type { return clientcertificate.Type(0) } -// KeyExchangeAlgorithm controls what key exchange algorithm is using during the handshake +// KeyExchangeAlgorithm controls what key exchange algorithm is using during the handshake. func (c *TLSPskWithAes128GcmSha256) KeyExchangeAlgorithm() KeyExchangeAlgorithm { return KeyExchangeAlgorithmPsk } -// ID returns the ID of the CipherSuite +// ID returns the ID of the CipherSuite. func (c *TLSPskWithAes128GcmSha256) ID() ID { return TLS_PSK_WITH_AES_128_GCM_SHA256 } @@ -29,7 +29,7 @@ func (c *TLSPskWithAes128GcmSha256) String() string { return "TLS_PSK_WITH_AES_128_GCM_SHA256" } -// AuthenticationType controls what authentication method is using during the handshake +// AuthenticationType controls what authentication method is using during the handshake. func (c *TLSPskWithAes128GcmSha256) AuthenticationType() AuthenticationType { return AuthenticationTypePreSharedKey } diff --git a/internal/ciphersuite/tls_psk_with_aes_256_ccm8.go b/internal/ciphersuite/tls_psk_with_aes_256_ccm8.go index 211bdae03..771a1d42e 100644 --- a/internal/ciphersuite/tls_psk_with_aes_256_ccm8.go +++ b/internal/ciphersuite/tls_psk_with_aes_256_ccm8.go @@ -8,7 +8,14 @@ import ( "github.com/pion/dtls/v3/pkg/crypto/clientcertificate" ) -// NewTLSPskWithAes256Ccm8 returns the TLS_PSK_WITH_AES_256_CCM_8 CipherSuite +// NewTLSPskWithAes256Ccm8 returns the TLS_PSK_WITH_AES_256_CCM_8 CipherSuite. func NewTLSPskWithAes256Ccm8() *Aes256Ccm { - return newAes256Ccm(clientcertificate.Type(0), TLS_PSK_WITH_AES_256_CCM_8, true, ciphersuite.CCMTagLength8, KeyExchangeAlgorithmPsk, false) + return newAes256Ccm( + clientcertificate.Type(0), + TLS_PSK_WITH_AES_256_CCM_8, + true, + ciphersuite.CCMTagLength8, + KeyExchangeAlgorithmPsk, + false, + ) } diff --git a/internal/ciphersuite/types/authentication_type.go b/internal/ciphersuite/types/authentication_type.go index 2da21e642..09681cec5 100644 --- a/internal/ciphersuite/types/authentication_type.go +++ b/internal/ciphersuite/types/authentication_type.go @@ -3,10 +3,10 @@ package types -// AuthenticationType controls what authentication method is using during the handshake +// AuthenticationType controls what authentication method is using during the handshake. type AuthenticationType int -// AuthenticationType Enums +// AuthenticationType Enums. const ( AuthenticationTypeCertificate AuthenticationType = iota + 1 AuthenticationTypePreSharedKey diff --git a/internal/ciphersuite/types/key_exchange_algorithm.go b/internal/ciphersuite/types/key_exchange_algorithm.go index c2c39113a..5b59f2410 100644 --- a/internal/ciphersuite/types/key_exchange_algorithm.go +++ b/internal/ciphersuite/types/key_exchange_algorithm.go @@ -7,7 +7,7 @@ package types // KeyExchangeAlgorithm controls what exchange algorithm was chosen. type KeyExchangeAlgorithm int -// KeyExchangeAlgorithm Bitmask +// KeyExchangeAlgorithm Bitmask. const ( KeyExchangeAlgorithmNone KeyExchangeAlgorithm = 0 KeyExchangeAlgorithmPsk KeyExchangeAlgorithm = iota << 1 diff --git a/internal/closer/closer.go b/internal/closer/closer.go index bfa171cda..a1c25f379 100644 --- a/internal/closer/closer.go +++ b/internal/closer/closer.go @@ -8,41 +8,43 @@ import ( "context" ) -// Closer allows for each signaling a channel for shutdown +// Closer allows for each signaling a channel for shutdown. type Closer struct { - ctx context.Context + ctx context.Context //nolint:containedctx closeFunc func() } -// NewCloser creates a new instance of Closer +// NewCloser creates a new instance of Closer. func NewCloser() *Closer { ctx, closeFunc := context.WithCancel(context.Background()) + return &Closer{ ctx: ctx, closeFunc: closeFunc, } } -// NewCloserWithParent creates a new instance of Closer with a parent context +// NewCloserWithParent creates a new instance of Closer with a parent context. func NewCloserWithParent(ctx context.Context) *Closer { ctx, closeFunc := context.WithCancel(ctx) + return &Closer{ ctx: ctx, closeFunc: closeFunc, } } -// Done returns a channel signaling when it is done +// Done returns a channel signaling when it is done. func (c *Closer) Done() <-chan struct{} { return c.ctx.Done() } -// Err returns an error of the context +// Err returns an error of the context. func (c *Closer) Err() error { return c.ctx.Err() } -// Close sends a signal to trigger the ctx done channel +// Close sends a signal to trigger the ctx done channel. func (c *Closer) Close() { c.closeFunc() } diff --git a/internal/net/buffer.go b/internal/net/buffer.go index 9ab290e4c..c763f15e4 100644 --- a/internal/net/buffer.go +++ b/internal/net/buffer.go @@ -68,11 +68,12 @@ func NewPacketBuffer() *PacketBuffer { // WriteTo writes a single packet to the buffer. The supplied address will // remain associated with the packet. -func (b *PacketBuffer) WriteTo(p []byte, addr net.Addr) (int, error) { +func (b *PacketBuffer) WriteTo(pkt []byte, addr net.Addr) (int, error) { b.mutex.Lock() if b.closed { b.mutex.Unlock() + return 0, io.ErrClosedPipe } @@ -113,9 +114,10 @@ func (b *PacketBuffer) WriteTo(p []byte, addr net.Addr) (int, error) { // Store the packet at the write pointer. packet := &b.packets[b.write] packet.data.Reset() - n, err := packet.data.Write(p) + n, err := packet.data.Write(pkt) if err != nil { b.mutex.Unlock() + return n, err } packet.addr = addr @@ -145,7 +147,7 @@ func (b *PacketBuffer) WriteTo(p []byte, addr net.Addr) (int, error) { // ReadFrom reads a single packet from the buffer, or blocks until one is // available. -func (b *PacketBuffer) ReadFrom(packet []byte) (n int, addr net.Addr, err error) { +func (b *PacketBuffer) ReadFrom(packet []byte) (n int, addr net.Addr, err error) { //nolint:cyclop select { case <-b.readDeadline.Done(): return 0, nil, ErrTimeout @@ -159,6 +161,7 @@ func (b *PacketBuffer) ReadFrom(packet []byte) (n int, addr net.Addr, err error) ap := b.packets[b.read] if len(packet) < ap.data.Len() { b.mutex.Unlock() + return 0, nil, io.ErrShortBuffer } @@ -166,6 +169,7 @@ func (b *PacketBuffer) ReadFrom(packet []byte) (n int, addr net.Addr, err error) n, err := ap.data.Read(packet) if err != nil { b.mutex.Unlock() + return n, nil, err } @@ -188,6 +192,7 @@ func (b *PacketBuffer) ReadFrom(packet []byte) (n int, addr net.Addr, err error) if b.closed { b.mutex.Unlock() + return 0, nil, io.EOF } @@ -212,6 +217,7 @@ func (b *PacketBuffer) Close() (err error) { if b.closed { b.mutex.Unlock() + return nil } @@ -231,5 +237,6 @@ func (b *PacketBuffer) Close() (err error) { // SetReadDeadline sets the read deadline for the buffer. func (b *PacketBuffer) SetReadDeadline(t time.Time) error { b.readDeadline.Set(t) + return nil } diff --git a/internal/net/buffer_test.go b/internal/net/buffer_test.go index e416dde53..87a2a6ce1 100644 --- a/internal/net/buffer_test.go +++ b/internal/net/buffer_test.go @@ -15,12 +15,16 @@ import ( ) func equalInt(t *testing.T, expected, actual int) { + t.Helper() + if expected != actual { t.Errorf("Expected %d got %d", expected, actual) } } func equalUDPAddr(t *testing.T, expected, actual net.Addr) { + t.Helper() + if expected == nil && actual == nil { return } @@ -30,12 +34,14 @@ func equalUDPAddr(t *testing.T, expected, actual net.Addr) { } func equalBytes(t *testing.T, expected, actual []byte) { + t.Helper() + if !bytes.Equal(expected, actual) { t.Errorf("Expected %v got %v", expected, actual) } } -func TestBuffer(t *testing.T) { +func TestBuffer(t *testing.T) { //nolint:cyclop buffer := NewPacketBuffer() packet := make([]byte, 4) addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:5684") @@ -243,6 +249,7 @@ func TestBufferAsync(t *testing.T) { n, raddr, rErr := buffer.ReadFrom(packet) if rErr != nil { done <- rErr.Error() + return } @@ -281,7 +288,9 @@ func TestBufferAsync(t *testing.T) { } } -func benchmarkBufferWR(b *testing.B, size int64, write bool, grow int) { // nolint:unparam +func benchmarkBufferWR(b *testing.B, size int64, write bool, grow int) { // nolint:unparam,cyclop + b.Helper() + addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:5684") if err != nil { b.Fatalf("net.ResolveUDPAddr: %v", err) @@ -335,7 +344,7 @@ func BenchmarkBufferWR1400(b *testing.B) { benchmarkBufferWR(b, 1400, false, 128) } -// Here, the buffer never becomes empty, which forces wraparound +// Here, the buffer never becomes empty, which forces wraparound. func BenchmarkBufferWWR14(b *testing.B) { benchmarkBufferWR(b, 14, true, 128) } @@ -349,6 +358,8 @@ func BenchmarkBufferWWR1400(b *testing.B) { } func benchmarkBuffer(b *testing.B, size int64) { + b.Helper() + addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:5684") if err != nil { b.Fatalf("net.ResolveUDPAddr: %v", err) @@ -366,6 +377,7 @@ func benchmarkBuffer(b *testing.B, size int64) { break } else if err != nil { b.Error(err) + break } } diff --git a/internal/net/udp/packet_conn.go b/internal/net/udp/packet_conn.go index da5b6fae8..e3e214ce9 100644 --- a/internal/net/udp/packet_conn.go +++ b/internal/net/udp/packet_conn.go @@ -34,13 +34,13 @@ const ( defaultListenBacklog = 128 // same as Linux default ) -// Typed errors +// Typed errors. var ( ErrClosedListener = errors.New("udp: listener closed") ErrListenQueueExceeded = errors.New("udp: listen queue exceeded") ) -// listener augments a connection-oriented Listener over a UDP PacketConn +// listener augments a connection-oriented Listener over a UDP PacketConn. type listener struct { pConn *net.UDPConn @@ -68,10 +68,12 @@ func (l *listener) Accept() (net.PacketConn, net.Addr, error) { select { case c := <-l.acceptCh: l.connWG.Add(1) + return c, c.raddr, nil case <-l.readDoneCh: err, _ := l.errRead.Load().(error) + return nil, nil, err case <-l.doneCh: @@ -168,7 +170,7 @@ func (lc *ListenConfig) Listen(network string, laddr *net.UDPAddr) (dtlsnet.Pack return nil, err } - l := &listener{ + packetListener := &listener{ pConn: conn, acceptCh: make(chan *PacketConn, lc.Backlog), conns: make(map[string]*PacketConn), @@ -179,20 +181,20 @@ func (lc *ListenConfig) Listen(network string, laddr *net.UDPAddr) (dtlsnet.Pack readDoneCh: make(chan struct{}), } - l.accepting.Store(true) - l.connWG.Add(1) - l.readWG.Add(2) // wait readLoop and Close execution routine + packetListener.accepting.Store(true) + packetListener.connWG.Add(1) + packetListener.readWG.Add(2) // wait readLoop and Close execution routine - go l.readLoop() + go packetListener.readLoop() go func() { - l.connWG.Wait() - if err := l.pConn.Close(); err != nil { - l.errClose.Store(err) + packetListener.connWG.Wait() + if err := packetListener.pConn.Close(); err != nil { + packetListener.errClose.Store(err) } - l.readWG.Done() + packetListener.readWG.Done() }() - return l, nil + return packetListener, nil } // Listen creates a new listener using default ListenConfig. @@ -212,6 +214,7 @@ func (l *listener) readLoop() { n, raddr, err := l.pConn.ReadFrom(buf) if err != nil { l.errRead.Store(err) + return } conn, ok, err := l.getConn(raddr, buf[:n]) @@ -225,7 +228,7 @@ func (l *listener) readLoop() { } // getConn gets an existing connection or creates a new one. -func (l *listener) getConn(raddr net.Addr, buf []byte) (*PacketConn, bool, error) { +func (l *listener) getConn(raddr net.Addr, buf []byte) (*PacketConn, bool, error) { //nolint:cyclop l.connLock.Lock() defer l.connLock.Unlock() // If we have a custom resolver, use it. @@ -257,6 +260,7 @@ func (l *listener) getConn(raddr net.Addr, buf []byte) (*PacketConn, bool, error return nil, false, ErrListenQueueExceeded } } + return conn, true, nil } @@ -292,19 +296,19 @@ func (l *listener) newPacketConn(raddr net.Addr) *PacketConn { // ReadFrom reads a single packet payload and its associated remote address from // the underlying buffer. -func (c *PacketConn) ReadFrom(p []byte) (int, net.Addr, error) { - return c.buffer.ReadFrom(p) +func (c *PacketConn) ReadFrom(buff []byte) (int, net.Addr, error) { + return c.buffer.ReadFrom(buff) } -// WriteTo writes len(p) bytes from p to the specified address. -func (c *PacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { +// WriteTo writes len(payload) bytes from payload to the specified address. +func (c *PacketConn) WriteTo(payload []byte, addr net.Addr) (n int, err error) { // If we have a connection identifier, check to see if the outgoing packet // sets it. if c.listener.connIdentifier != nil { id := c.id.Load() // Only update establish identifier if we haven't already done so. if id == nil { - candidate, ok := c.listener.connIdentifier(p) + candidate, ok := c.listener.connIdentifier(payload) // If we have an identifier, add entry to connection map. if ok { c.listener.connLock.Lock() @@ -340,10 +344,11 @@ func (c *PacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { return 0, context.DeadlineExceeded default: } - return c.listener.pConn.WriteTo(p, addr) + + return c.listener.pConn.WriteTo(payload, addr) } -// Close closes the conn and releases any Read calls +// Close closes the conn and releases any Read calls. func (c *PacketConn) Close() error { var err error c.doneOnce.Do(func() { @@ -390,6 +395,7 @@ func (c *PacketConn) LocalAddr() net.Addr { // SetDeadline implements net.PacketConn.SetDeadline. func (c *PacketConn) SetDeadline(t time.Time) error { c.writeDeadline.Set(t) + return c.SetReadDeadline(t) } diff --git a/internal/net/udp/packet_conn_test.go b/internal/net/udp/packet_conn_test.go index 2e9f4a063..53b3c06d5 100644 --- a/internal/net/udp/packet_conn_test.go +++ b/internal/net/udp/packet_conn_test.go @@ -51,6 +51,7 @@ func fromPC(p net.PacketConn, raddr net.Addr) *rw { func (r *rw) Read(p []byte) (int, error) { n, _, err := r.p.ReadFrom(p) + return n, err } @@ -59,6 +60,8 @@ func (r *rw) Write(p []byte) (int, error) { } func stressDuplex(t *testing.T) { + t.Helper() + listener, ca, cb, err := pipe() if err != nil { t.Fatal(err) @@ -135,6 +138,7 @@ func TestListenerCloseUnaccepted(t *testing.T) { conn, dErr := net.DialUDP(network, nil, listener.Addr().(*net.UDPAddr)) if dErr != nil { t.Error(dErr) + continue } if _, wErr := conn.Write([]byte{byte(i)}); wErr != nil { @@ -153,7 +157,7 @@ func TestListenerCloseUnaccepted(t *testing.T) { } } -func TestListenerAcceptFilter(t *testing.T) { +func TestListenerAcceptFilter(t *testing.T) { //nolint:cyclop // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() @@ -220,6 +224,7 @@ func TestListenerAcceptFilter(t *testing.T) { if !errors.Is(aArr, ErrClosedListener) { t.Error(aArr) } + return } close(chAccepted) @@ -246,7 +251,7 @@ func TestListenerAcceptFilter(t *testing.T) { } } -func TestListenerConcurrent(t *testing.T) { +func TestListenerConcurrent(t *testing.T) { //nolint:gocyclo,cyclop // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() @@ -269,6 +274,7 @@ func TestListenerConcurrent(t *testing.T) { conn, dErr := net.DialUDP(network, nil, listener.Addr().(*net.UDPAddr)) if dErr != nil { t.Error(dErr) + continue } if _, wErr := conn.Write([]byte{byte(i)}); wErr != nil { @@ -285,6 +291,7 @@ func TestListenerConcurrent(t *testing.T) { conn, _, lErr := listener.Accept() if lErr != nil { t.Error(lErr) + continue } b := make([]byte, 1) @@ -367,7 +374,7 @@ func getConfig() (string, *net.UDPAddr) { return "udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0} } -func TestConnClose(t *testing.T) { +func TestConnClose(t *testing.T) { //nolint:cyclop lim := test.TimeOut(time.Second * 5) defer lim.Stop() @@ -376,7 +383,7 @@ func TestConnClose(t *testing.T) { report := test.CheckRoutines(t) defer report() - l, ca, cb, errPipe := pipe() + udpListener, ca, cb, errPipe := pipe() if errPipe != nil { t.Fatal(errPipe) } @@ -386,7 +393,7 @@ func TestConnClose(t *testing.T) { if err := cb.Close(); err != nil { t.Errorf("Failed to close B side: %v", err) } - if err := l.Close(); err != nil { + if err := udpListener.Close(); err != nil { t.Errorf("Failed to close listener: %v", err) } }) @@ -395,12 +402,12 @@ func TestConnClose(t *testing.T) { report := test.CheckRoutines(t) defer report() - l, ca, cb, errPipe := pipe() + udpListener, ca, cb, errPipe := pipe() if errPipe != nil { t.Fatal(errPipe) } // Close l.pConn to inject error. - if err := l.(*listener).pConn.Close(); err != nil { //nolint:forcetypeassert + if err := udpListener.(*listener).pConn.Close(); err != nil { //nolint:forcetypeassert t.Error(err) } @@ -410,7 +417,7 @@ func TestConnClose(t *testing.T) { if err := ca.Close(); err != nil { t.Errorf("Failed to close B side: %v", err) } - if err := l.Close(); err == nil { + if err := udpListener.Close(); err == nil { t.Errorf("Error is not propagated to Listener.Close") } }) @@ -447,7 +454,7 @@ func TestConnClose(t *testing.T) { report := test.CheckRoutines(t) defer report() - l, ca, cb, errPipe := pipe() + listener, ca, cb, errPipe := pipe() if errPipe != nil { t.Fatal(errPipe) } @@ -474,13 +481,13 @@ func TestConnClose(t *testing.T) { if err := cb.Close(); err != nil { t.Errorf("Failed to close A side: %v", err) } - if err := l.Close(); err != nil { + if err := listener.Close(); err != nil { t.Errorf("Failed to close listener: %v", err) } }) } -func TestListenerCustomConnIDs(t *testing.T) { +func TestListenerCustomConnIDs(t *testing.T) { //nolint:gocyclo,cyclop,maintidx const helloPayload, setPayload = "hello", "set" const serverCount, clientCount = 5, 20 // Limit runtime in case of deadlocks. @@ -507,6 +514,7 @@ func TestListenerCustomConnIDs(t *testing.T) { if p.Payload == helloPayload { return "", false } + return fmt.Sprint(p.ID), true }, // Use the outgoing "set" payload to add an identifier for a connection. @@ -518,6 +526,7 @@ func TestListenerCustomConnIDs(t *testing.T) { if p.Payload == setPayload { return fmt.Sprint(p.ID), true } + return "", false }, }).Listen(network, addr) @@ -543,26 +552,30 @@ func TestListenerCustomConnIDs(t *testing.T) { conn, _, err := listener.Accept() if err != nil { t.Error(err) + return } buf := make([]byte, 100) n, raddr, rErr := conn.ReadFrom(buf) if rErr != nil { t.Error(err) + return } - var p pkt - if uErr := json.Unmarshal(buf[:n], &p); uErr != nil { + var udpPkt pkt + if uErr := json.Unmarshal(buf[:n], &udpPkt); uErr != nil { t.Error(err) + return } // First message should be a hello and custom connection // ID function will use remote address as identifier. - if p.Payload != helloPayload { + if udpPkt.Payload != helloPayload { t.Error("Expected hello message") + return } - connID := p.ID + connID := udpPkt.ID // Send set message to associate ID with this connection. buf, err = json.Marshal(&pkt{ @@ -571,10 +584,12 @@ func TestListenerCustomConnIDs(t *testing.T) { }) if err != nil { t.Error(err) + return } if _, wErr := conn.WriteTo(buf, raddr); wErr != nil { t.Error(wErr) + return } // Signal to the corresponding clients that connection ID has been @@ -587,25 +602,29 @@ func TestListenerCustomConnIDs(t *testing.T) { n, _, err := conn.ReadFrom(buf) if err != nil { t.Error(err) + return } - var p pkt - if err := json.Unmarshal(buf[:n], &p); err != nil { + var udpPkt pkt + if err := json.Unmarshal(buf[:n], &udpPkt); err != nil { t.Error(err) + return } - if p.ID != connID { - t.Errorf("Expected connection ID %d, but got %d", connID, p.ID) + if udpPkt.ID != connID { + t.Errorf("Expected connection ID %d, but got %d", connID, udpPkt.ID) + return } // Ensure we only ever receive one message from // a given client. clientMapMu.Lock() - if _, ok := clientMap[p.Payload]; ok { - t.Errorf("Multiple messages from single client %s", p.Payload) + if _, ok := clientMap[udpPkt.Payload]; ok { + t.Errorf("Multiple messages from single client %s", udpPkt.Payload) + return } - clientMap[p.Payload] = struct{}{} + clientMap[udpPkt.Payload] = struct{}{} clientMapMu.Unlock() } if err := conn.Close(); err != nil { @@ -623,6 +642,7 @@ func TestListenerCustomConnIDs(t *testing.T) { conn, dErr := net.DialUDP(network, nil, listener.Addr().(*net.UDPAddr)) if dErr != nil { t.Error(dErr) + return } hbuf, err := json.Marshal(&pkt{ @@ -631,35 +651,41 @@ func TestListenerCustomConnIDs(t *testing.T) { }) if err != nil { t.Error(err) + return } if _, wErr := conn.Write(hbuf); wErr != nil { t.Error(wErr) + return } - var p pkt + var udpPacket pkt buf := make([]byte, 100) n, err := conn.Read(buf) if err != nil { t.Error(err) + return } - if err := json.Unmarshal(buf[:n], &p); err != nil { + if err := json.Unmarshal(buf[:n], &udpPacket); err != nil { t.Error(err) + return } // Second message should be a set and custom connection identifier // function will update the connection ID from remote address to the // supplied ID. - if p.Payload != "set" { + if udpPacket.Payload != "set" { t.Error("Expected set message") + return } // Ensure the connection ID matches what the "hello" message // indicated. - if p.ID != connID { - t.Errorf("Expected connection ID %d, but got %d", connID, p.ID) + if udpPacket.ID != connID { + t.Errorf("Expected connection ID %d, but got %d", connID, udpPacket.ID) + return } // Close connection. We will reconnect from a different remote @@ -681,6 +707,7 @@ func TestListenerCustomConnIDs(t *testing.T) { conn, dErr := net.DialUDP(network, nil, listener.Addr().(*net.UDPAddr)) if dErr != nil { t.Error(dErr) + return } // Send a packet with a connection ID and this client's local @@ -691,10 +718,12 @@ func TestListenerCustomConnIDs(t *testing.T) { }) if err != nil { t.Error(err) + return } if _, wErr := conn.Write(buf); wErr != nil { t.Error(wErr) + return } if cErr := conn.Close(); cErr != nil { diff --git a/internal/util/util.go b/internal/util/util.go index 382a0e1cd..8ebbcd44f 100644 --- a/internal/util/util.go +++ b/internal/util/util.go @@ -10,7 +10,7 @@ import ( "golang.org/x/crypto/cryptobyte" ) -// BigEndianUint24 returns the value of a big endian uint24 +// BigEndianUint24 returns the value of a big endian uint24. func BigEndianUint24(raw []byte) uint32 { if len(raw) < 3 { return 0 @@ -18,28 +18,30 @@ func BigEndianUint24(raw []byte) uint32 { rawCopy := make([]byte, 4) copy(rawCopy[1:], raw) + return binary.BigEndian.Uint32(rawCopy) } -// PutBigEndianUint24 encodes a uint24 and places into out +// PutBigEndianUint24 encodes a uint24 and places into out. func PutBigEndianUint24(out []byte, in uint32) { tmp := make([]byte, 4) binary.BigEndian.PutUint32(tmp, in) copy(out, tmp[1:]) } -// PutBigEndianUint48 encodes a uint64 and places into out +// PutBigEndianUint48 encodes a uint64 and places into out. func PutBigEndianUint48(out []byte, in uint64) { tmp := make([]byte, 8) binary.BigEndian.PutUint64(tmp, in) copy(out, tmp[2:]) } -// Max returns the larger value +// Max returns the larger value. func Max(a, b int) int { if a > b { return a } + return b } diff --git a/internal/util/util_test.go b/internal/util/util_test.go index c247f68bb..41127c0b8 100644 --- a/internal/util/util_test.go +++ b/internal/util/util_test.go @@ -29,16 +29,19 @@ func TestAddUint48(t *testing.T) { builder: func() *cryptobyte.Builder { var b cryptobyte.Builder b.AddUint64(0xffffffffffffffff) + return &b }(), in: 0xfefcff3cfdfc, want: []byte{255, 255, 255, 255, 255, 255, 255, 255, 254, 252, 255, 60, 253, 252}, }, "ExistingAddUint48AndMore": { + //nolint:lll reason: "Adding a 48-bit unsigned integer to a builder with existing bytes, then adding more bytes, should yield expected result.", builder: func() *cryptobyte.Builder { var b cryptobyte.Builder b.AddUint64(0xffffffffffffffff) + return &b }(), postAdd: func(b *cryptobyte.Builder) { diff --git a/listener.go b/listener.go index 22e56ce52..3583d0308 100644 --- a/listener.go +++ b/listener.go @@ -12,7 +12,7 @@ import ( "github.com/pion/dtls/v3/pkg/protocol/recordlayer" ) -// Listen creates a DTLS listener +// Listen creates a DTLS listener. func Listen(network string, laddr *net.UDPAddr, config *Config) (net.Listener, error) { if err := validateConfig(config); err != nil { return nil, err @@ -28,6 +28,7 @@ func Listen(network string, laddr *net.UDPAddr, config *Config) (net.Listener, e if err := h.Unmarshal(pkts[0]); err != nil { return false } + return h.ContentType == protocol.ContentTypeHandshake }, } @@ -41,6 +42,7 @@ func Listen(network string, laddr *net.UDPAddr, config *Config) (net.Listener, e if err != nil { return nil, err } + return &listener{ config: config, parent: parent, @@ -59,7 +61,7 @@ func NewListener(inner dtlsnet.PacketListener, config *Config) (net.Listener, er }, nil } -// listener represents a DTLS listener +// listener represents a DTLS listener. type listener struct { config *Config parent dtlsnet.PacketListener @@ -72,6 +74,7 @@ func (l *listener) Accept() (net.Conn, error) { if err != nil { return nil, err } + return Server(c, raddr, l.config) } diff --git a/nettest_test.go b/nettest_test.go index dc245b28d..3a712434f 100644 --- a/nettest_test.go +++ b/nettest_test.go @@ -28,6 +28,7 @@ func TestNetTest(t *testing.T) { _ = c1.Close() _ = c2.Close() } + return }) } diff --git a/pkg/crypto/ccm/ccm.go b/pkg/crypto/ccm/ccm.go index 73476f219..bc268deeb 100644 --- a/pkg/crypto/ccm/ccm.go +++ b/pkg/crypto/ccm/ccm.go @@ -65,7 +65,8 @@ func NewCCM(b cipher.Block, tagsize, noncesize int) (CCM, error) { if lensize < 2 || lensize > 8 { return nil, errInvalidNonceSize } - c := &ccm{b: b, M: uint8(tagsize), L: uint8(lensize)} + c := &ccm{b: b, M: uint8(tagsize), L: uint8(lensize)} //nolint:gosec // G114 + return c, nil } @@ -75,13 +76,14 @@ func (c *ccm) MaxLength() int { return maxlen(c.L, c.Overhead()) } func maxlen(l uint8, tagsize int) int { mLen := (uint64(1) << (8 * l)) - 1 - if m64 := uint64(math.MaxInt64) - uint64(tagsize); l > 8 || mLen > m64 { + if m64 := uint64(math.MaxInt64) - uint64(tagsize); l > 8 || mLen > m64 { //nolint:gosec // G114 mLen = m64 // The maximum lentgh on a 64bit arch } - if mLen != uint64(int(mLen)) { + if mLen != uint64(int(mLen)) { //nolint:gosec // G114 return math.MaxInt32 - tagsize // We have only 32bit int's } - return int(mLen) + + return int(mLen) //nolint:gosec // G114 } // MaxNonceLength returns the maximum nonce length for a given plaintext length. @@ -90,10 +92,11 @@ func maxlen(l uint8, tagsize int) int { func MaxNonceLength(pdatalen int) int { const tagsize = 16 for L := 2; L <= 8; L++ { - if maxlen(uint8(L), tagsize) >= pdatalen { + if maxlen(uint8(L), tagsize) >= pdatalen { //nolint:gosec // G115 return 15 - L } } + return 0 } @@ -137,20 +140,20 @@ func (c *ccm) tag(nonce, plaintext, adata []byte) ([]byte, error) { c.b.Encrypt(mac[:], mac[:]) var block [ccmBlockSize]byte - if n := uint64(len(adata)); n > 0 { + if adataLength := uint64(len(adata)); adataLength > 0 { //nolint:nestif // First adata block includes adata length i := 2 - if n <= 0xfeff { - binary.BigEndian.PutUint16(block[:i], uint16(n)) + if adataLength <= 0xfeff { + binary.BigEndian.PutUint16(block[:i], uint16(adataLength)) } else { block[0] = 0xfe block[1] = 0xff - if n < uint64(1<<32) { + if adataLength < uint64(1<<32) { i = 2 + 4 - binary.BigEndian.PutUint32(block[2:i], uint32(n)) + binary.BigEndian.PutUint32(block[2:i], uint32(adataLength)) //nolint:gosec // G115 } else { i = 2 + 8 - binary.BigEndian.PutUint64(block[2:i], n) + binary.BigEndian.PutUint64(block[2:i], adataLength) } } i = copy(block[i:], adata) @@ -170,6 +173,7 @@ func (c *ccm) tag(nonce, plaintext, adata []byte) ([]byte, error) { // second slice that aliases into it and contains only the extra bytes. If the // original slice has sufficient capacity then no allocation is performed. // From crypto/cipher/gcm.go +// . func sliceForAppend(in []byte, n int) (head, tail []byte) { if total := len(in) + n; cap(in) >= total { head = in[:total] @@ -178,6 +182,7 @@ func sliceForAppend(in []byte, n int) (head, tail []byte) { copy(head, in) } tail = head[len(in):] + return } @@ -207,6 +212,7 @@ func (c *ccm) Seal(dst, nonce, plaintext, adata []byte) []byte { ret, out := sliceForAppend(dst, len(plaintext)+int(c.M)) stream.XORKeyStream(out, plaintext) copy(out[len(plaintext):], tag) + return ret } @@ -250,5 +256,6 @@ func (c *ccm) Open(dst, nonce, ciphertext, adata []byte) ([]byte, error) { if subtle.ConstantTimeCompare(tag, expectedTag) != 1 { return nil, errOpen } + return append(dst, plaintext...), nil } diff --git a/pkg/crypto/ccm/ccm_test.go b/pkg/crypto/ccm/ccm_test.go index da88f2a05..c8056c721 100644 --- a/pkg/crypto/ccm/ccm_test.go +++ b/pkg/crypto/ccm/ccm_test.go @@ -19,6 +19,7 @@ func mustHexDecode(s string) []byte { if err != nil { panic(err) } + return r } @@ -32,7 +33,7 @@ var ( // ClearHeaderOctets: Input with X cleartext header octets // Data: Input with X cleartext header octets // M: length(CBC-MAC) -// Nonce: Nonce +// Nonce: Nonce. type vector struct { AESKey []byte CipherText []byte @@ -42,7 +43,7 @@ type vector struct { Nonce []byte } -func TestRFC3610Vectors(t *testing.T) { +func TestRFC3610Vectors(t *testing.T) { //nolint:maintidx cases := []vector{ // Vectors 1-12 { @@ -62,8 +63,10 @@ func TestRFC3610Vectors(t *testing.T) { Nonce: mustHexDecode("00000004030201a0a1a2a3a4a5"), }, { - AESKey: aesKey1to12, - CipherText: mustHexDecode("000102030405060751b1e5f44a197d1da46b0f8e2d282ae871e838bb64da8596574adaa76fbd9fb0c5"), + AESKey: aesKey1to12, + CipherText: mustHexDecode( + "000102030405060751b1e5f44a197d1da46b0f8e2d282ae871e838bb64da8596574adaa76fbd9fb0c5", + ), ClearHeaderOctets: 8, Data: mustHexDecode("000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20"), M: 8, @@ -86,56 +89,70 @@ func TestRFC3610Vectors(t *testing.T) { Nonce: mustHexDecode("00000007060504a0a1a2a3a4a5"), }, { - AESKey: aesKey1to12, - CipherText: mustHexDecode("000102030405060708090a0b6fc1b011f006568b5171a42d953d469b2570a4bd87405a0443ac91cb94"), + AESKey: aesKey1to12, + CipherText: mustHexDecode( + "000102030405060708090a0b6fc1b011f006568b5171a42d953d469b2570a4bd87405a0443ac91cb94", + ), ClearHeaderOctets: 12, Data: mustHexDecode("000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20"), M: 8, Nonce: mustHexDecode("00000008070605a0a1a2a3a4a5"), }, { - AESKey: aesKey1to12, - CipherText: mustHexDecode("00010203040506070135d1b2c95f41d5d1d4fec185d166b8094e999dfed96c048c56602c97acbb7490"), + AESKey: aesKey1to12, + CipherText: mustHexDecode( + "00010203040506070135d1b2c95f41d5d1d4fec185d166b8094e999dfed96c048c56602c97acbb7490", + ), ClearHeaderOctets: 8, Data: mustHexDecode("000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e"), M: 10, Nonce: mustHexDecode("00000009080706a0a1a2a3a4a5"), }, { - AESKey: aesKey1to12, - CipherText: mustHexDecode("00010203040506077b75399ac0831dd2f0bbd75879a2fd8f6cae6b6cd9b7db24c17b4433f434963f34b4"), + AESKey: aesKey1to12, + CipherText: mustHexDecode( + "00010203040506077b75399ac0831dd2f0bbd75879a2fd8f6cae6b6cd9b7db24c17b4433f434963f34b4", + ), ClearHeaderOctets: 8, Data: mustHexDecode("000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f"), M: 10, Nonce: mustHexDecode("0000000a090807a0a1a2a3a4a5"), }, { - AESKey: aesKey1to12, - CipherText: mustHexDecode("000102030405060782531a60cc24945a4b8279181ab5c84df21ce7f9b73f42e197ea9c07e56b5eb17e5f4e"), + AESKey: aesKey1to12, + CipherText: mustHexDecode( + "000102030405060782531a60cc24945a4b8279181ab5c84df21ce7f9b73f42e197ea9c07e56b5eb17e5f4e", + ), ClearHeaderOctets: 8, Data: mustHexDecode("000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20"), M: 10, Nonce: mustHexDecode("0000000b0a0908a0a1a2a3a4a5"), }, { - AESKey: aesKey1to12, - CipherText: mustHexDecode("000102030405060708090a0b07342594157785152b074098330abb141b947b566aa9406b4d999988dd"), + AESKey: aesKey1to12, + CipherText: mustHexDecode( + "000102030405060708090a0b07342594157785152b074098330abb141b947b566aa9406b4d999988dd", + ), ClearHeaderOctets: 12, Data: mustHexDecode("000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e"), M: 10, Nonce: mustHexDecode("0000000c0b0a09a0a1a2a3a4a5"), }, { - AESKey: aesKey1to12, - CipherText: mustHexDecode("000102030405060708090a0b676bb20380b0e301e8ab79590a396da78b834934f53aa2e9107a8b6c022c"), + AESKey: aesKey1to12, + CipherText: mustHexDecode( + "000102030405060708090a0b676bb20380b0e301e8ab79590a396da78b834934f53aa2e9107a8b6c022c", + ), ClearHeaderOctets: 12, Data: mustHexDecode("000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f"), M: 10, Nonce: mustHexDecode("0000000d0c0b0aa0a1a2a3a4a5"), }, { - AESKey: aesKey1to12, - CipherText: mustHexDecode("000102030405060708090a0bc0ffa0d6f05bdb67f24d43a4338d2aa4bed7b20e43cd1aa31662e7ad65d6db"), + AESKey: aesKey1to12, + CipherText: mustHexDecode( + "000102030405060708090a0bc0ffa0d6f05bdb67f24d43a4338d2aa4bed7b20e43cd1aa31662e7ad65d6db", + ), ClearHeaderOctets: 12, Data: mustHexDecode("000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20"), M: 10, @@ -159,8 +176,10 @@ func TestRFC3610Vectors(t *testing.T) { Nonce: mustHexDecode("0033568ef7b2633c9696766cfa"), }, { - AESKey: aesKey13to24, - CipherText: mustHexDecode("aa6cfa36cae86b40b1d23a2220ddc0ac900d9aa03c61fcf4a559a4417767089708a776796edb723506"), + AESKey: aesKey13to24, + CipherText: mustHexDecode( + "aa6cfa36cae86b40b1d23a2220ddc0ac900d9aa03c61fcf4a559a4417767089708a776796edb723506", + ), ClearHeaderOctets: 8, Data: mustHexDecode("aa6cfa36cae86b40b916e0eacc1c00d7dcec68ec0b3bbb1a02de8a2d1aa346132e"), M: 8, @@ -183,56 +202,70 @@ func TestRFC3610Vectors(t *testing.T) { Nonce: mustHexDecode("00f8b678094e3b3c9696766cfa"), }, { - AESKey: aesKey13to24, - CipherText: mustHexDecode("cd9044d2b71fdb8120ea60c0009769ecabdf48625594c59251e6035722675e04c847099e5ae0704551"), + AESKey: aesKey13to24, + CipherText: mustHexDecode( + "cd9044d2b71fdb8120ea60c0009769ecabdf48625594c59251e6035722675e04c847099e5ae0704551", + ), ClearHeaderOctets: 12, Data: mustHexDecode("cd9044d2b71fdb8120ea60c06435acbafb11a82e2f071d7ca4a5ebd93a803ba87f"), M: 8, Nonce: mustHexDecode("00d560912d3f703c9696766cfa"), }, { - AESKey: aesKey13to24, - CipherText: mustHexDecode("d85bc7e69f944fb8bc218daa947427b6db386a99ac1aef23ade0b52939cb6a637cf9bec2408897c6ba"), + AESKey: aesKey13to24, + CipherText: mustHexDecode( + "d85bc7e69f944fb8bc218daa947427b6db386a99ac1aef23ade0b52939cb6a637cf9bec2408897c6ba", + ), ClearHeaderOctets: 8, Data: mustHexDecode("d85bc7e69f944fb88a19b950bcf71a018e5e6701c91787659809d67dbedd18"), M: 10, Nonce: mustHexDecode("0042fff8f1951c3c9696766cfa"), }, { - AESKey: aesKey13to24, - CipherText: mustHexDecode("74a0ebc9069f5b375810e6fd25874022e80361a478e3e9cf484ab04f447efff6f0a477cc2fc9bf548944"), + AESKey: aesKey13to24, + CipherText: mustHexDecode( + "74a0ebc9069f5b375810e6fd25874022e80361a478e3e9cf484ab04f447efff6f0a477cc2fc9bf548944", + ), ClearHeaderOctets: 8, Data: mustHexDecode("74a0ebc9069f5b371761433c37c5a35fc1f39f406302eb907c6163be38c98437"), M: 10, Nonce: mustHexDecode("00920f40e56cdc3c9696766cfa"), }, { - AESKey: aesKey13to24, - CipherText: mustHexDecode("44a3aa3aae6475caf2beed7bc5098e83feb5b31608f8e29c38819a89c8e776f1544d4151a4ed3a8b87b9ce"), + AESKey: aesKey13to24, + CipherText: mustHexDecode( + "44a3aa3aae6475caf2beed7bc5098e83feb5b31608f8e29c38819a89c8e776f1544d4151a4ed3a8b87b9ce", + ), ClearHeaderOctets: 8, Data: mustHexDecode("44a3aa3aae6475caa434a8e58500c6e41530538862d686ea9e81301b5ae4226bfa"), M: 10, Nonce: mustHexDecode("0027ca0c7120bc3c9696766cfa"), }, { - AESKey: aesKey13to24, - CipherText: mustHexDecode("ec46bb63b02520c33c49fd7031d750a09da3ed7fddd49a2032aabf17ec8ebf7d22c8088c666be5c197"), + AESKey: aesKey13to24, + CipherText: mustHexDecode( + "ec46bb63b02520c33c49fd7031d750a09da3ed7fddd49a2032aabf17ec8ebf7d22c8088c666be5c197", + ), ClearHeaderOctets: 12, Data: mustHexDecode("ec46bb63b02520c33c49fd70b96b49e21d621741632875db7f6c9243d2d7c2"), M: 10, Nonce: mustHexDecode("005b8ccbcd9af83c9696766cfa"), }, { - AESKey: aesKey13to24, - CipherText: mustHexDecode("47a65ac78b3d594227e85e71e882f1dbd38ce3eda7c23f04dd65071eb41342acdf7e00dccec7ae52987d"), + AESKey: aesKey13to24, + CipherText: mustHexDecode( + "47a65ac78b3d594227e85e71e882f1dbd38ce3eda7c23f04dd65071eb41342acdf7e00dccec7ae52987d", + ), ClearHeaderOctets: 12, Data: mustHexDecode("47a65ac78b3d594227e85e71e2fcfbb880442c731bf95167c8ffd7895e337076"), M: 10, Nonce: mustHexDecode("003ebe94044b9a3c9696766cfa"), }, { - AESKey: aesKey13to24, - CipherText: mustHexDecode("6e37a6ef546d955d34ab6059f32905b88a641b04b9c9ffb58cc390900f3da12ab16dce9e82efa16da62059"), + AESKey: aesKey13to24, + CipherText: mustHexDecode( + "6e37a6ef546d955d34ab6059f32905b88a641b04b9c9ffb58cc390900f3da12ab16dce9e82efa16da62059", + ), ClearHeaderOctets: 12, Data: mustHexDecode("6e37a6ef546d955d34ab6059abf21c0b02feb88f856df4a37381bce3cc128517d4"), M: 10, @@ -245,37 +278,47 @@ func TestRFC3610Vectors(t *testing.T) { t.FailNow() //nolint:revive } - for idx, c := range cases { - c := c + for idx, testCase := range cases { + testCase := testCase t.Run(fmt.Sprintf("packet vector #%d", idx+1), func(t *testing.T) { - blk, err := aes.NewCipher(c.AESKey) + blk, err := aes.NewCipher(testCase.AESKey) if err != nil { t.Fatalf("could not initialize AES block cipher from key: %v", err) } - lccm, err := NewCCM(blk, c.M, len(c.Nonce)) + lccm, err := NewCCM(blk, testCase.M, len(testCase.Nonce)) if err != nil { t.Fatalf("could not create CCM: %v", err) } t.Run("seal", func(t *testing.T) { var dst []byte - dst = lccm.Seal(dst, c.Nonce, c.Data[c.ClearHeaderOctets:], c.Data[:c.ClearHeaderOctets]) - if !bytes.Equal(c.CipherText[c.ClearHeaderOctets:], dst) { + dst = lccm.Seal( + dst, + testCase.Nonce, + testCase.Data[testCase.ClearHeaderOctets:], + testCase.Data[:testCase.ClearHeaderOctets], + ) + if !bytes.Equal(testCase.CipherText[testCase.ClearHeaderOctets:], dst) { t.Fatalf("ciphertext does not match, wanted %v, got %v", - c.CipherText[c.ClearHeaderOctets:], dst) + testCase.CipherText[testCase.ClearHeaderOctets:], dst) } }) t.Run("open", func(t *testing.T) { var dst []byte - dst, err = lccm.Open(dst, c.Nonce, c.CipherText[c.ClearHeaderOctets:], c.CipherText[:c.ClearHeaderOctets]) + dst, err = lccm.Open( + dst, + testCase.Nonce, + testCase.CipherText[testCase.ClearHeaderOctets:], + testCase.CipherText[:testCase.ClearHeaderOctets], + ) if err != nil { t.Fatalf("failed to unseal: %v", err) } - if !bytes.Equal(c.Data[c.ClearHeaderOctets:], dst) { + if !bytes.Equal(testCase.Data[testCase.ClearHeaderOctets:], dst) { t.Fatalf("plaintext does not match, wanted %v, got %v", - c.Data[c.ClearHeaderOctets:], dst) + testCase.Data[testCase.ClearHeaderOctets:], dst) } }) }) @@ -363,21 +406,26 @@ func TestSealError(t *testing.T) { t.Fatalf("could not create CCM: %v", err) } - for name, c := range cases { - c := c + for name, testCase := range cases { + testCase := testCase t.Run(name, func(t *testing.T) { defer func() { err, ok := recover().(error) if !ok { - t.Errorf("expected panic '%v', got '%v'", c.err, err) + t.Errorf("expected panic '%v', got '%v'", testCase.err, err) } - if !errors.Is(err, c.err) { - t.Errorf("expected panic '%v', got '%v'", c.err, err) + if !errors.Is(err, testCase.err) { + t.Errorf("expected panic '%v', got '%v'", testCase.err, err) } }() var dst []byte - _ = lccm.Seal(dst, c.Nonce, c.Data[c.ClearHeaderOctets:], c.Data[:c.ClearHeaderOctets]) + _ = lccm.Seal( + dst, + testCase.Nonce, + testCase.Data[testCase.ClearHeaderOctets:], + testCase.Data[:testCase.ClearHeaderOctets], + ) }) } } diff --git a/pkg/crypto/ciphersuite/cbc.go b/pkg/crypto/ciphersuite/cbc.go index 68a080d5d..ab2588f9f 100644 --- a/pkg/crypto/ciphersuite/cbc.go +++ b/pkg/crypto/ciphersuite/cbc.go @@ -24,15 +24,18 @@ type cbcMode interface { SetIV([]byte) } -// CBC Provides an API to Encrypt/Decrypt DTLS 1.2 Packets +// CBC Provides an API to Encrypt/Decrypt DTLS 1.2 Packets. type CBC struct { writeCBC, readCBC cbcMode writeMac, readMac []byte h prf.HashFunc } -// NewCBC creates a DTLS CBC Cipher -func NewCBC(localKey, localWriteIV, localMac, remoteKey, remoteWriteIV, remoteMac []byte, h prf.HashFunc) (*CBC, error) { +// NewCBC creates a DTLS CBC Cipher. +func NewCBC( + localKey, localWriteIV, localMac, remoteKey, remoteWriteIV, remoteMac []byte, + hashFunc prf.HashFunc, +) (*CBC, error) { writeBlock, err := aes.NewCipher(localKey) if err != nil { return nil, err @@ -59,11 +62,11 @@ func NewCBC(localKey, localWriteIV, localMac, remoteKey, remoteWriteIV, remoteMa readCBC: readCBC, readMac: remoteMac, - h: h, + h: hashFunc, }, nil } -// Encrypt encrypt a DTLS RecordLayer message +// Encrypt encrypt a DTLS RecordLayer message. func (c *CBC) Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error) { payload := raw[pkt.Header.Size():] raw = raw[:pkt.Header.Size()] @@ -101,29 +104,29 @@ func (c *CBC) Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error) // Set IV + Encrypt + Prepend IV c.writeCBC.SetIV(iv) c.writeCBC.CryptBlocks(payload, payload) - payload = append(iv, payload...) + payload = append(iv, payload...) //nolint:makezero // todo: FIX // Prepend unencrypted header with encrypted payload raw = append(raw, payload...) // Update recordLayer size to include IV+MAC+Padding - binary.BigEndian.PutUint16(raw[pkt.Header.Size()-2:], uint16(len(raw)-pkt.Header.Size())) + binary.BigEndian.PutUint16(raw[pkt.Header.Size()-2:], uint16(len(raw)-pkt.Header.Size())) //nolint:gosec //G115 return raw, nil } -// Decrypt decrypts a DTLS RecordLayer message -func (c *CBC) Decrypt(h recordlayer.Header, in []byte) ([]byte, error) { +// Decrypt decrypts a DTLS RecordLayer message. +func (c *CBC) Decrypt(header recordlayer.Header, in []byte) ([]byte, error) { blockSize := c.readCBC.BlockSize() mac := c.h() - if err := h.Unmarshal(in); err != nil { + if err := header.Unmarshal(in); err != nil { return nil, err } - body := in[h.Size():] + body := in[header.Size():] switch { - case h.ContentType == protocol.ContentTypeChangeCipherSpec: + case header.ContentType == protocol.ContentTypeChangeCipherSpec: // Nothing to encrypt with ChangeCipherSpec return in, nil case len(body)%blockSize != 0 || len(body) < blockSize+util.Max(mac.Size()+1, blockSize): @@ -154,21 +157,33 @@ func (c *CBC) Decrypt(h recordlayer.Header, in []byte) ([]byte, error) { expectedMAC := body[dataEnd : dataEnd+macSize] var err error var actualMAC []byte - if h.ContentType == protocol.ContentTypeConnectionID { - actualMAC, err = c.hmacCID(h.Epoch, h.SequenceNumber, h.Version, body[:dataEnd], c.readMac, c.h, h.ConnectionID) + if header.ContentType == protocol.ContentTypeConnectionID { + actualMAC, err = c.hmacCID( + header.Epoch, header.SequenceNumber, header.Version, body[:dataEnd], c.readMac, c.h, header.ConnectionID, + ) } else { - actualMAC, err = c.hmac(h.Epoch, h.SequenceNumber, h.ContentType, h.Version, body[:dataEnd], c.readMac, c.h) + actualMAC, err = c.hmac( + header.Epoch, header.SequenceNumber, header.ContentType, header.Version, body[:dataEnd], c.readMac, c.h, + ) } // Compute Local MAC and compare if err != nil || !hmac.Equal(actualMAC, expectedMAC) { return nil, errInvalidMAC } - return append(in[:h.Size()], body[:dataEnd]...), nil + return append(in[:header.Size()], body[:dataEnd]...), nil } -func (c *CBC) hmac(epoch uint16, sequenceNumber uint64, contentType protocol.ContentType, protocolVersion protocol.Version, payload []byte, key []byte, hf func() hash.Hash) ([]byte, error) { - h := hmac.New(hf, key) +func (c *CBC) hmac( + epoch uint16, + sequenceNumber uint64, + contentType protocol.ContentType, + protocolVersion protocol.Version, + payload []byte, + key []byte, + hf func() hash.Hash, +) ([]byte, error) { + hmacHash := hmac.New(hf, key) msg := make([]byte, 13) @@ -177,51 +192,59 @@ func (c *CBC) hmac(epoch uint16, sequenceNumber uint64, contentType protocol.Con msg[8] = byte(contentType) msg[9] = protocolVersion.Major msg[10] = protocolVersion.Minor - binary.BigEndian.PutUint16(msg[11:], uint16(len(payload))) + binary.BigEndian.PutUint16(msg[11:], uint16(len(payload))) //nolint:gosec //G115 - if _, err := h.Write(msg); err != nil { + if _, err := hmacHash.Write(msg); err != nil { return nil, err } - if _, err := h.Write(payload); err != nil { + if _, err := hmacHash.Write(payload); err != nil { return nil, err } - return h.Sum(nil), nil + return hmacHash.Sum(nil), nil } // hmacCID calculates a MAC according to // https://datatracker.ietf.org/doc/html/rfc9146#section-5.1 -func (c *CBC) hmacCID(epoch uint16, sequenceNumber uint64, protocolVersion protocol.Version, payload []byte, key []byte, hf func() hash.Hash, cid []byte) ([]byte, error) { +func (c *CBC) hmacCID( + epoch uint16, + sequenceNumber uint64, + protocolVersion protocol.Version, + payload []byte, + key []byte, + hf func() hash.Hash, + cid []byte, +) ([]byte, error) { // Must unmarshal inner plaintext in orde to perform MAC. ip := &recordlayer.InnerPlaintext{} if err := ip.Unmarshal(payload); err != nil { return nil, err } - h := hmac.New(hf, key) + hmacHash := hmac.New(hf, key) var msg cryptobyte.Builder msg.AddUint64(seqNumPlaceholder) msg.AddUint8(uint8(protocol.ContentTypeConnectionID)) - msg.AddUint8(uint8(len(cid))) + msg.AddUint8(uint8(len(cid))) //nolint:gosec //G115 msg.AddUint8(uint8(protocol.ContentTypeConnectionID)) msg.AddUint8(protocolVersion.Major) msg.AddUint8(protocolVersion.Minor) msg.AddUint16(epoch) util.AddUint48(&msg, sequenceNumber) msg.AddBytes(cid) - msg.AddUint16(uint16(len(payload))) + msg.AddUint16(uint16(len(payload))) //nolint:gosec //G115 msg.AddBytes(ip.Content) msg.AddUint8(uint8(ip.RealType)) msg.AddBytes(make([]byte, ip.Zeros)) - if _, err := h.Write(msg.BytesOrPanic()); err != nil { + if _, err := hmacHash.Write(msg.BytesOrPanic()); err != nil { return nil, err } - if _, err := h.Write(payload); err != nil { + if _, err := hmacHash.Write(payload); err != nil { return nil, err } - return h.Sum(nil), nil + return hmacHash.Sum(nil), nil } diff --git a/pkg/crypto/ciphersuite/ccm.go b/pkg/crypto/ciphersuite/ccm.go index 4c296e2bb..9a40cae8f 100644 --- a/pkg/crypto/ciphersuite/ccm.go +++ b/pkg/crypto/ciphersuite/ccm.go @@ -14,24 +14,24 @@ import ( "github.com/pion/dtls/v3/pkg/protocol/recordlayer" ) -// CCMTagLen is the length of Authentication Tag +// CCMTagLen is the length of Authentication Tag. type CCMTagLen int -// CCM Enums +// CCM Enums. const ( CCMTagLength8 CCMTagLen = 8 CCMTagLength CCMTagLen = 16 ccmNonceLength = 12 ) -// CCM Provides an API to Encrypt/Decrypt DTLS 1.2 Packets +// CCM Provides an API to Encrypt/Decrypt DTLS 1.2 Packets. type CCM struct { localCCM, remoteCCM ccm.CCM localWriteIV, remoteWriteIV []byte tagLen CCMTagLen } -// NewCCM creates a DTLS GCM Cipher +// NewCCM creates a DTLS GCM Cipher. func NewCCM(tagLen CCMTagLen, localKey, localWriteIV, remoteKey, remoteWriteIV []byte) (*CCM, error) { localBlock, err := aes.NewCipher(localKey) if err != nil { @@ -60,7 +60,7 @@ func NewCCM(tagLen CCMTagLen, localKey, localWriteIV, remoteKey, remoteWriteIV [ }, nil } -// Encrypt encrypt a DTLS RecordLayer message +// Encrypt encrypt a DTLS RecordLayer message. func (c *CCM) Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error) { payload := raw[pkt.Header.Size():] raw = raw[:pkt.Header.Size()] @@ -82,36 +82,38 @@ func (c *CCM) Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error) raw = append(raw, encryptedPayload...) // Update recordLayer size to include explicit nonce - binary.BigEndian.PutUint16(raw[pkt.Header.Size()-2:], uint16(len(raw)-pkt.Header.Size())) + binary.BigEndian.PutUint16(raw[pkt.Header.Size()-2:], uint16(len(raw)-pkt.Header.Size())) //nolint:gosec //G115 + return raw, nil } -// Decrypt decrypts a DTLS RecordLayer message -func (c *CCM) Decrypt(h recordlayer.Header, in []byte) ([]byte, error) { - if err := h.Unmarshal(in); err != nil { +// Decrypt decrypts a DTLS RecordLayer message. +func (c *CCM) Decrypt(header recordlayer.Header, in []byte) ([]byte, error) { + if err := header.Unmarshal(in); err != nil { return nil, err } switch { - case h.ContentType == protocol.ContentTypeChangeCipherSpec: + case header.ContentType == protocol.ContentTypeChangeCipherSpec: // Nothing to encrypt with ChangeCipherSpec return in, nil - case len(in) <= (8 + h.Size()): + case len(in) <= (8 + header.Size()): return nil, errNotEnoughRoomForNonce } - nonce := append(append([]byte{}, c.remoteWriteIV[:4]...), in[h.Size():h.Size()+8]...) - out := in[h.Size()+8:] + nonce := append(append([]byte{}, c.remoteWriteIV[:4]...), in[header.Size():header.Size()+8]...) + out := in[header.Size()+8:] var additionalData []byte - if h.ContentType == protocol.ContentTypeConnectionID { - additionalData = generateAEADAdditionalDataCID(&h, len(out)-int(c.tagLen)) + if header.ContentType == protocol.ContentTypeConnectionID { + additionalData = generateAEADAdditionalDataCID(&header, len(out)-int(c.tagLen)) } else { - additionalData = generateAEADAdditionalData(&h, len(out)-int(c.tagLen)) + additionalData = generateAEADAdditionalData(&header, len(out)-int(c.tagLen)) } var err error out, err = c.remoteCCM.Open(out[:0], nonce, out, additionalData) if err != nil { return nil, fmt.Errorf("%w: %v", errDecryptPacket, err) //nolint:errorlint } - return append(in[:h.Size()], out...), nil + + return append(in[:header.Size()], out...), nil } diff --git a/pkg/crypto/ciphersuite/ciphersuite.go b/pkg/crypto/ciphersuite/ciphersuite.go index 90ddf6105..5c01de580 100644 --- a/pkg/crypto/ciphersuite/ciphersuite.go +++ b/pkg/crypto/ciphersuite/ciphersuite.go @@ -21,10 +21,14 @@ const ( ) var ( - errNotEnoughRoomForNonce = &protocol.InternalError{Err: errors.New("buffer not long enough to contain nonce")} //nolint:goerr113 - errDecryptPacket = &protocol.TemporaryError{Err: errors.New("failed to decrypt packet")} //nolint:goerr113 - errInvalidMAC = &protocol.TemporaryError{Err: errors.New("invalid mac")} //nolint:goerr113 - errFailedToCast = &protocol.FatalError{Err: errors.New("failed to cast")} //nolint:goerr113 + //nolint:goerr113 + errNotEnoughRoomForNonce = &protocol.InternalError{Err: errors.New("buffer not long enough to contain nonce")} + //nolint:goerr113 + errDecryptPacket = &protocol.TemporaryError{Err: errors.New("failed to decrypt packet")} + //nolint:goerr113 + errInvalidMAC = &protocol.TemporaryError{Err: errors.New("invalid mac")} + //nolint:goerr113 + errFailedToCast = &protocol.FatalError{Err: errors.New("failed to cast")} ) func generateAEADAdditionalData(h *recordlayer.Header, payloadLen int) []byte { @@ -37,6 +41,7 @@ func generateAEADAdditionalData(h *recordlayer.Header, payloadLen int) []byte { additionalData[8] = byte(h.ContentType) additionalData[9] = h.Version.Major additionalData[10] = h.Version.Minor + //nolint:gosec //G115 binary.BigEndian.PutUint16(additionalData[len(additionalData)-2:], uint16(payloadLen)) return additionalData[:] @@ -45,20 +50,20 @@ func generateAEADAdditionalData(h *recordlayer.Header, payloadLen int) []byte { // generateAEADAdditionalDataCID generates additional data for AEAD ciphers // according to https://datatracker.ietf.org/doc/html/rfc9146#name-aead-ciphers func generateAEADAdditionalDataCID(h *recordlayer.Header, payloadLen int) []byte { - var b cryptobyte.Builder - - b.AddUint64(seqNumPlaceholder) - b.AddUint8(uint8(protocol.ContentTypeConnectionID)) - b.AddUint8(uint8(len(h.ConnectionID))) - b.AddUint8(uint8(protocol.ContentTypeConnectionID)) - b.AddUint8(h.Version.Major) - b.AddUint8(h.Version.Minor) - b.AddUint16(h.Epoch) - util.AddUint48(&b, h.SequenceNumber) - b.AddBytes(h.ConnectionID) - b.AddUint16(uint16(payloadLen)) - - return b.BytesOrPanic() + var builder cryptobyte.Builder + + builder.AddUint64(seqNumPlaceholder) + builder.AddUint8(uint8(protocol.ContentTypeConnectionID)) + builder.AddUint8(uint8(len(h.ConnectionID))) //nolint:gosec //G115 + builder.AddUint8(uint8(protocol.ContentTypeConnectionID)) + builder.AddUint8(h.Version.Major) + builder.AddUint8(h.Version.Minor) + builder.AddUint16(h.Epoch) + util.AddUint48(&builder, h.SequenceNumber) + builder.AddBytes(h.ConnectionID) + builder.AddUint16(uint16(payloadLen)) //nolint:gosec //G115 + + return builder.BytesOrPanic() } // examinePadding returns, in constant time, the length of the padding to remove @@ -72,9 +77,9 @@ func examinePadding(payload []byte) (toRemove int, good byte) { } paddingLen := payload[len(payload)-1] - t := uint(len(payload)-1) - uint(paddingLen) + t := uint(len(payload)-1) - uint(paddingLen) //nolint:gosec //G115 // if len(payload) >= (paddingLen - 1) then the MSB of t is zero - good = byte(int32(^t) >> 31) + good = byte(int32(^t) >> 31) //nolint:gosec //G115 // The maximum possible padding length plus the actual length field toCheck := 256 @@ -84,9 +89,9 @@ func examinePadding(payload []byte) (toRemove int, good byte) { } for i := 0; i < toCheck; i++ { - t := uint(paddingLen) - uint(i) + t := uint(paddingLen) - uint(i) //nolint:gosec //G115 // if i <= paddingLen then the MSB of t is zero - mask := byte(int32(^t) >> 31) + mask := byte(int32(^t) >> 31) //nolint:gosec //G115 b := payload[len(payload)-1-i] good &^= mask&paddingLen ^ mask&b } @@ -96,7 +101,7 @@ func examinePadding(payload []byte) (toRemove int, good byte) { good &= good << 4 good &= good << 2 good &= good << 1 - good = uint8(int8(good) >> 7) + good = uint8(int8(good) >> 7) //nolint:gosec //G115 toRemove = int(paddingLen) + 1 diff --git a/pkg/crypto/ciphersuite/ciphersuite_test.go b/pkg/crypto/ciphersuite/ciphersuite_test.go index 100d718d9..3767d9661 100644 --- a/pkg/crypto/ciphersuite/ciphersuite_test.go +++ b/pkg/crypto/ciphersuite/ciphersuite_test.go @@ -29,7 +29,10 @@ func TestGenerateAEADAdditionalDataCID(t *testing.T) { SequenceNumber: 277, }, payloadLen: 1784, - expected: []byte{255, 255, 255, 255, 255, 255, 255, 255, 25, 8, 25, 254, 253, 0, 2, 0, 0, 0, 0, 1, 21, 1, 2, 3, 4, 5, 6, 7, 8, 6, 248}, + expected: []byte{ + 255, 255, 255, 255, 255, 255, 255, 255, 25, 8, 25, 254, 253, + 0, 2, 0, 0, 0, 0, 1, 21, 1, 2, 3, 4, 5, 6, 7, 8, 6, 248, + }, }, "IgnoreContentType": { reason: "Should use Connection ID content type regardless of header content type.", @@ -41,7 +44,10 @@ func TestGenerateAEADAdditionalDataCID(t *testing.T) { SequenceNumber: 277, }, payloadLen: 1784, - expected: []byte{255, 255, 255, 255, 255, 255, 255, 255, 25, 8, 25, 254, 253, 0, 2, 0, 0, 0, 0, 1, 21, 1, 2, 3, 4, 5, 6, 7, 8, 6, 248}, + expected: []byte{ + 255, 255, 255, 255, 255, 255, 255, 255, 25, 8, 25, 254, 253, + 0, 2, 0, 0, 0, 0, 1, 21, 1, 2, 3, 4, 5, 6, 7, 8, 6, 248, + }, }, } for name, tc := range cases { diff --git a/pkg/crypto/ciphersuite/gcm.go b/pkg/crypto/ciphersuite/gcm.go index a7f828246..1c50dd967 100644 --- a/pkg/crypto/ciphersuite/gcm.go +++ b/pkg/crypto/ciphersuite/gcm.go @@ -19,13 +19,13 @@ const ( gcmNonceLength = 12 ) -// GCM Provides an API to Encrypt/Decrypt DTLS 1.2 Packets +// GCM Provides an API to Encrypt/Decrypt DTLS 1.2 Packets. type GCM struct { localGCM, remoteGCM cipher.AEAD localWriteIV, remoteWriteIV []byte } -// NewGCM creates a DTLS GCM Cipher +// NewGCM creates a DTLS GCM Cipher. func NewGCM(localKey, localWriteIV, remoteKey, remoteWriteIV []byte) (*GCM, error) { localBlock, err := aes.NewCipher(localKey) if err != nil { @@ -53,7 +53,7 @@ func NewGCM(localKey, localWriteIV, remoteKey, remoteWriteIV []byte) (*GCM, erro }, nil } -// Encrypt encrypt a DTLS RecordLayer message +// Encrypt encrypt a DTLS RecordLayer message. func (g *GCM) Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error) { payload := raw[pkt.Header.Size():] raw = raw[:pkt.Header.Size()] @@ -77,36 +77,38 @@ func (g *GCM) Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error) copy(r[len(raw)+len(nonce[4:]):], encryptedPayload) // Update recordLayer size to include explicit nonce - binary.BigEndian.PutUint16(r[pkt.Header.Size()-2:], uint16(len(r)-pkt.Header.Size())) + binary.BigEndian.PutUint16(r[pkt.Header.Size()-2:], uint16(len(r)-pkt.Header.Size())) //nolint:gosec //G115 + return r, nil } -// Decrypt decrypts a DTLS RecordLayer message -func (g *GCM) Decrypt(h recordlayer.Header, in []byte) ([]byte, error) { - err := h.Unmarshal(in) +// Decrypt decrypts a DTLS RecordLayer message. +func (g *GCM) Decrypt(header recordlayer.Header, in []byte) ([]byte, error) { + err := header.Unmarshal(in) switch { case err != nil: return nil, err - case h.ContentType == protocol.ContentTypeChangeCipherSpec: + case header.ContentType == protocol.ContentTypeChangeCipherSpec: // Nothing to encrypt with ChangeCipherSpec return in, nil - case len(in) <= (8 + h.Size()): + case len(in) <= (8 + header.Size()): return nil, errNotEnoughRoomForNonce } nonce := make([]byte, 0, gcmNonceLength) - nonce = append(append(nonce, g.remoteWriteIV[:4]...), in[h.Size():h.Size()+8]...) - out := in[h.Size()+8:] + nonce = append(append(nonce, g.remoteWriteIV[:4]...), in[header.Size():header.Size()+8]...) + out := in[header.Size()+8:] var additionalData []byte - if h.ContentType == protocol.ContentTypeConnectionID { - additionalData = generateAEADAdditionalDataCID(&h, len(out)-gcmTagLength) + if header.ContentType == protocol.ContentTypeConnectionID { + additionalData = generateAEADAdditionalDataCID(&header, len(out)-gcmTagLength) } else { - additionalData = generateAEADAdditionalData(&h, len(out)-gcmTagLength) + additionalData = generateAEADAdditionalData(&header, len(out)-gcmTagLength) } out, err = g.remoteGCM.Open(out[:0], nonce, out, additionalData) if err != nil { return nil, fmt.Errorf("%w: %v", errDecryptPacket, err) //nolint:errorlint } - return append(in[:h.Size()], out...), nil + + return append(in[:header.Size()], out...), nil } diff --git a/pkg/crypto/clientcertificate/client_certificate.go b/pkg/crypto/clientcertificate/client_certificate.go index ddfa39ebe..0a510d4d4 100644 --- a/pkg/crypto/clientcertificate/client_certificate.go +++ b/pkg/crypto/clientcertificate/client_certificate.go @@ -10,13 +10,13 @@ package clientcertificate // https://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml#tls-parameters-2 type Type byte -// ClientCertificateType enums +// ClientCertificateType enums. const ( RSASign Type = 1 ECDSASign Type = 64 ) -// Types returns all valid ClientCertificate Types +// Types returns all valid ClientCertificate Types. func Types() map[Type]bool { return map[Type]bool{ RSASign: true, diff --git a/pkg/crypto/elliptic/elliptic.go b/pkg/crypto/elliptic/elliptic.go index 126523872..b98fb4f9c 100644 --- a/pkg/crypto/elliptic/elliptic.go +++ b/pkg/crypto/elliptic/elliptic.go @@ -20,12 +20,12 @@ var errInvalidNamedCurve = errors.New("invalid named curve") // https://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-9 type CurvePointFormat byte -// CurvePointFormat enums +// CurvePointFormat enums. const ( CurvePointFormatUncompressed CurvePointFormat = 0 ) -// Keypair is a Curve with a Private/Public Keypair +// Keypair is a Curve with a Private/Public Keypair. type Keypair struct { Curve Curve PublicKey []byte @@ -37,12 +37,12 @@ type Keypair struct { // https://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml#tls-parameters-10 type CurveType byte -// CurveType enums +// CurveType enums. const ( CurveTypeNamedCurve CurveType = 0x03 ) -// CurveTypes returns all known curves +// CurveTypes returns all known curves. func CurveTypes() map[CurveType]struct{} { return map[CurveType]struct{}{ CurveTypeNamedCurve: {}, @@ -54,7 +54,7 @@ func CurveTypes() map[CurveType]struct{} { // https://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-8 type Curve uint16 -// Curve enums +// Curve enums. const ( P256 Curve = 0x0017 P384 Curve = 0x0018 @@ -70,10 +70,11 @@ func (c Curve) String() string { case X25519: return "X25519" } + return fmt.Sprintf("%#x", uint16(c)) } -// Curves returns all curves we implement +// Curves returns all curves we implement. func Curves() map[Curve]bool { return map[Curve]bool{ X25519: true, @@ -82,7 +83,7 @@ func Curves() map[Curve]bool { } } -// GenerateKeypair generates a keypair for the given Curve +// GenerateKeypair generates a keypair for the given Curve. func GenerateKeypair(c Curve) (*Keypair, error) { switch c { //nolint:revive case X25519: @@ -95,6 +96,7 @@ func GenerateKeypair(c Curve) (*Keypair, error) { copy(private[:], tmp) curve25519.ScalarBaseMult(&public, &private) + return &Keypair{X25519, public[:], private[:]}, nil case P256: return ellipticCurveKeypair(P256, elliptic.P256(), elliptic.P256()) diff --git a/pkg/crypto/fingerprint/fingerprint.go b/pkg/crypto/fingerprint/fingerprint.go index 7c66265c7..43d4a4000 100644 --- a/pkg/crypto/fingerprint/fingerprint.go +++ b/pkg/crypto/fingerprint/fingerprint.go @@ -16,7 +16,7 @@ var ( errInvalidFingerprintLength = errors.New("fingerprint: invalid fingerprint length") ) -// Fingerprint creates a fingerprint for a certificate using the specified hash algorithm +// Fingerprint creates a fingerprint for a certificate using the specified hash algorithm. func Fingerprint(cert *x509.Certificate, algo crypto.Hash) (string, error) { if !algo.Available() { return "", errHashUnavailable diff --git a/pkg/crypto/fingerprint/fingerprint_test.go b/pkg/crypto/fingerprint/fingerprint_test.go index 3266d1153..0a22d8b6c 100644 --- a/pkg/crypto/fingerprint/fingerprint_test.go +++ b/pkg/crypto/fingerprint/fingerprint_test.go @@ -12,23 +12,29 @@ import ( func TestFingerprint(t *testing.T) { rawCertificate := []byte{ - 0x30, 0x82, 0x01, 0x98, 0x30, 0x82, 0x01, 0x3d, 0xa0, 0x03, 0x02, 0x01, 0x02, 0x02, 0x11, 0x00, 0xa9, 0x91, 0x76, 0x0a, 0xcd, 0x97, 0x4c, 0x36, 0xba, - 0xc9, 0xc2, 0x66, 0x91, 0x47, 0x6c, 0xac, 0x30, 0x0a, 0x06, 0x08, 0x2a, 0x86, 0x48, 0xce, 0x3d, 0x04, 0x03, 0x02, 0x30, 0x2b, 0x31, 0x29, 0x30, 0x27, - 0x06, 0x03, 0x55, 0x04, 0x03, 0x13, 0x20, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, - 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x1e, 0x17, 0x0d, 0x31, 0x39, 0x31, 0x31, 0x31, 0x30, 0x30, - 0x39, 0x30, 0x34, 0x32, 0x33, 0x5a, 0x17, 0x0d, 0x31, 0x39, 0x31, 0x32, 0x31, 0x30, 0x30, 0x39, 0x30, 0x34, 0x32, 0x33, 0x5a, 0x30, 0x2b, 0x31, 0x29, - 0x30, 0x27, 0x06, 0x03, 0x55, 0x04, 0x03, 0x13, 0x20, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, - 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x59, 0x30, 0x13, 0x06, 0x07, 0x2a, 0x86, 0x48, - 0xce, 0x3d, 0x02, 0x01, 0x06, 0x08, 0x2a, 0x86, 0x48, 0xce, 0x3d, 0x03, 0x01, 0x07, 0x03, 0x42, 0x00, 0x04, 0x9c, 0x12, 0x8e, 0xb5, 0x21, 0x23, 0x9f, - 0x35, 0x5d, 0x39, 0x64, 0xc3, 0x75, 0x81, 0xa4, 0xc8, 0xc8, 0x08, 0x8a, 0xa8, 0x42, 0x30, 0x30, 0x65, 0xb8, 0xb1, 0x3e, 0x4a, 0x51, 0x86, 0xeb, 0xad, - 0x03, 0x02, 0x35, 0x83, 0xc4, 0x19, 0x3a, 0x5b, 0x79, 0x83, 0xec, 0x59, 0x0e, 0x4f, 0x99, 0xb1, 0xd2, 0xf0, 0x50, 0xfa, 0xb8, 0x5f, 0xfc, 0x88, 0xf3, - 0x15, 0xed, 0xb8, 0x14, 0xf0, 0xba, 0xcd, 0xa3, 0x42, 0x30, 0x40, 0x30, 0x0e, 0x06, 0x03, 0x55, 0x1d, 0x0f, 0x01, 0x01, 0xff, 0x04, 0x04, 0x03, 0x02, - 0x05, 0xa0, 0x30, 0x1d, 0x06, 0x03, 0x55, 0x1d, 0x25, 0x04, 0x16, 0x30, 0x14, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x03, 0x02, 0x06, 0x08, - 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x03, 0x01, 0x30, 0x0f, 0x06, 0x03, 0x55, 0x1d, 0x13, 0x01, 0x01, 0xff, 0x04, 0x05, 0x30, 0x03, 0x01, 0x01, 0xff, - 0x30, 0x0a, 0x06, 0x08, 0x2a, 0x86, 0x48, 0xce, 0x3d, 0x04, 0x03, 0x02, 0x03, 0x49, 0x00, 0x30, 0x46, 0x02, 0x21, 0x00, 0xcd, 0x44, 0xb1, 0xf2, 0x09, - 0xe5, 0xf1, 0xf4, 0xc9, 0x26, 0x95, 0x9a, 0x2d, 0x6d, 0xf3, 0x0c, 0xb8, 0xeb, 0x27, 0x2d, 0x81, 0x19, 0xe9, 0x51, 0xf7, 0xad, 0x64, 0x7d, 0x42, 0x32, - 0x9e, 0xf8, 0x02, 0x21, 0x00, 0xee, 0xad, 0x96, 0x41, 0xf1, 0x12, 0xd0, 0x6b, 0xcd, 0x09, 0xf0, 0x3c, 0x67, 0xb3, 0xdd, 0xed, 0x0a, 0xf1, 0xd8, 0x41, - 0x4f, 0x61, 0xfd, 0x53, 0x1d, 0xf5, 0x27, 0xbe, 0x6d, 0x0b, 0xe2, 0x0d, + 0x30, 0x82, 0x01, 0x98, 0x30, 0x82, 0x01, 0x3d, 0xa0, 0x03, 0x02, 0x01, 0x02, 0x02, 0x11, 0x00, 0xa9, 0x91, + 0x76, 0x0a, 0xcd, 0x97, 0x4c, 0x36, 0xba, 0xc9, 0xc2, 0x66, 0x91, 0x47, 0x6c, 0xac, 0x30, 0x0a, 0x06, 0x08, + 0x2a, 0x86, 0x48, 0xce, 0x3d, 0x04, 0x03, 0x02, 0x30, 0x2b, 0x31, 0x29, 0x30, 0x27, 0x06, 0x03, 0x55, 0x04, + 0x03, 0x13, 0x20, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, + 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, + 0x1e, 0x17, 0x0d, 0x31, 0x39, 0x31, 0x31, 0x31, 0x30, 0x30, 0x39, 0x30, 0x34, 0x32, 0x33, 0x5a, 0x17, 0x0d, + 0x31, 0x39, 0x31, 0x32, 0x31, 0x30, 0x30, 0x39, 0x30, 0x34, 0x32, 0x33, 0x5a, 0x30, 0x2b, 0x31, 0x29, 0x30, + 0x27, 0x06, 0x03, 0x55, 0x04, 0x03, 0x13, 0x20, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, + 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, + 0x30, 0x30, 0x30, 0x30, 0x30, 0x59, 0x30, 0x13, 0x06, 0x07, 0x2a, 0x86, 0x48, 0xce, 0x3d, 0x02, 0x01, 0x06, + 0x08, 0x2a, 0x86, 0x48, 0xce, 0x3d, 0x03, 0x01, 0x07, 0x03, 0x42, 0x00, 0x04, 0x9c, 0x12, 0x8e, 0xb5, 0x21, + 0x23, 0x9f, 0x35, 0x5d, 0x39, 0x64, 0xc3, 0x75, 0x81, 0xa4, 0xc8, 0xc8, 0x08, 0x8a, 0xa8, 0x42, 0x30, 0x30, + 0x65, 0xb8, 0xb1, 0x3e, 0x4a, 0x51, 0x86, 0xeb, 0xad, 0x03, 0x02, 0x35, 0x83, 0xc4, 0x19, 0x3a, 0x5b, 0x79, + 0x83, 0xec, 0x59, 0x0e, 0x4f, 0x99, 0xb1, 0xd2, 0xf0, 0x50, 0xfa, 0xb8, 0x5f, 0xfc, 0x88, 0xf3, 0x15, 0xed, + 0xb8, 0x14, 0xf0, 0xba, 0xcd, 0xa3, 0x42, 0x30, 0x40, 0x30, 0x0e, 0x06, 0x03, 0x55, 0x1d, 0x0f, 0x01, 0x01, + 0xff, 0x04, 0x04, 0x03, 0x02, 0x05, 0xa0, 0x30, 0x1d, 0x06, 0x03, 0x55, 0x1d, 0x25, 0x04, 0x16, 0x30, 0x14, + 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x03, 0x02, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, + 0x03, 0x01, 0x30, 0x0f, 0x06, 0x03, 0x55, 0x1d, 0x13, 0x01, 0x01, 0xff, 0x04, 0x05, 0x30, 0x03, 0x01, 0x01, + 0xff, 0x30, 0x0a, 0x06, 0x08, 0x2a, 0x86, 0x48, 0xce, 0x3d, 0x04, 0x03, 0x02, 0x03, 0x49, 0x00, 0x30, 0x46, + 0x02, 0x21, 0x00, 0xcd, 0x44, 0xb1, 0xf2, 0x09, 0xe5, 0xf1, 0xf4, 0xc9, 0x26, 0x95, 0x9a, 0x2d, 0x6d, 0xf3, + 0x0c, 0xb8, 0xeb, 0x27, 0x2d, 0x81, 0x19, 0xe9, 0x51, 0xf7, 0xad, 0x64, 0x7d, 0x42, 0x32, 0x9e, 0xf8, 0x02, + 0x21, 0x00, 0xee, 0xad, 0x96, 0x41, 0xf1, 0x12, 0xd0, 0x6b, 0xcd, 0x09, 0xf0, 0x3c, 0x67, 0xb3, 0xdd, 0xed, + 0x0a, 0xf1, 0xd8, 0x41, 0x4f, 0x61, 0xfd, 0x53, 0x1d, 0xf5, 0x27, 0xbe, 0x6d, 0x0b, 0xe2, 0x0d, } cert, err := x509.ParseCertificate(rawCertificate) @@ -36,6 +42,7 @@ func TestFingerprint(t *testing.T) { t.Fatal(err) } + //nolint:lll const expectedSHA256 = "60:ef:f5:79:ad:8d:3e:d7:e8:4d:5a:5a:d6:1e:71:2d:47:52:a5:cb:df:34:37:87:10:a5:4e:d7:2a:2c:37:34" actualSHA256, err := Fingerprint(cert, crypto.SHA256) if err != nil { diff --git a/pkg/crypto/fingerprint/hash.go b/pkg/crypto/fingerprint/hash.go index 3f988ffb7..8aacd673d 100644 --- a/pkg/crypto/fingerprint/hash.go +++ b/pkg/crypto/fingerprint/hash.go @@ -22,11 +22,12 @@ func nameToHash() map[string]crypto.Hash { } } -// HashFromString allows looking up a hash algorithm by it's string representation +// HashFromString allows looking up a hash algorithm by it's string representation. func HashFromString(s string) (crypto.Hash, error) { if h, ok := nameToHash()[strings.ToLower(s)]; ok { return h, nil } + return 0, errInvalidHashAlgorithm } @@ -37,5 +38,6 @@ func StringFromHash(hash crypto.Hash) (string, error) { return s, nil } } + return "", errInvalidHashAlgorithm } diff --git a/pkg/crypto/fingerprint/hash_test.go b/pkg/crypto/fingerprint/hash_test.go index b71a7a363..51a29bfce 100644 --- a/pkg/crypto/fingerprint/hash_test.go +++ b/pkg/crypto/fingerprint/hash_test.go @@ -22,7 +22,7 @@ func TestHashFromString(t *testing.T) { t.Fatalf("Unexpected error for valid hash name, got '%v'", err) } if h != crypto.SHA512 { - t.Errorf("Expected hash ID of %d, got %d", int(crypto.SHA512), int(h)) + t.Errorf("Expected hash ID of %d, got %d", int(crypto.SHA512), int(h)) //nolint:gosec //G115 } }) t.Run("ValidCaseInsensitiveHashAlgorithm", func(t *testing.T) { @@ -31,6 +31,7 @@ func TestHashFromString(t *testing.T) { t.Fatalf("Unexpected error for valid hash name, got '%v'", err) } if h != crypto.SHA512 { + //nolint:gosec // G115 t.Errorf("Expected hash ID of %d, got %d", int(crypto.SHA512), int(h)) } }) diff --git a/pkg/crypto/hash/hash.go b/pkg/crypto/hash/hash.go index 9966626e3..a390170fe 100644 --- a/pkg/crypto/hash/hash.go +++ b/pkg/crypto/hash/hash.go @@ -16,7 +16,7 @@ import ( //nolint:gci // https://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml#tls-parameters-18 type Algorithm uint16 -// Supported hash algorithms +// Supported hash algorithms. const ( None Algorithm = 0 // Blacklisted MD5 Algorithm = 1 // Blacklisted @@ -28,7 +28,7 @@ const ( Ed25519 Algorithm = 8 ) -// String makes hashAlgorithm printable +// String makes hashAlgorithm printable. func (a Algorithm) String() string { switch a { case None: @@ -52,28 +52,34 @@ func (a Algorithm) String() string { } } -// Digest performs a digest on the passed value +// Digest performs a digest on the passed value. func (a Algorithm) Digest(b []byte) []byte { switch a { case None: return nil case MD5: hash := md5.Sum(b) // #nosec + return hash[:] case SHA1: hash := sha1.Sum(b) // #nosec + return hash[:] case SHA224: hash := sha256.Sum224(b) + return hash[:] case SHA256: hash := sha256.Sum256(b) + return hash[:] case SHA384: hash := sha512.Sum384(b) + return hash[:] case SHA512: hash := sha512.Sum512(b) + return hash[:] default: return nil @@ -81,6 +87,7 @@ func (a Algorithm) Digest(b []byte) []byte { } // Insecure returns if the given HashAlgorithm is considered secure in DTLS 1.2 +// . func (a Algorithm) Insecure() bool { switch a { case None, MD5, SHA1: @@ -90,7 +97,7 @@ func (a Algorithm) Insecure() bool { } } -// CryptoHash returns the crypto.Hash implementation for the given HashAlgorithm +// CryptoHash returns the crypto.Hash implementation for the given HashAlgorithm. func (a Algorithm) CryptoHash() crypto.Hash { switch a { case None: @@ -114,7 +121,7 @@ func (a Algorithm) CryptoHash() crypto.Hash { } } -// Algorithms returns all the supported Hash Algorithms +// Algorithms returns all the supported Hash Algorithms. func Algorithms() map[Algorithm]struct{} { return map[Algorithm]struct{}{ None: {}, diff --git a/pkg/crypto/hash/hash_test.go b/pkg/crypto/hash/hash_test.go index d6660254b..c6ba906ae 100644 --- a/pkg/crypto/hash/hash_test.go +++ b/pkg/crypto/hash/hash_test.go @@ -22,7 +22,10 @@ func TestHashAlgorithm_StringRoundtrip(t *testing.T) { t.Fatalf("fingerprint.HashFromString failed: %v", err) } if hash1 != hash2 { - t.Errorf("Hash algorithm mismatch, input: %d, after roundtrip: %d", int(hash1), int(hash2)) + t.Errorf( + "Hash algorithm mismatch, input: %d, after roundtrip: %d", + int(hash1), int(hash2), //nolint:gosec // G115 + ) } } } diff --git a/pkg/crypto/prf/prf.go b/pkg/crypto/prf/prf.go index b5ac19f79..9eace83b7 100644 --- a/pkg/crypto/prf/prf.go +++ b/pkg/crypto/prf/prf.go @@ -26,10 +26,10 @@ const ( verifyDataServerLabel = "server finished" ) -// HashFunc allows callers to decide what hash is used in PRF +// HashFunc allows callers to decide what hash is used in PRF. type HashFunc func() hash.Hash -// EncryptionKeys is all the state needed for a TLS CipherSuite +// EncryptionKeys is all the state needed for a TLS CipherSuite. type EncryptionKeys struct { MasterSecret []byte ClientMACKey []byte @@ -68,7 +68,7 @@ func (e *EncryptionKeys) String() string { // // https://tools.ietf.org/html/rfc4279#section-2 func PSKPreMasterSecret(psk []byte) []byte { - pskLen := uint16(len(psk)) + pskLen := uint16(len(psk)) //nolint:gosec // G115 out := append(make([]byte, 2+pskLen+2), psk...) binary.BigEndian.PutUint16(out, pskLen) @@ -89,7 +89,7 @@ func EcdhePSKPreMasterSecret(psk, publicKey, privateKey []byte, curve elliptic.C // write preMasterSecret length offset := 0 - binary.BigEndian.PutUint16(out[offset:], uint16(len(preMasterSecret))) + binary.BigEndian.PutUint16(out[offset:], uint16(len(preMasterSecret))) //nolint:gosec // G115 offset += 2 // write preMasterSecret @@ -97,15 +97,16 @@ func EcdhePSKPreMasterSecret(psk, publicKey, privateKey []byte, curve elliptic.C offset += len(preMasterSecret) // write psk length - binary.BigEndian.PutUint16(out[offset:], uint16(len(psk))) + binary.BigEndian.PutUint16(out[offset:], uint16(len(psk))) //nolint:gosec // G115 offset += 2 // write psk copy(out[offset:], psk) + return out, nil } -// PreMasterSecret implements TLS 1.2 Premaster Secret generation given a keypair and a curve +// PreMasterSecret implements TLS 1.2 Premaster Secret generation given a keypair and a curve. func PreMasterSecret(publicKey, privateKey []byte, curve elliptic.Curve) ([]byte, error) { switch curve { case elliptic.X25519: @@ -129,6 +130,7 @@ func ellipticCurvePreMasterSecret(publicKey, privateKey []byte, c1, c2 ellipticS preMasterSecret := make([]byte, (c2.Params().BitSize+7)>>3) resultBytes := result.Bytes() copy(preMasterSecret[len(preMasterSecret)-len(resultBytes):], resultBytes) + return preMasterSecret, nil } @@ -155,12 +157,13 @@ func ellipticCurvePreMasterSecret(publicKey, privateKey []byte, c1, c2 ellipticS // output data. // // https://tools.ietf.org/html/rfc4346w -func PHash(secret, seed []byte, requestedLength int, h HashFunc) ([]byte, error) { +func PHash(secret, seed []byte, requestedLength int, hashFunc HashFunc) ([]byte, error) { hmacSHA256 := func(key, data []byte) ([]byte, error) { - mac := hmac.New(h, key) + mac := hmac.New(hashFunc, key) if _, err := mac.Write(data); err != nil { return nil, err } + return mac.Sum(nil), nil } @@ -168,7 +171,7 @@ func PHash(secret, seed []byte, requestedLength int, h HashFunc) ([]byte, error) lastRound := seed out := []byte{} - iterations := int(math.Ceil(float64(requestedLength) / float64(h().Size()))) + iterations := int(math.Ceil(float64(requestedLength) / float64(hashFunc().Size()))) for i := 0; i < iterations; i++ { lastRound, err = hmacSHA256(secret, lastRound) if err != nil { @@ -188,18 +191,24 @@ func PHash(secret, seed []byte, requestedLength int, h HashFunc) ([]byte, error) // https://tools.ietf.org/html/rfc7627 func ExtendedMasterSecret(preMasterSecret, sessionHash []byte, h HashFunc) ([]byte, error) { seed := append([]byte(extendedMasterSecretLabel), sessionHash...) + return PHash(preMasterSecret, seed, 48, h) } -// MasterSecret generates a TLS 1.2 MasterSecret +// MasterSecret generates a TLS 1.2 MasterSecret. func MasterSecret(preMasterSecret, clientRandom, serverRandom []byte, h HashFunc) ([]byte, error) { seed := append(append([]byte(masterSecretLabel), clientRandom...), serverRandom...) + return PHash(preMasterSecret, seed, 48, h) } // GenerateEncryptionKeys is the final step TLS 1.2 PRF. Given all state generated so far generates -// the final keys need for encryption -func GenerateEncryptionKeys(masterSecret, clientRandom, serverRandom []byte, macLen, keyLen, ivLen int, h HashFunc) (*EncryptionKeys, error) { +// the final keys need for encryption. +func GenerateEncryptionKeys( + masterSecret, clientRandom, serverRandom []byte, + macLen, keyLen, ivLen int, + h HashFunc, +) (*EncryptionKeys, error) { seed := append(append([]byte(keyExpansionLabel), serverRandom...), clientRandom...) keyMaterial, err := PHash(masterSecret, seed, (2*macLen)+(2*keyLen)+(2*ivLen), h) if err != nil { @@ -241,15 +250,16 @@ func prfVerifyData(masterSecret, handshakeBodies []byte, label string, hashFunc } seed := append([]byte(label), h.Sum(nil)...) + return PHash(masterSecret, seed, 12, hashFunc) } -// VerifyDataClient is caled on the Client Side to either verify or generate the VerifyData message +// VerifyDataClient is caled on the Client Side to either verify or generate the VerifyData message. func VerifyDataClient(masterSecret, handshakeBodies []byte, h HashFunc) ([]byte, error) { return prfVerifyData(masterSecret, handshakeBodies, verifyDataClientLabel, h) } -// VerifyDataServer is caled on the Server Side to either verify or generate the VerifyData message +// VerifyDataServer is caled on the Server Side to either verify or generate the VerifyData message. func VerifyDataServer(masterSecret, handshakeBodies []byte, h HashFunc) ([]byte, error) { return prfVerifyData(masterSecret, handshakeBodies, verifyDataServerLabel, h) } diff --git a/pkg/crypto/prf/prf_test.go b/pkg/crypto/prf/prf_test.go index fa9b32510..9d2f90451 100644 --- a/pkg/crypto/prf/prf_test.go +++ b/pkg/crypto/prf/prf_test.go @@ -13,9 +13,18 @@ import ( ) func TestPreMasterSecret(t *testing.T) { - privateKey := []byte{0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f} - publicKey := []byte{0x9f, 0xd7, 0xad, 0x6d, 0xcf, 0xf4, 0x29, 0x8d, 0xd3, 0xf9, 0x6d, 0x5b, 0x1b, 0x2a, 0xf9, 0x10, 0xa0, 0x53, 0x5b, 0x14, 0x88, 0xd7, 0xf8, 0xfa, 0xbb, 0x34, 0x9a, 0x98, 0x28, 0x80, 0xb6, 0x15} - expectedPreMasterSecret := []byte{0xdf, 0x4a, 0x29, 0x1b, 0xaa, 0x1e, 0xb7, 0xcf, 0xa6, 0x93, 0x4b, 0x29, 0xb4, 0x74, 0xba, 0xad, 0x26, 0x97, 0xe2, 0x9f, 0x1f, 0x92, 0x0d, 0xcc, 0x77, 0xc8, 0xa0, 0xa0, 0x88, 0x44, 0x76, 0x24} + privateKey := []byte{ + 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, + 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f, + } + publicKey := []byte{ + 0x9f, 0xd7, 0xad, 0x6d, 0xcf, 0xf4, 0x29, 0x8d, 0xd3, 0xf9, 0x6d, 0x5b, 0x1b, 0x2a, 0xf9, 0x10, + 0xa0, 0x53, 0x5b, 0x14, 0x88, 0xd7, 0xf8, 0xfa, 0xbb, 0x34, 0x9a, 0x98, 0x28, 0x80, 0xb6, 0x15, + } + expectedPreMasterSecret := []byte{ + 0xdf, 0x4a, 0x29, 0x1b, 0xaa, 0x1e, 0xb7, 0xcf, 0xa6, 0x93, 0x4b, 0x29, 0xb4, 0x74, 0xba, 0xad, + 0x26, 0x97, 0xe2, 0x9f, 0x1f, 0x92, 0x0d, 0xcc, 0x77, 0xc8, 0xa0, 0xa0, 0x88, 0x44, 0x76, 0x24, + } preMasterSecret, err := PreMasterSecret(publicKey, privateKey, elliptic.X25519) if err != nil { @@ -26,10 +35,23 @@ func TestPreMasterSecret(t *testing.T) { } func TestMasterSecret(t *testing.T) { - preMasterSecret := []byte{0xdf, 0x4a, 0x29, 0x1b, 0xaa, 0x1e, 0xb7, 0xcf, 0xa6, 0x93, 0x4b, 0x29, 0xb4, 0x74, 0xba, 0xad, 0x26, 0x97, 0xe2, 0x9f, 0x1f, 0x92, 0x0d, 0xcc, 0x77, 0xc8, 0xa0, 0xa0, 0x88, 0x44, 0x76, 0x24} - clientRandom := []byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f} - serverRandom := []byte{0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f, 0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87, 0x88, 0x89, 0x8a, 0x8b, 0x8c, 0x8d, 0x8e, 0x8f} - expectedMasterSecret := []byte{0x91, 0x6a, 0xbf, 0x9d, 0xa5, 0x59, 0x73, 0xe1, 0x36, 0x14, 0xae, 0x0a, 0x3f, 0x5d, 0x3f, 0x37, 0xb0, 0x23, 0xba, 0x12, 0x9a, 0xee, 0x02, 0xcc, 0x91, 0x34, 0x33, 0x81, 0x27, 0xcd, 0x70, 0x49, 0x78, 0x1c, 0x8e, 0x19, 0xfc, 0x1e, 0xb2, 0xa7, 0x38, 0x7a, 0xc0, 0x6a, 0xe2, 0x37, 0x34, 0x4c} + preMasterSecret := []byte{ + 0xdf, 0x4a, 0x29, 0x1b, 0xaa, 0x1e, 0xb7, 0xcf, 0xa6, 0x93, 0x4b, 0x29, 0xb4, 0x74, 0xba, 0xad, + 0x26, 0x97, 0xe2, 0x9f, 0x1f, 0x92, 0x0d, 0xcc, 0x77, 0xc8, 0xa0, 0xa0, 0x88, 0x44, 0x76, 0x24, + } + clientRandom := []byte{ + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, + 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, + } + serverRandom := []byte{ + 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f, + 0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87, 0x88, 0x89, 0x8a, 0x8b, 0x8c, 0x8d, 0x8e, 0x8f, + } + expectedMasterSecret := []byte{ + 0x91, 0x6a, 0xbf, 0x9d, 0xa5, 0x59, 0x73, 0xe1, 0x36, 0x14, 0xae, 0x0a, 0x3f, 0x5d, 0x3f, 0x37, + 0xb0, 0x23, 0xba, 0x12, 0x9a, 0xee, 0x02, 0xcc, 0x91, 0x34, 0x33, 0x81, 0x27, 0xcd, 0x70, 0x49, + 0x78, 0x1c, 0x8e, 0x19, 0xfc, 0x1e, 0xb2, 0xa7, 0x38, 0x7a, 0xc0, 0x6a, 0xe2, 0x37, 0x34, 0x4c, + } masterSecret, err := MasterSecret(preMasterSecret, clientRandom, serverRandom, sha256.New) if err != nil { @@ -40,18 +62,32 @@ func TestMasterSecret(t *testing.T) { } func TestEncryptionKeys(t *testing.T) { - clientRandom := []byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f} - serverRandom := []byte{0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f, 0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87, 0x88, 0x89, 0x8a, 0x8b, 0x8c, 0x8d, 0x8e, 0x8f} - masterSecret := []byte{0x91, 0x6a, 0xbf, 0x9d, 0xa5, 0x59, 0x73, 0xe1, 0x36, 0x14, 0xae, 0x0a, 0x3f, 0x5d, 0x3f, 0x37, 0xb0, 0x23, 0xba, 0x12, 0x9a, 0xee, 0x02, 0xcc, 0x91, 0x34, 0x33, 0x81, 0x27, 0xcd, 0x70, 0x49, 0x78, 0x1c, 0x8e, 0x19, 0xfc, 0x1e, 0xb2, 0xa7, 0x38, 0x7a, 0xc0, 0x6a, 0xe2, 0x37, 0x34, 0x4c} + clientRandom := []byte{ + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, + 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, + } + serverRandom := []byte{ + 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f, + 0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87, 0x88, 0x89, 0x8a, 0x8b, 0x8c, 0x8d, 0x8e, 0x8f, + } + masterSecret := []byte{ + 0x91, 0x6a, 0xbf, 0x9d, 0xa5, 0x59, 0x73, 0xe1, 0x36, 0x14, 0xae, 0x0a, 0x3f, 0x5d, 0x3f, 0x37, + 0xb0, 0x23, 0xba, 0x12, 0x9a, 0xee, 0x02, 0xcc, 0x91, 0x34, 0x33, 0x81, 0x27, 0xcd, 0x70, 0x49, + 0x78, 0x1c, 0x8e, 0x19, 0xfc, 0x1e, 0xb2, 0xa7, 0x38, 0x7a, 0xc0, 0x6a, 0xe2, 0x37, 0x34, 0x4c, + } expectedEncryptionKeys := &EncryptionKeys{ - MasterSecret: masterSecret, - ClientMACKey: []byte{}, - ServerMACKey: []byte{}, - ClientWriteKey: []byte{0x1b, 0x7d, 0x11, 0x7c, 0x7d, 0x5f, 0x69, 0x0b, 0xc2, 0x63, 0xca, 0xe8, 0xef, 0x60, 0xaf, 0x0f}, - ServerWriteKey: []byte{0x18, 0x78, 0xac, 0xc2, 0x2a, 0xd8, 0xbd, 0xd8, 0xc6, 0x01, 0xa6, 0x17, 0x12, 0x6f, 0x63, 0x54}, - ClientWriteIV: []byte{0x0e, 0xb2, 0x09, 0x06}, - ServerWriteIV: []byte{0xf7, 0x81, 0xfa, 0xd2}, + MasterSecret: masterSecret, + ClientMACKey: []byte{}, + ServerMACKey: []byte{}, + ClientWriteKey: []byte{ + 0x1b, 0x7d, 0x11, 0x7c, 0x7d, 0x5f, 0x69, 0x0b, 0xc2, 0x63, 0xca, 0xe8, 0xef, 0x60, 0xaf, 0x0f, + }, + ServerWriteKey: []byte{ + 0x18, 0x78, 0xac, 0xc2, 0x2a, 0xd8, 0xbd, 0xd8, 0xc6, 0x01, 0xa6, 0x17, 0x12, 0x6f, 0x63, 0x54, + }, + ClientWriteIV: []byte{0x0e, 0xb2, 0x09, 0x06}, + ServerWriteIV: []byte{0xf7, 0x81, 0xfa, 0xd2}, } keys, err := GenerateEncryptionKeys(masterSecret, clientRandom, serverRandom, 0, 16, 4, sha256.New) @@ -63,15 +99,123 @@ func TestEncryptionKeys(t *testing.T) { } func TestVerifyData(t *testing.T) { - clientHello := []byte{0x01, 0x00, 0x00, 0xa1, 0x03, 0x03, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x00, 0x00, 0x20, 0xcc, 0xa8, 0xcc, 0xa9, 0xc0, 0x2f, 0xc0, 0x30, 0xc0, 0x2b, 0xc0, 0x2c, 0xc0, 0x13, 0xc0, 0x09, 0xc0, 0x14, 0xc0, 0x0a, 0x00, 0x9c, 0x00, 0x9d, 0x00, 0x2f, 0x00, 0x35, 0xc0, 0x12, 0x00, 0x0a, 0x01, 0x00, 0x00, 0x58, 0x00, 0x00, 0x00, 0x18, 0x00, 0x16, 0x00, 0x00, 0x13, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2e, 0x75, 0x6c, 0x66, 0x68, 0x65, 0x69, 0x6d, 0x2e, 0x6e, 0x65, 0x74, 0x00, 0x05, 0x00, 0x05, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0a, 0x00, 0x0a, 0x00, 0x08, 0x00, 0x1d, 0x00, 0x17, 0x00, 0x18, 0x00, 0x19, 0x00, 0x0b, 0x00, 0x02, 0x01, 0x00, 0x00, 0x0d, 0x00, 0x12, 0x00, 0x10, 0x04, 0x01, 0x04, 0x03, 0x05, 0x01, 0x05, 0x03, 0x06, 0x01, 0x06, 0x03, 0x02, 0x01, 0x02, 0x03, 0xff, 0x01, 0x00, 0x01, 0x00, 0x00, 0x12, 0x00, 0x00} - serverHello := []byte{0x02, 0x00, 0x00, 0x2d, 0x03, 0x03, 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f, 0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87, 0x88, 0x89, 0x8a, 0x8b, 0x8c, 0x8d, 0x8e, 0x8f, 0x00, 0xc0, 0x13, 0x00, 0x00, 0x05, 0xff, 0x01, 0x00, 0x01, 0x00} - serverCertificate := []byte{0x0b, 0x00, 0x03, 0x2b, 0x00, 0x03, 0x28, 0x00, 0x03, 0x25, 0x30, 0x82, 0x03, 0x21, 0x30, 0x82, 0x02, 0x09, 0xa0, 0x03, 0x02, 0x01, 0x02, 0x02, 0x08, 0x15, 0x5a, 0x92, 0xad, 0xc2, 0x04, 0x8f, 0x90, 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x0b, 0x05, 0x00, 0x30, 0x22, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13, 0x02, 0x55, 0x53, 0x31, 0x13, 0x30, 0x11, 0x06, 0x03, 0x55, 0x04, 0x0a, 0x13, 0x0a, 0x45, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x20, 0x43, 0x41, 0x30, 0x1e, 0x17, 0x0d, 0x31, 0x38, 0x31, 0x30, 0x30, 0x35, 0x30, 0x31, 0x33, 0x38, 0x31, 0x37, 0x5a, 0x17, 0x0d, 0x31, 0x39, 0x31, 0x30, 0x30, 0x35, 0x30, 0x31, 0x33, 0x38, 0x31, 0x37, 0x5a, 0x30, 0x2b, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13, 0x02, 0x55, 0x53, 0x31, 0x1c, 0x30, 0x1a, 0x06, 0x03, 0x55, 0x04, 0x03, 0x13, 0x13, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2e, 0x75, 0x6c, 0x66, 0x68, 0x65, 0x69, 0x6d, 0x2e, 0x6e, 0x65, 0x74, 0x30, 0x82, 0x01, 0x22, 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x01, 0x05, 0x00, 0x03, 0x82, 0x01, 0x0f, 0x00, 0x30, 0x82, 0x01, 0x0a, 0x02, 0x82, 0x01, 0x01, 0x00, 0xc4, 0x80, 0x36, 0x06, 0xba, 0xe7, 0x47, 0x6b, 0x08, 0x94, 0x04, 0xec, 0xa7, 0xb6, 0x91, 0x04, 0x3f, 0xf7, 0x92, 0xbc, 0x19, 0xee, 0xfb, 0x7d, 0x74, 0xd7, 0xa8, 0x0d, 0x00, 0x1e, 0x7b, 0x4b, 0x3a, 0x4a, 0xe6, 0x0f, 0xe8, 0xc0, 0x71, 0xfc, 0x73, 0xe7, 0x02, 0x4c, 0x0d, 0xbc, 0xf4, 0xbd, 0xd1, 0x1d, 0x39, 0x6b, 0xba, 0x70, 0x46, 0x4a, 0x13, 0xe9, 0x4a, 0xf8, 0x3d, 0xf3, 0xe1, 0x09, 0x59, 0x54, 0x7b, 0xc9, 0x55, 0xfb, 0x41, 0x2d, 0xa3, 0x76, 0x52, 0x11, 0xe1, 0xf3, 0xdc, 0x77, 0x6c, 0xaa, 0x53, 0x37, 0x6e, 0xca, 0x3a, 0xec, 0xbe, 0xc3, 0xaa, 0xb7, 0x3b, 0x31, 0xd5, 0x6c, 0xb6, 0x52, 0x9c, 0x80, 0x98, 0xbc, 0xc9, 0xe0, 0x28, 0x18, 0xe2, 0x0b, 0xf7, 0xf8, 0xa0, 0x3a, 0xfd, 0x17, 0x04, 0x50, 0x9e, 0xce, 0x79, 0xbd, 0x9f, 0x39, 0xf1, 0xea, 0x69, 0xec, 0x47, 0x97, 0x2e, 0x83, 0x0f, 0xb5, 0xca, 0x95, 0xde, 0x95, 0xa1, 0xe6, 0x04, 0x22, 0xd5, 0xee, 0xbe, 0x52, 0x79, 0x54, 0xa1, 0xe7, 0xbf, 0x8a, 0x86, 0xf6, 0x46, 0x6d, 0x0d, 0x9f, 0x16, 0x95, 0x1a, 0x4c, 0xf7, 0xa0, 0x46, 0x92, 0x59, 0x5c, 0x13, 0x52, 0xf2, 0x54, 0x9e, 0x5a, 0xfb, 0x4e, 0xbf, 0xd7, 0x7a, 0x37, 0x95, 0x01, 0x44, 0xe4, 0xc0, 0x26, 0x87, 0x4c, 0x65, 0x3e, 0x40, 0x7d, 0x7d, 0x23, 0x07, 0x44, 0x01, 0xf4, 0x84, 0xff, 0xd0, 0x8f, 0x7a, 0x1f, 0xa0, 0x52, 0x10, 0xd1, 0xf4, 0xf0, 0xd5, 0xce, 0x79, 0x70, 0x29, 0x32, 0xe2, 0xca, 0xbe, 0x70, 0x1f, 0xdf, 0xad, 0x6b, 0x4b, 0xb7, 0x11, 0x01, 0xf4, 0x4b, 0xad, 0x66, 0x6a, 0x11, 0x13, 0x0f, 0xe2, 0xee, 0x82, 0x9e, 0x4d, 0x02, 0x9d, 0xc9, 0x1c, 0xdd, 0x67, 0x16, 0xdb, 0xb9, 0x06, 0x18, 0x86, 0xed, 0xc1, 0xba, 0x94, 0x21, 0x02, 0x03, 0x01, 0x00, 0x01, 0xa3, 0x52, 0x30, 0x50, 0x30, 0x0e, 0x06, 0x03, 0x55, 0x1d, 0x0f, 0x01, 0x01, 0xff, 0x04, 0x04, 0x03, 0x02, 0x05, 0xa0, 0x30, 0x1d, 0x06, 0x03, 0x55, 0x1d, 0x25, 0x04, 0x16, 0x30, 0x14, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x03, 0x02, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x03, 0x01, 0x30, 0x1f, 0x06, 0x03, 0x55, 0x1d, 0x23, 0x04, 0x18, 0x30, 0x16, 0x80, 0x14, 0x89, 0x4f, 0xde, 0x5b, 0xcc, 0x69, 0xe2, 0x52, 0xcf, 0x3e, 0xa3, 0x00, 0xdf, 0xb1, 0x97, 0xb8, 0x1d, 0xe1, 0xc1, 0x46, 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x0b, 0x05, 0x00, 0x03, 0x82, 0x01, 0x01, 0x00, 0x59, 0x16, 0x45, 0xa6, 0x9a, 0x2e, 0x37, 0x79, 0xe4, 0xf6, 0xdd, 0x27, 0x1a, 0xba, 0x1c, 0x0b, 0xfd, 0x6c, 0xd7, 0x55, 0x99, 0xb5, 0xe7, 0xc3, 0x6e, 0x53, 0x3e, 0xff, 0x36, 0x59, 0x08, 0x43, 0x24, 0xc9, 0xe7, 0xa5, 0x04, 0x07, 0x9d, 0x39, 0xe0, 0xd4, 0x29, 0x87, 0xff, 0xe3, 0xeb, 0xdd, 0x09, 0xc1, 0xcf, 0x1d, 0x91, 0x44, 0x55, 0x87, 0x0b, 0x57, 0x1d, 0xd1, 0x9b, 0xdf, 0x1d, 0x24, 0xf8, 0xbb, 0x9a, 0x11, 0xfe, 0x80, 0xfd, 0x59, 0x2b, 0xa0, 0x39, 0x8c, 0xde, 0x11, 0xe2, 0x65, 0x1e, 0x61, 0x8c, 0xe5, 0x98, 0xfa, 0x96, 0xe5, 0x37, 0x2e, 0xef, 0x3d, 0x24, 0x8a, 0xfd, 0xe1, 0x74, 0x63, 0xeb, 0xbf, 0xab, 0xb8, 0xe4, 0xd1, 0xab, 0x50, 0x2a, 0x54, 0xec, 0x00, 0x64, 0xe9, 0x2f, 0x78, 0x19, 0x66, 0x0d, 0x3f, 0x27, 0xcf, 0x20, 0x9e, 0x66, 0x7f, 0xce, 0x5a, 0xe2, 0xe4, 0xac, 0x99, 0xc7, 0xc9, 0x38, 0x18, 0xf8, 0xb2, 0x51, 0x07, 0x22, 0xdf, 0xed, 0x97, 0xf3, 0x2e, 0x3e, 0x93, 0x49, 0xd4, 0xc6, 0x6c, 0x9e, 0xa6, 0x39, 0x6d, 0x74, 0x44, 0x62, 0xa0, 0x6b, 0x42, 0xc6, 0xd5, 0xba, 0x68, 0x8e, 0xac, 0x3a, 0x01, 0x7b, 0xdd, 0xfc, 0x8e, 0x2c, 0xfc, 0xad, 0x27, 0xcb, 0x69, 0xd3, 0xcc, 0xdc, 0xa2, 0x80, 0x41, 0x44, 0x65, 0xd3, 0xae, 0x34, 0x8c, 0xe0, 0xf3, 0x4a, 0xb2, 0xfb, 0x9c, 0x61, 0x83, 0x71, 0x31, 0x2b, 0x19, 0x10, 0x41, 0x64, 0x1c, 0x23, 0x7f, 0x11, 0xa5, 0xd6, 0x5c, 0x84, 0x4f, 0x04, 0x04, 0x84, 0x99, 0x38, 0x71, 0x2b, 0x95, 0x9e, 0xd6, 0x85, 0xbc, 0x5c, 0x5d, 0xd6, 0x45, 0xed, 0x19, 0x90, 0x94, 0x73, 0x40, 0x29, 0x26, 0xdc, 0xb4, 0x0e, 0x34, 0x69, 0xa1, 0x59, 0x41, 0xe8, 0xe2, 0xcc, 0xa8, 0x4b, 0xb6, 0x08, 0x46, 0x36, 0xa0} - serverKeyExchange := []byte{0x0c, 0x00, 0x01, 0x28, 0x03, 0x00, 0x1d, 0x20, 0x9f, 0xd7, 0xad, 0x6d, 0xcf, 0xf4, 0x29, 0x8d, 0xd3, 0xf9, 0x6d, 0x5b, 0x1b, 0x2a, 0xf9, 0x10, 0xa0, 0x53, 0x5b, 0x14, 0x88, 0xd7, 0xf8, 0xfa, 0xbb, 0x34, 0x9a, 0x98, 0x28, 0x80, 0xb6, 0x15, 0x04, 0x01, 0x01, 0x00, 0x04, 0x02, 0xb6, 0x61, 0xf7, 0xc1, 0x91, 0xee, 0x59, 0xbe, 0x45, 0x37, 0x66, 0x39, 0xbd, 0xc3, 0xd4, 0xbb, 0x81, 0xe1, 0x15, 0xca, 0x73, 0xc8, 0x34, 0x8b, 0x52, 0x5b, 0x0d, 0x23, 0x38, 0xaa, 0x14, 0x46, 0x67, 0xed, 0x94, 0x31, 0x02, 0x14, 0x12, 0xcd, 0x9b, 0x84, 0x4c, 0xba, 0x29, 0x93, 0x4a, 0xaa, 0xcc, 0xe8, 0x73, 0x41, 0x4e, 0xc1, 0x1c, 0xb0, 0x2e, 0x27, 0x2d, 0x0a, 0xd8, 0x1f, 0x76, 0x7d, 0x33, 0x07, 0x67, 0x21, 0xf1, 0x3b, 0xf3, 0x60, 0x20, 0xcf, 0x0b, 0x1f, 0xd0, 0xec, 0xb0, 0x78, 0xde, 0x11, 0x28, 0xbe, 0xba, 0x09, 0x49, 0xeb, 0xec, 0xe1, 0xa1, 0xf9, 0x6e, 0x20, 0x9d, 0xc3, 0x6e, 0x4f, 0xff, 0xd3, 0x6b, 0x67, 0x3a, 0x7d, 0xdc, 0x15, 0x97, 0xad, 0x44, 0x08, 0xe4, 0x85, 0xc4, 0xad, 0xb2, 0xc8, 0x73, 0x84, 0x12, 0x49, 0x37, 0x25, 0x23, 0x80, 0x9e, 0x43, 0x12, 0xd0, 0xc7, 0xb3, 0x52, 0x2e, 0xf9, 0x83, 0xca, 0xc1, 0xe0, 0x39, 0x35, 0xff, 0x13, 0xa8, 0xe9, 0x6b, 0xa6, 0x81, 0xa6, 0x2e, 0x40, 0xd3, 0xe7, 0x0a, 0x7f, 0xf3, 0x58, 0x66, 0xd3, 0xd9, 0x99, 0x3f, 0x9e, 0x26, 0xa6, 0x34, 0xc8, 0x1b, 0x4e, 0x71, 0x38, 0x0f, 0xcd, 0xd6, 0xf4, 0xe8, 0x35, 0xf7, 0x5a, 0x64, 0x09, 0xc7, 0xdc, 0x2c, 0x07, 0x41, 0x0e, 0x6f, 0x87, 0x85, 0x8c, 0x7b, 0x94, 0xc0, 0x1c, 0x2e, 0x32, 0xf2, 0x91, 0x76, 0x9e, 0xac, 0xca, 0x71, 0x64, 0x3b, 0x8b, 0x98, 0xa9, 0x63, 0xdf, 0x0a, 0x32, 0x9b, 0xea, 0x4e, 0xd6, 0x39, 0x7e, 0x8c, 0xd0, 0x1a, 0x11, 0x0a, 0xb3, 0x61, 0xac, 0x5b, 0xad, 0x1c, 0xcd, 0x84, 0x0a, 0x6c, 0x8a, 0x6e, 0xaa, 0x00, 0x1a, 0x9d, 0x7d, 0x87, 0xdc, 0x33, 0x18, 0x64, 0x35, 0x71, 0x22, 0x6c, 0x4d, 0xd2, 0xc2, 0xac, 0x41, 0xfb} + clientHello := []byte{ + 0x01, 0x00, 0x00, 0xa1, 0x03, 0x03, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, + 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, + 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x00, 0x00, 0x20, 0xcc, 0xa8, 0xcc, 0xa9, 0xc0, 0x2f, 0xc0, + 0x30, 0xc0, 0x2b, 0xc0, 0x2c, 0xc0, 0x13, 0xc0, 0x09, 0xc0, 0x14, 0xc0, 0x0a, 0x00, 0x9c, 0x00, + 0x9d, 0x00, 0x2f, 0x00, 0x35, 0xc0, 0x12, 0x00, 0x0a, 0x01, 0x00, 0x00, 0x58, 0x00, 0x00, 0x00, + 0x18, 0x00, 0x16, 0x00, 0x00, 0x13, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2e, 0x75, 0x6c, + 0x66, 0x68, 0x65, 0x69, 0x6d, 0x2e, 0x6e, 0x65, 0x74, 0x00, 0x05, 0x00, 0x05, 0x01, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x0a, 0x00, 0x0a, 0x00, 0x08, 0x00, 0x1d, 0x00, 0x17, 0x00, 0x18, 0x00, 0x19, + 0x00, 0x0b, 0x00, 0x02, 0x01, 0x00, 0x00, 0x0d, 0x00, 0x12, 0x00, 0x10, 0x04, 0x01, 0x04, 0x03, + 0x05, 0x01, 0x05, 0x03, 0x06, 0x01, 0x06, 0x03, 0x02, 0x01, 0x02, 0x03, 0xff, 0x01, 0x00, 0x01, + 0x00, 0x00, 0x12, 0x00, 0x00, + } + serverHello := []byte{ + 0x02, 0x00, 0x00, 0x2d, 0x03, 0x03, 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79, + 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f, 0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87, 0x88, 0x89, + 0x8a, 0x8b, 0x8c, 0x8d, 0x8e, 0x8f, 0x00, 0xc0, 0x13, 0x00, 0x00, 0x05, 0xff, 0x01, 0x00, 0x01, + 0x00, + } + serverCertificate := []byte{ + 0x0b, 0x00, 0x03, 0x2b, 0x00, 0x03, 0x28, 0x00, 0x03, 0x25, 0x30, 0x82, 0x03, 0x21, 0x30, 0x82, + 0x02, 0x09, 0xa0, 0x03, 0x02, 0x01, 0x02, 0x02, 0x08, 0x15, 0x5a, 0x92, 0xad, 0xc2, 0x04, 0x8f, + 0x90, 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x0b, 0x05, 0x00, + 0x30, 0x22, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13, 0x02, 0x55, 0x53, 0x31, + 0x13, 0x30, 0x11, 0x06, 0x03, 0x55, 0x04, 0x0a, 0x13, 0x0a, 0x45, 0x78, 0x61, 0x6d, 0x70, 0x6c, + 0x65, 0x20, 0x43, 0x41, 0x30, 0x1e, 0x17, 0x0d, 0x31, 0x38, 0x31, 0x30, 0x30, 0x35, 0x30, 0x31, + 0x33, 0x38, 0x31, 0x37, 0x5a, 0x17, 0x0d, 0x31, 0x39, 0x31, 0x30, 0x30, 0x35, 0x30, 0x31, 0x33, + 0x38, 0x31, 0x37, 0x5a, 0x30, 0x2b, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13, + 0x02, 0x55, 0x53, 0x31, 0x1c, 0x30, 0x1a, 0x06, 0x03, 0x55, 0x04, 0x03, 0x13, 0x13, 0x65, 0x78, + 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2e, 0x75, 0x6c, 0x66, 0x68, 0x65, 0x69, 0x6d, 0x2e, 0x6e, 0x65, + 0x74, 0x30, 0x82, 0x01, 0x22, 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, + 0x01, 0x01, 0x05, 0x00, 0x03, 0x82, 0x01, 0x0f, 0x00, 0x30, 0x82, 0x01, 0x0a, 0x02, 0x82, 0x01, + 0x01, 0x00, 0xc4, 0x80, 0x36, 0x06, 0xba, 0xe7, 0x47, 0x6b, 0x08, 0x94, 0x04, 0xec, 0xa7, 0xb6, + 0x91, 0x04, 0x3f, 0xf7, 0x92, 0xbc, 0x19, 0xee, 0xfb, 0x7d, 0x74, 0xd7, 0xa8, 0x0d, 0x00, 0x1e, + 0x7b, 0x4b, 0x3a, 0x4a, 0xe6, 0x0f, 0xe8, 0xc0, 0x71, 0xfc, 0x73, 0xe7, 0x02, 0x4c, 0x0d, 0xbc, + 0xf4, 0xbd, 0xd1, 0x1d, 0x39, 0x6b, 0xba, 0x70, 0x46, 0x4a, 0x13, 0xe9, 0x4a, 0xf8, 0x3d, 0xf3, + 0xe1, 0x09, 0x59, 0x54, 0x7b, 0xc9, 0x55, 0xfb, 0x41, 0x2d, 0xa3, 0x76, 0x52, 0x11, 0xe1, 0xf3, + 0xdc, 0x77, 0x6c, 0xaa, 0x53, 0x37, 0x6e, 0xca, 0x3a, 0xec, 0xbe, 0xc3, 0xaa, 0xb7, 0x3b, 0x31, + 0xd5, 0x6c, 0xb6, 0x52, 0x9c, 0x80, 0x98, 0xbc, 0xc9, 0xe0, 0x28, 0x18, 0xe2, 0x0b, 0xf7, 0xf8, + 0xa0, 0x3a, 0xfd, 0x17, 0x04, 0x50, 0x9e, 0xce, 0x79, 0xbd, 0x9f, 0x39, 0xf1, 0xea, 0x69, 0xec, + 0x47, 0x97, 0x2e, 0x83, 0x0f, 0xb5, 0xca, 0x95, 0xde, 0x95, 0xa1, 0xe6, 0x04, 0x22, 0xd5, 0xee, + 0xbe, 0x52, 0x79, 0x54, 0xa1, 0xe7, 0xbf, 0x8a, 0x86, 0xf6, 0x46, 0x6d, 0x0d, 0x9f, 0x16, 0x95, + 0x1a, 0x4c, 0xf7, 0xa0, 0x46, 0x92, 0x59, 0x5c, 0x13, 0x52, 0xf2, 0x54, 0x9e, 0x5a, 0xfb, 0x4e, + 0xbf, 0xd7, 0x7a, 0x37, 0x95, 0x01, 0x44, 0xe4, 0xc0, 0x26, 0x87, 0x4c, 0x65, 0x3e, 0x40, 0x7d, + 0x7d, 0x23, 0x07, 0x44, 0x01, 0xf4, 0x84, 0xff, 0xd0, 0x8f, 0x7a, 0x1f, 0xa0, 0x52, 0x10, 0xd1, + 0xf4, 0xf0, 0xd5, 0xce, 0x79, 0x70, 0x29, 0x32, 0xe2, 0xca, 0xbe, 0x70, 0x1f, 0xdf, 0xad, 0x6b, + 0x4b, 0xb7, 0x11, 0x01, 0xf4, 0x4b, 0xad, 0x66, 0x6a, 0x11, 0x13, 0x0f, 0xe2, 0xee, 0x82, 0x9e, + 0x4d, 0x02, 0x9d, 0xc9, 0x1c, 0xdd, 0x67, 0x16, 0xdb, 0xb9, 0x06, 0x18, 0x86, 0xed, 0xc1, 0xba, + 0x94, 0x21, 0x02, 0x03, 0x01, 0x00, 0x01, 0xa3, 0x52, 0x30, 0x50, 0x30, 0x0e, 0x06, 0x03, 0x55, + 0x1d, 0x0f, 0x01, 0x01, 0xff, 0x04, 0x04, 0x03, 0x02, 0x05, 0xa0, 0x30, 0x1d, 0x06, 0x03, 0x55, + 0x1d, 0x25, 0x04, 0x16, 0x30, 0x14, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x03, 0x02, + 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x03, 0x01, 0x30, 0x1f, 0x06, 0x03, 0x55, 0x1d, + 0x23, 0x04, 0x18, 0x30, 0x16, 0x80, 0x14, 0x89, 0x4f, 0xde, 0x5b, 0xcc, 0x69, 0xe2, 0x52, 0xcf, + 0x3e, 0xa3, 0x00, 0xdf, 0xb1, 0x97, 0xb8, 0x1d, 0xe1, 0xc1, 0x46, 0x30, 0x0d, 0x06, 0x09, 0x2a, + 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x0b, 0x05, 0x00, 0x03, 0x82, 0x01, 0x01, 0x00, 0x59, + 0x16, 0x45, 0xa6, 0x9a, 0x2e, 0x37, 0x79, 0xe4, 0xf6, 0xdd, 0x27, 0x1a, 0xba, 0x1c, 0x0b, 0xfd, + 0x6c, 0xd7, 0x55, 0x99, 0xb5, 0xe7, 0xc3, 0x6e, 0x53, 0x3e, 0xff, 0x36, 0x59, 0x08, 0x43, 0x24, + 0xc9, 0xe7, 0xa5, 0x04, 0x07, 0x9d, 0x39, 0xe0, 0xd4, 0x29, 0x87, 0xff, 0xe3, 0xeb, 0xdd, 0x09, + 0xc1, 0xcf, 0x1d, 0x91, 0x44, 0x55, 0x87, 0x0b, 0x57, 0x1d, 0xd1, 0x9b, 0xdf, 0x1d, 0x24, 0xf8, + 0xbb, 0x9a, 0x11, 0xfe, 0x80, 0xfd, 0x59, 0x2b, 0xa0, 0x39, 0x8c, 0xde, 0x11, 0xe2, 0x65, 0x1e, + 0x61, 0x8c, 0xe5, 0x98, 0xfa, 0x96, 0xe5, 0x37, 0x2e, 0xef, 0x3d, 0x24, 0x8a, 0xfd, 0xe1, 0x74, + 0x63, 0xeb, 0xbf, 0xab, 0xb8, 0xe4, 0xd1, 0xab, 0x50, 0x2a, 0x54, 0xec, 0x00, 0x64, 0xe9, 0x2f, + 0x78, 0x19, 0x66, 0x0d, 0x3f, 0x27, 0xcf, 0x20, 0x9e, 0x66, 0x7f, 0xce, 0x5a, 0xe2, 0xe4, 0xac, + 0x99, 0xc7, 0xc9, 0x38, 0x18, 0xf8, 0xb2, 0x51, 0x07, 0x22, 0xdf, 0xed, 0x97, 0xf3, 0x2e, 0x3e, + 0x93, 0x49, 0xd4, 0xc6, 0x6c, 0x9e, 0xa6, 0x39, 0x6d, 0x74, 0x44, 0x62, 0xa0, 0x6b, 0x42, 0xc6, + 0xd5, 0xba, 0x68, 0x8e, 0xac, 0x3a, 0x01, 0x7b, 0xdd, 0xfc, 0x8e, 0x2c, 0xfc, 0xad, 0x27, 0xcb, + 0x69, 0xd3, 0xcc, 0xdc, 0xa2, 0x80, 0x41, 0x44, 0x65, 0xd3, 0xae, 0x34, 0x8c, 0xe0, 0xf3, 0x4a, + 0xb2, 0xfb, 0x9c, 0x61, 0x83, 0x71, 0x31, 0x2b, 0x19, 0x10, 0x41, 0x64, 0x1c, 0x23, 0x7f, 0x11, + 0xa5, 0xd6, 0x5c, 0x84, 0x4f, 0x04, 0x04, 0x84, 0x99, 0x38, 0x71, 0x2b, 0x95, 0x9e, 0xd6, 0x85, + 0xbc, 0x5c, 0x5d, 0xd6, 0x45, 0xed, 0x19, 0x90, 0x94, 0x73, 0x40, 0x29, 0x26, 0xdc, 0xb4, 0x0e, + 0x34, 0x69, 0xa1, 0x59, 0x41, 0xe8, 0xe2, 0xcc, 0xa8, 0x4b, 0xb6, 0x08, 0x46, 0x36, 0xa0, + } + serverKeyExchange := []byte{ + 0x0c, 0x00, 0x01, 0x28, 0x03, 0x00, 0x1d, 0x20, 0x9f, 0xd7, 0xad, 0x6d, 0xcf, 0xf4, 0x29, 0x8d, + 0xd3, 0xf9, 0x6d, 0x5b, 0x1b, 0x2a, 0xf9, 0x10, 0xa0, 0x53, 0x5b, 0x14, 0x88, 0xd7, 0xf8, 0xfa, + 0xbb, 0x34, 0x9a, 0x98, 0x28, 0x80, 0xb6, 0x15, 0x04, 0x01, 0x01, 0x00, 0x04, 0x02, 0xb6, 0x61, + 0xf7, 0xc1, 0x91, 0xee, 0x59, 0xbe, 0x45, 0x37, 0x66, 0x39, 0xbd, 0xc3, 0xd4, 0xbb, 0x81, 0xe1, + 0x15, 0xca, 0x73, 0xc8, 0x34, 0x8b, 0x52, 0x5b, 0x0d, 0x23, 0x38, 0xaa, 0x14, 0x46, 0x67, 0xed, + 0x94, 0x31, 0x02, 0x14, 0x12, 0xcd, 0x9b, 0x84, 0x4c, 0xba, 0x29, 0x93, 0x4a, 0xaa, 0xcc, 0xe8, + 0x73, 0x41, 0x4e, 0xc1, 0x1c, 0xb0, 0x2e, 0x27, 0x2d, 0x0a, 0xd8, 0x1f, 0x76, 0x7d, 0x33, 0x07, + 0x67, 0x21, 0xf1, 0x3b, 0xf3, 0x60, 0x20, 0xcf, 0x0b, 0x1f, 0xd0, 0xec, 0xb0, 0x78, 0xde, 0x11, + 0x28, 0xbe, 0xba, 0x09, 0x49, 0xeb, 0xec, 0xe1, 0xa1, 0xf9, 0x6e, 0x20, 0x9d, 0xc3, 0x6e, 0x4f, + 0xff, 0xd3, 0x6b, 0x67, 0x3a, 0x7d, 0xdc, 0x15, 0x97, 0xad, 0x44, 0x08, 0xe4, 0x85, 0xc4, 0xad, + 0xb2, 0xc8, 0x73, 0x84, 0x12, 0x49, 0x37, 0x25, 0x23, 0x80, 0x9e, 0x43, 0x12, 0xd0, 0xc7, 0xb3, + 0x52, 0x2e, 0xf9, 0x83, 0xca, 0xc1, 0xe0, 0x39, 0x35, 0xff, 0x13, 0xa8, 0xe9, 0x6b, 0xa6, 0x81, + 0xa6, 0x2e, 0x40, 0xd3, 0xe7, 0x0a, 0x7f, 0xf3, 0x58, 0x66, 0xd3, 0xd9, 0x99, 0x3f, 0x9e, 0x26, + 0xa6, 0x34, 0xc8, 0x1b, 0x4e, 0x71, 0x38, 0x0f, 0xcd, 0xd6, 0xf4, 0xe8, 0x35, 0xf7, 0x5a, 0x64, + 0x09, 0xc7, 0xdc, 0x2c, 0x07, 0x41, 0x0e, 0x6f, 0x87, 0x85, 0x8c, 0x7b, 0x94, 0xc0, 0x1c, 0x2e, + 0x32, 0xf2, 0x91, 0x76, 0x9e, 0xac, 0xca, 0x71, 0x64, 0x3b, 0x8b, 0x98, 0xa9, 0x63, 0xdf, 0x0a, + 0x32, 0x9b, 0xea, 0x4e, 0xd6, 0x39, 0x7e, 0x8c, 0xd0, 0x1a, 0x11, 0x0a, 0xb3, 0x61, 0xac, 0x5b, + 0xad, 0x1c, 0xcd, 0x84, 0x0a, 0x6c, 0x8a, 0x6e, 0xaa, 0x00, 0x1a, 0x9d, 0x7d, 0x87, 0xdc, 0x33, + 0x18, 0x64, 0x35, 0x71, 0x22, 0x6c, 0x4d, 0xd2, 0xc2, 0xac, 0x41, 0xfb, + } serverHelloDone := []byte{0x0e, 0x00, 0x00, 0x00} - clientKeyExchange := []byte{0x10, 0x00, 0x00, 0x21, 0x20, 0x35, 0x80, 0x72, 0xd6, 0x36, 0x58, 0x80, 0xd1, 0xae, 0xea, 0x32, 0x9a, 0xdf, 0x91, 0x21, 0x38, 0x38, 0x51, 0xed, 0x21, 0xa2, 0x8e, 0x3b, 0x75, 0xe9, 0x65, 0xd0, 0xd2, 0xcd, 0x16, 0x62, 0x54} + clientKeyExchange := []byte{ + 0x10, 0x00, 0x00, 0x21, 0x20, 0x35, 0x80, 0x72, 0xd6, 0x36, 0x58, 0x80, 0xd1, 0xae, 0xea, + 0x32, 0x9a, 0xdf, 0x91, 0x21, 0x38, 0x38, 0x51, 0xed, 0x21, 0xa2, 0x8e, 0x3b, 0x75, 0xe9, + 0x65, 0xd0, 0xd2, 0xcd, 0x16, 0x62, 0x54, + } - finalMsg := append(append(append(append(append(clientHello, serverHello...), serverCertificate...), serverKeyExchange...), serverHelloDone...), clientKeyExchange...) - masterSecret := []byte{0x91, 0x6a, 0xbf, 0x9d, 0xa5, 0x59, 0x73, 0xe1, 0x36, 0x14, 0xae, 0x0a, 0x3f, 0x5d, 0x3f, 0x37, 0xb0, 0x23, 0xba, 0x12, 0x9a, 0xee, 0x02, 0xcc, 0x91, 0x34, 0x33, 0x81, 0x27, 0xcd, 0x70, 0x49, 0x78, 0x1c, 0x8e, 0x19, 0xfc, 0x1e, 0xb2, 0xa7, 0x38, 0x7a, 0xc0, 0x6a, 0xe2, 0x37, 0x34, 0x4c} + finalMsg := append( + append( + append( + append( + append( + clientHello, serverHello..., + ), serverCertificate..., + ), serverKeyExchange..., + ), serverHelloDone..., + ), clientKeyExchange..., + ) + masterSecret := []byte{ + 0x91, 0x6a, 0xbf, 0x9d, 0xa5, 0x59, 0x73, 0xe1, 0x36, 0x14, 0xae, 0x0a, 0x3f, 0x5d, 0x3f, + 0x37, 0xb0, 0x23, 0xba, 0x12, 0x9a, 0xee, 0x02, 0xcc, 0x91, 0x34, 0x33, 0x81, 0x27, 0xcd, + 0x70, 0x49, 0x78, 0x1c, 0x8e, 0x19, 0xfc, 0x1e, 0xb2, 0xa7, 0x38, 0x7a, 0xc0, 0x6a, 0xe2, + 0x37, 0x34, 0x4c, + } expectedVerifyData := []byte{0xcf, 0x91, 0x96, 0x26, 0xf1, 0x36, 0x0c, 0x53, 0x6a, 0xaa, 0xd7, 0x3a} verifyData, err := VerifyDataClient(masterSecret, finalMsg, sha256.New) diff --git a/pkg/crypto/selfsign/selfsign.go b/pkg/crypto/selfsign/selfsign.go index 6ef016724..fd238b18a 100644 --- a/pkg/crypto/selfsign/selfsign.go +++ b/pkg/crypto/selfsign/selfsign.go @@ -21,7 +21,7 @@ import ( var errInvalidPrivateKey = errors.New("selfsign: invalid private key type") -// GenerateSelfSigned creates a self-signed certificate +// GenerateSelfSigned creates a self-signed certificate. func GenerateSelfSigned() (tls.Certificate, error) { priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { @@ -31,7 +31,7 @@ func GenerateSelfSigned() (tls.Certificate, error) { return SelfSign(priv) } -// GenerateSelfSignedWithDNS creates a self-signed certificate +// GenerateSelfSignedWithDNS creates a self-signed certificate. func GenerateSelfSignedWithDNS(cn string, sans ...string) (tls.Certificate, error) { priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { @@ -41,12 +41,12 @@ func GenerateSelfSignedWithDNS(cn string, sans ...string) (tls.Certificate, erro return WithDNS(priv, cn, sans...) } -// SelfSign creates a self-signed certificate from a elliptic curve key +// SelfSign creates a self-signed certificate from a elliptic curve key. func SelfSign(key crypto.PrivateKey) (tls.Certificate, error) { return WithDNS(key, "self-signed cert") } -// WithDNS creates a self-signed certificate from a elliptic curve key +// WithDNS creates a self-signed certificate from a elliptic curve key. func WithDNS(key crypto.PrivateKey, cn string, sans ...string) (tls.Certificate, error) { var ( pubKey crypto.PublicKey diff --git a/pkg/crypto/signature/signature.go b/pkg/crypto/signature/signature.go index fec7fba3b..53fb7c952 100644 --- a/pkg/crypto/signature/signature.go +++ b/pkg/crypto/signature/signature.go @@ -8,7 +8,7 @@ package signature // https://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml#tls-parameters-16 type Algorithm uint16 -// SignatureAlgorithm enums +// SignatureAlgorithm enums. const ( Anonymous Algorithm = 0 RSA Algorithm = 1 @@ -16,7 +16,7 @@ const ( Ed25519 Algorithm = 7 ) -// Algorithms returns all implemented Signature Algorithms +// Algorithms returns all implemented Signature Algorithms. func Algorithms() map[Algorithm]struct{} { return map[Algorithm]struct{}{ Anonymous: {}, diff --git a/pkg/crypto/signaturehash/signaturehash.go b/pkg/crypto/signaturehash/signaturehash.go index 38587768b..b036fdbfc 100644 --- a/pkg/crypto/signaturehash/signaturehash.go +++ b/pkg/crypto/signaturehash/signaturehash.go @@ -25,7 +25,7 @@ type Algorithm struct { Signature signature.Algorithm } -// Algorithms are all the know SignatureHash Algorithms +// Algorithms are all the know SignatureHash Algorithms. func Algorithms() []Algorithm { return []Algorithm{ {hash.SHA256, signature.ECDSA}, @@ -45,6 +45,7 @@ func SelectSignatureScheme(sigs []Algorithm, privateKey crypto.PrivateKey) (Algo return ss, nil } } + return Algorithm{}, errNoAvailableSignatureSchemes } diff --git a/pkg/net/net.go b/pkg/net/net.go index e76daf56a..3db604777 100644 --- a/pkg/net/net.go +++ b/pkg/net/net.go @@ -47,6 +47,7 @@ func (p *packetListenerWrapper) Accept() (net.PacketConn, net.Addr, error) { if err != nil { return PacketConnFromConn(c), nil, err } + return PacketConnFromConn(c), c.RemoteAddr(), nil } @@ -73,12 +74,14 @@ type packetConnWrapper struct { // ReadFrom reads from the underlying net.Conn and returns its remote address. func (p *packetConnWrapper) ReadFrom(b []byte) (int, net.Addr, error) { n, err := p.conn.Read(b) + return n, p.conn.RemoteAddr(), err } // WriteTo writes to the underlying net.Conn. func (p *packetConnWrapper) WriteTo(b []byte, _ net.Addr) (int, error) { n, err := p.conn.Write(b) + return n, err } diff --git a/pkg/protocol/alert/alert.go b/pkg/protocol/alert/alert.go index 0a9e0a215..8fac65962 100644 --- a/pkg/protocol/alert/alert.go +++ b/pkg/protocol/alert/alert.go @@ -13,10 +13,10 @@ import ( var errBufferTooSmall = &protocol.TemporaryError{Err: errors.New("buffer is too small")} //nolint:goerr113 -// Level is the level of the TLS Alert +// Level is the level of the TLS Alert. type Level byte -// Level enums +// Level enums. const ( Warning Level = 1 Fatal Level = 2 @@ -33,10 +33,10 @@ func (l Level) String() string { } } -// Description is the extended info of the TLS Alert +// Description is the extended info of the TLS Alert. type Description byte -// Description enums +// Description enums. const ( CloseNotify Description = 0 UnexpectedMessage Description = 10 @@ -66,7 +66,7 @@ const ( NoApplicationProtocol Description = 120 ) -func (d Description) String() string { +func (d Description) String() string { //nolint:cyclop switch d { case CloseNotify: return "CloseNotify" @@ -140,17 +140,17 @@ type Alert struct { Description Description } -// ContentType returns the ContentType of this Content +// ContentType returns the ContentType of this Content. func (a Alert) ContentType() protocol.ContentType { return protocol.ContentTypeAlert } -// Marshal returns the encoded alert +// Marshal returns the encoded alert. func (a *Alert) Marshal() ([]byte, error) { return []byte{byte(a.Level), byte(a.Description)}, nil } -// Unmarshal populates the alert from binary data +// Unmarshal populates the alert from binary data. func (a *Alert) Unmarshal(data []byte) error { if len(data) != 2 { return errBufferTooSmall @@ -158,6 +158,7 @@ func (a *Alert) Unmarshal(data []byte) error { a.Level = Level(data[0]) a.Description = Description(data[1]) + return nil } diff --git a/pkg/protocol/application_data.go b/pkg/protocol/application_data.go index f42211511..f478c4231 100644 --- a/pkg/protocol/application_data.go +++ b/pkg/protocol/application_data.go @@ -12,18 +12,19 @@ type ApplicationData struct { Data []byte } -// ContentType returns the ContentType of this content +// ContentType returns the ContentType of this content. func (a ApplicationData) ContentType() ContentType { return ContentTypeApplicationData } -// Marshal encodes the ApplicationData to binary +// Marshal encodes the ApplicationData to binary. func (a *ApplicationData) Marshal() ([]byte, error) { return append([]byte{}, a.Data...), nil } -// Unmarshal populates the ApplicationData from binary +// Unmarshal populates the ApplicationData from binary. func (a *ApplicationData) Unmarshal(data []byte) error { a.Data = append([]byte{}, data...) + return nil } diff --git a/pkg/protocol/change_cipher_spec.go b/pkg/protocol/change_cipher_spec.go index 87f28bc37..4813cd564 100644 --- a/pkg/protocol/change_cipher_spec.go +++ b/pkg/protocol/change_cipher_spec.go @@ -10,17 +10,17 @@ package protocol // https://tools.ietf.org/html/rfc5246#section-7.1 type ChangeCipherSpec struct{} -// ContentType returns the ContentType of this content +// ContentType returns the ContentType of this content. func (c ChangeCipherSpec) ContentType() ContentType { return ContentTypeChangeCipherSpec } -// Marshal encodes the ChangeCipherSpec to binary +// Marshal encodes the ChangeCipherSpec to binary. func (c *ChangeCipherSpec) Marshal() ([]byte, error) { return []byte{0x01}, nil } -// Unmarshal populates the ChangeCipherSpec from binary +// Unmarshal populates the ChangeCipherSpec from binary. func (c *ChangeCipherSpec) Unmarshal(data []byte) error { if len(data) == 1 && data[0] == 0x01 { return nil diff --git a/pkg/protocol/compression_method.go b/pkg/protocol/compression_method.go index 3478ee38c..0fb99a51b 100644 --- a/pkg/protocol/compression_method.go +++ b/pkg/protocol/compression_method.go @@ -3,26 +3,26 @@ package protocol -// CompressionMethodID is the ID for a CompressionMethod +// CompressionMethodID is the ID for a CompressionMethod. type CompressionMethodID byte const ( compressionMethodNull CompressionMethodID = 0 ) -// CompressionMethod represents a TLS Compression Method +// CompressionMethod represents a TLS Compression Method. type CompressionMethod struct { ID CompressionMethodID } -// CompressionMethods returns all supported CompressionMethods +// CompressionMethods returns all supported CompressionMethods. func CompressionMethods() map[CompressionMethodID]*CompressionMethod { return map[CompressionMethodID]*CompressionMethod{ compressionMethodNull: {ID: compressionMethodNull}, } } -// DecodeCompressionMethods the given compression methods +// DecodeCompressionMethods the given compression methods. func DecodeCompressionMethods(buf []byte) ([]*CompressionMethod, error) { if len(buf) < 1 { return nil, errBufferTooSmall @@ -38,14 +38,16 @@ func DecodeCompressionMethods(buf []byte) ([]*CompressionMethod, error) { c = append(c, compressionMethod) } } + return c, nil } -// EncodeCompressionMethods the given compression methods +// EncodeCompressionMethods the given compression methods. func EncodeCompressionMethods(c []*CompressionMethod) []byte { out := []byte{byte(len(c))} for i := len(c); i > 0; i-- { out = append(out, byte(c[i-1].ID)) } + return out } diff --git a/pkg/protocol/content.go b/pkg/protocol/content.go index 154005e2c..9b6daa51f 100644 --- a/pkg/protocol/content.go +++ b/pkg/protocol/content.go @@ -8,7 +8,7 @@ package protocol // https://tools.ietf.org/html/rfc4346#section-6.2.1 type ContentType uint8 -// ContentType enums +// ContentType enums. const ( ContentTypeChangeCipherSpec ContentType = 20 ContentTypeAlert ContentType = 21 @@ -17,7 +17,7 @@ const ( ContentTypeConnectionID ContentType = 25 ) -// Content is the top level distinguisher for a DTLS Datagram +// Content is the top level distinguisher for a DTLS Datagram. type Content interface { ContentType() ContentType Marshal() ([]byte, error) diff --git a/pkg/protocol/errors.go b/pkg/protocol/errors.go index d87aff7fb..dc091bff3 100644 --- a/pkg/protocol/errors.go +++ b/pkg/protocol/errors.go @@ -20,7 +20,8 @@ type FatalError struct { Err error } -// InternalError indicates and internal error caused by the implementation, and the DTLS connection is no longer available. +// InternalError indicates and internal error caused by the implementation, +// and the DTLS connection is no longer available. // It is mainly caused by bugs or tried to use unimplemented features. type InternalError struct { Err error @@ -41,10 +42,10 @@ type HandshakeError struct { Err error } -// Timeout implements net.Error.Timeout() +// Timeout implements net.Error.Timeout(). func (*FatalError) Timeout() bool { return false } -// Temporary implements net.Error.Temporary() +// Temporary implements net.Error.Temporary(). func (*FatalError) Temporary() bool { return false } // Unwrap implements Go1.13 error unwrapper. @@ -52,10 +53,10 @@ func (e *FatalError) Unwrap() error { return e.Err } func (e *FatalError) Error() string { return fmt.Sprintf("dtls fatal: %v", e.Err) } -// Timeout implements net.Error.Timeout() +// Timeout implements net.Error.Timeout(). func (*InternalError) Timeout() bool { return false } -// Temporary implements net.Error.Temporary() +// Temporary implements net.Error.Temporary(). func (*InternalError) Temporary() bool { return false } // Unwrap implements Go1.13 error unwrapper. @@ -63,10 +64,10 @@ func (e *InternalError) Unwrap() error { return e.Err } func (e *InternalError) Error() string { return fmt.Sprintf("dtls internal: %v", e.Err) } -// Timeout implements net.Error.Timeout() +// Timeout implements net.Error.Timeout(). func (*TemporaryError) Timeout() bool { return false } -// Temporary implements net.Error.Temporary() +// Temporary implements net.Error.Temporary(). func (*TemporaryError) Temporary() bool { return true } // Unwrap implements Go1.13 error unwrapper. @@ -74,10 +75,10 @@ func (e *TemporaryError) Unwrap() error { return e.Err } func (e *TemporaryError) Error() string { return fmt.Sprintf("dtls temporary: %v", e.Err) } -// Timeout implements net.Error.Timeout() +// Timeout implements net.Error.Timeout(). func (*TimeoutError) Timeout() bool { return true } -// Temporary implements net.Error.Temporary() +// Temporary implements net.Error.Temporary(). func (*TimeoutError) Temporary() bool { return true } // Unwrap implements Go1.13 error unwrapper. @@ -85,21 +86,23 @@ func (e *TimeoutError) Unwrap() error { return e.Err } func (e *TimeoutError) Error() string { return fmt.Sprintf("dtls timeout: %v", e.Err) } -// Timeout implements net.Error.Timeout() +// Timeout implements net.Error.Timeout(). func (e *HandshakeError) Timeout() bool { var netErr net.Error if errors.As(e.Err, &netErr) { return netErr.Timeout() } + return false } -// Temporary implements net.Error.Temporary() +// Temporary implements net.Error.Temporary(). func (e *HandshakeError) Temporary() bool { var netErr net.Error if errors.As(e.Err, &netErr) { return netErr.Temporary() //nolint } + return false } diff --git a/pkg/protocol/extension/alpn.go b/pkg/protocol/extension/alpn.go index e780dc9e1..719428601 100644 --- a/pkg/protocol/extension/alpn.go +++ b/pkg/protocol/extension/alpn.go @@ -15,16 +15,16 @@ type ALPN struct { ProtocolNameList []string } -// TypeValue returns the extension TypeValue +// TypeValue returns the extension TypeValue. func (a ALPN) TypeValue() TypeValue { return ALPNTypeValue } -// Marshal encodes the extension +// Marshal encodes the extension. func (a *ALPN) Marshal() ([]byte, error) { - var b cryptobyte.Builder - b.AddUint16(uint16(a.TypeValue())) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + var builder cryptobyte.Builder + builder.AddUint16(uint16(a.TypeValue())) + builder.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { for _, proto := range a.ProtocolNameList { p := proto // Satisfy range scope lint @@ -34,10 +34,11 @@ func (a *ALPN) Marshal() ([]byte, error) { } }) }) - return b.Bytes() + + return builder.Bytes() } -// Unmarshal populates the extension from encoded data +// Unmarshal populates the extension from encoded data. func (a *ALPN) Unmarshal(data []byte) error { val := cryptobyte.String(data) @@ -61,10 +62,11 @@ func (a *ALPN) Unmarshal(data []byte) error { } a.ProtocolNameList = append(a.ProtocolNameList, string(proto)) } + return nil } -// ALPNProtocolSelection negotiates a shared protocol according to #3.2 of rfc7301 +// ALPNProtocolSelection negotiates a shared protocol according to #3.2 of rfc7301. func ALPNProtocolSelection(supportedProtocols, peerSupportedProtocols []string) (string, error) { if len(supportedProtocols) == 0 || len(peerSupportedProtocols) == 0 { return "", nil @@ -76,5 +78,6 @@ func ALPNProtocolSelection(supportedProtocols, peerSupportedProtocols []string) } } } + return "", errALPNNoAppProto } diff --git a/pkg/protocol/extension/alpn_test.go b/pkg/protocol/extension/alpn_test.go index 6b12af0f7..11468fac0 100644 --- a/pkg/protocol/extension/alpn_test.go +++ b/pkg/protocol/extension/alpn_test.go @@ -31,22 +31,22 @@ func TestALPN(t *testing.T) { } func TestALPNProtocolSelection(t *testing.T) { - s, err := ALPNProtocolSelection([]string{"http/1.1", "spd/1"}, []string{"spd/1"}) + selectedProtocol, err := ALPNProtocolSelection([]string{"http/1.1", "spd/1"}, []string{"spd/1"}) if err != nil { t.Fatal(err) } - if s != "spd/1" { - t.Errorf("expected: spd/1, got: %v", s) + if selectedProtocol != "spd/1" { + t.Errorf("expected: spd/1, got: %v", selectedProtocol) } _, err = ALPNProtocolSelection([]string{"http/1.1"}, []string{"spd/1"}) if !errors.Is(err, errALPNNoAppProto) { t.Fatal("expected to fail negotiating an application protocol") } - s, err = ALPNProtocolSelection([]string{"http/1.1", "spd/1"}, []string{}) + selectedProtocol, err = ALPNProtocolSelection([]string{"http/1.1", "spd/1"}, []string{}) if err != nil { t.Fatal(err) } - if s != "" { - t.Errorf("expected not to negotiate a protocol, got: %v", s) + if selectedProtocol != "" { + t.Errorf("expected not to negotiate a protocol, got: %v", selectedProtocol) } } diff --git a/pkg/protocol/extension/connection_id.go b/pkg/protocol/extension/connection_id.go index b3fe1640f..6c8a7f566 100644 --- a/pkg/protocol/extension/connection_id.go +++ b/pkg/protocol/extension/connection_id.go @@ -18,12 +18,12 @@ type ConnectionID struct { CID []byte // variable length } -// TypeValue returns the extension TypeValue +// TypeValue returns the extension TypeValue. func (c ConnectionID) TypeValue() TypeValue { return ConnectionIDTypeValue } -// Marshal encodes the extension +// Marshal encodes the extension. func (c *ConnectionID) Marshal() ([]byte, error) { var b cryptobyte.Builder b.AddUint16(uint16(c.TypeValue())) @@ -32,10 +32,11 @@ func (c *ConnectionID) Marshal() ([]byte, error) { b.AddBytes(c.CID) }) }) + return b.Bytes() } -// Unmarshal populates the extension from encoded data +// Unmarshal populates the extension from encoded data. func (c *ConnectionID) Unmarshal(data []byte) error { val := cryptobyte.String(data) var extension uint16 @@ -55,5 +56,6 @@ func (c *ConnectionID) Unmarshal(data []byte) error { if !cid.CopyBytes(c.CID) { return errInvalidCIDFormat } + return nil } diff --git a/pkg/protocol/extension/errors.go b/pkg/protocol/extension/errors.go index 5999c96fe..424ae5b1a 100644 --- a/pkg/protocol/extension/errors.go +++ b/pkg/protocol/extension/errors.go @@ -10,13 +10,29 @@ import ( ) var ( - // ErrALPNInvalidFormat is raised when the ALPN format is invalid - ErrALPNInvalidFormat = &protocol.FatalError{Err: errors.New("invalid alpn format")} //nolint:goerr113 - errALPNNoAppProto = &protocol.FatalError{Err: errors.New("no application protocol")} //nolint:goerr113 - errBufferTooSmall = &protocol.TemporaryError{Err: errors.New("buffer is too small")} //nolint:goerr113 - errInvalidExtensionType = &protocol.FatalError{Err: errors.New("invalid extension type")} //nolint:goerr113 - errInvalidSNIFormat = &protocol.FatalError{Err: errors.New("invalid server name format")} //nolint:goerr113 - errInvalidCIDFormat = &protocol.FatalError{Err: errors.New("invalid connection ID format")} //nolint:goerr113 - errLengthMismatch = &protocol.InternalError{Err: errors.New("data length and declared length do not match")} //nolint:goerr113 - errMasterKeyIdentifierTooLarge = &protocol.FatalError{Err: errors.New("master key identifier is over 255 bytes")} //nolint:goerr113 + // ErrALPNInvalidFormat is raised when the ALPN format is invalid. + ErrALPNInvalidFormat = &protocol.FatalError{ + Err: errors.New("invalid alpn format"), //nolint:goerr113 + } + errALPNNoAppProto = &protocol.FatalError{ + Err: errors.New("no application protocol"), //nolint:goerr113 + } + errBufferTooSmall = &protocol.TemporaryError{ + Err: errors.New("buffer is too small"), //nolint:goerr113 + } + errInvalidExtensionType = &protocol.FatalError{ + Err: errors.New("invalid extension type"), //nolint:goerr113 + } + errInvalidSNIFormat = &protocol.FatalError{ + Err: errors.New("invalid server name format"), //nolint:goerr113 + } + errInvalidCIDFormat = &protocol.FatalError{ + Err: errors.New("invalid connection ID format"), //nolint:goerr113 + } + errLengthMismatch = &protocol.InternalError{ + Err: errors.New("data length and declared length do not match"), //nolint:goerr113 + } + errMasterKeyIdentifierTooLarge = &protocol.FatalError{ + Err: errors.New("master key identifier is over 255 bytes"), //nolint:goerr113 + } ) diff --git a/pkg/protocol/extension/extension.go b/pkg/protocol/extension/extension.go index e4df859f8..ba82beea3 100644 --- a/pkg/protocol/extension/extension.go +++ b/pkg/protocol/extension/extension.go @@ -11,7 +11,7 @@ import "encoding/binary" // https://www.iana.org/assignments/tls-extensiontype-values/tls-extensiontype-values.xhtml type TypeValue uint16 -// TypeValue constants +// TypeValue constants. const ( ServerNameTypeValue TypeValue = 0 SupportedEllipticCurvesTypeValue TypeValue = 10 @@ -24,15 +24,15 @@ const ( RenegotiationInfoTypeValue TypeValue = 65281 ) -// Extension represents a single TLS extension +// Extension represents a single TLS extension. type Extension interface { Marshal() ([]byte, error) Unmarshal(data []byte) error TypeValue() TypeValue } -// Unmarshal many extensions at once -func Unmarshal(buf []byte) ([]Extension, error) { +// Unmarshal many extensions at once. +func Unmarshal(buf []byte) ([]Extension, error) { //nolint:cyclop switch { case len(buf) == 0: return []Extension{}, nil @@ -52,6 +52,7 @@ func Unmarshal(buf []byte) ([]Extension, error) { return err } extensions = append(extensions, e) + return nil } @@ -90,10 +91,11 @@ func Unmarshal(buf []byte) ([]Extension, error) { extensionLength := binary.BigEndian.Uint16(buf[offset+2:]) offset += (4 + int(extensionLength)) } + return extensions, nil } -// Marshal many extensions at once +// Marshal many extensions at once. func Marshal(e []Extension) ([]byte, error) { extensions := []byte{} for _, e := range e { @@ -104,6 +106,7 @@ func Marshal(e []Extension) ([]byte, error) { extensions = append(extensions, raw...) } out := []byte{0x00, 0x00} - binary.BigEndian.PutUint16(out, uint16(len(extensions))) + binary.BigEndian.PutUint16(out, uint16(len(extensions))) //nolint:gosec // G115 + return append(out, extensions...), nil } diff --git a/pkg/protocol/extension/renegotiation_info.go b/pkg/protocol/extension/renegotiation_info.go index c5092a7db..57432fd0b 100644 --- a/pkg/protocol/extension/renegotiation_info.go +++ b/pkg/protocol/extension/renegotiation_info.go @@ -17,22 +17,23 @@ type RenegotiationInfo struct { RenegotiatedConnection uint8 } -// TypeValue returns the extension TypeValue +// TypeValue returns the extension TypeValue. func (r RenegotiationInfo) TypeValue() TypeValue { return RenegotiationInfoTypeValue } -// Marshal encodes the extension +// Marshal encodes the extension. func (r *RenegotiationInfo) Marshal() ([]byte, error) { out := make([]byte, renegotiationInfoHeaderSize) binary.BigEndian.PutUint16(out, uint16(r.TypeValue())) binary.BigEndian.PutUint16(out[2:], uint16(1)) // length out[4] = r.RenegotiatedConnection + return out, nil } -// Unmarshal populates the extension from encoded data +// Unmarshal populates the extension from encoded data. func (r *RenegotiationInfo) Unmarshal(data []byte) error { if len(data) < renegotiationInfoHeaderSize { return errBufferTooSmall diff --git a/pkg/protocol/extension/renegotiation_info_test.go b/pkg/protocol/extension/renegotiation_info_test.go index 63b9609d1..252144332 100644 --- a/pkg/protocol/extension/renegotiation_info_test.go +++ b/pkg/protocol/extension/renegotiation_info_test.go @@ -20,6 +20,9 @@ func TestRenegotiationInfo(t *testing.T) { } if newExtension.RenegotiatedConnection != extension.RenegotiatedConnection { - t.Errorf("extensionRenegotiationInfo marshal: got %d expected %d", newExtension.RenegotiatedConnection, extension.RenegotiatedConnection) + t.Errorf( + "extensionRenegotiationInfo marshal: got %d expected %d", + newExtension.RenegotiatedConnection, extension.RenegotiatedConnection, + ) } } diff --git a/pkg/protocol/extension/server_name.go b/pkg/protocol/extension/server_name.go index 183e08e6e..31e6327d3 100644 --- a/pkg/protocol/extension/server_name.go +++ b/pkg/protocol/extension/server_name.go @@ -20,12 +20,12 @@ type ServerName struct { ServerName string } -// TypeValue returns the extension TypeValue +// TypeValue returns the extension TypeValue. func (s ServerName) TypeValue() TypeValue { return ServerNameTypeValue } -// Marshal encodes the extension +// Marshal encodes the extension. func (s *ServerName) Marshal() ([]byte, error) { var b cryptobyte.Builder b.AddUint16(uint16(s.TypeValue())) @@ -37,11 +37,12 @@ func (s *ServerName) Marshal() ([]byte, error) { }) }) }) + return b.Bytes() } -// Unmarshal populates the extension from encoded data -func (s *ServerName) Unmarshal(data []byte) error { +// Unmarshal populates the extension from encoded data. +func (s *ServerName) Unmarshal(data []byte) error { //nolint:cyclop val := cryptobyte.String(data) var extension uint16 val.ReadUint16(&extension) @@ -77,5 +78,6 @@ func (s *ServerName) Unmarshal(data []byte) error { return errInvalidSNIFormat } } + return nil } diff --git a/pkg/protocol/extension/supported_elliptic_curves.go b/pkg/protocol/extension/supported_elliptic_curves.go index e83b5ccb8..e3e87634b 100644 --- a/pkg/protocol/extension/supported_elliptic_curves.go +++ b/pkg/protocol/extension/supported_elliptic_curves.go @@ -21,28 +21,28 @@ type SupportedEllipticCurves struct { EllipticCurves []elliptic.Curve } -// TypeValue returns the extension TypeValue +// TypeValue returns the extension TypeValue. func (s SupportedEllipticCurves) TypeValue() TypeValue { return SupportedEllipticCurvesTypeValue } -// Marshal encodes the extension +// Marshal encodes the extension. func (s *SupportedEllipticCurves) Marshal() ([]byte, error) { out := make([]byte, supportedGroupsHeaderSize) binary.BigEndian.PutUint16(out, uint16(s.TypeValue())) - binary.BigEndian.PutUint16(out[2:], uint16(2+(len(s.EllipticCurves)*2))) - binary.BigEndian.PutUint16(out[4:], uint16(len(s.EllipticCurves)*2)) + binary.BigEndian.PutUint16(out[2:], uint16(2+(len(s.EllipticCurves)*2))) //nolint:gosec // G115 + binary.BigEndian.PutUint16(out[4:], uint16(len(s.EllipticCurves)*2)) //nolint:gosec // G115 for _, v := range s.EllipticCurves { - out = append(out, []byte{0x00, 0x00}...) + out = append(out, []byte{0x00, 0x00}...) //nolint:makezero // todo: fix binary.BigEndian.PutUint16(out[len(out)-2:], uint16(v)) } return out, nil } -// Unmarshal populates the extension from encoded data +// Unmarshal populates the extension from encoded data. func (s *SupportedEllipticCurves) Unmarshal(data []byte) error { if len(data) <= supportedGroupsHeaderSize { return errBufferTooSmall @@ -61,5 +61,6 @@ func (s *SupportedEllipticCurves) Unmarshal(data []byte) error { s.EllipticCurves = append(s.EllipticCurves, supportedGroupID) } } + return nil } diff --git a/pkg/protocol/extension/supported_point_formats.go b/pkg/protocol/extension/supported_point_formats.go index ec877aff2..77dc4fd50 100644 --- a/pkg/protocol/extension/supported_point_formats.go +++ b/pkg/protocol/extension/supported_point_formats.go @@ -21,26 +21,27 @@ type SupportedPointFormats struct { PointFormats []elliptic.CurvePointFormat } -// TypeValue returns the extension TypeValue +// TypeValue returns the extension TypeValue. func (s SupportedPointFormats) TypeValue() TypeValue { return SupportedPointFormatsTypeValue } -// Marshal encodes the extension +// Marshal encodes the extension. func (s *SupportedPointFormats) Marshal() ([]byte, error) { out := make([]byte, supportedPointFormatsSize) binary.BigEndian.PutUint16(out, uint16(s.TypeValue())) - binary.BigEndian.PutUint16(out[2:], uint16(1+(len(s.PointFormats)))) + binary.BigEndian.PutUint16(out[2:], uint16(1+(len(s.PointFormats)))) //nolint:gosec // G115 out[4] = byte(len(s.PointFormats)) for _, v := range s.PointFormats { - out = append(out, byte(v)) + out = append(out, byte(v)) //nolint:makezero // todo: fix } + return out, nil } -// Unmarshal populates the extension from encoded data +// Unmarshal populates the extension from encoded data. func (s *SupportedPointFormats) Unmarshal(data []byte) error { if len(data) <= supportedPointFormatsSize { return errBufferTooSmall @@ -63,5 +64,6 @@ func (s *SupportedPointFormats) Unmarshal(data []byte) error { default: } } + return nil } diff --git a/pkg/protocol/extension/supported_point_formats_test.go b/pkg/protocol/extension/supported_point_formats_test.go index eecbd5c82..f57bdd0ca 100644 --- a/pkg/protocol/extension/supported_point_formats_test.go +++ b/pkg/protocol/extension/supported_point_formats_test.go @@ -27,6 +27,9 @@ func TestExtensionSupportedPointFormats(t *testing.T) { if err := roundtrip.Unmarshal(raw); err != nil { t.Error(err) } else if !reflect.DeepEqual(roundtrip, parsedExtensionSupportedPointFormats) { - t.Errorf("extensionSupportedPointFormats unmarshal: got %#v, want %#v", roundtrip, parsedExtensionSupportedPointFormats) + t.Errorf( + "extensionSupportedPointFormats unmarshal: got %#v, want %#v", + roundtrip, parsedExtensionSupportedPointFormats, + ) } } diff --git a/pkg/protocol/extension/supported_signature_algorithms.go b/pkg/protocol/extension/supported_signature_algorithms.go index 396b9ae38..e7ad0d422 100644 --- a/pkg/protocol/extension/supported_signature_algorithms.go +++ b/pkg/protocol/extension/supported_signature_algorithms.go @@ -23,20 +23,20 @@ type SupportedSignatureAlgorithms struct { SignatureHashAlgorithms []signaturehash.Algorithm } -// TypeValue returns the extension TypeValue +// TypeValue returns the extension TypeValue. func (s SupportedSignatureAlgorithms) TypeValue() TypeValue { return SupportedSignatureAlgorithmsTypeValue } -// Marshal encodes the extension +// Marshal encodes the extension. func (s *SupportedSignatureAlgorithms) Marshal() ([]byte, error) { out := make([]byte, supportedSignatureAlgorithmsHeaderSize) binary.BigEndian.PutUint16(out, uint16(s.TypeValue())) - binary.BigEndian.PutUint16(out[2:], uint16(2+(len(s.SignatureHashAlgorithms)*2))) - binary.BigEndian.PutUint16(out[4:], uint16(len(s.SignatureHashAlgorithms)*2)) + binary.BigEndian.PutUint16(out[2:], uint16(2+(len(s.SignatureHashAlgorithms)*2))) //nolint:gosec // G115 + binary.BigEndian.PutUint16(out[4:], uint16(len(s.SignatureHashAlgorithms)*2)) //nolint:gosec // G115 for _, v := range s.SignatureHashAlgorithms { - out = append(out, []byte{0x00, 0x00}...) + out = append(out, []byte{0x00, 0x00}...) //nolint:makezero // todo: fix out[len(out)-2] = byte(v.Hash) out[len(out)-1] = byte(v.Signature) } @@ -44,7 +44,7 @@ func (s *SupportedSignatureAlgorithms) Marshal() ([]byte, error) { return out, nil } -// Unmarshal populates the extension from encoded data +// Unmarshal populates the extension from encoded data. func (s *SupportedSignatureAlgorithms) Unmarshal(data []byte) error { if len(data) <= supportedSignatureAlgorithmsHeaderSize { return errBufferTooSmall diff --git a/pkg/protocol/extension/supported_signature_algorithms_test.go b/pkg/protocol/extension/supported_signature_algorithms_test.go index 80b5692e3..653d4b558 100644 --- a/pkg/protocol/extension/supported_signature_algorithms_test.go +++ b/pkg/protocol/extension/supported_signature_algorithms_test.go @@ -33,13 +33,19 @@ func TestExtensionSupportedSignatureAlgorithms(t *testing.T) { if err != nil { t.Fatal(err) } else if !reflect.DeepEqual(raw, rawExtensionSupportedSignatureAlgorithms) { - t.Fatalf("extensionSupportedSignatureAlgorithms marshal: got %#v, want %#v", raw, rawExtensionSupportedSignatureAlgorithms) + t.Fatalf( + "extensionSupportedSignatureAlgorithms marshal: got %#v, want %#v", + raw, rawExtensionSupportedSignatureAlgorithms, + ) } roundtrip := &SupportedSignatureAlgorithms{} if err := roundtrip.Unmarshal(raw); err != nil { t.Error(err) } else if !reflect.DeepEqual(roundtrip, parsedExtensionSupportedSignatureAlgorithms) { - t.Errorf("extensionSupportedSignatureAlgorithms unmarshal: got %#v, want %#v", roundtrip, parsedExtensionSupportedSignatureAlgorithms) + t.Errorf( + "extensionSupportedSignatureAlgorithms unmarshal: got %#v, want %#v", + roundtrip, parsedExtensionSupportedSignatureAlgorithms, + ) } } diff --git a/pkg/protocol/extension/use_master_secret.go b/pkg/protocol/extension/use_master_secret.go index d0b70cafb..fcf5dd289 100644 --- a/pkg/protocol/extension/use_master_secret.go +++ b/pkg/protocol/extension/use_master_secret.go @@ -16,12 +16,12 @@ type UseExtendedMasterSecret struct { Supported bool } -// TypeValue returns the extension TypeValue +// TypeValue returns the extension TypeValue. func (u UseExtendedMasterSecret) TypeValue() TypeValue { return UseExtendedMasterSecretTypeValue } -// Marshal encodes the extension +// Marshal encodes the extension. func (u *UseExtendedMasterSecret) Marshal() ([]byte, error) { if !u.Supported { return []byte{}, nil @@ -31,10 +31,11 @@ func (u *UseExtendedMasterSecret) Marshal() ([]byte, error) { binary.BigEndian.PutUint16(out, uint16(u.TypeValue())) binary.BigEndian.PutUint16(out[2:], uint16(0)) // length + return out, nil } -// Unmarshal populates the extension from encoded data +// Unmarshal populates the extension from encoded data. func (u *UseExtendedMasterSecret) Unmarshal(data []byte) error { if len(data) < useExtendedMasterSecretHeaderSize { return errBufferTooSmall diff --git a/pkg/protocol/extension/use_srtp.go b/pkg/protocol/extension/use_srtp.go index 6d5f54b23..4e0410cae 100644 --- a/pkg/protocol/extension/use_srtp.go +++ b/pkg/protocol/extension/use_srtp.go @@ -20,34 +20,38 @@ type UseSRTP struct { MasterKeyIdentifier []byte } -// TypeValue returns the extension TypeValue +// TypeValue returns the extension TypeValue. func (u UseSRTP) TypeValue() TypeValue { return UseSRTPTypeValue } -// Marshal encodes the extension +// Marshal encodes the extension. func (u *UseSRTP) Marshal() ([]byte, error) { out := make([]byte, useSRTPHeaderSize) binary.BigEndian.PutUint16(out, uint16(u.TypeValue())) - binary.BigEndian.PutUint16(out[2:], uint16(2+(len(u.ProtectionProfiles)*2)+ /* MKI Length */ 1+len(u.MasterKeyIdentifier))) - binary.BigEndian.PutUint16(out[4:], uint16(len(u.ProtectionProfiles)*2)) + //nolint:gosec // G115 + binary.BigEndian.PutUint16( + out[2:], + uint16(2+(len(u.ProtectionProfiles)*2)+ /* MKI Length */ 1+len(u.MasterKeyIdentifier)), + ) + binary.BigEndian.PutUint16(out[4:], uint16(len(u.ProtectionProfiles)*2)) //nolint:gosec // G115 for _, v := range u.ProtectionProfiles { - out = append(out, []byte{0x00, 0x00}...) + out = append(out, []byte{0x00, 0x00}...) //nolint:makezero // todo: fix binary.BigEndian.PutUint16(out[len(out)-2:], uint16(v)) } if len(u.MasterKeyIdentifier) > 255 { return nil, errMasterKeyIdentifierTooLarge } - out = append(out, byte(len(u.MasterKeyIdentifier))) - out = append(out, u.MasterKeyIdentifier...) + out = append(out, byte(len(u.MasterKeyIdentifier))) //nolint:makezero // todo: fix + out = append(out, u.MasterKeyIdentifier...) //nolint:makezero // todo: fix return out, nil } -// Unmarshal populates the extension from encoded data +// Unmarshal populates the extension from encoded data. func (u *UseSRTP) Unmarshal(data []byte) error { if len(data) <= useSRTPHeaderSize { return errBufferTooSmall @@ -73,7 +77,10 @@ func (u *UseSRTP) Unmarshal(data []byte) error { return errLengthMismatch } - u.MasterKeyIdentifier = append([]byte{}, data[masterKeyIdentifierIndex+1:masterKeyIdentifierIndex+1+masterKeyIdentifierLen]...) + u.MasterKeyIdentifier = append( + []byte{}, + data[masterKeyIdentifierIndex+1:masterKeyIdentifierIndex+1+masterKeyIdentifierLen]..., + ) return nil } diff --git a/pkg/protocol/extension/use_srtp_test.go b/pkg/protocol/extension/use_srtp_test.go index 36e284cd4..c88c61c15 100644 --- a/pkg/protocol/extension/use_srtp_test.go +++ b/pkg/protocol/extension/use_srtp_test.go @@ -9,7 +9,7 @@ import ( "testing" ) -func TestExtensionUseSRTP(t *testing.T) { +func TestExtensionUseSRTP(t *testing.T) { //nolint:cyclop t.Run("No MasterKeyIdentifier", func(t *testing.T) { rawUseSRTP := []byte{0x00, 0x0e, 0x00, 0x05, 0x00, 0x02, 0x00, 0x01, 0x00} parsedUseSRTP := &UseSRTP{ @@ -57,11 +57,15 @@ func TestExtensionUseSRTP(t *testing.T) { t.Run("Invalid Lengths", func(t *testing.T) { unmarshaled := &UseSRTP{} - if err := unmarshaled.Unmarshal([]byte{0x00, 0x0e, 0x00, 0x05, 0x00, 0x04, 0x00, 0x01, 0x00}); !errors.Is(errLengthMismatch, err) { + if err := unmarshaled.Unmarshal( + []byte{0x00, 0x0e, 0x00, 0x05, 0x00, 0x04, 0x00, 0x01, 0x00}, + ); !errors.Is(errLengthMismatch, err) { t.Error(err) } - if err := unmarshaled.Unmarshal([]byte{0x00, 0x0e, 0x00, 0x0a, 0x00, 0x02, 0x00, 0x01, 0x01}); !errors.Is(errLengthMismatch, err) { + if err := unmarshaled.Unmarshal( + []byte{0x00, 0x0e, 0x00, 0x0a, 0x00, 0x02, 0x00, 0x01, 0x01}, + ); !errors.Is(errLengthMismatch, err) { t.Error(err) } diff --git a/pkg/protocol/handshake/cipher_suite.go b/pkg/protocol/handshake/cipher_suite.go index b29629717..49d2b7407 100644 --- a/pkg/protocol/handshake/cipher_suite.go +++ b/pkg/protocol/handshake/cipher_suite.go @@ -18,15 +18,17 @@ func decodeCipherSuiteIDs(buf []byte) ([]uint16, error) { rtrn[i] = binary.BigEndian.Uint16(buf[(i*2)+2:]) } + return rtrn, nil } func encodeCipherSuiteIDs(cipherSuiteIDs []uint16) []byte { out := []byte{0x00, 0x00} - binary.BigEndian.PutUint16(out[len(out)-2:], uint16(len(cipherSuiteIDs)*2)) + binary.BigEndian.PutUint16(out[len(out)-2:], uint16(len(cipherSuiteIDs)*2)) //nolint:gosec // G115 for _, id := range cipherSuiteIDs { out = append(out, []byte{0x00, 0x00}...) binary.BigEndian.PutUint16(out[len(out)-2:], id) } + return out } diff --git a/pkg/protocol/handshake/errors.go b/pkg/protocol/handshake/errors.go index 4f007198f..20794f78c 100644 --- a/pkg/protocol/handshake/errors.go +++ b/pkg/protocol/handshake/errors.go @@ -9,20 +9,48 @@ import ( "github.com/pion/dtls/v3/pkg/protocol" ) -// Typed errors +// Typed errors. var ( - errUnableToMarshalFragmented = &protocol.InternalError{Err: errors.New("unable to marshal fragmented handshakes")} //nolint:goerr113 - errHandshakeMessageUnset = &protocol.InternalError{Err: errors.New("handshake message unset, unable to marshal")} //nolint:goerr113 - errBufferTooSmall = &protocol.TemporaryError{Err: errors.New("buffer is too small")} //nolint:goerr113 - errLengthMismatch = &protocol.InternalError{Err: errors.New("data length and declared length do not match")} //nolint:goerr113 - errInvalidClientKeyExchange = &protocol.FatalError{Err: errors.New("unable to determine if ClientKeyExchange is a public key or PSK Identity")} //nolint:goerr113 - errInvalidHashAlgorithm = &protocol.FatalError{Err: errors.New("invalid hash algorithm")} //nolint:goerr113 - errInvalidSignatureAlgorithm = &protocol.FatalError{Err: errors.New("invalid signature algorithm")} //nolint:goerr113 - errCookieTooLong = &protocol.FatalError{Err: errors.New("cookie must not be longer then 255 bytes")} //nolint:goerr113 - errInvalidEllipticCurveType = &protocol.FatalError{Err: errors.New("invalid or unknown elliptic curve type")} //nolint:goerr113 - errInvalidNamedCurve = &protocol.FatalError{Err: errors.New("invalid named curve")} //nolint:goerr113 - errCipherSuiteUnset = &protocol.FatalError{Err: errors.New("server hello can not be created without a cipher suite")} //nolint:goerr113 - errCompressionMethodUnset = &protocol.FatalError{Err: errors.New("server hello can not be created without a compression method")} //nolint:goerr113 - errInvalidCompressionMethod = &protocol.FatalError{Err: errors.New("invalid or unknown compression method")} //nolint:goerr113 - errNotImplemented = &protocol.InternalError{Err: errors.New("feature has not been implemented yet")} //nolint:goerr113 + errUnableToMarshalFragmented = &protocol.InternalError{ + Err: errors.New("unable to marshal fragmented handshakes"), //nolint:err113 + } + errHandshakeMessageUnset = &protocol.InternalError{ + Err: errors.New("handshake message unset, unable to marshal"), //nolint:err113 + } + errBufferTooSmall = &protocol.TemporaryError{ + Err: errors.New("buffer is too small"), //nolint:err113 + } + errLengthMismatch = &protocol.InternalError{ + Err: errors.New("data length and declared length do not match"), //nolint:err113 + } + errInvalidClientKeyExchange = &protocol.FatalError{ + Err: errors.New("unable to determine if ClientKeyExchange is a public key or PSK Identity"), //nolint:err113 + } + errInvalidHashAlgorithm = &protocol.FatalError{ + Err: errors.New("invalid hash algorithm"), //nolint:err113 + } + errInvalidSignatureAlgorithm = &protocol.FatalError{ + Err: errors.New("invalid signature algorithm"), //nolint:err113 + } + errCookieTooLong = &protocol.FatalError{ + Err: errors.New("cookie must not be longer then 255 bytes"), //nolint:err113 + } + errInvalidEllipticCurveType = &protocol.FatalError{ + Err: errors.New("invalid or unknown elliptic curve type"), //nolint:err113 + } + errInvalidNamedCurve = &protocol.FatalError{ + Err: errors.New("invalid named curve"), //nolint:err113 + } + errCipherSuiteUnset = &protocol.FatalError{ + Err: errors.New("server hello can not be created without a cipher suite"), //nolint:err113 + } + errCompressionMethodUnset = &protocol.FatalError{ + Err: errors.New("server hello can not be created without a compression method"), //nolint:err113 + } + errInvalidCompressionMethod = &protocol.FatalError{ + Err: errors.New("invalid or unknown compression method"), //nolint:err113 + } + errNotImplemented = &protocol.InternalError{ + Err: errors.New("feature has not been implemented yet"), //nolint:err113 + } ) diff --git a/pkg/protocol/handshake/handshake.go b/pkg/protocol/handshake/handshake.go index 6f0eb9c71..9c6187711 100644 --- a/pkg/protocol/handshake/handshake.go +++ b/pkg/protocol/handshake/handshake.go @@ -14,7 +14,7 @@ import ( // https://tools.ietf.org/html/rfc5246#section-7.4 type Type uint8 -// Types of DTLS Handshake messages we know about +// Types of DTLS Handshake messages we know about. const ( TypeHelloRequest Type = 0 TypeClientHello Type = 1 @@ -29,8 +29,8 @@ const ( TypeFinished Type = 20 ) -// String returns the string representation of this type -func (t Type) String() string { +// String returns the string representation of this type. +func (t Type) String() string { //nolint:cyclop switch t { case TypeHelloRequest: return "HelloRequest" @@ -55,10 +55,11 @@ func (t Type) String() string { case TypeFinished: return "Finished" } + return "" } -// Message is the body of a Handshake datagram +// Message is the body of a Handshake datagram. type Message interface { Marshal() ([]byte, error) Unmarshal(data []byte) error @@ -78,12 +79,12 @@ type Handshake struct { KeyExchangeAlgorithm types.KeyExchangeAlgorithm } -// ContentType returns what kind of content this message is carying +// ContentType returns what kind of content this message is carying. func (h Handshake) ContentType() protocol.ContentType { return protocol.ContentTypeHandshake } -// Marshal encodes a handshake into a binary message +// Marshal encodes a handshake into a binary message. func (h *Handshake) Marshal() ([]byte, error) { if h.Message == nil { return nil, errHandshakeMessageUnset @@ -96,7 +97,7 @@ func (h *Handshake) Marshal() ([]byte, error) { return nil, err } - h.Header.Length = uint32(len(msg)) + h.Header.Length = uint32(len(msg)) //nolint:gosec // G115 h.Header.FragmentLength = h.Header.Length h.Header.Type = h.Message.Type() header, err := h.Header.Marshal() @@ -107,14 +108,14 @@ func (h *Handshake) Marshal() ([]byte, error) { return append(header, msg...), nil } -// Unmarshal decodes a handshake from a binary message -func (h *Handshake) Unmarshal(data []byte) error { +// Unmarshal decodes a handshake from a binary message. +func (h *Handshake) Unmarshal(data []byte) error { //nolint:cyclop if err := h.Header.Unmarshal(data); err != nil { return err } reportedLen := util.BigEndianUint24(data[1:]) - if uint32(len(data)-HeaderLength) != reportedLen { + if uint32(len(data)-HeaderLength) != reportedLen { //nolint:gosec // G115 return errLengthMismatch } else if reportedLen != h.Header.FragmentLength { return errLengthMismatch @@ -146,5 +147,6 @@ func (h *Handshake) Unmarshal(data []byte) error { default: return errNotImplemented } + return h.Message.Unmarshal(data[HeaderLength:]) } diff --git a/pkg/protocol/handshake/header.go b/pkg/protocol/handshake/header.go index 619fd2bdb..4e909de54 100644 --- a/pkg/protocol/handshake/header.go +++ b/pkg/protocol/handshake/header.go @@ -10,7 +10,7 @@ import ( ) // HeaderLength msg_len for Handshake messages assumes an extra -// 12 bytes for sequence, fragment and version information vs TLS +// 12 bytes for sequence, fragment and version information vs TLS. const HeaderLength = 12 // Header is the static first 12 bytes of each RecordLayer @@ -26,7 +26,7 @@ type Header struct { FragmentLength uint32 // uint24 in spec } -// Marshal encodes the Header +// Marshal encodes the Header. func (h *Header) Marshal() ([]byte, error) { out := make([]byte, HeaderLength) @@ -35,10 +35,11 @@ func (h *Header) Marshal() ([]byte, error) { binary.BigEndian.PutUint16(out[4:], h.MessageSequence) util.PutBigEndianUint24(out[6:], h.FragmentOffset) util.PutBigEndianUint24(out[9:], h.FragmentLength) + return out, nil } -// Unmarshal populates the header from encoded data +// Unmarshal populates the header from encoded data. func (h *Header) Unmarshal(data []byte) error { if len(data) < HeaderLength { return errBufferTooSmall @@ -49,5 +50,6 @@ func (h *Header) Unmarshal(data []byte) error { h.MessageSequence = binary.BigEndian.Uint16(data[4:]) h.FragmentOffset = util.BigEndianUint24(data[6:]) h.FragmentLength = util.BigEndianUint24(data[9:]) + return nil } diff --git a/pkg/protocol/handshake/message_certificate.go b/pkg/protocol/handshake/message_certificate.go index 54b0bfddb..27d2ea99a 100644 --- a/pkg/protocol/handshake/message_certificate.go +++ b/pkg/protocol/handshake/message_certificate.go @@ -15,7 +15,7 @@ type MessageCertificate struct { Certificate [][]byte } -// Type returns the Handshake Type +// Type returns the Handshake Type. func (m MessageCertificate) Type() Type { return TypeCertificate } @@ -24,31 +24,36 @@ const ( handshakeMessageCertificateLengthFieldSize = 3 ) -// Marshal encodes the Handshake +// Marshal encodes the Handshake. func (m *MessageCertificate) Marshal() ([]byte, error) { out := make([]byte, handshakeMessageCertificateLengthFieldSize) for _, r := range m.Certificate { // Certificate Length + //nolint:makezero // todo: fix out = append(out, make([]byte, handshakeMessageCertificateLengthFieldSize)...) + //nolint:gosec // G115 util.PutBigEndianUint24(out[len(out)-handshakeMessageCertificateLengthFieldSize:], uint32(len(r))) // Certificate body - out = append(out, append([]byte{}, r...)...) + out = append(out, append([]byte{}, r...)...) //nolint:makezero // todo: fix } // Total Payload Size - util.PutBigEndianUint24(out[0:], uint32(len(out[handshakeMessageCertificateLengthFieldSize:]))) + util.PutBigEndianUint24(out[0:], uint32(len(out[handshakeMessageCertificateLengthFieldSize:]))) //nolint:gosec //G115 + return out, nil } -// Unmarshal populates the message from encoded data +// Unmarshal populates the message from encoded data. func (m *MessageCertificate) Unmarshal(data []byte) error { if len(data) < handshakeMessageCertificateLengthFieldSize { return errBufferTooSmall } - if certificateBodyLen := int(util.BigEndianUint24(data)); certificateBodyLen+handshakeMessageCertificateLengthFieldSize != len(data) { + if certificateBodyLen := int(util.BigEndianUint24( + data, + )); certificateBodyLen+handshakeMessageCertificateLengthFieldSize != len(data) { return errLengthMismatch } diff --git a/pkg/protocol/handshake/message_certificate_request.go b/pkg/protocol/handshake/message_certificate_request.go index 5c3ac63c5..28dabf35a 100644 --- a/pkg/protocol/handshake/message_certificate_request.go +++ b/pkg/protocol/handshake/message_certificate_request.go @@ -31,12 +31,12 @@ const ( messageCertificateRequestMinLength = 5 ) -// Type returns the Handshake Type +// Type returns the Handshake Type. func (m MessageCertificateRequest) Type() Type { return TypeCertificateRequest } -// Marshal encodes the Handshake +// Marshal encodes the Handshake. func (m *MessageCertificateRequest) Marshal() ([]byte, error) { out := []byte{byte(len(m.CertificateTypes))} for _, v := range m.CertificateTypes { @@ -44,7 +44,7 @@ func (m *MessageCertificateRequest) Marshal() ([]byte, error) { } out = append(out, []byte{0x00, 0x00}...) - binary.BigEndian.PutUint16(out[len(out)-2:], uint16(len(m.SignatureHashAlgorithms)*2)) + binary.BigEndian.PutUint16(out[len(out)-2:], uint16(len(m.SignatureHashAlgorithms)*2)) //nolint:gosec //G115 for _, v := range m.SignatureHashAlgorithms { out = append(out, byte(v.Hash)) out = append(out, byte(v.Signature)) @@ -56,19 +56,20 @@ func (m *MessageCertificateRequest) Marshal() ([]byte, error) { casLength += len(ca) + 2 } out = append(out, []byte{0x00, 0x00}...) - binary.BigEndian.PutUint16(out[len(out)-2:], uint16(casLength)) + binary.BigEndian.PutUint16(out[len(out)-2:], uint16(casLength)) //nolint:gosec //G115 if casLength > 0 { for _, ca := range m.CertificateAuthoritiesNames { out = append(out, []byte{0x00, 0x00}...) - binary.BigEndian.PutUint16(out[len(out)-2:], uint16(len(ca))) + binary.BigEndian.PutUint16(out[len(out)-2:], uint16(len(ca))) //nolint:gosec //G115 out = append(out, ca...) } } + return out, nil } -// Unmarshal populates the message from encoded data -func (m *MessageCertificateRequest) Unmarshal(data []byte) error { +// Unmarshal populates the message from encoded data. +func (m *MessageCertificateRequest) Unmarshal(data []byte) error { //nolint:cyclop if len(data) < messageCertificateRequestMinLength { return errBufferTooSmall } diff --git a/pkg/protocol/handshake/message_certificate_test.go b/pkg/protocol/handshake/message_certificate_test.go index 760d5efdf..8abe42578 100644 --- a/pkg/protocol/handshake/message_certificate_test.go +++ b/pkg/protocol/handshake/message_certificate_test.go @@ -60,21 +60,21 @@ func TestHandshakeMessageCertificate(t *testing.T) { Version: 1, } - c := &MessageCertificate{} - if err := c.Unmarshal(rawCertificate); err != nil { + certMessage := &MessageCertificate{} + if err := certMessage.Unmarshal(rawCertificate); err != nil { t.Error(err) } else { - certificate, err := x509.ParseCertificate(c.Certificate[0]) + certificate, err := x509.ParseCertificate(certMessage.Certificate[0]) if err != nil { t.Error(err) } copyCertificatePrivateMembers(certificate, parsedCertificate) if !reflect.DeepEqual(certificate, parsedCertificate) { - t.Errorf("handshakeMessageCertificate unmarshal: got %#v, want %#v", c, parsedCertificate) + t.Errorf("handshakeMessageCertificate unmarshal: got %#v, want %#v", certMessage, parsedCertificate) } } - raw, err := c.Marshal() + raw, err := certMessage.Marshal() if err != nil { t.Error(err) } else if !reflect.DeepEqual(raw, rawCertificate) { diff --git a/pkg/protocol/handshake/message_certificate_verify.go b/pkg/protocol/handshake/message_certificate_verify.go index 9d09b31ca..d10ffa035 100644 --- a/pkg/protocol/handshake/message_certificate_verify.go +++ b/pkg/protocol/handshake/message_certificate_verify.go @@ -22,23 +22,24 @@ type MessageCertificateVerify struct { const handshakeMessageCertificateVerifyMinLength = 4 -// Type returns the Handshake Type +// Type returns the Handshake Type. func (m MessageCertificateVerify) Type() Type { return TypeCertificateVerify } -// Marshal encodes the Handshake +// Marshal encodes the Handshake. func (m *MessageCertificateVerify) Marshal() ([]byte, error) { out := make([]byte, 1+1+2+len(m.Signature)) out[0] = byte(m.HashAlgorithm) out[1] = byte(m.SignatureAlgorithm) - binary.BigEndian.PutUint16(out[2:], uint16(len(m.Signature))) + binary.BigEndian.PutUint16(out[2:], uint16(len(m.Signature))) //nolint:gosec // G115 copy(out[4:], m.Signature) + return out, nil } -// Unmarshal populates the message from encoded data +// Unmarshal populates the message from encoded data. func (m *MessageCertificateVerify) Unmarshal(data []byte) error { if len(data) < handshakeMessageCertificateVerifyMinLength { return errBufferTooSmall @@ -60,5 +61,6 @@ func (m *MessageCertificateVerify) Unmarshal(data []byte) error { } m.Signature = append([]byte{}, data[4:]...) + return nil } diff --git a/pkg/protocol/handshake/message_client_hello.go b/pkg/protocol/handshake/message_client_hello.go index c651718e5..e7aa5e397 100644 --- a/pkg/protocol/handshake/message_client_hello.go +++ b/pkg/protocol/handshake/message_client_hello.go @@ -31,12 +31,12 @@ type MessageClientHello struct { const handshakeMessageClientHelloVariableWidthStart = 34 -// Type returns the Handshake Type +// Type returns the Handshake Type. func (m MessageClientHello) Type() Type { return TypeClientHello } -// Marshal encodes the Handshake +// Marshal encodes the Handshake. func (m *MessageClientHello) Marshal() ([]byte, error) { if len(m.Cookie) > 255 { return nil, errCookieTooLong @@ -49,24 +49,24 @@ func (m *MessageClientHello) Marshal() ([]byte, error) { rand := m.Random.MarshalFixed() copy(out[2:], rand[:]) - out = append(out, byte(len(m.SessionID))) - out = append(out, m.SessionID...) + out = append(out, byte(len(m.SessionID))) //nolint:makezero // todo: fix + out = append(out, m.SessionID...) //nolint:makezero // todo: fix - out = append(out, byte(len(m.Cookie))) - out = append(out, m.Cookie...) - out = append(out, encodeCipherSuiteIDs(m.CipherSuiteIDs)...) - out = append(out, protocol.EncodeCompressionMethods(m.CompressionMethods)...) + out = append(out, byte(len(m.Cookie))) //nolint:makezero // todo: fix + out = append(out, m.Cookie...) //nolint:makezero // todo: fix + out = append(out, encodeCipherSuiteIDs(m.CipherSuiteIDs)...) //nolint:makezero // todo: fix + out = append(out, protocol.EncodeCompressionMethods(m.CompressionMethods)...) //nolint:makezero // todo: fix extensions, err := extension.Marshal(m.Extensions) if err != nil { return nil, err } - return append(out, extensions...), nil + return append(out, extensions...), nil //nolint:makezero // todo: fix } -// Unmarshal populates the message from encoded data -func (m *MessageClientHello) Unmarshal(data []byte) error { +// Unmarshal populates the message from encoded data. +func (m *MessageClientHello) Unmarshal(data []byte) error { //nolint:cyclop if len(data) < 2+RandomLength { return errBufferTooSmall } @@ -137,5 +137,6 @@ func (m *MessageClientHello) Unmarshal(data []byte) error { return err } m.Extensions = extensions + return nil } diff --git a/pkg/protocol/handshake/message_client_hello_test.go b/pkg/protocol/handshake/message_client_hello_test.go index d46567691..f32fe1dff 100644 --- a/pkg/protocol/handshake/message_client_hello_test.go +++ b/pkg/protocol/handshake/message_client_hello_test.go @@ -26,10 +26,16 @@ func TestHandshakeMessageClientHello(t *testing.T) { Version: protocol.Version{Major: 0xFE, Minor: 0xFD}, Random: Random{ GMTUnixTime: time.Unix(3056586332, 0), - RandomBytes: [28]byte{0x42, 0x54, 0xff, 0x86, 0xe1, 0x24, 0x41, 0x91, 0x42, 0x62, 0x15, 0xad, 0x16, 0xc9, 0x15, 0x8d, 0x95, 0x71, 0x8a, 0xbb, 0x22, 0xd7, 0x47, 0xec, 0xd8, 0x3d, 0xdc, 0x4b}, + RandomBytes: [28]byte{ + 0x42, 0x54, 0xff, 0x86, 0xe1, 0x24, 0x41, 0x91, 0x42, 0x62, 0x15, 0xad, 0x16, 0xc9, + 0x15, 0x8d, 0x95, 0x71, 0x8a, 0xbb, 0x22, 0xd7, 0x47, 0xec, 0xd8, 0x3d, 0xdc, 0x4b, + }, }, SessionID: []byte{}, - Cookie: []byte{0xe6, 0x14, 0x3a, 0x1b, 0x04, 0xea, 0x9e, 0x7a, 0x14, 0xd6, 0x6c, 0x57, 0xd0, 0x0e, 0x32, 0x85, 0x76, 0x18, 0xde, 0xd8}, + Cookie: []byte{ + 0xe6, 0x14, 0x3a, 0x1b, 0x04, 0xea, 0x9e, 0x7a, 0x14, 0xd6, + 0x6c, 0x57, 0xd0, 0x0e, 0x32, 0x85, 0x76, 0x18, 0xde, 0xd8, + }, CipherSuiteIDs: []uint16{ 0xc02b, 0xc00a, diff --git a/pkg/protocol/handshake/message_client_key_exchange.go b/pkg/protocol/handshake/message_client_key_exchange.go index 626361c6c..60361a94a 100644 --- a/pkg/protocol/handshake/message_client_key_exchange.go +++ b/pkg/protocol/handshake/message_client_key_exchange.go @@ -24,12 +24,12 @@ type MessageClientKeyExchange struct { KeyExchangeAlgorithm types.KeyExchangeAlgorithm } -// Type returns the Handshake Type +// Type returns the Handshake Type. func (m MessageClientKeyExchange) Type() Type { return TypeClientKeyExchange } -// Marshal encodes the Handshake +// Marshal encodes the Handshake. func (m *MessageClientKeyExchange) Marshal() (out []byte, err error) { if m.IdentityHint == nil && m.PublicKey == nil { return nil, errInvalidClientKeyExchange @@ -37,7 +37,7 @@ func (m *MessageClientKeyExchange) Marshal() (out []byte, err error) { if m.IdentityHint != nil { out = append([]byte{0x00, 0x00}, m.IdentityHint...) - binary.BigEndian.PutUint16(out, uint16(len(out)-2)) + binary.BigEndian.PutUint16(out, uint16(len(out)-2)) //nolint:gosec // G115 } if m.PublicKey != nil { @@ -48,7 +48,7 @@ func (m *MessageClientKeyExchange) Marshal() (out []byte, err error) { return out, nil } -// Unmarshal populates the message from encoded data +// Unmarshal populates the message from encoded data. func (m *MessageClientKeyExchange) Unmarshal(data []byte) error { switch { case len(data) < 2: diff --git a/pkg/protocol/handshake/message_finished.go b/pkg/protocol/handshake/message_finished.go index 255aedd7e..f7187d88a 100644 --- a/pkg/protocol/handshake/message_finished.go +++ b/pkg/protocol/handshake/message_finished.go @@ -13,18 +13,19 @@ type MessageFinished struct { VerifyData []byte } -// Type returns the Handshake Type +// Type returns the Handshake Type. func (m MessageFinished) Type() Type { return TypeFinished } -// Marshal encodes the Handshake +// Marshal encodes the Handshake. func (m *MessageFinished) Marshal() ([]byte, error) { return append([]byte{}, m.VerifyData...), nil } -// Unmarshal populates the message from encoded data +// Unmarshal populates the message from encoded data. func (m *MessageFinished) Unmarshal(data []byte) error { m.VerifyData = append([]byte{}, data...) + return nil } diff --git a/pkg/protocol/handshake/message_hello_verify_request.go b/pkg/protocol/handshake/message_hello_verify_request.go index 98960400c..7f5bc95aa 100644 --- a/pkg/protocol/handshake/message_hello_verify_request.go +++ b/pkg/protocol/handshake/message_hello_verify_request.go @@ -27,12 +27,12 @@ type MessageHelloVerifyRequest struct { Cookie []byte } -// Type returns the Handshake Type +// Type returns the Handshake Type. func (m MessageHelloVerifyRequest) Type() Type { return TypeHelloVerifyRequest } -// Marshal encodes the Handshake +// Marshal encodes the Handshake. func (m *MessageHelloVerifyRequest) Marshal() ([]byte, error) { if len(m.Cookie) > 255 { return nil, errCookieTooLong @@ -47,7 +47,7 @@ func (m *MessageHelloVerifyRequest) Marshal() ([]byte, error) { return out, nil } -// Unmarshal populates the message from encoded data +// Unmarshal populates the message from encoded data. func (m *MessageHelloVerifyRequest) Unmarshal(data []byte) error { if len(data) < 3 { return errBufferTooSmall @@ -61,5 +61,6 @@ func (m *MessageHelloVerifyRequest) Unmarshal(data []byte) error { m.Cookie = make([]byte, cookieLength) copy(m.Cookie, data[3:3+cookieLength]) + return nil } diff --git a/pkg/protocol/handshake/message_hello_verify_request_test.go b/pkg/protocol/handshake/message_hello_verify_request_test.go index 82252566e..0cfd24b70 100644 --- a/pkg/protocol/handshake/message_hello_verify_request_test.go +++ b/pkg/protocol/handshake/message_hello_verify_request_test.go @@ -17,7 +17,10 @@ func TestHandshakeMessageHelloVerifyRequest(t *testing.T) { } parsedHelloVerifyRequest := &MessageHelloVerifyRequest{ Version: protocol.Version{Major: 0xFE, Minor: 0xFF}, - Cookie: []byte{0x25, 0xfb, 0xee, 0xb3, 0x7c, 0x95, 0xcf, 0x00, 0xeb, 0xad, 0xe2, 0xef, 0xc7, 0xfd, 0xbb, 0xed, 0xf7, 0x1f, 0x6c, 0xcd}, + Cookie: []byte{ + 0x25, 0xfb, 0xee, 0xb3, 0x7c, 0x95, 0xcf, 0x00, 0xeb, 0xad, + 0xe2, 0xef, 0xc7, 0xfd, 0xbb, 0xed, 0xf7, 0x1f, 0x6c, 0xcd, + }, } h := &MessageHelloVerifyRequest{} diff --git a/pkg/protocol/handshake/message_server_hello.go b/pkg/protocol/handshake/message_server_hello.go index a1e86e1b2..e2f19c1d9 100644 --- a/pkg/protocol/handshake/message_server_hello.go +++ b/pkg/protocol/handshake/message_server_hello.go @@ -29,12 +29,12 @@ type MessageServerHello struct { const messageServerHelloVariableWidthStart = 2 + RandomLength -// Type returns the Handshake Type +// Type returns the Handshake Type. func (m MessageServerHello) Type() Type { return TypeServerHello } -// Marshal encodes the Handshake +// Marshal encodes the Handshake. func (m *MessageServerHello) Marshal() ([]byte, error) { if m.CipherSuiteID == nil { return nil, errCipherSuiteUnset @@ -49,23 +49,23 @@ func (m *MessageServerHello) Marshal() ([]byte, error) { rand := m.Random.MarshalFixed() copy(out[2:], rand[:]) - out = append(out, byte(len(m.SessionID))) - out = append(out, m.SessionID...) + out = append(out, byte(len(m.SessionID))) //nolint:makezero // todo: fix + out = append(out, m.SessionID...) //nolint:makezero // todo: fix - out = append(out, []byte{0x00, 0x00}...) + out = append(out, []byte{0x00, 0x00}...) //nolint:makezero // todo: fix binary.BigEndian.PutUint16(out[len(out)-2:], *m.CipherSuiteID) - out = append(out, byte(m.CompressionMethod.ID)) + out = append(out, byte(m.CompressionMethod.ID)) //nolint:makezero // todo: fix extensions, err := extension.Marshal(m.Extensions) if err != nil { return nil, err } - return append(out, extensions...), nil + return append(out, extensions...), nil //nolint:makezero // todo: fix } -// Unmarshal populates the message from encoded data +// Unmarshal populates the message from encoded data. func (m *MessageServerHello) Unmarshal(data []byte) error { if len(data) < 2+RandomLength { return errBufferTooSmall @@ -110,6 +110,7 @@ func (m *MessageServerHello) Unmarshal(data []byte) error { if len(data) <= currOffset { m.Extensions = []extension.Extension{} + return nil } @@ -118,5 +119,6 @@ func (m *MessageServerHello) Unmarshal(data []byte) error { return err } m.Extensions = extensions + return nil } diff --git a/pkg/protocol/handshake/message_server_hello_done.go b/pkg/protocol/handshake/message_server_hello_done.go index b187dd417..49a830c2a 100644 --- a/pkg/protocol/handshake/message_server_hello_done.go +++ b/pkg/protocol/handshake/message_server_hello_done.go @@ -5,20 +5,20 @@ package handshake // MessageServerHelloDone is final non-encrypted message from server // this communicates server has sent all its handshake messages and next -// should be MessageFinished +// should be MessageFinished. type MessageServerHelloDone struct{} -// Type returns the Handshake Type +// Type returns the Handshake Type. func (m MessageServerHelloDone) Type() Type { return TypeServerHelloDone } -// Marshal encodes the Handshake +// Marshal encodes the Handshake. func (m *MessageServerHelloDone) Marshal() ([]byte, error) { return []byte{}, nil } -// Unmarshal populates the message from encoded data +// Unmarshal populates the message from encoded data. func (m *MessageServerHelloDone) Unmarshal([]byte) error { return nil } diff --git a/pkg/protocol/handshake/message_server_hello_test.go b/pkg/protocol/handshake/message_server_hello_test.go index c4d6683a5..810dcd348 100644 --- a/pkg/protocol/handshake/message_server_hello_test.go +++ b/pkg/protocol/handshake/message_server_hello_test.go @@ -27,7 +27,10 @@ func TestHandshakeMessageServerHello(t *testing.T) { Version: protocol.Version{Major: 0xFE, Minor: 0xFD}, Random: Random{ GMTUnixTime: time.Unix(560149025, 0), - RandomBytes: [28]byte{0x81, 0x0e, 0x98, 0x6c, 0x85, 0x3d, 0xa4, 0x39, 0xaf, 0x5f, 0xd6, 0x5c, 0xcc, 0x20, 0x7f, 0x7c, 0x78, 0xf1, 0x5f, 0x7e, 0x1c, 0xb7, 0xa1, 0x1e, 0xcf, 0x63, 0x84, 0x28}, + RandomBytes: [28]byte{ + 0x81, 0x0e, 0x98, 0x6c, 0x85, 0x3d, 0xa4, 0x39, 0xaf, 0x5f, 0xd6, 0x5c, 0xcc, 0x20, + 0x7f, 0x7c, 0x78, 0xf1, 0x5f, 0x7e, 0x1c, 0xb7, 0xa1, 0x1e, 0xcf, 0x63, 0x84, 0x28, + }, }, SessionID: []byte{}, CipherSuiteID: &cipherSuiteID, diff --git a/pkg/protocol/handshake/message_server_key_exchange.go b/pkg/protocol/handshake/message_server_key_exchange.go index 1edac45f9..59a5392f3 100644 --- a/pkg/protocol/handshake/message_server_key_exchange.go +++ b/pkg/protocol/handshake/message_server_key_exchange.go @@ -12,7 +12,7 @@ import ( "github.com/pion/dtls/v3/pkg/crypto/signature" ) -// MessageServerKeyExchange supports ECDH and PSK +// MessageServerKeyExchange supports ECDH and PSK. type MessageServerKeyExchange struct { IdentityHint []byte @@ -27,17 +27,17 @@ type MessageServerKeyExchange struct { KeyExchangeAlgorithm types.KeyExchangeAlgorithm } -// Type returns the Handshake Type +// Type returns the Handshake Type. func (m MessageServerKeyExchange) Type() Type { return TypeServerKeyExchange } -// Marshal encodes the Handshake -func (m *MessageServerKeyExchange) Marshal() ([]byte, error) { +// Marshal encodes the Handshake. +func (m *MessageServerKeyExchange) Marshal() ([]byte, error) { //nolint:cyclop var out []byte if m.IdentityHint != nil { out = append([]byte{0x00, 0x00}, m.IdentityHint...) - binary.BigEndian.PutUint16(out, uint16(len(out)-2)) + binary.BigEndian.PutUint16(out, uint16(len(out)-2)) //nolint:gosec //G115 } if m.EllipticCurveType == 0 || len(m.PublicKey) == 0 { @@ -60,14 +60,14 @@ func (m *MessageServerKeyExchange) Marshal() ([]byte, error) { } out = append(out, []byte{byte(m.HashAlgorithm), byte(m.SignatureAlgorithm), 0x00, 0x00}...) - binary.BigEndian.PutUint16(out[len(out)-2:], uint16(len(m.Signature))) + binary.BigEndian.PutUint16(out[len(out)-2:], uint16(len(m.Signature))) //nolint:gosec // G115 out = append(out, m.Signature...) return out, nil } -// Unmarshal populates the message from encoded data -func (m *MessageServerKeyExchange) Unmarshal(data []byte) error { +// Unmarshal populates the message from encoded data. +func (m *MessageServerKeyExchange) Unmarshal(data []byte) error { //nolint:cyclop switch { case len(data) < 2: return errBufferTooSmall @@ -84,6 +84,7 @@ func (m *MessageServerKeyExchange) Unmarshal(data []byte) error { if len(data) == 0 { return nil } + return errLengthMismatch } @@ -144,5 +145,6 @@ func (m *MessageServerKeyExchange) Unmarshal(data []byte) error { return errBufferTooSmall } m.Signature = append([]byte{}, data[offset:offset+signatureLength]...) + return nil } diff --git a/pkg/protocol/handshake/random.go b/pkg/protocol/handshake/random.go index 56f37569b..6eb2815f4 100644 --- a/pkg/protocol/handshake/random.go +++ b/pkg/protocol/handshake/random.go @@ -9,7 +9,7 @@ import ( "time" ) -// Consts for Random in Handshake +// Consts for Random in Handshake. const ( RandomBytesLength = 28 RandomLength = RandomBytesLength + 4 @@ -23,24 +23,24 @@ type Random struct { RandomBytes [RandomBytesLength]byte } -// MarshalFixed encodes the Handshake +// MarshalFixed encodes the Handshake. func (r *Random) MarshalFixed() [RandomLength]byte { var out [RandomLength]byte - binary.BigEndian.PutUint32(out[0:], uint32(r.GMTUnixTime.Unix())) + binary.BigEndian.PutUint32(out[0:], uint32(r.GMTUnixTime.Unix())) //nolint:gosec // G115 copy(out[4:], r.RandomBytes[:]) return out } -// UnmarshalFixed populates the message from encoded data +// UnmarshalFixed populates the message from encoded data. func (r *Random) UnmarshalFixed(data [RandomLength]byte) { r.GMTUnixTime = time.Unix(int64(binary.BigEndian.Uint32(data[0:])), 0) copy(r.RandomBytes[:], data[4:]) } // Populate fills the handshakeRandom with random values -// may be called multiple times +// may be called multiple times. func (r *Random) Populate() error { r.GMTUnixTime = time.Now() diff --git a/pkg/protocol/recordlayer/errors.go b/pkg/protocol/recordlayer/errors.go index ba2b396b8..09599249b 100644 --- a/pkg/protocol/recordlayer/errors.go +++ b/pkg/protocol/recordlayer/errors.go @@ -11,8 +11,11 @@ import ( ) var ( - // ErrInvalidPacketLength is returned when the packet length too small or declared length do not match - ErrInvalidPacketLength = &protocol.TemporaryError{Err: errors.New("packet length and declared length do not match")} //nolint:goerr113 + // ErrInvalidPacketLength is returned when the packet length too small + // or declared length do not match. + ErrInvalidPacketLength = &protocol.TemporaryError{ + Err: errors.New("packet length and declared length do not match"), //nolint:goerr113 + } errBufferTooSmall = &protocol.TemporaryError{Err: errors.New("buffer is too small")} //nolint:goerr113 errSequenceNumberOverflow = &protocol.InternalError{Err: errors.New("sequence number overflow")} //nolint:goerr113 diff --git a/pkg/protocol/recordlayer/header.go b/pkg/protocol/recordlayer/header.go index 0899d5a22..47af855df 100644 --- a/pkg/protocol/recordlayer/header.go +++ b/pkg/protocol/recordlayer/header.go @@ -10,7 +10,7 @@ import ( "github.com/pion/dtls/v3/pkg/protocol" ) -// Header implements a TLS RecordLayer header +// Header implements a TLS RecordLayer header. type Header struct { ContentType protocol.ContentType ContentLen uint16 @@ -22,7 +22,7 @@ type Header struct { ConnectionID []byte } -// RecordLayer enums +// RecordLayer enums. const ( // FixedHeaderSize is the size of a DTLS record header when connection IDs // are not in use. @@ -30,7 +30,7 @@ const ( MaxSequenceNumber = 0x0000FFFFFFFFFFFF ) -// Marshal encodes a TLS RecordLayer Header to binary +// Marshal encodes a TLS RecordLayer Header to binary. func (h *Header) Marshal() ([]byte, error) { if h.SequenceNumber > MaxSequenceNumber { return nil, errSequenceNumberOverflow @@ -46,10 +46,11 @@ func (h *Header) Marshal() ([]byte, error) { util.PutBigEndianUint48(out[5:], h.SequenceNumber) copy(out[11:11+len(h.ConnectionID)], h.ConnectionID) binary.BigEndian.PutUint16(out[hs-2:], h.ContentLen) + return out, nil } -// Unmarshal populates a TLS RecordLayer Header from binary +// Unmarshal populates a TLS RecordLayer Header from binary. func (h *Header) Unmarshal(data []byte) error { if len(data) < FixedHeaderSize { return errBufferTooSmall diff --git a/pkg/protocol/recordlayer/inner_plaintext.go b/pkg/protocol/recordlayer/inner_plaintext.go index 296475a6c..2c67c86f2 100644 --- a/pkg/protocol/recordlayer/inner_plaintext.go +++ b/pkg/protocol/recordlayer/inner_plaintext.go @@ -17,22 +17,24 @@ type InnerPlaintext struct { Zeros uint } -// Marshal encodes a DTLS InnerPlaintext to binary +// Marshal encodes a DTLS InnerPlaintext to binary. func (p *InnerPlaintext) Marshal() ([]byte, error) { var out cryptobyte.Builder out.AddBytes(p.Content) out.AddUint8(uint8(p.RealType)) out.AddBytes(make([]byte, p.Zeros)) + return out.Bytes() } -// Unmarshal populates a DTLS InnerPlaintext from binary +// Unmarshal populates a DTLS InnerPlaintext from binary. func (p *InnerPlaintext) Unmarshal(data []byte) error { // Process in reverse i := len(data) - 1 for i >= 0 { if data[i] != 0 { - p.Zeros = uint(len(data) - 1 - i) + p.Zeros = uint(len(data) - 1 - i) //nolint:gosec // G115 + break } i-- diff --git a/pkg/protocol/recordlayer/recordlayer.go b/pkg/protocol/recordlayer/recordlayer.go index 61e25a563..95113da4f 100644 --- a/pkg/protocol/recordlayer/recordlayer.go +++ b/pkg/protocol/recordlayer/recordlayer.go @@ -48,14 +48,14 @@ type RecordLayer struct { Content protocol.Content } -// Marshal encodes the RecordLayer to binary +// Marshal encodes the RecordLayer to binary. func (r *RecordLayer) Marshal() ([]byte, error) { contentRaw, err := r.Content.Marshal() if err != nil { return nil, err } - r.Header.ContentLen = uint16(len(contentRaw)) + r.Header.ContentLen = uint16(len(contentRaw)) //nolint:gosec // G115 r.Header.ContentType = r.Content.ContentType() headerRaw, err := r.Header.Marshal() @@ -66,7 +66,7 @@ func (r *RecordLayer) Marshal() ([]byte, error) { return append(headerRaw, contentRaw...), nil } -// Unmarshal populates the RecordLayer from binary +// Unmarshal populates the RecordLayer from binary. func (r *RecordLayer) Unmarshal(data []byte) error { if err := r.Header.Unmarshal(data); err != nil { return err diff --git a/pkg/protocol/version.go b/pkg/protocol/version.go index c4d94ac3a..3943c1504 100644 --- a/pkg/protocol/version.go +++ b/pkg/protocol/version.go @@ -4,7 +4,7 @@ // Package protocol provides the DTLS wire format package protocol -// Version enums +// Version enums. var ( Version1_0 = Version{Major: 0xfe, Minor: 0xff} //nolint:gochecknoglobals Version1_2 = Version{Major: 0xfe, Minor: 0xfd} //nolint:gochecknoglobals @@ -18,7 +18,7 @@ type Version struct { Major, Minor uint8 } -// Equal determines if two protocol versions are equal +// Equal determines if two protocol versions are equal. func (v Version) Equal(x Version) bool { return v.Major == x.Major && v.Minor == x.Minor } diff --git a/replayprotection_test.go b/replayprotection_test.go index 725e05239..e0984ef62 100644 --- a/replayprotection_test.go +++ b/replayprotection_test.go @@ -16,7 +16,7 @@ import ( "github.com/pion/transport/v3/test" ) -func TestReplayProtection(t *testing.T) { +func TestReplayProtection(t *testing.T) { //nolint:cyclop // Limit runtime in case of deadlocks lim := test.TimeOut(5 * time.Second) defer lim.Stop() @@ -52,6 +52,7 @@ func TestReplayProtection(t *testing.T) { } if _, werr := cb.Write(b[:n]); werr != nil { t.Error(werr) + return } @@ -109,10 +110,12 @@ func TestReplayProtection(t *testing.T) { sent = append(sent, data) if _, werr := ca.Write(data); werr != nil { t.Error(werr) + return } if _, werr := cb.Write(data); werr != nil { t.Error(werr) + return } } diff --git a/resume.go b/resume.go index 0b76314a5..954907dd0 100644 --- a/resume.go +++ b/resume.go @@ -7,10 +7,11 @@ import ( "net" ) -// Resume imports an already established dtls connection using a specific dtls state +// Resume imports an already established dtls connection using a specific dtls state. func Resume(state *State, conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) { if err := state.initCipherSuite(); err != nil { return nil, err } + return createConn(conn, rAddr, config, state.isClient, state) } diff --git a/resume_test.go b/resume_test.go index 4f79adb3f..a3ce26cec 100644 --- a/resume_test.go +++ b/resume_test.go @@ -32,11 +32,20 @@ func TestResumeServer(t *testing.T) { } func fatal(t *testing.T, errChan chan error, err error) { + t.Helper() + close(errChan) t.Fatal(err) } -func DoTestResume(t *testing.T, newLocal, newRemote func(net.PacketConn, net.Addr, *Config) (*Conn, error)) { +//nolint:cyclop +func DoTestResume( + t *testing.T, + newLocal, + newRemote func(net.PacketConn, net.Addr, *Config) (*Conn, error), +) { + t.Helper() + // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() @@ -176,8 +185,10 @@ func (b *backupConn) Read(data []byte) (n int, err error) { b.curr = b.next b.next = nil b.mux.Unlock() + return b.Read(data) } + return n, err } @@ -188,8 +199,10 @@ func (b *backupConn) Write(data []byte) (n int, err error) { b.curr = b.next b.next = nil b.mux.Unlock() + return b.Write(data) } + return n, err } diff --git a/session.go b/session.go index 99bf5a499..912a5997a 100644 --- a/session.go +++ b/session.go @@ -3,7 +3,7 @@ package dtls -// Session store data needed in resumption +// Session store data needed in resumption. type Session struct { // ID store session id ID []byte diff --git a/state.go b/state.go index f1afb857d..364f7dcc1 100644 --- a/state.go +++ b/state.go @@ -16,7 +16,8 @@ import ( "github.com/pion/transport/v3/replaydetector" ) -// State holds the dtls connection state and implements both encoding.BinaryMarshaler and encoding.BinaryUnmarshaler +// State holds the dtls connection state and implements both encoding.BinaryMarshaler and +// encoding.BinaryUnmarshaler. type State struct { localEpoch, remoteEpoch atomic.Value localSequenceNumber []uint64 // uint48 @@ -112,6 +113,7 @@ func (s *State) serialize() (*serializedState, error) { remoteRnd := s.remoteRandom.MarshalFixed() epoch := s.getLocalEpoch() + return &serializedState{ LocalEpoch: s.getLocalEpoch(), RemoteEpoch: s.getRemoteEpoch(), @@ -193,10 +195,11 @@ func (s *State) initCipherSuite() error { if err != nil { return err } + return nil } -// MarshalBinary is a binary.BinaryMarshaler.MarshalBinary implementation +// MarshalBinary is a binary.BinaryMarshaler.MarshalBinary implementation. func (s *State) MarshalBinary() ([]byte, error) { serialized, err := s.serialize() if err != nil { @@ -208,10 +211,11 @@ func (s *State) MarshalBinary() ([]byte, error) { if err := enc.Encode(*serialized); err != nil { return nil, err } + return buf.Bytes(), nil } -// UnmarshalBinary is a binary.BinaryUnmarshaler.UnmarshalBinary implementation +// UnmarshalBinary is a binary.BinaryUnmarshaler.UnmarshalBinary implementation. func (s *State) UnmarshalBinary(data []byte) error { enc := gob.NewDecoder(bytes.NewBuffer(data)) var serialized serializedState @@ -227,7 +231,7 @@ func (s *State) UnmarshalBinary(data []byte) error { // ExportKeyingMaterial returns length bytes of exported key material in a new // slice as defined in RFC 5705. // This allows protocols to use DTLS for key establishment, but -// then use some of the keying material for their own purposes +// then use some of the keying material for their own purposes. func (s *State) ExportKeyingMaterial(label string, context []byte, length int) ([]byte, error) { if s.getLocalEpoch() == 0 { return nil, errHandshakeInProgress @@ -246,6 +250,7 @@ func (s *State) ExportKeyingMaterial(label string, context []byte, length int) ( } else { seed = append(append(seed, remoteRandom[:]...), localRandom[:]...) } + return prf.PHash(s.masterSecret, seed, length, s.cipherSuite.HashFunc()) } @@ -253,6 +258,7 @@ func (s *State) getRemoteEpoch() uint16 { if remoteEpoch, ok := s.remoteEpoch.Load().(uint16); ok { return remoteEpoch } + return 0 } @@ -260,6 +266,7 @@ func (s *State) getLocalEpoch() uint16 { if localEpoch, ok := s.localEpoch.Load().(uint16); ok { return localEpoch } + return 0 } @@ -287,7 +294,7 @@ func (s *State) setLocalConnectionID(v []byte) { s.localConnectionID.Store(v) } -// RemoteRandomBytes returns the remote client hello random bytes +// RemoteRandomBytes returns the remote client hello random bytes. func (s *State) RemoteRandomBytes() [handshake.RandomBytesLength]byte { return s.remoteRandom.RandomBytes } diff --git a/util.go b/util.go index 663c4437c..3d9b0bc85 100644 --- a/util.go +++ b/util.go @@ -11,6 +11,7 @@ func findMatchingSRTPProfile(a, b []SRTPProtectionProfile) (SRTPProtectionProfil } } } + return 0, false } @@ -22,6 +23,7 @@ func findMatchingCipherSuite(a, b []CipherSuite) (CipherSuite, bool) { } } } + return nil, false }