From f19300581933c51934f0e0de8792b85f649c7652 Mon Sep 17 00:00:00 2001 From: Joe Turki Date: Sat, 25 Jan 2025 05:33:22 -0600 Subject: [PATCH] Upgrade golangci-lint, more linters Introduces new linters, upgrade golangci-lint to version (v1.63.4) --- .golangci.yml | 47 +- ack_timer.go | 10 +- association.go | 655 ++++++++++++--------- association_test.go | 304 ++++++---- chunk_abort.go | 6 +- chunk_cookie_ack.go | 5 +- chunk_cookie_echo.go | 3 +- chunk_error.go | 6 +- chunk_forward_tsn.go | 8 +- chunk_heartbeat.go | 3 +- chunk_heartbeat_ack.go | 3 +- chunk_init.go | 5 +- chunk_init_ack.go | 9 +- chunk_init_common.go | 7 +- chunk_init_test.go | 14 +- chunk_payload_data.go | 9 +- chunk_reconfig.go | 6 +- chunk_reconfig_test.go | 29 +- chunk_selective_ack.go | 11 +- chunk_shutdown.go | 7 +- chunk_shutdown_ack.go | 7 +- chunk_shutdown_complete.go | 7 +- chunk_test.go | 170 ++++-- chunkheader.go | 13 +- chunktype.go | 4 +- control_queue.go | 1 + error_cause.go | 22 +- error_cause_header.go | 9 +- error_cause_invalid_mandatory_parameter.go | 4 +- error_cause_protocol_violation.go | 5 +- error_cause_unrecognized_chunk_type.go | 6 +- error_cause_user_initiated_abort.go | 4 +- examples/ping-pong/ping/conn.go | 18 +- examples/ping-pong/ping/main.go | 2 +- examples/ping-pong/pong/conn.go | 18 +- examples/ping-pong/pong/main.go | 2 +- packet.go | 52 +- packet_test.go | 29 +- param.go | 6 +- param_ecn_capable.go | 2 + param_forward_tsn_supported.go | 2 + param_heartbeat_info.go | 2 + param_outgoing_reset_request.go | 3 +- param_outgoing_reset_request_test.go | 5 +- param_random.go | 2 + param_reconfig_response.go | 2 +- param_state_cookie.go | 4 +- param_zero_checksum.go | 4 +- paramheader.go | 16 +- paramtype.go | 7 +- payload_queue.go | 2 + payload_queue_test.go | 4 +- pending_queue.go | 25 +- pending_queue_test.go | 51 +- queue.go | 1 + queue_test.go | 40 +- reassembly_queue.go | 39 +- reassembly_queue_test.go | 2 +- receive_payload_queue.go | 43 +- receive_payload_queue_test.go | 123 ++-- rtx_timer.go | 29 +- rtx_timer_test.go | 2 +- stream.go | 64 +- stream_test.go | 32 +- util.go | 3 +- vnet_test.go | 21 +- 66 files changed, 1254 insertions(+), 802 deletions(-) diff --git a/.golangci.yml b/.golangci.yml index a3235bec..88cb4fbf 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -25,17 +25,32 @@ linters-settings: - ^os.Exit$ - ^panic$ - ^print(ln)?$ + varnamelen: + max-distance: 12 + min-name-length: 2 + ignore-type-assert-ok: true + ignore-map-index-ok: true + ignore-chan-recv-ok: true + ignore-decls: + - i int + - n int + - w io.Writer + - r io.Reader + - b []byte linters: enable: - asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers - bidichk # Checks for dangerous unicode character sequences - bodyclose # checks whether HTTP response body is closed successfully + - containedctx # containedctx is a linter that detects struct contained context.Context field - contextcheck # check the function whether use a non-inherited context + - cyclop # checks function and package cyclomatic complexity - decorder # check declaration order and count of types, constants, variables and functions - dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f()) - dupl # Tool for code clone detection - durationcheck # check for two durations multiplied together + - err113 # Golang linter to check the errors handling expressions - errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases - errchkjson # Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occations, where the check for the returned error can be omitted. - errname # Checks that sentinel errors are prefixed with the `Err` and error types are suffixed with the `Error`. @@ -46,18 +61,17 @@ linters: - forcetypeassert # finds forced type assertions - gci # Gci control golang package import order and make it always deterministic. - gochecknoglobals # Checks that no globals are present in Go code - - gochecknoinits # Checks that no init functions are present in Go code - gocognit # Computes and checks the cognitive complexity of functions - goconst # Finds repeated strings that could be replaced by a constant - gocritic # The most opinionated Go source code linter + - gocyclo # Computes and checks the cyclomatic complexity of functions + - godot # Check if comments end in a period - godox # Tool for detection of FIXME, TODO and other comment keywords - - err113 # Golang linter to check the errors handling expressions - gofmt # Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification - gofumpt # Gofumpt checks whether code was gofumpt-ed. - goheader # Checks is file header matches to pattern - goimports # Goimports does everything that gofmt does. Additionally it checks unused imports - gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod. - - gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations. - goprintffuncname # Checks that printf-like functions are named with `f` at the end - gosec # Inspects source code for security problems - gosimple # Linter for Go source code that specializes in simplifying a code @@ -65,9 +79,15 @@ linters: - grouper # An analyzer to analyze expression groups. - importas # Enforces consistent import aliases - ineffassign # Detects when assignments to existing variables are not used + - lll # Reports long lines + - maintidx # maintidx measures the maintainability index of each function. + - makezero # Finds slice declarations with non-zero initial length - misspell # Finds commonly misspelled English words in comments + - nakedret # Finds naked returns in functions greater than a specified function length + - nestif # Reports deeply nested if statements - nilerr # Finds the code that returns nil even if it checks that the error is not nil. - nilnil # Checks that there is no simultaneous return of `nil` error and an invalid value. + - nlreturn # nlreturn checks for a new line before return and branch statements to increase code clarity - noctx # noctx finds sending http request without context.Context - predeclared # find code that shadows one of Go's predeclared identifiers - revive # golint replacement, finds style mistakes @@ -75,28 +95,22 @@ linters: - stylecheck # Stylecheck is a replacement for golint - tagliatelle # Checks the struct tags. - tenv # tenv is analyzer that detects using os.Setenv instead of t.Setenv since Go1.17 - - tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes + - thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers - typecheck # Like the front-end of a Go compiler, parses and type-checks Go code - unconvert # Remove unnecessary type conversions - unparam # Reports unused function parameters - unused # Checks Go code for unused constants, variables, functions and types + - varnamelen # checks that the length of a variable's name matches its scope - wastedassign # wastedassign finds wasted assignment statements - whitespace # Tool for detection of leading and trailing whitespace disable: - depguard # Go linter that checks if package imports are in a list of acceptable packages - - containedctx # containedctx is a linter that detects struct contained context.Context field - - cyclop # checks function and package cyclomatic complexity - funlen # Tool for detection of long functions - - gocyclo # Computes and checks the cyclomatic complexity of functions - - godot # Check if comments end in a period - - gomnd # An analyzer to detect magic numbers. + - gochecknoinits # Checks that no init functions are present in Go code + - gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations. + - interfacebloat # A linter that checks length of interface. - ireturn # Accept Interfaces, Return Concrete Types - - lll # Reports long lines - - maintidx # maintidx measures the maintainability index of each function. - - makezero # Finds slice declarations with non-zero initial length - - nakedret # Finds naked returns in functions greater than a specified function length - - nestif # Reports deeply nested if statements - - nlreturn # nlreturn checks for a new line before return and branch statements to increase code clarity + - mnd # An analyzer to detect magic numbers - nolintlint # Reports ill-formed or insufficient nolint directives - paralleltest # paralleltest detects missing usage of t.Parallel() method in your Go test - prealloc # Finds slice declarations that could potentially be preallocated @@ -104,8 +118,7 @@ linters: - rowserrcheck # checks whether Err of rows is checked successfully - sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed. - testpackage # linter that makes you use a separate _test package - - thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers - - varnamelen # checks that the length of a variable's name matches its scope + - tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes - wrapcheck # Checks that errors returned from external packages are wrapped - wsl # Whitespace Linter - Forces you to use empty lines! diff --git a/ack_timer.go b/ack_timer.go index b6008d18..80e96885 100644 --- a/ack_timer.go +++ b/ack_timer.go @@ -26,7 +26,7 @@ const ( ackTimerClosed ) -// ackTimer provides the retnransmission timer conforms with RFC 4960 Sec 6.3.1 +// ackTimer provides the retnransmission timer conforms with RFC 4960 Sec 6.3.1. type ackTimer struct { timer *time.Timer observer ackTimerObserver @@ -40,6 +40,7 @@ func newAckTimer(observer ackTimerObserver) *ackTimer { t := &ackTimer{observer: observer} t.timer = time.AfterFunc(math.MaxInt64, t.timeout) t.timer.Stop() + return t } @@ -65,11 +66,12 @@ func (t *ackTimer) start() bool { t.state = ackTimerStarted t.pending++ t.timer.Reset(ackInterval) + return true } // stops the timer. this is similar to stop() but subsequent start() call -// will fail (the timer is no longer usable) +// will fail (the timer is no longer usable). func (t *ackTimer) stop() { t.mutex.Lock() defer t.mutex.Unlock() @@ -83,7 +85,7 @@ func (t *ackTimer) stop() { } // closes the timer. this is similar to stop() but subsequent start() call -// will fail (the timer is no longer usable) +// will fail (the timer is no longer usable). func (t *ackTimer) close() { t.mutex.Lock() defer t.mutex.Unlock() @@ -95,7 +97,7 @@ func (t *ackTimer) close() { } // isRunning tests if the timer is running. -// Debug purpose only +// Debug purpose only. func (t *ackTimer) isRunning() bool { t.mutex.Lock() defer t.mutex.Unlock() diff --git a/association.go b/association.go index 3a85e575..2992724b 100644 --- a/association.go +++ b/association.go @@ -29,7 +29,7 @@ const defaultSCTPSrcDstPort = 5000 // Use global random generator to properly seed by crypto grade random. var globalMathRandomGenerator = randutil.NewMathRandomGenerator() // nolint:gochecknoglobals -// Association errors +// Association errors. var ( ErrChunk = errors.New("abort chunk, with following errors") ErrShutdownNonEstablished = errors.New("shutdown called in non-established state") @@ -41,18 +41,20 @@ var ( ErrSCTPPacketSourcePortZero = errors.New("sctp packet must not have a source port of 0") ErrSCTPPacketDestinationPortZero = errors.New("sctp packet must not have a destination port of 0") ErrInitChunkBundled = errors.New("init chunk must not be bundled with any other chunk") - ErrInitChunkVerifyTagNotZero = errors.New("init chunk expects a verification tag of 0 on the packet when out-of-the-blue") - ErrHandleInitState = errors.New("todo: handle Init when in state") - ErrInitAckNoCookie = errors.New("no cookie in InitAck") - ErrInflightQueueTSNPop = errors.New("unable to be popped from inflight queue TSN") - ErrTSNRequestNotExist = errors.New("requested non-existent TSN") - ErrResetPacketInStateNotExist = errors.New("sending reset packet in non-established state") - ErrParamterType = errors.New("unexpected parameter type") - ErrPayloadDataStateNotExist = errors.New("sending payload data in non-established state") - ErrChunkTypeUnhandled = errors.New("unhandled chunk type") - ErrHandshakeInitAck = errors.New("handshake failed (INIT ACK)") - ErrHandshakeCookieEcho = errors.New("handshake failed (COOKIE ECHO)") - ErrTooManyReconfigRequests = errors.New("too many outstanding reconfig requests") + ErrInitChunkVerifyTagNotZero = errors.New( + "init chunk expects a verification tag of 0 on the packet when out-of-the-blue", + ) + ErrHandleInitState = errors.New("todo: handle Init when in state") + ErrInitAckNoCookie = errors.New("no cookie in InitAck") + ErrInflightQueueTSNPop = errors.New("unable to be popped from inflight queue TSN") + ErrTSNRequestNotExist = errors.New("requested non-existent TSN") + ErrResetPacketInStateNotExist = errors.New("sending reset packet in non-established state") + ErrParamterType = errors.New("unexpected parameter type") + ErrPayloadDataStateNotExist = errors.New("sending payload data in non-established state") + ErrChunkTypeUnhandled = errors.New("unhandled chunk type") + ErrHandshakeInitAck = errors.New("handshake failed (INIT ACK)") + ErrHandshakeCookieEcho = errors.New("handshake failed (COOKIE ECHO)") + ErrTooManyReconfigRequests = errors.New("too many outstanding reconfig requests") ) const ( @@ -64,7 +66,7 @@ const ( defaultMaxMessageSize uint32 = 65536 ) -// association state enums +// association state enums. const ( closed uint32 = iota cookieWait @@ -76,7 +78,7 @@ const ( shutdownSent ) -// retransmission timer IDs +// retransmission timer IDs. const ( timerT1Init int = iota timerT1Cookie @@ -85,21 +87,21 @@ const ( timerReconfig ) -// ack mode (for testing) +// ack mode (for testing). const ( ackModeNormal int = iota ackModeNoDelay ackModeAlwaysDelay ) -// ack transmission state +// ack transmission state. const ( ackStateIdle int = iota // ack timer is off ackStateImmediate // will send ack immediately ackStateDelay // ack timer is on (ack is being delayed) ) -// other constants +// other constants. const ( acceptChSize = 16 // avgChunkSize is an estimate of the average chunk size. There is no theory behind @@ -107,18 +109,18 @@ const ( avgChunkSize = 500 // minTSNOffset is the minimum offset over the cummulative TSN that we will enqueue // irrespective of the receive buffer size - // see getMaxTSNOffset + // see getMaxTSNOffset. minTSNOffset = 2000 // maxTSNOffset is the maximum offset over the cummulative TSN that we will enqueue // irrespective of the receive buffer size - // see getMaxTSNOffset + // see getMaxTSNOffset. maxTSNOffset = 40000 - // maxReconfigRequests is the maximum number of reconfig requests we will keep outstanding + // maxReconfigRequests is the maximum number of reconfig requests we will keep outstanding. maxReconfigRequests = 1000 ) -func getAssociationStateString(a uint32) string { - switch a { +func getAssociationStateString(assoc uint32) string { + switch assoc { case closed: return "Closed" case cookieWait: @@ -136,7 +138,7 @@ func getAssociationStateString(a uint32) string { case shutdownAckSent: return "ShutdownAckSent" default: - return fmt.Sprintf("Invalid association state %d", a) + return fmt.Sprintf("Invalid association state %d", assoc) } } @@ -261,7 +263,7 @@ type Association struct { } // Config collects the arguments to createAssociation construction into -// a single structure +// a single structure. type Config struct { Name string NetConn net.Conn @@ -282,7 +284,7 @@ type Config struct { CwndCAStep uint32 } -// Server accepts a SCTP stream over a conn +// Server accepts a SCTP stream over a conn. func Server(config Config) (*Association, error) { a := createAssociation(config) a.init(false) @@ -292,32 +294,35 @@ func Server(config Config) (*Association, error) { if err != nil { return nil, err } + return a, nil case <-a.readLoopCloseCh: return nil, ErrAssociationClosedBeforeConn } } -// Client opens a SCTP stream over a conn +// Client opens a SCTP stream over a conn. func Client(config Config) (*Association, error) { return createClientWithContext(context.Background(), config) } func createClientWithContext(ctx context.Context, config Config) (*Association, error) { - a := createAssociation(config) - a.init(true) + assoc := createAssociation(config) + assoc.init(true) select { case <-ctx.Done(): - a.log.Errorf("[%s] client handshake canceled: state=%s", a.name, getAssociationStateString(a.getState())) - a.Close() // nolint:errcheck,gosec + assoc.log.Errorf("[%s] client handshake canceled: state=%s", assoc.name, getAssociationStateString(assoc.getState())) + assoc.Close() // nolint:errcheck,gosec + return nil, ctx.Err() - case err := <-a.handshakeCompletedCh: + case err := <-assoc.handshakeCompletedCh: if err != nil { return nil, err } - return a, nil - case <-a.readLoopCloseCh: + + return assoc, nil + case <-assoc.readLoopCloseCh: return nil, ErrAssociationClosedBeforeConn } } @@ -338,7 +343,7 @@ func createAssociation(config Config) *Association { } tsn := globalMathRandomGenerator.Uint32() - a := &Association{ + assoc := &Association{ netConn: config.NetConn, maxReceiveBufferSize: maxReceiveBufferSize, maxMessageSize: maxMessageSize, @@ -385,27 +390,27 @@ func createAssociation(config Config) *Association { writeNotify: make(chan struct{}, 1), } - if a.name == "" { - a.name = fmt.Sprintf("%p", a) + if assoc.name == "" { + assoc.name = fmt.Sprintf("%p", assoc) } // RFC 4690 Sec 7.2.1 // o The initial cwnd before DATA transmission or after a sufficiently // long idle period MUST be set to min(4*MTU, max (2*MTU, 4380 // bytes)). - a.setCWND(min32(4*a.MTU(), max32(2*a.MTU(), 4380))) - a.log.Tracef("[%s] updated cwnd=%d ssthresh=%d inflight=%d (INI)", - a.name, a.CWND(), a.ssthresh, a.inflightQueue.getNumBytes()) + assoc.setCWND(min32(4*assoc.MTU(), max32(2*assoc.MTU(), 4380))) + assoc.log.Tracef("[%s] updated cwnd=%d ssthresh=%d inflight=%d (INI)", + assoc.name, assoc.CWND(), assoc.ssthresh, assoc.inflightQueue.getNumBytes()) - a.srtt.Store(float64(0)) - a.t1Init = newRTXTimer(timerT1Init, a, maxInitRetrans, config.RTOMax) - a.t1Cookie = newRTXTimer(timerT1Cookie, a, maxInitRetrans, config.RTOMax) - a.t2Shutdown = newRTXTimer(timerT2Shutdown, a, noMaxRetrans, config.RTOMax) - a.t3RTX = newRTXTimer(timerT3RTX, a, noMaxRetrans, config.RTOMax) - a.tReconfig = newRTXTimer(timerReconfig, a, noMaxRetrans, config.RTOMax) - a.ackTimer = newAckTimer(a) + assoc.srtt.Store(float64(0)) + assoc.t1Init = newRTXTimer(timerT1Init, assoc, maxInitRetrans, config.RTOMax) + assoc.t1Cookie = newRTXTimer(timerT1Cookie, assoc, maxInitRetrans, config.RTOMax) + assoc.t2Shutdown = newRTXTimer(timerT2Shutdown, assoc, noMaxRetrans, config.RTOMax) + assoc.t3RTX = newRTXTimer(timerT3RTX, assoc, noMaxRetrans, config.RTOMax) + assoc.tReconfig = newRTXTimer(timerReconfig, assoc, noMaxRetrans, config.RTOMax) + assoc.ackTimer = newAckTimer(assoc) - return a + return assoc } func (a *Association) init(isClient bool) { @@ -444,7 +449,7 @@ func (a *Association) init(isClient bool) { } } -// caller must hold a.lock +// caller must hold a.lock. func (a *Association) sendInit() error { a.log.Debugf("[%s] sending INIT", a.name) if a.storedInit == nil { @@ -466,7 +471,7 @@ func (a *Association) sendInit() error { return nil } -// caller must hold a.lock +// caller must hold a.lock. func (a *Association) sendCookieEcho() error { if a.storedCookieEcho == nil { return ErrCookieEchoNotStoredToSend @@ -520,7 +525,7 @@ func (a *Association) Shutdown(ctx context.Context) error { } } -// Close ends the SCTP Association and cleans up any state +// Close ends the SCTP Association and cleans up any state. func (a *Association) Close() error { a.log.Debugf("[%s] closing association..", a.name) @@ -618,6 +623,7 @@ func (a *Association) readLoop() { n, err := a.netConn.Read(buffer) if err != nil { closeErr = err + break } // Make a buffer sized to what we read, then copy the data we @@ -626,9 +632,10 @@ func (a *Association) readLoop() { // copying. inbound := make([]byte, n) copy(inbound, buffer[:n]) - atomic.AddUint64(&a.bytesReceived, uint64(n)) + atomic.AddUint64(&a.bytesReceived, uint64(n)) //nolint:gosec // G115 if err = a.handleInbound(inbound); err != nil { closeErr = err + break } } @@ -651,6 +658,7 @@ loop: a.log.Warnf("[%s] failed to write packets on netConn: %v", a.name, err) } a.log.Debugf("[%s] writeLoop ended", a.name) + break loop } atomic.AddUint64(&a.bytesSent, uint64(len(raw))) @@ -715,6 +723,7 @@ func chunkMandatoryChecksum(cc []chunk) bool { return true } } + return false } @@ -727,26 +736,29 @@ func (a *Association) unmarshalPacket(raw []byte) (*packet, error) { if err := p.unmarshal(!a.recvZeroChecksum, raw); err != nil { return nil, err } + return p, nil } -// handleInbound parses incoming raw packets +// handleInbound parses incoming raw packets. func (a *Association) handleInbound(raw []byte) error { - p, err := a.unmarshalPacket(raw) + pkt, err := a.unmarshalPacket(raw) if err != nil { a.log.Warnf("[%s] unable to parse SCTP packet %s", a.name, err) + return nil } - if err := checkPacket(p); err != nil { + if err := checkPacket(pkt); err != nil { a.log.Warnf("[%s] failed validating packet %s", a.name, err) + return nil } a.handleChunksStart() - for _, c := range p.chunks { - if err := a.handleChunk(p, c); err != nil { + for _, c := range pkt.chunks { + if err := a.handleChunk(pkt, c); err != nil { return err } } @@ -756,12 +768,13 @@ func (a *Association) handleInbound(raw []byte) error { return nil } -// The caller should hold the lock +// The caller should hold the lock. func (a *Association) gatherDataPacketsToRetransmit(rawPackets [][]byte) [][]byte { for _, p := range a.getDataPacketsToRetransmit() { raw, err := a.marshalPacket(p) if err != nil { a.log.Warnf("[%s] failed to serialize a DATA packet to be retransmitted", a.name) + continue } rawPackets = append(rawPackets, raw) @@ -770,7 +783,9 @@ func (a *Association) gatherDataPacketsToRetransmit(rawPackets [][]byte) [][]byt return rawPackets } -// The caller should hold the lock +// The caller should hold the lock. +// +//nolint:cyclop func (a *Association) gatherOutboundDataAndReconfigPackets(rawPackets [][]byte) [][]byte { // Pop unsent data chunks from the pending queue to send as much as // cwnd and rwnd allow. @@ -783,13 +798,14 @@ func (a *Association) gatherOutboundDataAndReconfigPackets(rawPackets [][]byte) raw, err := a.marshalPacket(p) if err != nil { a.log.Warnf("[%s] failed to serialize a DATA packet", a.name) + continue } rawPackets = append(rawPackets, raw) } } - if len(sisToReset) > 0 || a.willRetransmitReconfig { + if len(sisToReset) > 0 || a.willRetransmitReconfig { //nolint:nestif if a.willRetransmitReconfig { a.willRetransmitReconfig = false a.log.Debugf("[%s] retransmit %d RECONFIG chunk(s)", a.name, len(a.reconfigs)) @@ -834,9 +850,11 @@ func (a *Association) gatherOutboundDataAndReconfigPackets(rawPackets [][]byte) return rawPackets } -// The caller should hold the lock +// The caller should hold the lock. +// +//nolint:cyclop func (a *Association) gatherOutboundFastRetransmissionPackets(rawPackets [][]byte) [][]byte { - if a.willRetransmitFast { + if a.willRetransmitFast { //nolint:nestif a.willRetransmitFast = false toFastRetrans := []*chunkPayloadData{} @@ -847,16 +865,16 @@ func (a *Association) gatherOutboundFastRetransmissionPackets(rawPackets [][]byt fastRetransWnd = a.fastRtxWnd } for i := 0; ; i++ { - c, ok := a.inflightQueue.get(a.cumulativeTSNAckPoint + uint32(i) + 1) + chunkPayload, ok := a.inflightQueue.get(a.cumulativeTSNAckPoint + uint32(i) + 1) //nolint:gosec // G115 if !ok { break // end of pending data } - if c.acked || c.abandoned() { + if chunkPayload.acked || chunkPayload.abandoned() { continue } - if c.nSent > 1 || c.missIndicator < 3 { + if chunkPayload.nSent > 1 || chunkPayload.missIndicator < 3 { continue } @@ -870,18 +888,18 @@ func (a *Association) gatherOutboundFastRetransmissionPackets(rawPackets [][]byt // of cwnd and SHOULD NOT delay retransmission for this single // packet. - dataChunkSize := dataChunkHeaderSize + uint32(len(c.userData)) + dataChunkSize := dataChunkHeaderSize + uint32(len(chunkPayload.userData)) //nolint:gosec // G115 if fastRetransWnd < fastRetransSize+dataChunkSize { break } fastRetransSize += dataChunkSize a.stats.incFastRetrans() - c.nSent++ - a.checkPartialReliabilityStatus(c) - toFastRetrans = append(toFastRetrans, c) + chunkPayload.nSent++ + a.checkPartialReliabilityStatus(chunkPayload) + toFastRetrans = append(toFastRetrans, chunkPayload) a.log.Tracef("[%s] fast-retransmit: tsn=%d sent=%d htna=%d", - a.name, c.tsn, c.nSent, a.fastRecoverExitPoint) + a.name, chunkPayload.tsn, chunkPayload.nSent, a.fastRecoverExitPoint) } if len(toFastRetrans) > 0 { @@ -889,6 +907,7 @@ func (a *Association) gatherOutboundFastRetransmissionPackets(rawPackets [][]byt raw, err := a.marshalPacket(p) if err != nil { a.log.Warnf("[%s] failed to serialize a DATA packet to be fast-retransmitted", a.name) + continue } rawPackets = append(rawPackets, raw) @@ -899,7 +918,7 @@ func (a *Association) gatherOutboundFastRetransmissionPackets(rawPackets [][]byt return rawPackets } -// The caller should hold the lock +// The caller should hold the lock. func (a *Association) gatherOutboundSackPackets(rawPackets [][]byte) [][]byte { if a.ackState == ackStateImmediate { a.ackState = ackStateIdle @@ -917,7 +936,7 @@ func (a *Association) gatherOutboundSackPackets(rawPackets [][]byte) [][]byte { return rawPackets } -// The caller should hold the lock +// The caller should hold the lock. func (a *Association) gatherOutboundForwardTSNPackets(rawPackets [][]byte) [][]byte { if a.willSendForwardTSN { a.willSendForwardTSN = false @@ -1009,6 +1028,7 @@ func (a *Association) gatherOutbound() ([][]byte, bool) { pkt, err := a.gatherAbortPacket() if err != nil { a.log.Warnf("[%s] failed to serialize an abort packet", a.name) + return nil, false } @@ -1022,6 +1042,7 @@ func (a *Association) gatherOutbound() ([][]byte, bool) { raw, err := a.marshalPacket(p) if err != nil { a.log.Warnf("[%s] failed to serialize a control packet", a.name) + continue } rawPackets = append(rawPackets, raw) @@ -1051,7 +1072,7 @@ func (a *Association) gatherOutbound() ([][]byte, bool) { return rawPackets, ok } -func checkPacket(p *packet) error { +func checkPacket(pkt *packet) error { // All packets must adhere to these rules // This is the SCTP sender's port number. It can be used by the @@ -1059,7 +1080,7 @@ func checkPacket(p *packet) error { // destination port, and possibly the destination IP address to // identify the association to which this packet belongs. The port // number 0 MUST NOT be used. - if p.sourcePort == 0 { + if pkt.sourcePort == 0 { return ErrSCTPPacketSourcePortZero } @@ -1067,24 +1088,24 @@ func checkPacket(p *packet) error { // The receiving host will use this port number to de-multiplex the // SCTP packet to the correct receiving endpoint/application. The // port number 0 MUST NOT be used. - if p.destinationPort == 0 { + if pkt.destinationPort == 0 { return ErrSCTPPacketDestinationPortZero } // Check values on the packet that are specific to a particular chunk type - for _, c := range p.chunks { + for _, c := range pkt.chunks { switch c.(type) { // nolint:gocritic case *chunkInit: // An INIT or INIT ACK chunk MUST NOT be bundled with any other chunk. // They MUST be the only chunks present in the SCTP packets that carry // them. - if len(p.chunks) != 1 { + if len(pkt.chunks) != 1 { return ErrInitChunkBundled } // A packet containing an INIT chunk MUST have a zero Verification // Tag. - if p.verificationTag != 0 { + if pkt.verificationTag != 0 { return ErrInitChunkVerifyTagNotZero } } @@ -1097,6 +1118,7 @@ func min16(a, b uint16) uint16 { if a < b { return a } + return b } @@ -1104,6 +1126,7 @@ func max32(a, b uint32) uint32 { if a > b { return a } + return b } @@ -1111,10 +1134,11 @@ func min32(a, b uint32) uint32 { if a < b { return a } + return b } -// peerLastTSN return last received cumulative TSN +// peerLastTSN return last received cumulative TSN. func (a *Association) peerLastTSN() uint32 { return a.payloadQueue.getcumulativeTSN() } @@ -1136,22 +1160,22 @@ func (a *Association) getState() uint32 { return atomic.LoadUint32(&a.state) } -// BytesSent returns the number of bytes sent +// BytesSent returns the number of bytes sent. func (a *Association) BytesSent() uint64 { return atomic.LoadUint64(&a.bytesSent) } -// BytesReceived returns the number of bytes received +// BytesReceived returns the number of bytes received. func (a *Association) BytesReceived() uint64 { return atomic.LoadUint64(&a.bytesReceived) } -// MTU returns the association's current MTU +// MTU returns the association's current MTU. func (a *Association) MTU() uint32 { return atomic.LoadUint32(&a.mtu) } -// CWND returns the association's current congestion window (cwnd) +// CWND returns the association's current congestion window (cwnd). func (a *Association) CWND() uint32 { return atomic.LoadUint32(&a.cwnd) } @@ -1163,7 +1187,7 @@ func (a *Association) setCWND(cwnd uint32) { atomic.StoreUint32(&a.cwnd, cwnd) } -// RWND returns the association's current receiver window (rwnd) +// RWND returns the association's current receiver window (rwnd). func (a *Association) RWND() uint32 { return atomic.LoadUint32(&a.rwnd) } @@ -1172,7 +1196,7 @@ func (a *Association) setRWND(rwnd uint32) { atomic.StoreUint32(&a.rwnd, rwnd) } -// SRTT returns the latest smoothed round-trip time (srrt) +// SRTT returns the latest smoothed round-trip time (srrt). func (a *Association) SRTT() float64 { return a.srtt.Load().(float64) //nolint:forcetypeassert } @@ -1189,6 +1213,7 @@ func getMaxTSNOffset(maxReceiveBufferSize uint32) uint32 { if offset > maxTSNOffset { offset = maxTSNOffset } + return offset } @@ -1204,7 +1229,9 @@ func setSupportedExtensions(init *chunkInitCommon) { } // The caller should hold the lock. -func (a *Association) handleInit(p *packet, i *chunkInit) ([]*packet, error) { +// +//nolint:cyclop +func (a *Association) handleInit(pkt *packet, initChunk *chunkInit) ([]*packet, error) { state := a.getState() a.log.Debugf("[%s] chunkInit received in state '%s'", a.name, getAssociationStateString(state)) @@ -1225,19 +1252,19 @@ func (a *Association) handleInit(p *packet, i *chunkInit) ([]*packet, error) { // our cookie is not compliant with https://www.rfc-editor.org/rfc/rfc9260#section-5.1-2.2.3. // It makes us more vulnerable to resource attacks, albeit minimally so. // https://www.rfc-editor.org/rfc/rfc9260#sec_handle_stream_parameters - a.myMaxNumInboundStreams = min16(i.numInboundStreams, a.myMaxNumInboundStreams) - a.myMaxNumOutboundStreams = min16(i.numOutboundStreams, a.myMaxNumOutboundStreams) - a.peerVerificationTag = i.initiateTag - a.sourcePort = p.destinationPort - a.destinationPort = p.sourcePort + a.myMaxNumInboundStreams = min16(initChunk.numInboundStreams, a.myMaxNumInboundStreams) + a.myMaxNumOutboundStreams = min16(initChunk.numOutboundStreams, a.myMaxNumOutboundStreams) + a.peerVerificationTag = initChunk.initiateTag + a.sourcePort = pkt.destinationPort + a.destinationPort = pkt.sourcePort // 13.2 This is the last TSN received in sequence. This value // is set initially by taking the peer's initial TSN, // received in the INIT or INIT ACK chunk, and // subtracting one from it. - a.payloadQueue.init(i.initialTSN - 1) + a.payloadQueue.init(initChunk.initialTSN - 1) - for _, param := range i.params { + for _, param := range initChunk.params { switch v := param.(type) { // nolint:gocritic case *paramSupportedExtensions: for _, t := range v.ChunkTypes { @@ -1293,7 +1320,7 @@ func (a *Association) handleInit(p *packet, i *chunkInit) ([]*packet, error) { } // The caller should hold the lock. -func (a *Association) handleInitAck(p *packet, i *chunkInitAck) error { +func (a *Association) handleInitAck(pkt *packet, initChunkAck *chunkInitAck) error { //nolint:cyclop state := a.getState() a.log.Debugf("[%s] chunkInitAck received in state '%s'", a.name, getAssociationStateString(state)) if state != cookieWait { @@ -1306,17 +1333,18 @@ func (a *Association) handleInitAck(p *packet, i *chunkInitAck) error { return nil } - a.myMaxNumInboundStreams = min16(i.numInboundStreams, a.myMaxNumInboundStreams) - a.myMaxNumOutboundStreams = min16(i.numOutboundStreams, a.myMaxNumOutboundStreams) - a.peerVerificationTag = i.initiateTag - a.payloadQueue.init(i.initialTSN - 1) - if a.sourcePort != p.destinationPort || - a.destinationPort != p.sourcePort { + a.myMaxNumInboundStreams = min16(initChunkAck.numInboundStreams, a.myMaxNumInboundStreams) + a.myMaxNumOutboundStreams = min16(initChunkAck.numOutboundStreams, a.myMaxNumOutboundStreams) + a.peerVerificationTag = initChunkAck.initiateTag + a.payloadQueue.init(initChunkAck.initialTSN - 1) + if a.sourcePort != pkt.destinationPort || + a.destinationPort != pkt.sourcePort { a.log.Warnf("[%s] handleInitAck: port mismatch", a.name) + return nil } - a.setRWND(i.advertisedReceiverWindowCredit) + a.setRWND(initChunkAck.advertisedReceiverWindowCredit) a.log.Debugf("[%s] initial rwnd=%d", a.name, a.RWND()) // RFC 4690 Sec 7.2.1 @@ -1331,7 +1359,7 @@ func (a *Association) handleInitAck(p *packet, i *chunkInitAck) error { a.storedInit = nil var cookieParam *paramStateCookie - for _, param := range i.params { + for _, param := range initChunkAck.params { switch v := param.(type) { case *paramStateCookie: cookieParam = v @@ -1366,6 +1394,7 @@ func (a *Association) handleInitAck(p *packet, i *chunkInitAck) error { a.t1Cookie.start(a.rtoMgr.getRTO()) a.setState(cookieEchoed) + return nil } @@ -1392,23 +1421,24 @@ func (a *Association) handleHeartbeat(c *chunkHeartbeat) []*packet { } // The caller should hold the lock. -func (a *Association) handleCookieEcho(c *chunkCookieEcho) []*packet { +func (a *Association) handleCookieEcho(cookieEcho *chunkCookieEcho) []*packet { state := a.getState() a.log.Debugf("[%s] COOKIE-ECHO received in state '%s'", a.name, getAssociationStateString(state)) if a.myCookie == nil { a.log.Debugf("[%s] COOKIE-ECHO received before initialization", a.name) + return nil } switch state { default: return nil case established: - if !bytes.Equal(a.myCookie.cookie, c.cookie) { + if !bytes.Equal(a.myCookie.cookie, cookieEcho.cookie) { return nil } case closed, cookieWait, cookieEchoed: - if !bytes.Equal(a.myCookie.cookie, c.cookie) { + if !bytes.Equal(a.myCookie.cookie, cookieEcho.cookie) { return nil } @@ -1432,6 +1462,7 @@ func (a *Association) handleCookieEcho(c *chunkCookieEcho) []*packet { destinationPort: a.destinationPort, chunks: []chunk{&chunkCookieAck{}}, } + return pack(p) } @@ -1455,44 +1486,51 @@ func (a *Association) handleCookieAck() { } // The caller should hold the lock. -func (a *Association) handleData(d *chunkPayloadData) []*packet { +func (a *Association) handleData(chunkPayload *chunkPayloadData) []*packet { a.log.Tracef("[%s] DATA: tsn=%d immediateSack=%v len=%d", - a.name, d.tsn, d.immediateSack, len(d.userData)) + a.name, chunkPayload.tsn, chunkPayload.immediateSack, len(chunkPayload.userData)) a.stats.incDATAs() - canPush := a.payloadQueue.canPush(d.tsn) - if canPush { - s := a.getOrCreateStream(d.streamIdentifier, true, PayloadTypeUnknown) - if s == nil { + canPush := a.payloadQueue.canPush(chunkPayload.tsn) + if canPush { //nolint:nestif + stream := a.getOrCreateStream(chunkPayload.streamIdentifier, true, PayloadTypeUnknown) + if stream == nil { // silently discard the data. (sender will retry on T3-rtx timeout) // see pion/sctp#30 - a.log.Debugf("[%s] discard %d", a.name, d.streamSequenceNumber) + a.log.Debugf("[%s] discard %d", a.name, chunkPayload.streamSequenceNumber) + return nil } if a.getMyReceiverWindowCredit() > 0 { // Pass the new chunk to stream level as soon as it arrives - a.payloadQueue.push(d.tsn) - s.handleData(d) + a.payloadQueue.push(chunkPayload.tsn) + stream.handleData(chunkPayload) } else { // Receive buffer is full lastTSN, ok := a.payloadQueue.getLastTSNReceived() - if ok && sna32LT(d.tsn, lastTSN) { - a.log.Debugf("[%s] receive buffer full, but accepted as this is a missing chunk with tsn=%d ssn=%d", a.name, d.tsn, d.streamSequenceNumber) - a.payloadQueue.push(d.tsn) - s.handleData(d) + if ok && sna32LT(chunkPayload.tsn, lastTSN) { + a.log.Debugf( + "[%s] receive buffer full, but accepted as this is a missing chunk with tsn=%d ssn=%d", + a.name, chunkPayload.tsn, chunkPayload.streamSequenceNumber, + ) + a.payloadQueue.push(chunkPayload.tsn) + stream.handleData(chunkPayload) } else { - a.log.Debugf("[%s] receive buffer full. dropping DATA with tsn=%d ssn=%d", a.name, d.tsn, d.streamSequenceNumber) + a.log.Debugf( + "[%s] receive buffer full. dropping DATA with tsn=%d ssn=%d", + a.name, chunkPayload.tsn, chunkPayload.streamSequenceNumber, + ) } } } - return a.handlePeerLastTSNAndAcknowledgement(d.immediateSack) + return a.handlePeerLastTSNAndAcknowledgement(chunkPayload.immediateSack) } // A common routine for handleData and handleForwardTSN routines // The caller should hold the lock. -func (a *Association) handlePeerLastTSNAndAcknowledgement(sackImmediately bool) []*packet { +func (a *Association) handlePeerLastTSNAndAcknowledgement(sackImmediately bool) []*packet { //nolint:cyclop var reply []*packet // Try to advance peerLastTSN @@ -1521,7 +1559,8 @@ func (a *Association) handlePeerLastTSNAndAcknowledgement(sackImmediately bool) a.log.Tracef("[%s] packetloss: %s", a.name, a.payloadQueue.getGapAckBlocksString()) } - if (a.ackState != ackStateImmediate && !sackImmediately && !hasPacketLoss && a.ackMode == ackModeNormal) || a.ackMode == ackModeAlwaysDelay { + if (a.ackState != ackStateImmediate && !sackImmediately && !hasPacketLoss && a.ackMode == ackModeNormal) || + a.ackMode == ackModeAlwaysDelay { if a.ackState == ackStateIdle { a.delayedAckTriggered = true } else { @@ -1538,17 +1577,21 @@ func (a *Association) handlePeerLastTSNAndAcknowledgement(sackImmediately bool) func (a *Association) getMyReceiverWindowCredit() uint32 { var bytesQueued uint32 for _, s := range a.streams { - bytesQueued += uint32(s.getNumBytesInReassemblyQueue()) + bytesQueued += uint32(s.getNumBytesInReassemblyQueue()) //nolint:gosec // G115 } if bytesQueued >= a.maxReceiveBufferSize { return 0 } + return a.maxReceiveBufferSize - bytesQueued } -// OpenStream opens a stream -func (a *Association) OpenStream(streamIdentifier uint16, defaultPayloadType PayloadProtocolIdentifier) (*Stream, error) { +// OpenStream opens a stream. +func (a *Association) OpenStream( + streamIdentifier uint16, + defaultPayloadType PayloadProtocolIdentifier, +) (*Stream, error) { a.lock.Lock() defer a.lock.Unlock() @@ -1560,18 +1603,19 @@ func (a *Association) OpenStream(streamIdentifier uint16, defaultPayloadType Pay return a.getOrCreateStream(streamIdentifier, false, defaultPayloadType), nil } -// AcceptStream accepts a stream +// AcceptStream accepts a stream. func (a *Association) AcceptStream() (*Stream, error) { s, ok := <-a.acceptCh if !ok { return nil, io.EOF // no more incoming streams } + return s, nil } // createStream creates a stream. The caller should hold the lock and check no stream exists for this id. func (a *Association) createStream(streamIdentifier uint16, accept bool) *Stream { - s := &Stream{ + stream := &Stream{ association: a, streamIdentifier: streamIdentifier, reassemblyQueue: newReassemblyQueue(streamIdentifier), @@ -1580,30 +1624,36 @@ func (a *Association) createStream(streamIdentifier uint16, accept bool) *Stream writeDeadline: deadline.New(), } - s.readNotifier = sync.NewCond(&s.lock) + stream.readNotifier = sync.NewCond(&stream.lock) if accept { select { - case a.acceptCh <- s: - a.streams[streamIdentifier] = s + case a.acceptCh <- stream: + a.streams[streamIdentifier] = stream a.log.Debugf("[%s] accepted a new stream (streamIdentifier: %d)", a.name, streamIdentifier) default: a.log.Debugf("[%s] dropped a new stream (acceptCh size: %d)", a.name, len(a.acceptCh)) + return nil } } else { - a.streams[streamIdentifier] = s + a.streams[streamIdentifier] = stream } - return s + return stream } // getOrCreateStream gets or creates a stream. The caller should hold the lock. -func (a *Association) getOrCreateStream(streamIdentifier uint16, accept bool, defaultPayloadType PayloadProtocolIdentifier) *Stream { +func (a *Association) getOrCreateStream( + streamIdentifier uint16, + accept bool, + defaultPayloadType PayloadProtocolIdentifier, +) *Stream { if s, ok := a.streams[streamIdentifier]; ok { s.SetDefaultPayloadType(defaultPayloadType) + return s } @@ -1611,23 +1661,26 @@ func (a *Association) getOrCreateStream(streamIdentifier uint16, accept bool, de if s != nil { s.SetDefaultPayloadType(defaultPayloadType) } + return s } // The caller should hold the lock. -func (a *Association) processSelectiveAck(d *chunkSelectiveAck) (map[uint16]int, uint32, error) { // nolint:gocognit +// +//nolint:gocognit,cyclop +func (a *Association) processSelectiveAck(selectiveAckChunk *chunkSelectiveAck) (map[uint16]int, uint32, error) { bytesAckedPerStream := map[uint16]int{} // New ack point, so pop all ACKed packets from inflightQueue // We add 1 because the "currentAckPoint" has already been popped from the inflight queue // For the first SACK we take care of this by setting the ackpoint to cumAck - 1 - for i := a.cumulativeTSNAckPoint + 1; sna32LTE(i, d.cumulativeTSNAck); i++ { - c, ok := a.inflightQueue.pop(i) + for i := a.cumulativeTSNAckPoint + 1; sna32LTE(i, selectiveAckChunk.cumulativeTSNAck); i++ { + chunkPayload, ok := a.inflightQueue.pop(i) if !ok { return nil, 0, fmt.Errorf("%w: %v", ErrInflightQueueTSNPop, i) } - if !c.acked { + if !chunkPayload.acked { // RFC 4096 sec 6.3.2. Retransmission Timer Rules // R3) Whenever a SACK is received that acknowledges the DATA chunk // with the earliest outstanding TSN for that address, restart the @@ -1638,13 +1691,13 @@ func (a *Association) processSelectiveAck(d *chunkSelectiveAck) (map[uint16]int, a.t3RTX.stop() } - nBytesAcked := len(c.userData) + nBytesAcked := len(chunkPayload.userData) // Sum the number of bytes acknowledged per stream - if amount, ok := bytesAckedPerStream[c.streamIdentifier]; ok { - bytesAckedPerStream[c.streamIdentifier] = amount + nBytesAcked + if amount, ok := bytesAckedPerStream[chunkPayload.streamIdentifier]; ok { + bytesAckedPerStream[chunkPayload.streamIdentifier] = amount + nBytesAcked } else { - bytesAckedPerStream[c.streamIdentifier] = nBytesAcked + bytesAckedPerStream[chunkPayload.streamIdentifier] = nBytesAcked } // RFC 4960 sec 6.3.1. RTO Calculation @@ -1656,9 +1709,9 @@ func (a *Association) processSelectiveAck(d *chunkSelectiveAck) (map[uint16]int, // packets that were retransmitted (and thus for which it is // ambiguous whether the reply was for the first instance of the // chunk or for a later instance) - if c.nSent == 1 && sna32GTE(c.tsn, a.minTSN2MeasureRTT) { + if chunkPayload.nSent == 1 && sna32GTE(chunkPayload.tsn, a.minTSN2MeasureRTT) { a.minTSN2MeasureRTT = a.myNextTSN - rtt := time.Since(c.since).Seconds() * 1000.0 + rtt := time.Since(chunkPayload.since).Seconds() * 1000.0 srtt := a.rtoMgr.setNewRTT(rtt) a.srtt.Store(srtt) a.log.Tracef("[%s] SACK: measured-rtt=%f srtt=%f new-rto=%f", @@ -1666,38 +1719,38 @@ func (a *Association) processSelectiveAck(d *chunkSelectiveAck) (map[uint16]int, } } - if a.inFastRecovery && c.tsn == a.fastRecoverExitPoint { + if a.inFastRecovery && chunkPayload.tsn == a.fastRecoverExitPoint { a.log.Debugf("[%s] exit fast-recovery", a.name) a.inFastRecovery = false } } - htna := d.cumulativeTSNAck + htna := selectiveAckChunk.cumulativeTSNAck // Mark selectively acknowledged chunks as "acked" - for _, g := range d.gapAckBlocks { + for _, g := range selectiveAckChunk.gapAckBlocks { for i := g.start; i <= g.end; i++ { - tsn := d.cumulativeTSNAck + uint32(i) - c, ok := a.inflightQueue.get(tsn) + tsn := selectiveAckChunk.cumulativeTSNAck + uint32(i) + chunkPayload, ok := a.inflightQueue.get(tsn) if !ok { return nil, 0, fmt.Errorf("%w: %v", ErrTSNRequestNotExist, tsn) } - if !c.acked { + if !chunkPayload.acked { nBytesAcked := a.inflightQueue.markAsAcked(tsn) // Sum the number of bytes acknowledged per stream - if amount, ok := bytesAckedPerStream[c.streamIdentifier]; ok { - bytesAckedPerStream[c.streamIdentifier] = amount + nBytesAcked + if amount, ok := bytesAckedPerStream[chunkPayload.streamIdentifier]; ok { + bytesAckedPerStream[chunkPayload.streamIdentifier] = amount + nBytesAcked } else { - bytesAckedPerStream[c.streamIdentifier] = nBytesAcked + bytesAckedPerStream[chunkPayload.streamIdentifier] = nBytesAcked } - a.log.Tracef("[%s] tsn=%d has been sacked", a.name, c.tsn) + a.log.Tracef("[%s] tsn=%d has been sacked", a.name, chunkPayload.tsn) - if c.nSent == 1 { + if chunkPayload.nSent == 1 { a.minTSN2MeasureRTT = a.myNextTSN - rtt := time.Since(c.since).Seconds() * 1000.0 + rtt := time.Since(chunkPayload.since).Seconds() * 1000.0 srtt := a.rtoMgr.setNewRTT(rtt) a.srtt.Store(srtt) a.log.Tracef("[%s] SACK: measured-rtt=%f srtt=%f new-rto=%f", @@ -1728,7 +1781,7 @@ func (a *Association) onCumulativeTSNAckPointAdvanced(totalBytesAcked int) { } // Update congestion control parameters - if a.CWND() <= a.ssthresh { + if a.CWND() <= a.ssthresh { //nolint:nestif // RFC 4096, sec 7.2.1. Slow-Start // o When cwnd is less than or equal to ssthresh, an SCTP endpoint MUST // use the slow-start algorithm to increase cwnd only if the current @@ -1742,7 +1795,7 @@ func (a *Association) onCumulativeTSNAckPointAdvanced(totalBytesAcked int) { // path MTU. if !a.inFastRecovery && a.pendingQueue.size() > 0 { - a.setCWND(a.CWND() + min32(uint32(totalBytesAcked), a.CWND())) + a.setCWND(a.CWND() + min32(uint32(totalBytesAcked), a.CWND())) //nolint:gosec // G115 // a.cwnd += min32(uint32(totalBytesAcked), a.MTU()) // SCTP way (slow) a.log.Tracef("[%s] updated cwnd=%d ssthresh=%d acked=%d (SS)", a.name, a.CWND(), a.ssthresh, totalBytesAcked) @@ -1757,7 +1810,7 @@ func (a *Association) onCumulativeTSNAckPointAdvanced(totalBytesAcked int) { // partial_bytes_acked by the total number of bytes of all new chunks // acknowledged in that SACK including chunks acknowledged by the new // Cumulative TSN Ack and by Gap Ack Blocks. - a.partialBytesAcked += uint32(totalBytesAcked) + a.partialBytesAcked += uint32(totalBytesAcked) //nolint:gosec // G115 // o When partial_bytes_acked is equal to or greater than cwnd and // before the arrival of the SACK the sender had cwnd or more bytes @@ -1778,7 +1831,14 @@ func (a *Association) onCumulativeTSNAckPointAdvanced(totalBytesAcked int) { } // The caller should hold the lock. -func (a *Association) processFastRetransmission(cumTSNAckPoint uint32, gapAckBlocks []gapAckBlock, htna uint32, cumTSNAckPointAdvanced bool) error { +// +//nolint:cyclop +func (a *Association) processFastRetransmission( + cumTSNAckPoint uint32, + gapAckBlocks []gapAckBlock, + htna uint32, + cumTSNAckPointAdvanced bool, +) error { // HTNA algorithm - RFC 4960 Sec 7.2.4 // Increment missIndicator of each chunks that the SACK reported missing // when either of the following is met: @@ -1788,7 +1848,9 @@ func (a *Association) processFastRetransmission(cumTSNAckPoint uint32, gapAckBlo // b) In fast-recovery AND the Cumulative TSN Ack Point advanced // the miss indications are incremented for all TSNs reported missing // in the SACK. - if !a.inFastRecovery || (a.inFastRecovery && cumTSNAckPointAdvanced) { + //nolint:nestif + if !a.inFastRecovery || + (a.inFastRecovery && cumTSNAckPointAdvanced) { var maxTSN uint32 if !a.inFastRecovery { // a) increment only for missing TSNs prior to the HTNA @@ -1836,8 +1898,13 @@ func (a *Association) processFastRetransmission(cumTSNAckPoint uint32, gapAckBlo } // The caller should hold the lock. -func (a *Association) handleSack(d *chunkSelectiveAck) error { - a.log.Tracef("[%s] SACK: cumTSN=%d a_rwnd=%d", a.name, d.cumulativeTSNAck, d.advertisedReceiverWindowCredit) +// +//nolint:cyclop +func (a *Association) handleSack(selectiveAckChunk *chunkSelectiveAck) error { + a.log.Tracef( + "[%s] SACK: cumTSN=%d a_rwnd=%d", + a.name, selectiveAckChunk.cumulativeTSNAck, selectiveAckChunk.advertisedReceiverWindowCredit, + ) state := a.getState() if state != established && state != shutdownPending && state != shutdownReceived { return nil @@ -1845,7 +1912,7 @@ func (a *Association) handleSack(d *chunkSelectiveAck) error { a.stats.incSACKsReceived() - if sna32GT(a.cumulativeTSNAckPoint, d.cumulativeTSNAck) { + if sna32GT(a.cumulativeTSNAckPoint, selectiveAckChunk.cumulativeTSNAck) { // RFC 4960 sec 6.2.1. Processing a Received SACK // D) // i) If Cumulative TSN Ack is less than the Cumulative TSN Ack @@ -1856,14 +1923,14 @@ func (a *Association) handleSack(d *chunkSelectiveAck) error { a.log.Debugf("[%s] SACK Cumulative ACK %v is older than ACK point %v", a.name, - d.cumulativeTSNAck, + selectiveAckChunk.cumulativeTSNAck, a.cumulativeTSNAckPoint) return nil } // Process selective ack - bytesAckedPerStream, htna, err := a.processSelectiveAck(d) + bytesAckedPerStream, htna, err := a.processSelectiveAck(selectiveAckChunk) if err != nil { return err } @@ -1874,13 +1941,13 @@ func (a *Association) handleSack(d *chunkSelectiveAck) error { } cumTSNAckPointAdvanced := false - if sna32LT(a.cumulativeTSNAckPoint, d.cumulativeTSNAck) { + if sna32LT(a.cumulativeTSNAckPoint, selectiveAckChunk.cumulativeTSNAck) { a.log.Tracef("[%s] SACK: cumTSN advanced: %d -> %d", a.name, a.cumulativeTSNAckPoint, - d.cumulativeTSNAck) + selectiveAckChunk.cumulativeTSNAck) - a.cumulativeTSNAckPoint = d.cumulativeTSNAck + a.cumulativeTSNAckPoint = selectiveAckChunk.cumulativeTSNAck cumTSNAckPointAdvanced = true a.onCumulativeTSNAckPointAdvanced(totalBytesAcked) } @@ -1901,14 +1968,16 @@ func (a *Association) handleSack(d *chunkSelectiveAck) error { // TSN Ack and the Gap Ack Blocks. // bytes acked were already subtracted by markAsAcked() method - bytesOutstanding := uint32(a.inflightQueue.getNumBytes()) - if bytesOutstanding >= d.advertisedReceiverWindowCredit { + bytesOutstanding := uint32(a.inflightQueue.getNumBytes()) //nolint:gosec // G115 + if bytesOutstanding >= selectiveAckChunk.advertisedReceiverWindowCredit { a.setRWND(0) } else { - a.setRWND(d.advertisedReceiverWindowCredit - bytesOutstanding) + a.setRWND(selectiveAckChunk.advertisedReceiverWindowCredit - bytesOutstanding) } - err = a.processFastRetransmission(d.cumulativeTSNAck, d.gapAckBlocks, htna, cumTSNAckPointAdvanced) + err = a.processFastRetransmission( + selectiveAckChunk.cumulativeTSNAck, selectiveAckChunk.gapAckBlocks, htna, cumTSNAckPointAdvanced, + ) if err != nil { return err } @@ -2060,7 +2129,10 @@ func (a *Association) createForwardTSN() *chunkForwardTSN { sequence: ssn, }) } - a.log.Tracef("[%s] building fwdtsn: newCumulativeTSN=%d cumTSN=%d - %s", a.name, fwdtsn.newCumulativeTSN, a.cumulativeTSNAckPoint, streamStr) + a.log.Tracef( + "[%s] building fwdtsn: newCumulativeTSN=%d cumTSN=%d - %s", + a.name, fwdtsn.newCumulativeTSN, a.cumulativeTSNAckPoint, streamStr, + ) return fwdtsn } @@ -2077,34 +2149,35 @@ func (a *Association) createPacket(cs []chunk) *packet { } // The caller should hold the lock. -func (a *Association) handleReconfig(c *chunkReconfig) ([]*packet, error) { +func (a *Association) handleReconfig(reconfigChunk *chunkReconfig) ([]*packet, error) { a.log.Tracef("[%s] handleReconfig", a.name) pp := make([]*packet, 0) - p, err := a.handleReconfigParam(c.paramA) + pkt, err := a.handleReconfigParam(reconfigChunk.paramA) if err != nil { return nil, err } - if p != nil { - pp = append(pp, p) + if pkt != nil { + pp = append(pp, pkt) } - if c.paramB != nil { - p, err = a.handleReconfigParam(c.paramB) + if reconfigChunk.paramB != nil { + pkt, err = a.handleReconfigParam(reconfigChunk.paramB) if err != nil { return nil, err } - if p != nil { - pp = append(pp, p) + if pkt != nil { + pp = append(pp, pkt) } } + return pp, nil } // The caller should hold the lock. -func (a *Association) handleForwardTSN(c *chunkForwardTSN) []*packet { - a.log.Tracef("[%s] FwdTSN: %s", a.name, c.String()) +func (a *Association) handleForwardTSN(chunkTSN *chunkForwardTSN) []*packet { + a.log.Tracef("[%s] FwdTSN: %s", a.name, chunkTSN.String()) if !a.useForwardTSN { a.log.Warn("[%s] received FwdTSN but not enabled") @@ -2117,6 +2190,7 @@ func (a *Association) handleForwardTSN(c *chunkForwardTSN) []*packet { outbound.sourcePort = a.sourcePort outbound.destinationPort = a.destinationPort outbound.chunks = []chunk{cerr} + return []*packet{outbound} } @@ -2129,12 +2203,13 @@ func (a *Association) handleForwardTSN(c *chunkForwardTSN) []*packet { // duplicate may indicate the previous SACK was lost in the network. a.log.Tracef("[%s] should send ack? newCumTSN=%d peerLastTSN=%d", - a.name, c.newCumulativeTSN, a.peerLastTSN()) - if sna32LTE(c.newCumulativeTSN, a.peerLastTSN()) { + a.name, chunkTSN.newCumulativeTSN, a.peerLastTSN()) + if sna32LTE(chunkTSN.newCumulativeTSN, a.peerLastTSN()) { a.log.Tracef("[%s] sending ack on Forward TSN", a.name) a.ackState = ackStateImmediate a.ackTimer.stop() a.awakeWriteLoop() + return nil } @@ -2149,14 +2224,14 @@ func (a *Association) handleForwardTSN(c *chunkForwardTSN) []*packet { // chunk, // Advance peerLastTSN - for sna32LT(a.peerLastTSN(), c.newCumulativeTSN) { + for sna32LT(a.peerLastTSN(), chunkTSN.newCumulativeTSN) { a.payloadQueue.pop(true) // may not exist } // Report new peerLastTSN value and abandoned largest SSN value to // corresponding streams so that the abandoned chunks can be removed // from the reassemblyQueue. - for _, forwarded := range c.streams { + for _, forwarded := range chunkTSN.streams { if s, ok := a.streams[forwarded.identifier]; ok { s.handleForwardTSNForOrdered(forwarded.sequence) } @@ -2168,7 +2243,7 @@ func (a *Association) handleForwardTSN(c *chunkForwardTSN) []*packet { // unordered chunks. // See https://github.com/pion/sctp/issues/106 for _, s := range a.streams { - s.handleForwardTSNForUnordered(c.newCumulativeTSN) + s.handleForwardTSNForUnordered(chunkTSN.newCumulativeTSN) } return a.handlePeerLastTSNAndAcknowledgement(false) @@ -2195,15 +2270,16 @@ func (a *Association) sendResetRequest(streamIdentifier uint16) error { a.pendingQueue.push(c) a.awakeWriteLoop() + return nil } // The caller should hold the lock. func (a *Association) handleReconfigParam(raw param) (*packet, error) { - switch p := raw.(type) { + switch par := raw.(type) { case *paramOutgoingResetRequest: a.log.Tracef("[%s] handleReconfigParam (OutgoingResetRequest)", a.name) - if a.peerLastTSN() < p.senderLastTSN && len(a.reconfigRequests) >= maxReconfigRequests { + if a.peerLastTSN() < par.senderLastTSN && len(a.reconfigRequests) >= maxReconfigRequests { // We have too many reconfig requests outstanding. Drop the request and let // the peer retransmit. A well behaved peer should only have 1 outstanding // reconfig request. @@ -2212,47 +2288,51 @@ func (a *Association) handleReconfigParam(raw param) (*packet, error) { // At any given time, there MUST NOT be more than one request in flight. // So, if the Re-configuration Timer is running and the RE-CONFIG chunk // contains at least one request parameter, the chunk MUST be buffered. - // chrome: https://chromium.googlesource.com/external/webrtc/+/refs/heads/main/net/dcsctp/socket/stream_reset_handler.cc#271 + // chrome: + // https://chromium.googlesource.com/external/webrtc/+/refs/heads/main/net/dcsctp/socket/stream_reset_handler.cc#271 return nil, fmt.Errorf("%w: %d", ErrTooManyReconfigRequests, len(a.reconfigRequests)) } - a.reconfigRequests[p.reconfigRequestSequenceNumber] = p - resp := a.resetStreamsIfAny(p) + a.reconfigRequests[par.reconfigRequestSequenceNumber] = par + resp := a.resetStreamsIfAny(par) if resp != nil { return resp, nil } + return nil, nil //nolint:nilnil case *paramReconfigResponse: a.log.Tracef("[%s] handleReconfigParam (ReconfigResponse)", a.name) - if p.result == reconfigResultInProgress { + if par.result == reconfigResultInProgress { // RFC 6525: https://www.rfc-editor.org/rfc/rfc6525.html#section-5.2.7 // // If the Result field indicates "In progress", the timer for the // Re-configuration Request Sequence Number is started again. If // the timer runs out, the RE-CONFIG chunk MUST be retransmitted // but the corresponding error counters MUST NOT be incremented. - if _, ok := a.reconfigs[p.reconfigResponseSequenceNumber]; ok { + if _, ok := a.reconfigs[par.reconfigResponseSequenceNumber]; ok { a.tReconfig.stop() a.tReconfig.start(a.rtoMgr.getRTO()) } + return nil, nil //nolint:nilnil } - delete(a.reconfigs, p.reconfigResponseSequenceNumber) + delete(a.reconfigs, par.reconfigResponseSequenceNumber) if len(a.reconfigs) == 0 { a.tReconfig.stop() } + return nil, nil //nolint:nilnil default: - return nil, fmt.Errorf("%w: %t", ErrParamterType, p) + return nil, fmt.Errorf("%w: %t", ErrParamterType, par) } } // The caller should hold the lock. -func (a *Association) resetStreamsIfAny(p *paramOutgoingResetRequest) *packet { +func (a *Association) resetStreamsIfAny(resetRequest *paramOutgoingResetRequest) *packet { result := reconfigResultSuccessPerformed - if sna32LTE(p.senderLastTSN, a.peerLastTSN()) { + if sna32LTE(resetRequest.senderLastTSN, a.peerLastTSN()) { a.log.Debugf("[%s] resetStream(): senderLastTSN=%d <= peerLastTSN=%d", - a.name, p.senderLastTSN, a.peerLastTSN()) - for _, id := range p.streamIdentifiers { + a.name, resetRequest.senderLastTSN, a.peerLastTSN()) + for _, id := range resetRequest.streamIdentifiers { s, ok := a.streams[id] if !ok { continue @@ -2263,16 +2343,16 @@ func (a *Association) resetStreamsIfAny(p *paramOutgoingResetRequest) *packet { a.log.Debugf("[%s] deleting stream %d", a.name, id) delete(a.streams, s.streamIdentifier) } - delete(a.reconfigRequests, p.reconfigRequestSequenceNumber) + delete(a.reconfigRequests, resetRequest.reconfigRequestSequenceNumber) } else { a.log.Debugf("[%s] resetStream(): senderLastTSN=%d > peerLastTSN=%d", - a.name, p.senderLastTSN, a.peerLastTSN()) + a.name, resetRequest.senderLastTSN, a.peerLastTSN()) result = reconfigResultInProgress } return a.createPacket([]chunk{&chunkReconfig{ paramA: ¶mReconfigResponse{ - reconfigResponseSequenceNumber: p.reconfigRequestSequenceNumber, + reconfigResponseSequenceNumber: resetRequest.reconfigRequestSequenceNumber, result: result, }, }}) @@ -2280,38 +2360,49 @@ func (a *Association) resetStreamsIfAny(p *paramOutgoingResetRequest) *packet { // Move the chunk peeked with a.pendingQueue.peek() to the inflightQueue. // The caller should hold the lock. -func (a *Association) movePendingDataChunkToInflightQueue(c *chunkPayloadData) { - if err := a.pendingQueue.pop(c); err != nil { +func (a *Association) movePendingDataChunkToInflightQueue(chunkPayload *chunkPayloadData) { + if err := a.pendingQueue.pop(chunkPayload); err != nil { a.log.Errorf("[%s] failed to pop from pending queue: %s", a.name, err.Error()) } // Mark all fragements are in-flight now - if c.endingFragment { - c.setAllInflight() + if chunkPayload.endingFragment { + chunkPayload.setAllInflight() } // Assign TSN - c.tsn = a.generateNextTSN() + chunkPayload.tsn = a.generateNextTSN() - c.since = time.Now() // use to calculate RTT and also for maxPacketLifeTime - c.nSent = 1 // being sent for the first time + chunkPayload.since = time.Now() // use to calculate RTT and also for maxPacketLifeTime + chunkPayload.nSent = 1 // being sent for the first time - a.checkPartialReliabilityStatus(c) + a.checkPartialReliabilityStatus(chunkPayload) - a.log.Tracef("[%s] sending ppi=%d tsn=%d ssn=%d sent=%d len=%d (%v,%v)", - a.name, c.payloadType, c.tsn, c.streamSequenceNumber, c.nSent, len(c.userData), c.beginningFragment, c.endingFragment) + a.log.Tracef( + "[%s] sending ppi=%d tsn=%d ssn=%d sent=%d len=%d (%v,%v)", + a.name, + chunkPayload.payloadType, + chunkPayload.tsn, + chunkPayload.streamSequenceNumber, + chunkPayload.nSent, + len(chunkPayload.userData), + chunkPayload.beginningFragment, + chunkPayload.endingFragment, + ) - a.inflightQueue.pushNoCheck(c) + a.inflightQueue.pushNoCheck(chunkPayload) } // popPendingDataChunksToSend pops chunks from the pending queues as many as // the cwnd and rwnd allows to send. // The caller should hold the lock. +// +//nolint:cyclop func (a *Association) popPendingDataChunksToSend() ([]*chunkPayloadData, []uint16) { chunks := []*chunkPayloadData{} var sisToReset []uint16 // stream identifieres to reset - if a.pendingQueue.size() > 0 { + if a.pendingQueue.size() > 0 { //nolint:nestif // RFC 4960 sec 6.1. Transmission of DATA Chunks // A) At any given time, the data sender MUST NOT transmit new data to // any destination transport address if its peer's rwnd indicates @@ -2321,22 +2412,23 @@ func (a *Association) popPendingDataChunksToSend() ([]*chunkPayloadData, []uint1 // the receiver if allowed by cwnd (see rule B, below). for { - c := a.pendingQueue.peek() - if c == nil { + chunkPayload := a.pendingQueue.peek() + if chunkPayload == nil { break // no more pending data } - dataLen := uint32(len(c.userData)) + dataLen := uint32(len(chunkPayload.userData)) //nolint:gosec // G115 if dataLen == 0 { - sisToReset = append(sisToReset, c.streamIdentifier) - err := a.pendingQueue.pop(c) + sisToReset = append(sisToReset, chunkPayload.streamIdentifier) + err := a.pendingQueue.pop(chunkPayload) if err != nil { a.log.Errorf("failed to pop from pending queue: %s", err.Error()) } + continue } - if uint32(a.inflightQueue.getNumBytes())+dataLen > a.CWND() { + if uint32(a.inflightQueue.getNumBytes())+dataLen > a.CWND() { //nolint:gosec // G115 break // would exceeds cwnd } @@ -2346,8 +2438,8 @@ func (a *Association) popPendingDataChunksToSend() ([]*chunkPayloadData, []uint1 a.setRWND(a.RWND() - dataLen) - a.movePendingDataChunkToInflightQueue(c) - chunks = append(chunks, c) + a.movePendingDataChunkToInflightQueue(chunkPayload) + chunks = append(chunks, chunkPayload) } // the data sender can always have one DATA chunk in flight to the receiver @@ -2378,20 +2470,20 @@ func (a *Association) bundleDataChunksIntoPackets(chunks []*chunkPayloadData) [] chunksToSend := []chunk{} bytesInPacket := int(commonHeaderSize) - for _, c := range chunks { + for _, chunkPayload := range chunks { // RFC 4960 sec 6.1. Transmission of DATA Chunks // Multiple DATA chunks committed for transmission MAY be bundled in a // single packet. Furthermore, DATA chunks being retransmitted MAY be // bundled with new DATA chunks, as long as the resulting packet size // does not exceed the path MTU. - chunkSizeInPacket := int(dataChunkHeaderSize) + len(c.userData) + chunkSizeInPacket := int(dataChunkHeaderSize) + len(chunkPayload.userData) chunkSizeInPacket += getPadding(chunkSizeInPacket) if bytesInPacket+chunkSizeInPacket > int(a.MTU()) { packets = append(packets, a.createPacket(chunksToSend)) chunksToSend = []chunk{} bytesInPacket = int(commonHeaderSize) } - chunksToSend = append(chunksToSend, c) + chunksToSend = append(chunksToSend, chunkPayload) bytesInPacket += chunkSizeInPacket } @@ -2409,6 +2501,7 @@ func (a *Association) sendPayloadData(ctx context.Context, chunks []*chunkPayloa state := a.getState() if state != established { a.lock.Unlock() + return fmt.Errorf("%w: state=%s", ErrPayloadDataStateNotExist, getAssociationStateString(state)) } @@ -2433,11 +2526,12 @@ func (a *Association) sendPayloadData(ctx context.Context, chunks []*chunkPayloa a.lock.Unlock() a.awakeWriteLoop() + return nil } // The caller should hold the lock. -func (a *Association) checkPartialReliabilityStatus(c *chunkPayloadData) { +func (a *Association) checkPartialReliabilityStatus(chunkPayload *chunkPayloadData) { if !a.useForwardTSN { return } @@ -2447,29 +2541,35 @@ func (a *Association) checkPartialReliabilityStatus(c *chunkPayloadData) { // All Data Channel Establishment Protocol messages MUST be sent using // ordered delivery and reliable transmission. // - if c.payloadType == PayloadTypeWebRTCDCEP { + if chunkPayload.payloadType == PayloadTypeWebRTCDCEP { return } // PR-SCTP - if s, ok := a.streams[c.streamIdentifier]; ok { - s.lock.RLock() - if s.reliabilityType == ReliabilityTypeRexmit { - if c.nSent >= s.reliabilityValue { - c.setAbandoned(true) - a.log.Tracef("[%s] marked as abandoned: tsn=%d ppi=%d (remix: %d)", a.name, c.tsn, c.payloadType, c.nSent) + if stream, ok := a.streams[chunkPayload.streamIdentifier]; ok { //nolint:nestif + stream.lock.RLock() + if stream.reliabilityType == ReliabilityTypeRexmit { + if chunkPayload.nSent >= stream.reliabilityValue { + chunkPayload.setAbandoned(true) + a.log.Tracef( + "[%s] marked as abandoned: tsn=%d ppi=%d (remix: %d)", + a.name, chunkPayload.tsn, chunkPayload.payloadType, chunkPayload.nSent, + ) } - } else if s.reliabilityType == ReliabilityTypeTimed { - elapsed := int64(time.Since(c.since).Seconds() * 1000) - if elapsed >= int64(s.reliabilityValue) { - c.setAbandoned(true) - a.log.Tracef("[%s] marked as abandoned: tsn=%d ppi=%d (timed: %d)", a.name, c.tsn, c.payloadType, elapsed) + } else if stream.reliabilityType == ReliabilityTypeTimed { + elapsed := int64(time.Since(chunkPayload.since).Seconds() * 1000) + if elapsed >= int64(stream.reliabilityValue) { + chunkPayload.setAbandoned(true) + a.log.Tracef( + "[%s] marked as abandoned: tsn=%d ppi=%d (timed: %d)", + a.name, chunkPayload.tsn, chunkPayload.payloadType, elapsed, + ) } } - s.lock.RUnlock() + stream.lock.RUnlock() } else { // Remote has reset its send side of the stream, we can still send data. - a.log.Tracef("[%s] stream %d not found, remote reset", a.name, c.streamIdentifier) + a.log.Tracef("[%s] stream %d not found, remote reset", a.name, chunkPayload.streamIdentifier) } } @@ -2483,34 +2583,37 @@ func (a *Association) getDataPacketsToRetransmit() []*packet { var done bool for i := 0; !done; i++ { - c, ok := a.inflightQueue.get(a.cumulativeTSNAckPoint + uint32(i) + 1) + chunkPayload, ok := a.inflightQueue.get(a.cumulativeTSNAckPoint + uint32(i) + 1) //nolint:gosec // G115 if !ok { break // end of pending data } - if !c.retransmit { + if !chunkPayload.retransmit { continue } - if i == 0 && int(a.RWND()) < len(c.userData) { + if i == 0 && int(a.RWND()) < len(chunkPayload.userData) { // Send it as a zero window probe done = true - } else if bytesToSend+len(c.userData) > int(awnd) { + } else if bytesToSend+len(chunkPayload.userData) > int(awnd) { break } // reset the retransmit flag not to retransmit again before the next // t3-rtx timer fires - c.retransmit = false - bytesToSend += len(c.userData) + chunkPayload.retransmit = false + bytesToSend += len(chunkPayload.userData) - c.nSent++ + chunkPayload.nSent++ - a.checkPartialReliabilityStatus(c) + a.checkPartialReliabilityStatus(chunkPayload) - a.log.Tracef("[%s] retransmitting tsn=%d ssn=%d sent=%d", a.name, c.tsn, c.streamSequenceNumber, c.nSent) + a.log.Tracef( + "[%s] retransmitting tsn=%d ssn=%d sent=%d", + a.name, chunkPayload.tsn, chunkPayload.streamSequenceNumber, chunkPayload.nSent, + ) - chunks = append(chunks, c) + chunks = append(chunks, chunkPayload) } return a.bundleDataChunksIntoPackets(chunks) @@ -2521,6 +2624,7 @@ func (a *Association) getDataPacketsToRetransmit() []*packet { func (a *Association) generateNextTSN() uint32 { tsn := a.myNextTSN a.myNextTSN++ + return tsn } @@ -2529,6 +2633,7 @@ func (a *Association) generateNextTSN() uint32 { func (a *Association) generateNextRSN() uint32 { rsn := a.myNextRSN a.myNextRSN++ + return rsn } @@ -2538,6 +2643,7 @@ func (a *Association) createSelectiveAckChunk() *chunkSelectiveAck { sack.advertisedReceiverWindowCredit = a.getMyReceiverWindowCredit() sack.duplicateTSN = a.payloadQueue.popDuplicates() sack.gapAckBlocks = a.payloadQueue.getGapAckBlocks() + return sack } @@ -2570,71 +2676,72 @@ func (a *Association) handleChunksEnd() { } } -func (a *Association) handleChunk(p *packet, c chunk) error { +func (a *Association) handleChunk(receivedPacket *packet, receivedChunk chunk) error { //nolint:cyclop a.lock.Lock() defer a.lock.Unlock() var packets []*packet var err error - if _, err = c.check(); err != nil { + if _, err = receivedChunk.check(); err != nil { a.log.Errorf("[%s] failed validating chunk: %s ", a.name, err) + return nil } isAbort := false - switch c := c.(type) { + switch receivedChunk := receivedChunk.(type) { // Note: We do not do the following for chunkInit, chunkInitAck, and chunkCookieEcho: // If an endpoint receives an INIT, INIT ACK, or COOKIE ECHO chunk but decides not to establish the // new association due to missing mandatory parameters in the received INIT or INIT ACK chunk, invalid // parameter values, or lack of local resources, it SHOULD respond with an ABORT chunk. case *chunkInit: - packets, err = a.handleInit(p, c) + packets, err = a.handleInit(receivedPacket, receivedChunk) case *chunkInitAck: - err = a.handleInitAck(p, c) + err = a.handleInitAck(receivedPacket, receivedChunk) case *chunkAbort: isAbort = true - err = a.handleAbort(c) + err = a.handleAbort(receivedChunk) case *chunkError: var errStr string - for _, e := range c.errorCauses { + for _, e := range receivedChunk.errorCauses { errStr += fmt.Sprintf("(%s)", e) } a.log.Debugf("[%s] Error chunk, with following errors: %s", a.name, errStr) // Note: chunkHeartbeatAck not handled? case *chunkHeartbeat: - packets = a.handleHeartbeat(c) + packets = a.handleHeartbeat(receivedChunk) case *chunkCookieEcho: - packets = a.handleCookieEcho(c) + packets = a.handleCookieEcho(receivedChunk) case *chunkCookieAck: a.handleCookieAck() case *chunkPayloadData: - packets = a.handleData(c) + packets = a.handleData(receivedChunk) case *chunkSelectiveAck: - err = a.handleSack(c) + err = a.handleSack(receivedChunk) case *chunkReconfig: - packets, err = a.handleReconfig(c) + packets, err = a.handleReconfig(receivedChunk) case *chunkForwardTSN: - packets = a.handleForwardTSN(c) + packets = a.handleForwardTSN(receivedChunk) case *chunkShutdown: - a.handleShutdown(c) + a.handleShutdown(receivedChunk) case *chunkShutdownAck: - a.handleShutdownAck(c) + a.handleShutdownAck(receivedChunk) case *chunkShutdownComplete: - err = a.handleShutdownComplete(c) + err = a.handleShutdownComplete(receivedChunk) default: err = ErrChunkTypeUnhandled @@ -2647,6 +2754,7 @@ func (a *Association) handleChunk(p *packet, c chunk) error { } a.log.Errorf("Failed to handle chunk: %v", err) + return nil } @@ -2658,7 +2766,7 @@ func (a *Association) handleChunk(p *packet, c chunk) error { return nil } -func (a *Association) onRetransmissionTimeout(id int, nRtos uint) { +func (a *Association) onRetransmissionTimeout(id int, nRtos uint) { //nolint:cyclop a.lock.Lock() defer a.lock.Unlock() @@ -2673,6 +2781,7 @@ func (a *Association) onRetransmissionTimeout(id int, nRtos uint) { if err != nil { a.log.Debugf("[%s] failed to retransmit init (nRtos=%d): %v", a.name, nRtos, err) } + return } @@ -2681,6 +2790,7 @@ func (a *Association) onRetransmissionTimeout(id int, nRtos uint) { if err != nil { a.log.Debugf("[%s] failed to retransmit cookie-echo (nRtos=%d): %v", a.name, nRtos, err) } + return } @@ -2698,7 +2808,7 @@ func (a *Association) onRetransmissionTimeout(id int, nRtos uint) { } } - if id == timerT3RTX { + if id == timerT3RTX { //nolint:nestif a.stats.incT3Timeouts() // RFC 4960 sec 6.3.3 @@ -2772,17 +2882,20 @@ func (a *Association) onRetransmissionFailure(id int) { if id == timerT1Init { a.log.Errorf("[%s] retransmission failure: T1-init", a.name) a.completeHandshake(ErrHandshakeInitAck) + return } if id == timerT1Cookie { a.log.Errorf("[%s] retransmission failure: T1-cookie", a.name) a.completeHandshake(ErrHandshakeCookieEcho) + return } if id == timerT2Shutdown { a.log.Errorf("[%s] retransmission failure: T2-shutdown", a.name) + return } @@ -2792,6 +2905,7 @@ func (a *Association) onRetransmissionFailure(id int) { // * ICE would fail if the connectivity is lost // * WebRTC spec is not clear how this incident should be reported to ULP a.log.Errorf("[%s] retransmission failure: T3-rtx (DATA)", a.name) + return } } @@ -2837,5 +2951,6 @@ func (a *Association) completeHandshake(handshakeErr error) bool { case <-a.closeWriteLoopCh: // check the read/write sides for closure case <-a.readLoopCloseCh: } + return false } diff --git a/association_test.go b/association_test.go index 5f1c2ed0..ed3f54c0 100644 --- a/association_test.go +++ b/association_test.go @@ -51,6 +51,8 @@ func TestAssocStressDuplex(t *testing.T) { } func stressDuplex(t *testing.T) { + t.Helper() + ca, cb, stop, err := pipe(pipeDump) if err != nil { t.Fatal(err) @@ -91,6 +93,8 @@ func pipe(piper piperFunc) (*Stream, *Stream, func(*testing.T), error) { } stop := func(t *testing.T) { + t.Helper() + err = sa.Close() if err != nil { t.Error(err) @@ -120,7 +124,7 @@ func association(piper piperFunc) (*Association, *Association, error) { err error } - c := make(chan result) + resultCh := make(chan result) loggerFactory := logging.NewDefaultLoggerFactory() // Setup client @@ -129,7 +133,7 @@ func association(piper piperFunc) (*Association, *Association, error) { NetConn: ca, LoggerFactory: loggerFactory, }) - c <- result{client, err} + resultCh <- result{client, err} }() // Setup server @@ -142,7 +146,7 @@ func association(piper piperFunc) (*Association, *Association, error) { } // Receive client - res := <-c + res := <-resultCh if res.err != nil { return nil, nil, res.err } @@ -189,12 +193,13 @@ type dumbConn struct { func acceptDumbConn() *dumbConn { pConn, err := net.ListenUDP("udp4", nil) check(err) + return &dumbConn{ pConn: pConn, } } -// Read +// Read. func (c *dumbConn) Read(p []byte) (int, error) { i, rAddr, err := c.pConn.ReadFrom(p) if err != nil { @@ -208,47 +213,54 @@ func (c *dumbConn) Read(p []byte) (int, error) { return i, err } -// Write writes len(p) bytes from p to the DTLS connection +// Write writes len(p) bytes from p to the DTLS connection. func (c *dumbConn) Write(p []byte) (n int, err error) { return c.pConn.WriteTo(p, c.RemoteAddr()) } -// Close closes the conn and releases any Read calls +// Close closes the conn and releases any Read calls. func (c *dumbConn) Close() error { return c.pConn.Close() } -// LocalAddr is a stub +// LocalAddr is a stub. func (c *dumbConn) LocalAddr() net.Addr { if c.pConn != nil { return c.pConn.LocalAddr() } + return nil } -// RemoteAddr is a stub +// RemoteAddr is a stub. func (c *dumbConn) RemoteAddr() net.Addr { c.mu.RLock() defer c.mu.RUnlock() + return c.rAddr } -// SetDeadline is a stub +// SetDeadline is a stub. func (c *dumbConn) SetDeadline(time.Time) error { return nil } -// SetReadDeadline is a stub +// SetReadDeadline is a stub. func (c *dumbConn) SetReadDeadline(time.Time) error { return nil } -// SetWriteDeadline is a stub +// SetWriteDeadline is a stub. func (c *dumbConn) SetWriteDeadline(time.Time) error { return nil } -func createNewAssociationPair(br *test.Bridge, ackMode int, recvBufSize uint32) (*Association, *Association, error) { +//nolint:cyclop +func createNewAssociationPair( + br *test.Bridge, + ackMode int, + recvBufSize uint32, +) (*Association, *Association, error) { var a0, a1 *Association var err0, err1 error loggerFactory := logging.NewDefaultLoggerFactory() @@ -416,14 +428,14 @@ func establishSessionPair(br *test.Bridge, a0, a1 *Association, si uint16) (*Str return s0, s1, nil } -func TestAssocReliable(t *testing.T) { +func TestAssocReliable(t *testing.T) { //nolint:cyclop,maintidx // sbuf - small enough not to be fragmented // large enough not to be bundled sbuf := make([]byte, 1000) for i := 0; i < len(sbuf); i++ { sbuf[i] = byte(i & 0xff) } - rand.Seed(time.Now().UnixNano()) + rand.Seed(time.Now().UnixNano()) //nolint:staticcheck // TODO: remove? rand.Shuffle(len(sbuf), func(i, j int) { sbuf[i], sbuf[j] = sbuf[j], sbuf[i] }) // sbufL - large enough to be fragmented into two chunks and each chunks are @@ -432,7 +444,7 @@ func TestAssocReliable(t *testing.T) { for i := 0; i < len(sbufL); i++ { sbufL[i] = byte(i & 0xff) } - rand.Seed(time.Now().UnixNano()) + rand.Seed(time.Now().UnixNano()) //nolint:staticcheck // TODO: remove? rand.Shuffle(len(sbufL), func(i, j int) { sbufL[i], sbufL[j] = sbufL[j], sbufL[i] }) t.Run("Simple", func(t *testing.T) { @@ -825,7 +837,7 @@ func TestAssocReliable(t *testing.T) { }) } -func TestAssocUnreliable(t *testing.T) { +func TestAssocUnreliable(t *testing.T) { //nolint:cyclop,maintidx // sbuf1, sbuf2: // large enough to be fragmented into two chunks and each chunks are // large enough not to be bundled @@ -834,12 +846,12 @@ func TestAssocUnreliable(t *testing.T) { for i := 0; i < len(sbuf1); i++ { sbuf1[i] = byte(i & 0xff) } - rand.Seed(time.Now().UnixNano()) + rand.Seed(time.Now().UnixNano()) //nolint:staticcheck // TODO: remove? rand.Shuffle(len(sbuf1), func(i, j int) { sbuf1[i], sbuf1[j] = sbuf1[j], sbuf1[i] }) for i := 0; i < len(sbuf2); i++ { sbuf2[i] = byte(i & 0xff) } - rand.Seed(time.Now().UnixNano()) + rand.Seed(time.Now().UnixNano()) //nolint:staticcheck // TODO: remove? rand.Shuffle(len(sbuf2), func(i, j int) { sbuf2[i], sbuf2[j] = sbuf2[j], sbuf2[i] }) // sbuf - small enough not to be fragmented @@ -848,7 +860,7 @@ func TestAssocUnreliable(t *testing.T) { for i := 0; i < len(sbuf); i++ { sbuf[i] = byte(i & 0xff) } - rand.Seed(time.Now().UnixNano()) + rand.Seed(time.Now().UnixNano()) //nolint:staticcheck // TODO: remove? rand.Shuffle(len(sbuf), func(i, j int) { sbuf[i], sbuf[j] = sbuf[j], sbuf[i] }) t.Run("Rexmit ordered no fragment", func(t *testing.T) { // nolint:dupl @@ -1185,7 +1197,7 @@ func TestAssocUnreliable(t *testing.T) { // A test for this PR https://github.com/pion/sctp/pull/341 // We drop the first INIT ACK, and we expect the verification tag to be 0 on // retransmission. -func TestInitVerificationTagIsZero(t *testing.T) { +func TestInitVerificationTagIsZero(t *testing.T) { //nolint:cyclop lim := test.TimeOut(time.Second * 10) defer lim.Stop() @@ -1230,6 +1242,7 @@ func TestInitVerificationTagIsZero(t *testing.T) { // Drop the first two Init Ack chunk. case *chunkInitAck: ackCount++ + return ackCount > 2 } @@ -1324,14 +1337,14 @@ func TestCreateForwardTSN(t *testing.T) { loggerFactory := logging.NewDefaultLoggerFactory() t.Run("forward one abandoned", func(t *testing.T) { - a := createAssociation(Config{ + assoc := createAssociation(Config{ NetConn: &dumbConn{}, LoggerFactory: loggerFactory, }) - a.cumulativeTSNAckPoint = 9 - a.advancedPeerTSNAckPoint = 10 - a.inflightQueue.pushNoCheck(&chunkPayloadData{ + assoc.cumulativeTSNAckPoint = 9 + assoc.advancedPeerTSNAckPoint = 10 + assoc.inflightQueue.pushNoCheck(&chunkPayloadData{ beginningFragment: true, endingFragment: true, tsn: 10, @@ -1342,7 +1355,7 @@ func TestCreateForwardTSN(t *testing.T) { _abandoned: true, }) - fwdtsn := a.createForwardTSN() + fwdtsn := assoc.createForwardTSN() assert.Equal(t, uint32(10), fwdtsn.newCumulativeTSN, "should be able to serialize") assert.Equal(t, 1, len(fwdtsn.streams), "there should be one stream") @@ -1351,14 +1364,14 @@ func TestCreateForwardTSN(t *testing.T) { }) t.Run("forward two abandoned with the same SI", func(t *testing.T) { - a := createAssociation(Config{ + assoc := createAssociation(Config{ NetConn: &dumbConn{}, LoggerFactory: loggerFactory, }) - a.cumulativeTSNAckPoint = 9 - a.advancedPeerTSNAckPoint = 12 - a.inflightQueue.pushNoCheck(&chunkPayloadData{ + assoc.cumulativeTSNAckPoint = 9 + assoc.advancedPeerTSNAckPoint = 12 + assoc.inflightQueue.pushNoCheck(&chunkPayloadData{ beginningFragment: true, endingFragment: true, tsn: 10, @@ -1368,7 +1381,7 @@ func TestCreateForwardTSN(t *testing.T) { nSent: 1, _abandoned: true, }) - a.inflightQueue.pushNoCheck(&chunkPayloadData{ + assoc.inflightQueue.pushNoCheck(&chunkPayloadData{ beginningFragment: true, endingFragment: true, tsn: 11, @@ -1378,7 +1391,7 @@ func TestCreateForwardTSN(t *testing.T) { nSent: 1, _abandoned: true, }) - a.inflightQueue.pushNoCheck(&chunkPayloadData{ + assoc.inflightQueue.pushNoCheck(&chunkPayloadData{ beginningFragment: true, endingFragment: true, tsn: 12, @@ -1389,7 +1402,7 @@ func TestCreateForwardTSN(t *testing.T) { _abandoned: true, }) - fwdtsn := a.createForwardTSN() + fwdtsn := assoc.createForwardTSN() assert.Equal(t, uint32(12), fwdtsn.newCumulativeTSN, "should be able to serialize") assert.Equal(t, 2, len(fwdtsn.streams), "there should be two stream") @@ -1417,116 +1430,116 @@ func TestHandleForwardTSN(t *testing.T) { loggerFactory := logging.NewDefaultLoggerFactory() t.Run("forward 3 unreceived chunks", func(t *testing.T) { - a := createAssociation(Config{ + assoc := createAssociation(Config{ NetConn: &dumbConn{}, LoggerFactory: loggerFactory, }) - a.useForwardTSN = true - prevTSN := a.peerLastTSN() + assoc.useForwardTSN = true + prevTSN := assoc.peerLastTSN() fwdtsn := &chunkForwardTSN{ newCumulativeTSN: prevTSN + 3, streams: []chunkForwardTSNStream{{identifier: 0, sequence: 0}}, } - p := a.handleForwardTSN(fwdtsn) + p := assoc.handleForwardTSN(fwdtsn) - a.lock.Lock() - delayedAckTriggered := a.delayedAckTriggered - immediateAckTriggered := a.immediateAckTriggered - a.lock.Unlock() - assert.Equal(t, a.peerLastTSN(), prevTSN+3, "peerLastTSN should advance by 3 ") + assoc.lock.Lock() + delayedAckTriggered := assoc.delayedAckTriggered + immediateAckTriggered := assoc.immediateAckTriggered + assoc.lock.Unlock() + assert.Equal(t, assoc.peerLastTSN(), prevTSN+3, "peerLastTSN should advance by 3 ") assert.True(t, delayedAckTriggered, "delayed sack should be triggered") assert.False(t, immediateAckTriggered, "immediate sack should NOT be triggered") assert.Nil(t, p, "should return nil") }) t.Run("forward 1 for 1 missing", func(t *testing.T) { - a := createAssociation(Config{ + assoc := createAssociation(Config{ NetConn: &dumbConn{}, LoggerFactory: loggerFactory, }) - a.useForwardTSN = true - prevTSN := a.peerLastTSN() + assoc.useForwardTSN = true + prevTSN := assoc.peerLastTSN() // this chunk is blocked by the missing chunk at tsn=1 - a.payloadQueue.push(a.peerLastTSN() + 2) + assoc.payloadQueue.push(assoc.peerLastTSN() + 2) fwdtsn := &chunkForwardTSN{ - newCumulativeTSN: a.peerLastTSN() + 1, + newCumulativeTSN: assoc.peerLastTSN() + 1, streams: []chunkForwardTSNStream{ {identifier: 0, sequence: 1}, }, } - p := a.handleForwardTSN(fwdtsn) + p := assoc.handleForwardTSN(fwdtsn) - a.lock.Lock() - delayedAckTriggered := a.delayedAckTriggered - immediateAckTriggered := a.immediateAckTriggered - a.lock.Unlock() - assert.Equal(t, a.peerLastTSN(), prevTSN+2, "peerLastTSN should advance by 3") + assoc.lock.Lock() + delayedAckTriggered := assoc.delayedAckTriggered + immediateAckTriggered := assoc.immediateAckTriggered + assoc.lock.Unlock() + assert.Equal(t, assoc.peerLastTSN(), prevTSN+2, "peerLastTSN should advance by 3") assert.True(t, delayedAckTriggered, "delayed sack should be triggered") assert.False(t, immediateAckTriggered, "immediate sack should NOT be triggered") assert.Nil(t, p, "should return nil") }) t.Run("forward 1 for 2 missing", func(t *testing.T) { - a := createAssociation(Config{ + assoc := createAssociation(Config{ NetConn: &dumbConn{}, LoggerFactory: loggerFactory, }) - a.useForwardTSN = true - prevTSN := a.peerLastTSN() + assoc.useForwardTSN = true + prevTSN := assoc.peerLastTSN() // this chunk is blocked by the missing chunk at tsn=1 - a.payloadQueue.push(a.peerLastTSN() + 3) + assoc.payloadQueue.push(assoc.peerLastTSN() + 3) fwdtsn := &chunkForwardTSN{ - newCumulativeTSN: a.peerLastTSN() + 1, + newCumulativeTSN: assoc.peerLastTSN() + 1, streams: []chunkForwardTSNStream{ {identifier: 0, sequence: 1}, }, } - p := a.handleForwardTSN(fwdtsn) + p := assoc.handleForwardTSN(fwdtsn) - a.lock.Lock() - immediateAckTriggered := a.immediateAckTriggered - a.lock.Unlock() - assert.Equal(t, a.peerLastTSN(), prevTSN+1, "peerLastTSN should advance by 1") + assoc.lock.Lock() + immediateAckTriggered := assoc.immediateAckTriggered + assoc.lock.Unlock() + assert.Equal(t, assoc.peerLastTSN(), prevTSN+1, "peerLastTSN should advance by 1") assert.True(t, immediateAckTriggered, "immediate sack should be triggered") assert.Nil(t, p, "should return nil") }) t.Run("dup forward TSN chunk should generate sack", func(t *testing.T) { - a := createAssociation(Config{ + assoc := createAssociation(Config{ NetConn: &dumbConn{}, LoggerFactory: loggerFactory, }) - a.useForwardTSN = true - prevTSN := a.peerLastTSN() + assoc.useForwardTSN = true + prevTSN := assoc.peerLastTSN() fwdtsn := &chunkForwardTSN{ - newCumulativeTSN: a.peerLastTSN(), // old TSN + newCumulativeTSN: assoc.peerLastTSN(), // old TSN streams: []chunkForwardTSNStream{ {identifier: 0, sequence: 1}, }, } - p := a.handleForwardTSN(fwdtsn) + p := assoc.handleForwardTSN(fwdtsn) - a.lock.Lock() - ackState := a.ackState - a.lock.Unlock() - assert.Equal(t, a.peerLastTSN(), prevTSN, "peerLastTSN should not advance") + assoc.lock.Lock() + ackState := assoc.ackState + assoc.lock.Unlock() + assert.Equal(t, assoc.peerLastTSN(), prevTSN, "peerLastTSN should not advance") assert.Equal(t, ackStateImmediate, ackState, "sack should be requested") assert.Nil(t, p, "should return nil") }) } -func TestAssocT1InitTimer(t *testing.T) { +func TestAssocT1InitTimer(t *testing.T) { //nolint:cyclop loggerFactory := logging.NewDefaultLoggerFactory() t.Run("Retransmission success", func(t *testing.T) { @@ -1655,7 +1668,7 @@ func TestAssocT1InitTimer(t *testing.T) { }) } -func TestAssocT1CookieTimer(t *testing.T) { +func TestAssocT1CookieTimer(t *testing.T) { //nolint:cyclop loggerFactory := logging.NewDefaultLoggerFactory() t.Run("Retransmission success", func(t *testing.T) { @@ -1765,6 +1778,7 @@ func TestAssocT1CookieTimer(t *testing.T) { return true } } + return true }) @@ -1795,32 +1809,32 @@ func TestAssocCreateNewStream(t *testing.T) { loggerFactory := logging.NewDefaultLoggerFactory() t.Run("acceptChSize", func(t *testing.T) { - a := createAssociation(Config{ + assoc := createAssociation(Config{ NetConn: &dumbConn{}, LoggerFactory: loggerFactory, }) for i := 0; i < acceptChSize; i++ { - s := a.createStream(uint16(i), true) - _, ok := a.streams[s.streamIdentifier] + s := assoc.createStream(uint16(i), true) //nolint:gosec + _, ok := assoc.streams[s.streamIdentifier] assert.True(t, ok, "should be in a.streams map") } newSI := uint16(acceptChSize) - s := a.createStream(newSI, true) + s := assoc.createStream(newSI, true) assert.Nil(t, s, "should be nil") - _, ok := a.streams[newSI] + _, ok := assoc.streams[newSI] assert.False(t, ok, "should NOT be in a.streams map") toBeIgnored := &chunkPayloadData{ beginningFragment: true, endingFragment: true, - tsn: a.peerLastTSN() + 1, + tsn: assoc.peerLastTSN() + 1, streamIdentifier: newSI, userData: []byte("ABC"), } - p := a.handleData(toBeIgnored) + p := assoc.handleData(toBeIgnored) assert.Nil(t, p, "should be nil") }) } @@ -1882,7 +1896,7 @@ func TestAssocT3RtxTimer(t *testing.T) { }) } -func TestAssocCongestionControl(t *testing.T) { +func TestAssocCongestionControl(t *testing.T) { //nolint:cyclop,maintidx // sbuf - large enough not to be bundled sbuf := make([]byte, 1000) for i := 0; i < len(sbuf); i++ { @@ -1913,7 +1927,7 @@ func TestAssocCongestionControl(t *testing.T) { br.DropNextNWrites(0, 1) // drop the next write for i := 0; i < 4; i++ { - binary.BigEndian.PutUint32(sbuf, uint32(i)) // uint32 sequence number + binary.BigEndian.PutUint32(sbuf, uint32(i)) //nolint:gosec // G115 uint32 sequence number n, err = s0.WriteSCTP(sbuf, PayloadTypeWebRTCBinary) assert.Nil(t, err, "WriteSCTP failed") assert.Equal(t, n, len(sbuf), "unexpected length of received data") @@ -1990,7 +2004,7 @@ func TestAssocCongestionControl(t *testing.T) { a1.stats.reset() for i := 0; i < nPacketsToSend; i++ { - binary.BigEndian.PutUint32(sbuf, uint32(i)) // uint32 sequence number + binary.BigEndian.PutUint32(sbuf, uint32(i)) //nolint:gosec // G115 uint32 sequence number n, err = s0.WriteSCTP(sbuf, PayloadTypeWebRTCBinary) assert.Nil(t, err, "WriteSCTP failed") assert.Equal(t, n, len(sbuf), "unexpected length of received data") @@ -2074,7 +2088,7 @@ func TestAssocCongestionControl(t *testing.T) { assert.Nil(t, err, "failed to establish session pair") for i := 0; i < nPacketsToSend; i++ { - binary.BigEndian.PutUint32(sbuf, uint32(i)) // uint32 sequence number + binary.BigEndian.PutUint32(sbuf, uint32(i)) // nolint:gosec // G115 uint32 sequence number n, err = s0.WriteSCTP(sbuf, PayloadTypeWebRTCBinary) assert.Nil(t, err, "WriteSCTP failed") assert.Equal(t, n, len(sbuf), "unexpected length of received data") @@ -2217,7 +2231,12 @@ func TestAssocDelayedAck(t *testing.T) { t.Logf("nAckTimeouts: %d\n", a1.stats.getNumAckTimeouts()) assert.Equal(t, uint64(1), a1.stats.getNumDATAs(), "DATA chunk count mismatch") - assert.Equal(t, a0.stats.getNumSACKsReceived(), a1.stats.getNumDATAs(), "sack count should be equal to the number of data chunks") + assert.Equal( + t, + a0.stats.getNumSACKsReceived(), + a1.stats.getNumDATAs(), + "sack count should be equal to the number of data chunks", + ) assert.Equal(t, uint64(1), a1.stats.getNumAckTimeouts(), "ackTimeout count mismatch") assert.Equal(t, uint64(0), a0.stats.getNumT3Timeouts(), "should be no retransmit") @@ -2226,6 +2245,8 @@ func TestAssocDelayedAck(t *testing.T) { } func checkGoroutineLeaks(t *testing.T) { + t.Helper() + // Get the count of goroutines at the start of the test. initialGoroutines := runtime.NumGoroutine() // Register a cleanup function to run after the test completes. @@ -2243,7 +2264,7 @@ func checkGoroutineLeaks(t *testing.T) { }) } -func TestAssocReset(t *testing.T) { +func TestAssocReset(t *testing.T) { //nolint:cyclop t.Run("Close one way", func(t *testing.T) { checkGoroutineLeaks(t) @@ -2285,6 +2306,7 @@ func TestAssocReset(t *testing.T) { n, ppi, err = s1.ReadSCTP(buf) if err != nil { doneCh <- err + return } @@ -2299,6 +2321,7 @@ func TestAssocReset(t *testing.T) { select { case err = <-doneCh: assert.Equal(t, io.EOF, err, "should end with EOF") + break loop default: } @@ -2350,6 +2373,7 @@ func TestAssocReset(t *testing.T) { n, ppi, err = s1.ReadSCTP(buf) if err != nil { doneCh <- err + return } @@ -2364,6 +2388,7 @@ func TestAssocReset(t *testing.T) { select { case err = <-doneCh: assert.Equal(t, io.EOF, err, "should end with EOF") + break loop0 default: } @@ -2381,6 +2406,7 @@ func TestAssocReset(t *testing.T) { assert.Equal(t, io.EOF, err, "should be EOF") if err != nil { doneCh <- err + return } } @@ -2482,6 +2508,7 @@ func (c *fakeEchoConn) Read(b []byte) (int, error) { return len(r), nil } + return 0, io.EOF } @@ -2495,6 +2522,7 @@ func (c *fakeEchoConn) Write(b []byte) (int, error) { } c.echo <- b c.bytesSent += uint64(len(b)) + return len(b), nil } @@ -2503,6 +2531,7 @@ func (c *fakeEchoConn) Close() error { defer c.mu.Unlock() close(c.echo) close(c.closed) + return c.errClose } func (c *fakeEchoConn) LocalAddr() net.Addr { return nil } @@ -2517,23 +2546,23 @@ func TestRoutineLeak(t *testing.T) { checkGoroutineLeaks(t) conn := newFakeEchoConn(io.EOF) - a, err := Client(Config{NetConn: conn, LoggerFactory: loggerFactory}) + assoc, err := Client(Config{NetConn: conn, LoggerFactory: loggerFactory}) assert.Equal(t, nil, err, "errored to initialize Client") <-conn.done - err = a.Close() + err = assoc.Close() assert.Equal(t, io.EOF, err, "Close() should fail with EOF") select { - case _, ok := <-a.closeWriteLoopCh: + case _, ok := <-assoc.closeWriteLoopCh: if ok { t.Errorf("closeWriteLoopCh is expected to be closed, but received signal") } default: t.Errorf("closeWriteLoopCh is expected to be closed, but not") } - _ = a + _ = assoc }) t.Run("Connection closed by remote host", func(t *testing.T) { checkGoroutineLeaks(t) @@ -2564,7 +2593,7 @@ func TestStats(t *testing.T) { loggerFactory := logging.NewDefaultLoggerFactory() conn := newFakeEchoConn(nil) - a, err := Client(Config{NetConn: conn, LoggerFactory: loggerFactory}) + assoc, err := Client(Config{NetConn: conn, LoggerFactory: loggerFactory}) assert.Equal(t, nil, err, "errored to initialize Client") <-conn.done @@ -2573,23 +2602,25 @@ func TestStats(t *testing.T) { conn.mu.Lock() defer conn.mu.Unlock() - assert.Equal(t, conn.bytesReceived, a.BytesReceived()) - assert.Equal(t, conn.bytesSent, a.BytesSent()) - assert.Equal(t, conn.mtu, a.MTU()) - assert.Equal(t, conn.cwnd, a.CWND()) - assert.Equal(t, conn.rwnd, a.RWND()) - assert.Equal(t, conn.srtt, a.SRTT()) + assert.Equal(t, conn.bytesReceived, assoc.BytesReceived()) + assert.Equal(t, conn.bytesSent, assoc.BytesSent()) + assert.Equal(t, conn.mtu, assoc.MTU()) + assert.Equal(t, conn.cwnd, assoc.CWND()) + assert.Equal(t, conn.rwnd, assoc.RWND()) + assert.Equal(t, conn.srtt, assoc.SRTT()) } func TestAssocHandleInit(t *testing.T) { loggerFactory := logging.NewDefaultLoggerFactory() handleInitTest := func(t *testing.T, initialState uint32, expectErr bool) { - a := createAssociation(Config{ + t.Helper() + + assoc := createAssociation(Config{ NetConn: &dumbConn{}, LoggerFactory: loggerFactory, }) - a.setState(initialState) + assoc.setState(initialState) pkt := &packet{ sourcePort: 5001, destinationPort: 5002, @@ -2602,19 +2633,20 @@ func TestAssocHandleInit(t *testing.T) { init.advertisedReceiverWindowCredit = 512 * 1024 setSupportedExtensions(&init.chunkInitCommon) - _, err := a.handleInit(pkt, init) + _, err := assoc.handleInit(pkt, init) if expectErr { assert.Error(t, err, "should fail") + return } assert.NoError(t, err, "should succeed") - assert.Equal(t, init.initialTSN-1, a.peerLastTSN(), "should match") - assert.Equal(t, uint16(1001), a.myMaxNumOutboundStreams, "should match") - assert.Equal(t, uint16(1002), a.myMaxNumInboundStreams, "should match") - assert.Equal(t, uint32(5678), a.peerVerificationTag, "should match") - assert.Equal(t, pkt.sourcePort, a.destinationPort, "should match") - assert.Equal(t, pkt.destinationPort, a.sourcePort, "should match") - assert.True(t, a.useForwardTSN, "should be set to true") + assert.Equal(t, init.initialTSN-1, assoc.peerLastTSN(), "should match") + assert.Equal(t, uint16(1001), assoc.myMaxNumOutboundStreams, "should match") + assert.Equal(t, uint16(1002), assoc.myMaxNumInboundStreams, "should match") + assert.Equal(t, uint32(5678), assoc.peerVerificationTag, "should match") + assert.Equal(t, pkt.sourcePort, assoc.destinationPort, "should match") + assert.Equal(t, pkt.destinationPort, assoc.sourcePort, "should match") + assert.True(t, assoc.useForwardTSN, "should be set to true") } t.Run("normal", func(t *testing.T) { @@ -2716,10 +2748,11 @@ func newDumbConn2(localAddr, remoteAddr net.Addr) *dumbConn2 { remoteAddr: remoteAddr, } c.cond = sync.NewCond(&c.mutex) + return c } -// Implement the net.Conn interface methods +// Implement the net.Conn interface methods. func (c *dumbConn2) Read(b []byte) (n int, err error) { c.mutex.Lock() defer c.mutex.Unlock() @@ -2729,6 +2762,7 @@ func (c *dumbConn2) Read(b []byte) (n int, err error) { packet := c.packets[0] c.packets = c.packets[1:] n := copy(b, packet) + return n, nil } @@ -2749,6 +2783,7 @@ func (c *dumbConn2) Write(b []byte) (int, error) { return 0, &net.OpError{Op: "write", Net: "udp", Addr: c.remoteAddr, Err: net.ErrClosed} } c.remoteInboundHandler(b) + return len(b), nil } @@ -2758,6 +2793,7 @@ func (c *dumbConn2) Close() error { c.closed = true c.cond.Signal() + return nil } @@ -2781,7 +2817,7 @@ func (c *dumbConn2) inboundHandler(packet []byte) { } } -// crateUDPConnPair creates a pair of net.UDPConn objects that are connected with each other +// crateUDPConnPair creates a pair of net.UDPConn objects that are connected with each other. func createUDPConnPair() (net.Conn, net.Conn) { addr1 := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 1234} addr2 := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 5678} @@ -2789,10 +2825,11 @@ func createUDPConnPair() (net.Conn, net.Conn) { conn2 := newDumbConn2(addr2, addr1) conn1.remoteInboundHandler = conn2.inboundHandler conn2.remoteInboundHandler = conn1.inboundHandler + return conn1, conn2 } -func createAssocs() (*Association, *Association, error) { +func createAssocs() (*Association, *Association, error) { //nolint:cyclop udp1, udp2 := createUDPConnPair() loggerFactory := logging.NewDefaultLoggerFactory() @@ -2855,22 +2892,25 @@ loop: } } } + return a1, a2, nil } // udpDiscardReader blocks all reads after block is set to true. -// This allows us to send arbitrary packets on a stream and block the packets received in response +// This allows us to send arbitrary packets on a stream and block the packets received in response. type udpDiscardReader struct { net.Conn - ctx context.Context + ctx context.Context //nolint:containedctx block atomic.Bool } func (d *udpDiscardReader) Read(b []byte) (n int, err error) { if d.block.Load() { <-d.ctx.Done() + return 0, d.ctx.Err() } + return d.Conn.Read(b) } @@ -2878,7 +2918,12 @@ func createAssociationPair(udpConn1 net.Conn, udpConn2 net.Conn) (*Association, return createAssociationPairWithConfig(udpConn1, udpConn2, Config{}) } -func createAssociationPairWithConfig(udpConn1 net.Conn, udpConn2 net.Conn, config Config) (*Association, *Association, error) { +//nolint:cyclop +func createAssociationPairWithConfig( + udpConn1 net.Conn, + udpConn2 net.Conn, + config Config, +) (*Association, *Association, error) { loggerFactory := logging.NewDefaultLoggerFactory() a1Chan := make(chan interface{}) @@ -2942,6 +2987,7 @@ loop: } } } + return a1, a2, nil } @@ -2955,6 +3001,7 @@ func noErrorClose(t *testing.T, closeF func() error) { func readMyNextTSN(a *Association) uint32 { a.lock.Lock() defer a.lock.Unlock() + return a.myNextTSN } @@ -3017,6 +3064,7 @@ func TestAssociationReceiveWindow(t *testing.T) { bytesQueued := s2.getNumBytesInReassemblyQueue() if bytesQueued > 5_000_000 { t.Error("too many bytes enqueued with receive window of 10kb", bytesQueued) + break } t.Log("bytes queued", bytesQueued) @@ -3085,7 +3133,14 @@ func TestAssociationFastRtxWnd(t *testing.T) { require.NoError(t, err) } - require.Eventually(t, func() bool { return dropCounter.Load() >= 15 }, 5*time.Second, 10*time.Millisecond, "drop %d", dropCounter.Load()) + require.Eventually( + t, + func() bool { return dropCounter.Load() >= 15 }, + 5*time.Second, + 10*time.Millisecond, + "drop %d", + dropCounter.Load(), + ) require.Zero(t, a1.stats.getNumFastRetrans()) require.False(t, a1.inFastRecovery) @@ -3094,7 +3149,7 @@ func TestAssociationFastRtxWnd(t *testing.T) { ack := *(lastSACK.Load()) ack.gapAckBlocks = []gapAckBlock{{start: 11}} for i := 11; i < 14; i++ { - ack.gapAckBlocks[0].end = uint16(i) + ack.gapAckBlocks[0].end = uint16(i) //nolint:gosec // G115 pkt := a1.createPacket([]chunk{&ack}) pktBuf, err1 := pkt.marshal(true) require.NoError(t, err1) @@ -3104,6 +3159,7 @@ func TestAssociationFastRtxWnd(t *testing.T) { require.Eventually(t, func() bool { a1.lock.RLock() defer a1.lock.RUnlock() + return a1.inFastRecovery }, 5*time.Second, 10*time.Millisecond) require.GreaterOrEqual(t, uint64(10), a1.stats.getNumFastRetrans()) @@ -3123,6 +3179,7 @@ func TestAssociationFastRtxWnd(t *testing.T) { // sack with cumAckPoint advanced, lastTSN should not be marked as missing ack.cumulativeTSNAck++ end := lastTSN - 1 - ack.cumulativeTSNAck + //nolint:gosec // G115 ack.gapAckBlocks = append(ack.gapAckBlocks, gapAckBlock{start: uint16(end), end: uint16(end)}) pkt := a1.createPacket([]chunk{&ack}) pktBuf, err := pkt.marshal(true) @@ -3131,6 +3188,7 @@ func TestAssociationFastRtxWnd(t *testing.T) { require.Eventually(t, func() bool { a1.lock.Lock() defer a1.lock.Unlock() + return lastChunkMinusTwo.missIndicator == 1 && lastChunk.missIndicator == 0 }, 5*time.Second, 10*time.Millisecond) } @@ -3166,11 +3224,13 @@ func TestAssociationMaxTSNOffset(t *testing.T) { raw, err := p.marshal(true) if err != nil { t.Fatal(err) + return } _, err = a1.netConn.Write(raw) if err != nil { t.Fatal(err) + return } } @@ -3413,20 +3473,20 @@ func TestAssociation_HandlePacketInCookieWaitState(t *testing.T) { testCase := testCase t.Run(name, func(t *testing.T) { aConn, charlieConn := pipeDump() - a := createAssociation(Config{ + assoc := createAssociation(Config{ NetConn: aConn, MaxReceiveBufferSize: 0, LoggerFactory: loggerFactory, }) - a.init(true) + assoc.init(true) if !testCase.skipClose { defer func() { - assert.NoError(t, a.close()) + assert.NoError(t, assoc.close()) }() } - packet, err := a.marshalPacket(testCase.inputPacket) + packet, err := assoc.marshalPacket(testCase.inputPacket) assert.NoError(t, err) _, err = charlieConn.Write(packet) assert.NoError(t, err) @@ -3691,9 +3751,9 @@ func TestAssociation_ReconfigRequestsLimited(t *testing.T) { for i := 0; i < maxReconfigRequests+100; i++ { c := &chunkReconfig{ paramA: ¶mOutgoingResetRequest{ - reconfigRequestSequenceNumber: 10 + uint32(i), - senderLastTSN: tsn + 10, // has to be enqueued - streamIdentifiers: []uint16{uint16(i)}, + reconfigRequestSequenceNumber: 10 + uint32(i), //nolint:gosec // G115 + senderLastTSN: tsn + 10, // has to be enqueued + streamIdentifiers: []uint16{uint16(i)}, //nolint:gosec // G115 }, } p := a1.createPacket([]chunk{c}) diff --git a/chunk_abort.go b/chunk_abort.go index eaed6978..2bf15dc0 100644 --- a/chunk_abort.go +++ b/chunk_abort.go @@ -34,7 +34,7 @@ type chunkAbort struct { errorCauses []errorCause } -// Abort chunk errors +// Abort chunk errors. var ( ErrChunkTypeNotAbort = errors.New("ChunkType is not of type ABORT") ErrBuildAbortChunkFailed = errors.New("failed build Abort Chunk") @@ -63,6 +63,7 @@ func (a *chunkAbort) unmarshal(raw []byte) error { offset += int(e.length()) a.errorCauses = append(a.errorCauses, e) } + return nil } @@ -77,6 +78,7 @@ func (a *chunkAbort) marshal() ([]byte, error) { } a.raw = append(a.raw, raw...) } + return a.chunkHeader.marshal() } @@ -84,7 +86,7 @@ func (a *chunkAbort) check() (abort bool, err error) { return false, nil } -// String makes chunkAbort printable +// String makes chunkAbort printable. func (a *chunkAbort) String() string { res := a.chunkHeader.String() diff --git a/chunk_cookie_ack.go b/chunk_cookie_ack.go index 814ed2ca..84567496 100644 --- a/chunk_cookie_ack.go +++ b/chunk_cookie_ack.go @@ -21,7 +21,7 @@ type chunkCookieAck struct { chunkHeader } -// Cookie ack chunk errors +// Cookie ack chunk errors. var ( ErrChunkTypeNotCookieAck = errors.New("ChunkType is not of type COOKIEACK") ) @@ -40,6 +40,7 @@ func (c *chunkCookieAck) unmarshal(raw []byte) error { func (c *chunkCookieAck) marshal() ([]byte, error) { c.chunkHeader.typ = ctCookieAck + return c.chunkHeader.marshal() } @@ -47,7 +48,7 @@ func (c *chunkCookieAck) check() (abort bool, err error) { return false, nil } -// String makes chunkCookieAck printable +// String makes chunkCookieAck printable. func (c *chunkCookieAck) String() string { return c.chunkHeader.String() } diff --git a/chunk_cookie_echo.go b/chunk_cookie_echo.go index 23a2fc26..bf7d361b 100644 --- a/chunk_cookie_echo.go +++ b/chunk_cookie_echo.go @@ -25,7 +25,7 @@ type chunkCookieEcho struct { cookie []byte } -// Cookie echo chunk errors +// Cookie echo chunk errors. var ( ErrChunkTypeNotCookieEcho = errors.New("ChunkType is not of type COOKIEECHO") ) @@ -46,6 +46,7 @@ func (c *chunkCookieEcho) unmarshal(raw []byte) error { func (c *chunkCookieEcho) marshal() ([]byte, error) { c.chunkHeader.typ = ctCookieEcho c.chunkHeader.raw = c.cookie + return c.chunkHeader.marshal() } diff --git a/chunk_error.go b/chunk_error.go index b62be3ae..89328460 100644 --- a/chunk_error.go +++ b/chunk_error.go @@ -41,7 +41,7 @@ type chunkError struct { errorCauses []errorCause } -// Error chunk errors +// Error chunk errors. var ( ErrChunkTypeNotCtError = errors.New("ChunkType is not of type ctError") ErrBuildErrorChunkFailed = errors.New("failed build Error Chunk") @@ -70,6 +70,7 @@ func (a *chunkError) unmarshal(raw []byte) error { offset += int(e.length()) a.errorCauses = append(a.errorCauses, e) } + return nil } @@ -84,6 +85,7 @@ func (a *chunkError) marshal() ([]byte, error) { } a.raw = append(a.raw, raw...) } + return a.chunkHeader.marshal() } @@ -91,7 +93,7 @@ func (a *chunkError) check() (abort bool, err error) { return false, nil } -// String makes chunkError printable +// String makes chunkError printable. func (a *chunkError) String() string { res := a.chunkHeader.String() diff --git a/chunk_forward_tsn.go b/chunk_forward_tsn.go index b053e98e..e342e658 100644 --- a/chunk_forward_tsn.go +++ b/chunk_forward_tsn.go @@ -46,7 +46,7 @@ const ( forwardTSNStreamLength = 4 ) -// Forward TSN chunk errors +// Forward TSN chunk errors. var ( ErrMarshalStreamFailed = errors.New("failed to marshal stream") ErrChunkTooShort = errors.New("chunk too short") @@ -90,11 +90,12 @@ func (c *chunkForwardTSN) marshal() ([]byte, error) { if err != nil { return nil, fmt.Errorf("%w: %v", ErrMarshalStreamFailed, err) //nolint:errorlint } - out = append(out, b...) + out = append(out, b...) //nolint:makezero // TODO: fix } c.typ = ctForwardTSN c.raw = out + return c.chunkHeader.marshal() } @@ -102,12 +103,13 @@ func (c *chunkForwardTSN) check() (abort bool, err error) { return true, nil } -// String makes chunkForwardTSN printable +// String makes chunkForwardTSN printable. func (c *chunkForwardTSN) String() string { res := fmt.Sprintf("New Cumulative TSN: %d\n", c.newCumulativeTSN) for _, s := range c.streams { res += fmt.Sprintf(" - si=%d, ssn=%d\n", s.identifier, s.sequence) } + return res } diff --git a/chunk_heartbeat.go b/chunk_heartbeat.go index 64914a36..341e5442 100644 --- a/chunk_heartbeat.go +++ b/chunk_heartbeat.go @@ -34,13 +34,14 @@ in Section 3.2.1, i.e.: Variable Parameters Status Type Value ------------------------------------------------------------- heartbeat Info Mandatory 1 +. */ type chunkHeartbeat struct { chunkHeader params []param } -// Heartbeat chunk errors +// Heartbeat chunk errors. var ( ErrChunkTypeNotHeartbeat = errors.New("ChunkType is not of type HEARTBEAT") ErrHeartbeatNotLongEnoughInfo = errors.New("heartbeat is not long enough to contain Heartbeat Info") diff --git a/chunk_heartbeat_ack.go b/chunk_heartbeat_ack.go index 6d825e00..8ccc1531 100644 --- a/chunk_heartbeat_ack.go +++ b/chunk_heartbeat_ack.go @@ -34,13 +34,14 @@ in Section 3.2.1, i.e.: Variable Parameters Status Type Value ------------------------------------------------------------- Heartbeat Info Mandatory 1 +. */ type chunkHeartbeatAck struct { chunkHeader params []param } -// Heartbeat ack chunk errors +// Heartbeat ack chunk errors. var ( ErrUnimplemented = errors.New("unimplemented") ErrHeartbeatAckParams = errors.New("heartbeat Ack must have one param") diff --git a/chunk_init.go b/chunk_init.go index a6a3be31..b31aa973 100644 --- a/chunk_init.go +++ b/chunk_init.go @@ -27,7 +27,7 @@ type chunkInit struct { chunkInitCommon } -// Init chunk errors +// Init chunk errors. var ( ErrChunkTypeNotTypeInit = errors.New("ChunkType is not of type INIT") ErrChunkValueNotLongEnough = errors.New("chunk Value isn't long enough for mandatory parameters exp") @@ -74,6 +74,7 @@ func (i *chunkInit) marshal() ([]byte, error) { i.chunkHeader.typ = ctInit i.chunkHeader.raw = initShared + return i.chunkHeader.marshal() } @@ -136,7 +137,7 @@ func (i *chunkInit) check() (abort bool, err error) { return false, nil } -// String makes chunkInit printable +// String makes chunkInit printable. func (i *chunkInit) String() string { return fmt.Sprintf("%s\n%s", i.chunkHeader, i.chunkInitCommon) } diff --git a/chunk_init_ack.go b/chunk_init_ack.go index ab790bd8..abd65c8c 100644 --- a/chunk_init_ack.go +++ b/chunk_init_ack.go @@ -27,7 +27,7 @@ type chunkInitAck struct { chunkInitCommon } -// Init ack chunk errors +// Init ack chunk errors. var ( ErrChunkTypeNotInitAck = errors.New("ChunkType is not of type INIT ACK") ErrChunkNotLongEnoughForParams = errors.New("chunk Value isn't long enough for mandatory parameters exp") @@ -73,6 +73,7 @@ func (i *chunkInitAck) marshal() ([]byte, error) { i.chunkHeader.typ = ctInitAck i.chunkHeader.raw = initShared + return i.chunkHeader.marshal() } @@ -91,6 +92,7 @@ func (i *chunkInitAck) check() (abort bool, err error) { // purpose. if i.initiateTag == 0 { abort = true + return abort, ErrChunkTypeInitAckInitateTagZero } @@ -106,6 +108,7 @@ func (i *chunkInitAck) check() (abort bool, err error) { // destroy the association discarding its TCB. if i.numInboundStreams == 0 { abort = true + return abort, ErrInitAckInboundStreamRequestZero } @@ -119,6 +122,7 @@ func (i *chunkInitAck) check() (abort bool, err error) { if i.numOutboundStreams == 0 { abort = true + return abort, ErrInitAckOutboundStreamRequestZero } @@ -128,13 +132,14 @@ func (i *chunkInitAck) check() (abort bool, err error) { // ACK. if i.advertisedReceiverWindowCredit < 1500 { abort = true + return abort, ErrInitAckAdvertisedReceiver1500 } return false, nil } -// String makes chunkInitAck printable +// String makes chunkInitAck printable. func (i *chunkInitAck) String() string { return fmt.Sprintf("%s\n%s", i.chunkHeader, i.chunkInitCommon) } diff --git a/chunk_init_common.go b/chunk_init_common.go index b3f8b845..5c893a7d 100644 --- a/chunk_init_common.go +++ b/chunk_init_common.go @@ -57,7 +57,7 @@ const ( initOptionalVarHeaderLength = 4 ) -// Init chunk errors +// Init chunk errors. var ( ErrInitChunkParseParamTypeFailed = errors.New("failed to parse param type") ErrInitAckMarshalParam = errors.New("unable to marshal parameter for INIT/INITACK") @@ -127,7 +127,7 @@ func (i *chunkInitCommon) marshal() ([]byte, error) { return nil, fmt.Errorf("%w: %v", ErrInitAckMarshalParam, err) //nolint:errorlint } - out = append(out, pp...) + out = append(out, pp...) //nolint:makezero // TODO: fix // Chunks (including Type, Length, and Value fields) are padded out // by the sender with all zero bytes to be a multiple of 4 bytes @@ -144,7 +144,7 @@ func (i *chunkInitCommon) marshal() ([]byte, error) { return out, nil } -// String makes chunkInitCommon printable +// String makes chunkInitCommon printable. func (i chunkInitCommon) String() string { format := `initiateTag: %d advertisedReceiverWindowCredit: %d @@ -163,5 +163,6 @@ func (i chunkInitCommon) String() string { for i, param := range i.params { res += fmt.Sprintf("Param %d:\n %s", i, param) } + return res } diff --git a/chunk_init_test.go b/chunk_init_test.go index e235e3fc..b4da1e90 100644 --- a/chunk_init_test.go +++ b/chunk_init_test.go @@ -16,20 +16,22 @@ func TestChunkInit_UnrecognizedParameters(t *testing.T) { unrecognizedSkip := append([]byte{}, initChunkHeader...) unrecognizedSkip = append(unrecognizedSkip, byte(paramHeaderUnrecognizedActionSkip), 0xFF, 0x00, 0x04, 0x00) - i := &chunkInitCommon{} - if err := i.unmarshal(unrecognizedSkip); err != nil { + initCommonChunk := &chunkInitCommon{} + if err := initCommonChunk.unmarshal(unrecognizedSkip); err != nil { t.Errorf("Unmarshal init Chunk failed: %v", err) - } else if len(i.unrecognizedParams) != 1 || i.unrecognizedParams[0].unrecognizedAction != paramHeaderUnrecognizedActionSkip { + } else if len(initCommonChunk.unrecognizedParams) != 1 || + initCommonChunk.unrecognizedParams[0].unrecognizedAction != paramHeaderUnrecognizedActionSkip { t.Errorf("Unrecognized Param parsed incorrectly") } unrecognizedStop := append([]byte{}, initChunkHeader...) unrecognizedStop = append(unrecognizedStop, byte(paramHeaderUnrecognizedActionStop), 0xFF, 0x00, 0x04, 0x00) - i = &chunkInitCommon{} - if err := i.unmarshal(unrecognizedStop); err != nil { + initCommonChunk = &chunkInitCommon{} + if err := initCommonChunk.unmarshal(unrecognizedStop); err != nil { t.Errorf("Unmarshal init Chunk failed: %v", err) - } else if len(i.unrecognizedParams) != 1 || i.unrecognizedParams[0].unrecognizedAction != paramHeaderUnrecognizedActionStop { + } else if len(initCommonChunk.unrecognizedParams) != 1 || + initCommonChunk.unrecognizedParams[0].unrecognizedAction != paramHeaderUnrecognizedActionStop { t.Errorf("Unrecognized Param parsed incorrectly") } } diff --git a/chunk_payload_data.go b/chunk_payload_data.go index a5a00064..1ff9b591 100644 --- a/chunk_payload_data.go +++ b/chunk_payload_data.go @@ -86,7 +86,7 @@ const ( payloadDataHeaderSize = 12 ) -// PayloadProtocolIdentifier is an enum for DataChannel payload types +// PayloadProtocolIdentifier is an enum for DataChannel payload types. type PayloadProtocolIdentifier uint32 // PayloadProtocolIdentifier enums @@ -100,7 +100,7 @@ const ( PayloadTypeWebRTCBinaryEmpty PayloadProtocolIdentifier = 57 ) -// Data chunk errors +// Data chunk errors. var ( ErrChunkPayloadSmall = errors.New("packet is smaller than the header size") ) @@ -170,6 +170,7 @@ func (p *chunkPayloadData) marshal() ([]byte, error) { p.chunkHeader.flags = flags p.chunkHeader.typ = ctPayloadData p.chunkHeader.raw = payRaw + return p.chunkHeader.marshal() } @@ -177,7 +178,7 @@ func (p *chunkPayloadData) check() (abort bool, err error) { return false, nil } -// String makes chunkPayloadData printable +// String makes chunkPayloadData printable. func (p *chunkPayloadData) String() string { return fmt.Sprintf("%s\n%d", p.chunkHeader, p.tsn) } @@ -186,12 +187,14 @@ func (p *chunkPayloadData) abandoned() bool { if p.head != nil { return p.head._abandoned && p.head._allInflight } + return p._abandoned && p._allInflight } func (p *chunkPayloadData) setAbandoned(abandoned bool) { if p.head != nil { p.head._abandoned = abandoned + return } p._abandoned = abandoned diff --git a/chunk_reconfig.go b/chunk_reconfig.go index 39cb1fbd..8e47c2bb 100644 --- a/chunk_reconfig.go +++ b/chunk_reconfig.go @@ -31,7 +31,7 @@ type chunkReconfig struct { paramB param } -// Reconfigure chunk errors +// Reconfigure chunk errors. var ( ErrChunkParseParamTypeFailed = errors.New("failed to parse param type") ErrChunkMarshalParamAReconfigFailed = errors.New("unable to marshal parameter A for reconfig") @@ -88,6 +88,7 @@ func (c *chunkReconfig) marshal() ([]byte, error) { c.typ = ctReconfig c.raw = out + return c.chunkHeader.marshal() } @@ -98,11 +99,12 @@ func (c *chunkReconfig) check() (abort bool, err error) { return true, nil } -// String makes chunkReconfig printable +// String makes chunkReconfig printable. func (c *chunkReconfig) String() string { res := fmt.Sprintf("Param A:\n %s", c.paramA) if c.paramB != nil { res += fmt.Sprintf("Param B:\n %s", c.paramB) } + return res } diff --git a/chunk_reconfig_test.go b/chunk_reconfig_test.go index 9b1323ff..43d5dcaa 100644 --- a/chunk_reconfig_test.go +++ b/chunk_reconfig_test.go @@ -13,11 +13,27 @@ func TestChunkReconfig_Success(t *testing.T) { tt := []struct { binary []byte }{ - {append([]byte{0x82, 0x0, 0x0, 0x1a}, testChunkReconfigParamA()...)}, // Note: chunk trailing padding is added in packet.marshal + { + // Note: chunk trailing padding is added in packet.marshal + append( + []byte{0x82, 0x0, 0x0, 0x1a}, testChunkReconfigParamA()..., + ), + }, {append([]byte{0x82, 0x0, 0x0, 0x14}, testChunkReconfigParamB()...)}, {append([]byte{0x82, 0x0, 0x0, 0x10}, testChunkReconfigResponce()...)}, - {append(append([]byte{0x82, 0x0, 0x0, 0x2c}, padByte(testChunkReconfigParamA(), 2)...), testChunkReconfigParamB()...)}, - {append(append([]byte{0x82, 0x0, 0x0, 0x2a}, testChunkReconfigParamB()...), testChunkReconfigParamA()...)}, // Note: chunk trailing padding is added in packet.marshal + { + append( + append([]byte{0x82, 0x0, 0x0, 0x2c}, padByte(testChunkReconfigParamA(), 2)...), + testChunkReconfigParamB()...), + }, + { + // Note: chunk trailing padding is added in packet.marshal + append( + append([]byte{0x82, 0x0, 0x0, 0x2a}, + testChunkReconfigParamB()...), + testChunkReconfigParamA()..., + ), + }, } for i, tc := range tt { @@ -43,7 +59,12 @@ func TestChunkReconfigUnmarshal_Failure(t *testing.T) { {"chunk header to short", []byte{0x82}}, {"missing parse param type (A)", []byte{0x82, 0x0, 0x0, 0x4}}, {"wrong param (A)", []byte{0x82, 0x0, 0x0, 0x8, 0x0, 0xd, 0x0, 0x0}}, - {"wrong param (B)", append(append([]byte{0x82, 0x0, 0x0, 0x18}, testChunkReconfigParamB()...), []byte{0x0, 0xd, 0x0, 0x0}...)}, + { + "wrong param (B)", + append(append([]byte{0x82, 0x0, 0x0, 0x18}, + testChunkReconfigParamB()...), + []byte{0x0, 0xd, 0x0, 0x0}...), + }, } for i, tc := range tt { diff --git a/chunk_selective_ack.go b/chunk_selective_ack.go index 0d60b6a2..0fed910f 100644 --- a/chunk_selective_ack.go +++ b/chunk_selective_ack.go @@ -49,14 +49,14 @@ type gapAckBlock struct { end uint16 } -// Selective ack chunk errors +// Selective ack chunk errors. var ( ErrChunkTypeNotSack = errors.New("ChunkType is not of type SACK") ErrSackSizeNotLargeEnoughInfo = errors.New("SACK Chunk size is not large enough to contain header") ErrSackSizeNotMatchPredicted = errors.New("SACK Chunk size does not match predicted amount from header values") ) -// String makes gapAckBlock printable +// String makes gapAckBlock printable. func (g gapAckBlock) String() string { return fmt.Sprintf("%d - %d", g.start, g.end) } @@ -114,8 +114,8 @@ func (s *chunkSelectiveAck) marshal() ([]byte, error) { sackRaw := make([]byte, selectiveAckHeaderSize+(4*len(s.gapAckBlocks)+(4*len(s.duplicateTSN)))) binary.BigEndian.PutUint32(sackRaw[0:], s.cumulativeTSNAck) binary.BigEndian.PutUint32(sackRaw[4:], s.advertisedReceiverWindowCredit) - binary.BigEndian.PutUint16(sackRaw[8:], uint16(len(s.gapAckBlocks))) - binary.BigEndian.PutUint16(sackRaw[10:], uint16(len(s.duplicateTSN))) + binary.BigEndian.PutUint16(sackRaw[8:], uint16(len(s.gapAckBlocks))) //nolint:gosec // G115 + binary.BigEndian.PutUint16(sackRaw[10:], uint16(len(s.duplicateTSN))) //nolint:gosec // G115 offset := selectiveAckHeaderSize for _, g := range s.gapAckBlocks { binary.BigEndian.PutUint16(sackRaw[offset:], g.start) @@ -129,6 +129,7 @@ func (s *chunkSelectiveAck) marshal() ([]byte, error) { s.chunkHeader.typ = ctSack s.chunkHeader.raw = sackRaw + return s.chunkHeader.marshal() } @@ -136,7 +137,7 @@ func (s *chunkSelectiveAck) check() (abort bool, err error) { return false, nil } -// String makes chunkSelectiveAck printable +// String makes chunkSelectiveAck printable. func (s *chunkSelectiveAck) String() string { res := fmt.Sprintf("SACK cumTsnAck=%d arwnd=%d dupTsn=%d", s.cumulativeTSNAck, diff --git a/chunk_shutdown.go b/chunk_shutdown.go index 9cc756c1..7259977a 100644 --- a/chunk_shutdown.go +++ b/chunk_shutdown.go @@ -18,7 +18,7 @@ chunkShutdown represents an SCTP Chunk of type chunkShutdown | Type = 7 | Chunk Flags | Length = 8 | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Cumulative TSN Ack | -+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+. */ type chunkShutdown struct { chunkHeader @@ -29,7 +29,7 @@ const ( cumulativeTSNAckLength = 4 ) -// Shutdown chunk errors +// Shutdown chunk errors. var ( ErrInvalidChunkSize = errors.New("invalid chunk size") ErrChunkTypeNotShutdown = errors.New("ChunkType is not of type SHUTDOWN") @@ -59,6 +59,7 @@ func (c *chunkShutdown) marshal() ([]byte, error) { c.typ = ctShutdown c.raw = out + return c.chunkHeader.marshal() } @@ -66,7 +67,7 @@ func (c *chunkShutdown) check() (abort bool, err error) { return false, nil } -// String makes chunkShutdown printable +// String makes chunkShutdown printable. func (c *chunkShutdown) String() string { return c.chunkHeader.String() } diff --git a/chunk_shutdown_ack.go b/chunk_shutdown_ack.go index 87f41700..33881003 100644 --- a/chunk_shutdown_ack.go +++ b/chunk_shutdown_ack.go @@ -15,13 +15,13 @@ chunkShutdownAck represents an SCTP Chunk of type chunkShutdownAck 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Type = 8 | Chunk Flags | Length = 4 | -+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+. */ type chunkShutdownAck struct { chunkHeader } -// Shutdown ack chunk errors +// Shutdown ack chunk errors. var ( ErrChunkTypeNotShutdownAck = errors.New("ChunkType is not of type SHUTDOWN-ACK") ) @@ -40,6 +40,7 @@ func (c *chunkShutdownAck) unmarshal(raw []byte) error { func (c *chunkShutdownAck) marshal() ([]byte, error) { c.typ = ctShutdownAck + return c.chunkHeader.marshal() } @@ -47,7 +48,7 @@ func (c *chunkShutdownAck) check() (abort bool, err error) { return false, nil } -// String makes chunkShutdownAck printable +// String makes chunkShutdownAck printable. func (c *chunkShutdownAck) String() string { return c.chunkHeader.String() } diff --git a/chunk_shutdown_complete.go b/chunk_shutdown_complete.go index 9652f887..dfd7dd6b 100644 --- a/chunk_shutdown_complete.go +++ b/chunk_shutdown_complete.go @@ -15,13 +15,13 @@ chunkShutdownComplete represents an SCTP Chunk of type chunkShutdownComplete 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Type = 14 |Reserved |T| Length = 4 | -+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+. */ type chunkShutdownComplete struct { chunkHeader } -// Shutdown complete chunk errors +// Shutdown complete chunk errors. var ( ErrChunkTypeNotShutdownComplete = errors.New("ChunkType is not of type SHUTDOWN-COMPLETE") ) @@ -40,6 +40,7 @@ func (c *chunkShutdownComplete) unmarshal(raw []byte) error { func (c *chunkShutdownComplete) marshal() ([]byte, error) { c.typ = ctShutdownComplete + return c.chunkHeader.marshal() } @@ -47,7 +48,7 @@ func (c *chunkShutdownComplete) check() (abort bool, err error) { return false, nil } -// String makes chunkShutdownComplete printable +// String makes chunkShutdownComplete printable. func (c *chunkShutdownComplete) String() string { return c.chunkHeader.String() } diff --git a/chunk_test.go b/chunk_test.go index c831e447..ddb6cf23 100644 --- a/chunk_test.go +++ b/chunk_test.go @@ -12,18 +12,19 @@ import ( func TestInitChunk(t *testing.T) { pkt := &packet{} rawPkt := []byte{ - 0x13, 0x88, 0x13, 0x88, 0x00, 0x00, 0x00, 0x00, 0x81, 0x46, 0x9d, 0xfc, 0x01, 0x00, 0x00, 0x56, 0x55, - 0xb9, 0x64, 0xa5, 0x00, 0x02, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0xe8, 0x6d, 0x10, 0x30, 0xc0, 0x00, 0x00, 0x04, 0x80, - 0x08, 0x00, 0x09, 0xc0, 0x0f, 0xc1, 0x80, 0x82, 0x00, 0x00, 0x00, 0x80, 0x02, 0x00, 0x24, 0x9f, 0xeb, 0xbb, 0x5c, 0x50, - 0xc9, 0xbf, 0x75, 0x9c, 0xb1, 0x2c, 0x57, 0x4f, 0xa4, 0x5a, 0x51, 0xba, 0x60, 0x17, 0x78, 0x27, 0x94, 0x5c, 0x31, 0xe6, - 0x5d, 0x5b, 0x09, 0x47, 0xe2, 0x22, 0x06, 0x80, 0x04, 0x00, 0x06, 0x00, 0x01, 0x00, 0x00, 0x80, 0x03, 0x00, 0x06, 0x80, 0xc1, 0x00, 0x00, + 0x13, 0x88, 0x13, 0x88, 0x00, 0x00, 0x00, 0x00, 0x81, 0x46, 0x9d, 0xfc, 0x01, 0x00, 0x00, 0x56, 0x55, 0xb9, + 0x64, 0xa5, 0x00, 0x02, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0xe8, 0x6d, 0x10, 0x30, 0xc0, 0x00, 0x00, 0x04, + 0x80, 0x08, 0x00, 0x09, 0xc0, 0x0f, 0xc1, 0x80, 0x82, 0x00, 0x00, 0x00, 0x80, 0x02, 0x00, 0x24, 0x9f, 0xeb, + 0xbb, 0x5c, 0x50, 0xc9, 0xbf, 0x75, 0x9c, 0xb1, 0x2c, 0x57, 0x4f, 0xa4, 0x5a, 0x51, 0xba, 0x60, 0x17, 0x78, + 0x27, 0x94, 0x5c, 0x31, 0xe6, 0x5d, 0x5b, 0x09, 0x47, 0xe2, 0x22, 0x06, 0x80, 0x04, 0x00, 0x06, 0x00, 0x01, + 0x00, 0x00, 0x80, 0x03, 0x00, 0x06, 0x80, 0xc1, 0x00, 0x00, } err := pkt.unmarshal(true, rawPkt) if err != nil { t.Errorf("Unmarshal failed, has chunk") } - i, ok := pkt.chunks[0].(*chunkInit) + initChunk, ok := pkt.chunks[0].(*chunkInit) if !ok { t.Errorf("Failed to cast Chunk -> Init") } @@ -31,22 +32,41 @@ func TestInitChunk(t *testing.T) { switch { case err != nil: t.Errorf("Unmarshal init Chunk failed: %v", err) - case i.initiateTag != 1438213285: - t.Errorf("Unmarshal passed for SCTP packet, but got incorrect initiate tag exp: %d act: %d", 1438213285, i.initiateTag) - case i.advertisedReceiverWindowCredit != 131072: - t.Errorf("Unmarshal passed for SCTP packet, but got incorrect advertisedReceiverWindowCredit exp: %d act: %d", 131072, i.advertisedReceiverWindowCredit) - case i.numOutboundStreams != 1024: - t.Errorf("Unmarshal passed for SCTP packet, but got incorrect numOutboundStreams tag exp: %d act: %d", 1024, i.numOutboundStreams) - case i.numInboundStreams != 2048: - t.Errorf("Unmarshal passed for SCTP packet, but got incorrect numInboundStreams exp: %d act: %d", 2048, i.numInboundStreams) - case i.initialTSN != uint32(3899461680): - t.Errorf("Unmarshal passed for SCTP packet, but got incorrect initialTSN exp: %d act: %d", uint32(3899461680), i.initialTSN) + case initChunk.initiateTag != 1438213285: + t.Errorf( + "Unmarshal passed for SCTP packet, but got incorrect initiate tag exp: %d act: %d", + 1438213285, initChunk.initiateTag, + ) + case initChunk.advertisedReceiverWindowCredit != 131072: + t.Errorf( + "Unmarshal passed for SCTP packet, but got incorrect advertisedReceiverWindowCredit exp: %d act: %d", + 131072, initChunk.advertisedReceiverWindowCredit, + ) + case initChunk.numOutboundStreams != 1024: + t.Errorf( + "Unmarshal passed for SCTP packet, but got incorrect numOutboundStreams tag exp: %d act: %d", + 1024, initChunk.numOutboundStreams, + ) + case initChunk.numInboundStreams != 2048: + t.Errorf( + "Unmarshal passed for SCTP packet, but got incorrect numInboundStreams exp: %d act: %d", + 2048, initChunk.numInboundStreams, + ) + case initChunk.initialTSN != uint32(3899461680): + t.Errorf( + "Unmarshal passed for SCTP packet, but got incorrect initialTSN exp: %d act: %d", + uint32(3899461680), initChunk.initialTSN, + ) } } func TestInitAck(t *testing.T) { pkt := &packet{} - rawPkt := []byte{0x13, 0x88, 0x13, 0x88, 0xce, 0x15, 0x79, 0xa2, 0x96, 0x19, 0xe8, 0xb2, 0x02, 0x00, 0x00, 0x1c, 0xeb, 0x81, 0x4e, 0x01, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0x50, 0xdf, 0x90, 0xd9, 0x00, 0x07, 0x00, 0x08, 0x94, 0x06, 0x2f, 0x93} + rawPkt := []byte{ + 0x13, 0x88, 0x13, 0x88, 0xce, 0x15, 0x79, 0xa2, 0x96, 0x19, 0xe8, 0xb2, 0x02, 0x00, 0x00, 0x1c, 0xeb, 0x81, + 0x4e, 0x01, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0x50, 0xdf, 0x90, 0xd9, 0x00, 0x07, 0x00, 0x08, + 0x94, 0x06, 0x2f, 0x93, + } err := pkt.unmarshal(true, rawPkt) if err != nil { t.Errorf("Unmarshal failed, has chunk: %v", err) @@ -62,7 +82,14 @@ func TestInitAck(t *testing.T) { func TestChromeChunk1Init(t *testing.T) { pkt := &packet{} - rawPkt := []byte{0x13, 0x88, 0x13, 0x88, 0x00, 0x00, 0x00, 0x00, 0xbc, 0xb3, 0x45, 0xa2, 0x01, 0x00, 0x00, 0x56, 0xce, 0x15, 0x79, 0xa2, 0x00, 0x02, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0x94, 0x57, 0x95, 0xc0, 0xc0, 0x00, 0x00, 0x04, 0x80, 0x08, 0x00, 0x09, 0xc0, 0x0f, 0xc1, 0x80, 0x82, 0x00, 0x00, 0x00, 0x80, 0x02, 0x00, 0x24, 0xff, 0x5c, 0x49, 0x19, 0x4a, 0x94, 0xe8, 0x2a, 0xec, 0x58, 0x55, 0x62, 0x29, 0x1f, 0x8e, 0x23, 0xcd, 0x7c, 0xe8, 0x46, 0xba, 0x58, 0x1b, 0x3d, 0xab, 0xd7, 0x7e, 0x50, 0xf2, 0x41, 0xb1, 0x2e, 0x80, 0x04, 0x00, 0x06, 0x00, 0x01, 0x00, 0x00, 0x80, 0x03, 0x00, 0x06, 0x80, 0xc1, 0x00, 0x00} + rawPkt := []byte{ + 0x13, 0x88, 0x13, 0x88, 0x00, 0x00, 0x00, 0x00, 0xbc, 0xb3, 0x45, 0xa2, 0x01, 0x00, 0x00, 0x56, 0xce, 0x15, + 0x79, 0xa2, 0x00, 0x02, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0x94, 0x57, 0x95, 0xc0, 0xc0, 0x00, 0x00, 0x04, + 0x80, 0x08, 0x00, 0x09, 0xc0, 0x0f, 0xc1, 0x80, 0x82, 0x00, 0x00, 0x00, 0x80, 0x02, 0x00, 0x24, 0xff, 0x5c, + 0x49, 0x19, 0x4a, 0x94, 0xe8, 0x2a, 0xec, 0x58, 0x55, 0x62, 0x29, 0x1f, 0x8e, 0x23, 0xcd, 0x7c, 0xe8, 0x46, + 0xba, 0x58, 0x1b, 0x3d, 0xab, 0xd7, 0x7e, 0x50, 0xf2, 0x41, 0xb1, 0x2e, 0x80, 0x04, 0x00, 0x06, 0x00, 0x01, + 0x00, 0x00, 0x80, 0x03, 0x00, 0x06, 0x80, 0xc1, 0x00, 0x00, + } err := pkt.unmarshal(true, rawPkt) if err != nil { t.Errorf("Unmarshal failed, has chunk: %v", err) @@ -78,7 +105,31 @@ func TestChromeChunk1Init(t *testing.T) { func TestChromeChunk2InitAck(t *testing.T) { pkt := &packet{} - rawPkt := []byte{0x13, 0x88, 0x13, 0x88, 0xce, 0x15, 0x79, 0xa2, 0xb5, 0xdb, 0x2d, 0x93, 0x02, 0x00, 0x01, 0x90, 0x9b, 0xd5, 0xb3, 0x6f, 0x00, 0x02, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0xef, 0xb4, 0x72, 0x87, 0xc0, 0x00, 0x00, 0x04, 0x80, 0x08, 0x00, 0x09, 0xc0, 0x0f, 0xc1, 0x80, 0x82, 0x00, 0x00, 0x00, 0x80, 0x02, 0x00, 0x24, 0x2e, 0xf9, 0x9c, 0x10, 0x63, 0x72, 0xed, 0x0d, 0x33, 0xc2, 0xdc, 0x7f, 0x9f, 0xd7, 0xef, 0x1b, 0xc9, 0xc4, 0xa7, 0x41, 0x9a, 0x07, 0x68, 0x6b, 0x66, 0xfb, 0x6a, 0x4e, 0x32, 0x5d, 0xe4, 0x25, 0x80, 0x04, 0x00, 0x06, 0x00, 0x01, 0x00, 0x00, 0x80, 0x03, 0x00, 0x06, 0x80, 0xc1, 0x00, 0x00, 0x00, 0x07, 0x01, 0x38, 0x4b, 0x41, 0x4d, 0x45, 0x2d, 0x42, 0x53, 0x44, 0x20, 0x31, 0x2e, 0x31, 0x00, 0x00, 0x00, 0x00, 0x9c, 0x1e, 0x49, 0x5b, 0x00, 0x00, 0x00, 0x00, 0xd2, 0x42, 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x60, 0xea, 0x00, 0x00, 0xc4, 0x13, 0x3d, 0xe9, 0x86, 0xb1, 0x85, 0x75, 0xa2, 0x79, 0x15, 0xce, 0x9b, 0xd5, 0xb3, 0x6f, 0x20, 0xe0, 0x9f, 0x89, 0xe0, 0x27, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x20, 0xe0, 0x9f, 0x89, 0xe0, 0x27, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x13, 0x88, 0x13, 0x88, 0x00, 0x00, 0x01, 0x00, 0x01, 0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x56, 0xce, 0x15, 0x79, 0xa2, 0x00, 0x02, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0x94, 0x57, 0x95, 0xc0, 0xc0, 0x00, 0x00, 0x04, 0x80, 0x08, 0x00, 0x09, 0xc0, 0x0f, 0xc1, 0x80, 0x82, 0x00, 0x00, 0x00, 0x80, 0x02, 0x00, 0x24, 0xff, 0x5c, 0x49, 0x19, 0x4a, 0x94, 0xe8, 0x2a, 0xec, 0x58, 0x55, 0x62, 0x29, 0x1f, 0x8e, 0x23, 0xcd, 0x7c, 0xe8, 0x46, 0xba, 0x58, 0x1b, 0x3d, 0xab, 0xd7, 0x7e, 0x50, 0xf2, 0x41, 0xb1, 0x2e, 0x80, 0x04, 0x00, 0x06, 0x00, 0x01, 0x00, 0x00, 0x80, 0x03, 0x00, 0x06, 0x80, 0xc1, 0x00, 0x00, 0x02, 0x00, 0x01, 0x90, 0x9b, 0xd5, 0xb3, 0x6f, 0x00, 0x02, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0xef, 0xb4, 0x72, 0x87, 0xc0, 0x00, 0x00, 0x04, 0x80, 0x08, 0x00, 0x09, 0xc0, 0x0f, 0xc1, 0x80, 0x82, 0x00, 0x00, 0x00, 0x80, 0x02, 0x00, 0x24, 0x2e, 0xf9, 0x9c, 0x10, 0x63, 0x72, 0xed, 0x0d, 0x33, 0xc2, 0xdc, 0x7f, 0x9f, 0xd7, 0xef, 0x1b, 0xc9, 0xc4, 0xa7, 0x41, 0x9a, 0x07, 0x68, 0x6b, 0x66, 0xfb, 0x6a, 0x4e, 0x32, 0x5d, 0xe4, 0x25, 0x80, 0x04, 0x00, 0x06, 0x00, 0x01, 0x00, 0x00, 0x80, 0x03, 0x00, 0x06, 0x80, 0xc1, 0x00, 0x00, 0xca, 0x0c, 0x21, 0x11, 0xce, 0xf4, 0xfc, 0xb3, 0x66, 0x99, 0x4f, 0xdb, 0x4f, 0x95, 0x6b, 0x6f, 0x3b, 0xb1, 0xdb, 0x5a} + rawPkt := []byte{ + 0x13, 0x88, 0x13, 0x88, 0xce, 0x15, 0x79, 0xa2, 0xb5, 0xdb, 0x2d, 0x93, 0x02, 0x00, 0x01, 0x90, 0x9b, 0xd5, + 0xb3, 0x6f, 0x00, 0x02, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0xef, 0xb4, 0x72, 0x87, 0xc0, 0x00, 0x00, 0x04, + 0x80, 0x08, 0x00, 0x09, 0xc0, 0x0f, 0xc1, 0x80, 0x82, 0x00, 0x00, 0x00, 0x80, 0x02, 0x00, 0x24, 0x2e, 0xf9, + 0x9c, 0x10, 0x63, 0x72, 0xed, 0x0d, 0x33, 0xc2, 0xdc, 0x7f, 0x9f, 0xd7, 0xef, 0x1b, 0xc9, 0xc4, 0xa7, 0x41, + 0x9a, 0x07, 0x68, 0x6b, 0x66, 0xfb, 0x6a, 0x4e, 0x32, 0x5d, 0xe4, 0x25, 0x80, 0x04, 0x00, 0x06, 0x00, 0x01, + 0x00, 0x00, 0x80, 0x03, 0x00, 0x06, 0x80, 0xc1, 0x00, 0x00, 0x00, 0x07, 0x01, 0x38, 0x4b, 0x41, 0x4d, 0x45, + 0x2d, 0x42, 0x53, 0x44, 0x20, 0x31, 0x2e, 0x31, 0x00, 0x00, 0x00, 0x00, 0x9c, 0x1e, 0x49, 0x5b, 0x00, 0x00, + 0x00, 0x00, 0xd2, 0x42, 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x60, 0xea, 0x00, 0x00, 0xc4, 0x13, 0x3d, 0xe9, + 0x86, 0xb1, 0x85, 0x75, 0xa2, 0x79, 0x15, 0xce, 0x9b, 0xd5, 0xb3, 0x6f, 0x20, 0xe0, 0x9f, 0x89, 0xe0, 0x27, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x20, 0xe0, 0x9f, 0x89, + 0xe0, 0x27, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x13, 0x88, 0x13, 0x88, 0x00, 0x00, 0x01, 0x00, 0x01, 0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x56, 0xce, 0x15, 0x79, 0xa2, 0x00, 0x02, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0x94, 0x57, + 0x95, 0xc0, 0xc0, 0x00, 0x00, 0x04, 0x80, 0x08, 0x00, 0x09, 0xc0, 0x0f, 0xc1, 0x80, 0x82, 0x00, 0x00, 0x00, + 0x80, 0x02, 0x00, 0x24, 0xff, 0x5c, 0x49, 0x19, 0x4a, 0x94, 0xe8, 0x2a, 0xec, 0x58, 0x55, 0x62, 0x29, 0x1f, + 0x8e, 0x23, 0xcd, 0x7c, 0xe8, 0x46, 0xba, 0x58, 0x1b, 0x3d, 0xab, 0xd7, 0x7e, 0x50, 0xf2, 0x41, 0xb1, 0x2e, + 0x80, 0x04, 0x00, 0x06, 0x00, 0x01, 0x00, 0x00, 0x80, 0x03, 0x00, 0x06, 0x80, 0xc1, 0x00, 0x00, 0x02, 0x00, + 0x01, 0x90, 0x9b, 0xd5, 0xb3, 0x6f, 0x00, 0x02, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0xef, 0xb4, 0x72, 0x87, + 0xc0, 0x00, 0x00, 0x04, 0x80, 0x08, 0x00, 0x09, 0xc0, 0x0f, 0xc1, 0x80, 0x82, 0x00, 0x00, 0x00, 0x80, 0x02, + 0x00, 0x24, 0x2e, 0xf9, 0x9c, 0x10, 0x63, 0x72, 0xed, 0x0d, 0x33, 0xc2, 0xdc, 0x7f, 0x9f, 0xd7, 0xef, 0x1b, + 0xc9, 0xc4, 0xa7, 0x41, 0x9a, 0x07, 0x68, 0x6b, 0x66, 0xfb, 0x6a, 0x4e, 0x32, 0x5d, 0xe4, 0x25, 0x80, 0x04, + 0x00, 0x06, 0x00, 0x01, 0x00, 0x00, 0x80, 0x03, 0x00, 0x06, 0x80, 0xc1, 0x00, 0x00, 0xca, 0x0c, 0x21, 0x11, + 0xce, 0xf4, 0xfc, 0xb3, 0x66, 0x99, 0x4f, 0xdb, 0x4f, 0x95, 0x6b, 0x6f, 0x3b, 0xb1, 0xdb, 0x5a, + } err := pkt.unmarshal(true, rawPkt) if err != nil { t.Errorf("Unmarshal failed, has chunk: %v", err) @@ -92,11 +143,11 @@ func TestChromeChunk2InitAck(t *testing.T) { assert.Equal(t, rawPkt, rawPkt2) } -func TestInitMarshalUnmarshal(t *testing.T) { - p := &packet{} - p.destinationPort = 1 - p.sourcePort = 1 - p.verificationTag = 123 +func TestInitMarshalUnmarshal(t *testing.T) { //nolint:cyclop + sctpPacket := &packet{} + sctpPacket.destinationPort = 1 + sctpPacket.sourcePort = 1 + sctpPacket.verificationTag = 123 initAck := &chunkInitAck{} @@ -111,8 +162,8 @@ func TestInitMarshalUnmarshal(t *testing.T) { } initAck.params = []param{cookie} - p.chunks = []chunk{initAck} - rawPkt, err := p.marshal(true) + sctpPacket.chunks = []chunk{initAck} + rawPkt, err := sctpPacket.marshal(true) if err != nil { t.Errorf("Failed to marshal packet: %v", err) } @@ -123,7 +174,7 @@ func TestInitMarshalUnmarshal(t *testing.T) { t.Errorf("Unmarshal failed, has chunk: %v", err) } - i, ok := pkt.chunks[0].(*chunkInitAck) + initAckChunk, ok := pkt.chunks[0].(*chunkInitAck) if !ok { t.Error("Failed to cast Chunk -> InitAck") } @@ -131,22 +182,43 @@ func TestInitMarshalUnmarshal(t *testing.T) { switch { case err != nil: t.Errorf("Unmarshal init ack Chunk failed: %v", err) - case i.initiateTag != 123: - t.Errorf("Unmarshal passed for SCTP packet, but got incorrect initiate tag exp: %d act: %d", 123, i.initiateTag) - case i.advertisedReceiverWindowCredit != 1024: - t.Errorf("Unmarshal passed for SCTP packet, but got incorrect advertisedReceiverWindowCredit exp: %d act: %d", 1024, i.advertisedReceiverWindowCredit) - case i.numOutboundStreams != 1: - t.Errorf("Unmarshal passed for SCTP packet, but got incorrect numOutboundStreams tag exp: %d act: %d", 1, i.numOutboundStreams) - case i.numInboundStreams != 1: - t.Errorf("Unmarshal passed for SCTP packet, but got incorrect numInboundStreams exp: %d act: %d", 1, i.numInboundStreams) - case i.initialTSN != 123: - t.Errorf("Unmarshal passed for SCTP packet, but got incorrect initialTSN exp: %d act: %d", 123, i.initialTSN) + case initAckChunk.initiateTag != 123: + t.Errorf( + "Unmarshal passed for SCTP packet, but got incorrect initiate tag exp: %d act: %d", + 123, initAckChunk.initiateTag, + ) + case initAckChunk.advertisedReceiverWindowCredit != 1024: + t.Errorf( + "Unmarshal passed for SCTP packet, but got incorrect advertisedReceiverWindowCredit exp: %d act: %d", + 1024, initAckChunk.advertisedReceiverWindowCredit, + ) + case initAckChunk.numOutboundStreams != 1: + t.Errorf( + "Unmarshal passed for SCTP packet, but got incorrect numOutboundStreams tag exp: %d act: %d", + 1, initAckChunk.numOutboundStreams, + ) + case initAckChunk.numInboundStreams != 1: + t.Errorf( + "Unmarshal passed for SCTP packet, but got incorrect numInboundStreams exp: %d act: %d", + 1, initAckChunk.numInboundStreams, + ) + case initAckChunk.initialTSN != 123: + t.Errorf( + "Unmarshal passed for SCTP packet, but got incorrect initialTSN exp: %d act: %d", + 123, initAckChunk.initialTSN, + ) } } func TestPayloadDataMarshalUnmarshal(t *testing.T) { pkt := &packet{} - rawPkt := []byte{0x13, 0x88, 0x13, 0x88, 0xfc, 0xd6, 0x3f, 0xc6, 0xbe, 0xfa, 0xdc, 0x52, 0x0a, 0x00, 0x00, 0x24, 0x9b, 0x28, 0x7e, 0x48, 0xa3, 0x7b, 0xc1, 0x83, 0xc4, 0x4b, 0x41, 0x04, 0xa4, 0xf7, 0xed, 0x4c, 0x93, 0x62, 0xc3, 0x49, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0x00, 0x1f, 0xa8, 0x79, 0xa1, 0xc7, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x32, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x66, 0x6f, 0x6f, 0x00} + rawPkt := []byte{ + 0x13, 0x88, 0x13, 0x88, 0xfc, 0xd6, 0x3f, 0xc6, 0xbe, 0xfa, 0xdc, 0x52, 0x0a, 0x00, 0x00, 0x24, 0x9b, 0x28, + 0x7e, 0x48, 0xa3, 0x7b, 0xc1, 0x83, 0xc4, 0x4b, 0x41, 0x04, 0xa4, 0xf7, 0xed, 0x4c, 0x93, 0x62, 0xc3, 0x49, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0x00, 0x1f, 0xa8, 0x79, + 0xa1, 0xc7, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x32, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x03, 0x00, 0x00, 0x66, 0x6f, 0x6f, 0x00, + } err := pkt.unmarshal(true, rawPkt) if err != nil { t.Errorf("Unmarshal failed, has chunk: %v", err) @@ -160,7 +232,11 @@ func TestPayloadDataMarshalUnmarshal(t *testing.T) { func TestSelectAckChunk(t *testing.T) { pkt := &packet{} - rawPkt := []byte{0x13, 0x88, 0x13, 0x88, 0xc2, 0x98, 0x98, 0x0f, 0x42, 0x31, 0xea, 0x78, 0x03, 0x00, 0x00, 0x14, 0x87, 0x73, 0xbd, 0xa4, 0x00, 0x01, 0xfe, 0x74, 0x00, 0x01, 0x00, 0x00, 0x00, 0x02, 0x00, 0x02} + rawPkt := []byte{ + 0x13, 0x88, 0x13, 0x88, 0xc2, 0x98, 0x98, 0x0f, 0x42, 0x31, 0xea, + 0x78, 0x03, 0x00, 0x00, 0x14, 0x87, 0x73, 0xbd, 0xa4, 0x00, 0x01, + 0xfe, 0x74, 0x00, 0x01, 0x00, 0x00, 0x00, 0x02, 0x00, 0x02, + } err := pkt.unmarshal(true, rawPkt) if err != nil { t.Errorf("Unmarshal failed, has chunk: %v", err) @@ -174,7 +250,11 @@ func TestSelectAckChunk(t *testing.T) { func TestReconfigChunk(t *testing.T) { pkt := &packet{} - rawPkt := []byte{0x13, 0x88, 0x13, 0x88, 0xb6, 0xa5, 0x12, 0xe5, 0x75, 0x3b, 0x12, 0xd3, 0x82, 0x0, 0x0, 0x16, 0x0, 0xd, 0x0, 0x12, 0x4e, 0x1c, 0xb9, 0xe6, 0x3a, 0x74, 0x8d, 0xff, 0x4e, 0x1c, 0xb9, 0xe6, 0x0, 0x1, 0x0, 0x0} + rawPkt := []byte{ + 0x13, 0x88, 0x13, 0x88, 0xb6, 0xa5, 0x12, 0xe5, 0x75, 0x3b, 0x12, 0xd3, 0x82, + 0x0, 0x0, 0x16, 0x0, 0xd, 0x0, 0x12, 0x4e, 0x1c, 0xb9, 0xe6, 0x3a, 0x74, 0x8d, + 0xff, 0x4e, 0x1c, 0xb9, 0xe6, 0x0, 0x1, 0x0, 0x0, + } err := pkt.unmarshal(true, rawPkt) if err != nil { t.Errorf("Unmarshal failed, has chunk: %v", err) @@ -186,13 +266,19 @@ func TestReconfigChunk(t *testing.T) { } if c.paramA.(*paramOutgoingResetRequest).streamIdentifiers[0] != uint16(1) { //nolint:forcetypeassert - t.Errorf("unexpected stream identifier: %d", c.paramA.(*paramOutgoingResetRequest).streamIdentifiers[0]) //nolint:forcetypeassert + t.Errorf( + "unexpected stream identifier: %d", + c.paramA.(*paramOutgoingResetRequest).streamIdentifiers[0], //nolint:forcetypeassert + ) } } func TestForwardTSNChunk(t *testing.T) { pkt := &packet{} - rawPkt := append([]byte{0x13, 0x88, 0x13, 0x88, 0xb6, 0xa5, 0x12, 0xe5, 0x1f, 0x9d, 0xa0, 0xfb}, testChunkForwardTSN()...) + rawPkt := append( + []byte{0x13, 0x88, 0x13, 0x88, 0xb6, 0xa5, 0x12, 0xe5, 0x1f, 0x9d, 0xa0, 0xfb}, + testChunkForwardTSN()..., + ) err := pkt.unmarshal(true, rawPkt) if err != nil { t.Errorf("Unmarshal failed, has chunk: %v", err) diff --git a/chunkheader.go b/chunkheader.go index 60b903b4..8c5c3129 100644 --- a/chunkheader.go +++ b/chunkheader.go @@ -36,7 +36,7 @@ const ( chunkHeaderSize = 4 ) -// SCTP chunk header errors +// SCTP chunk header errors. var ( ErrChunkHeaderTooSmall = errors.New("raw is too small for a SCTP chunk") ErrChunkHeaderNotEnoughSpace = errors.New("not enough data left in SCTP packet to satisfy requested length") @@ -45,7 +45,10 @@ var ( func (c *chunkHeader) unmarshal(raw []byte) error { if len(raw) < chunkHeaderSize { - return fmt.Errorf("%w: raw only %d bytes, %d is the minimum length", ErrChunkHeaderTooSmall, len(raw), chunkHeaderSize) + return fmt.Errorf( + "%w: raw only %d bytes, %d is the minimum length", + ErrChunkHeaderTooSmall, len(raw), chunkHeaderSize, + ) } c.typ = chunkType(raw[0]) @@ -77,6 +80,7 @@ func (c *chunkHeader) unmarshal(raw []byte) error { } c.raw = raw[chunkHeaderSize : chunkHeaderSize+valueLength] + return nil } @@ -85,8 +89,9 @@ func (c *chunkHeader) marshal() ([]byte, error) { raw[0] = uint8(c.typ) raw[1] = c.flags - binary.BigEndian.PutUint16(raw[2:], uint16(len(c.raw)+chunkHeaderSize)) + binary.BigEndian.PutUint16(raw[2:], uint16(len(c.raw)+chunkHeaderSize)) //nolint:gosec // G115 copy(raw[4:], c.raw) + return raw, nil } @@ -94,7 +99,7 @@ func (c *chunkHeader) valueLength() int { return len(c.raw) } -// String makes chunkHeader printable +// String makes chunkHeader printable. func (c chunkHeader) String() string { return c.typ.String() } diff --git a/chunktype.go b/chunktype.go index ed7e60cc..92133f29 100644 --- a/chunktype.go +++ b/chunktype.go @@ -10,7 +10,7 @@ import "fmt" // Chunk Value field. type chunkType uint8 -// List of known chunkType enums +// List of known chunkType enums. const ( ctPayloadData chunkType = 0 ctInit chunkType = 1 @@ -30,7 +30,7 @@ const ( ctForwardTSN chunkType = 192 ) -func (c chunkType) String() string { +func (c chunkType) String() string { //nolint:cyclop switch c { case ctPayloadData: return "DATA" diff --git a/control_queue.go b/control_queue.go index 5c417bf0..20ae995d 100644 --- a/control_queue.go +++ b/control_queue.go @@ -24,6 +24,7 @@ func (q *controlQueue) pushAll(packets []*packet) { func (q *controlQueue) popAll() []*packet { packets := q.queue q.queue = []*packet{} + return packets } diff --git a/error_cause.go b/error_cause.go index bb1c7c8c..f3cf72f5 100644 --- a/error_cause.go +++ b/error_cause.go @@ -9,7 +9,7 @@ import ( "fmt" ) -// errorCauseCode is a cause code that appears in either a ERROR or ABORT chunk +// errorCauseCode is a cause code that appears in either a ERROR or ABORT chunk. type errorCauseCode uint16 type errorCause interface { @@ -21,34 +21,34 @@ type errorCause interface { errorCauseCode() errorCauseCode } -// Error and abort chunk errors +// Error and abort chunk errors. var ( ErrBuildErrorCaseHandle = errors.New("BuildErrorCause does not handle") ) -// buildErrorCause delegates the building of a error cause from raw bytes to the correct structure +// buildErrorCause delegates the building of a error cause from raw bytes to the correct structure. func buildErrorCause(raw []byte) (errorCause, error) { - var e errorCause + var errCause errorCause c := errorCauseCode(binary.BigEndian.Uint16(raw[0:])) switch c { case invalidMandatoryParameter: - e = &errorCauseInvalidMandatoryParameter{} + errCause = &errorCauseInvalidMandatoryParameter{} case unrecognizedChunkType: - e = &errorCauseUnrecognizedChunkType{} + errCause = &errorCauseUnrecognizedChunkType{} case protocolViolation: - e = &errorCauseProtocolViolation{} + errCause = &errorCauseProtocolViolation{} case userInitiatedAbort: - e = &errorCauseUserInitiatedAbort{} + errCause = &errorCauseUserInitiatedAbort{} default: return nil, fmt.Errorf("%w: %s", ErrBuildErrorCaseHandle, c.String()) } - if err := e.unmarshal(raw); err != nil { + if err := errCause.unmarshal(raw); err != nil { return nil, err } - return e, nil + return errCause, nil } const ( @@ -67,7 +67,7 @@ const ( protocolViolation errorCauseCode = 13 ) -func (e errorCauseCode) String() string { +func (e errorCauseCode) String() string { //nolint:cyclop switch e { case invalidStreamIdentifier: return "Invalid Stream Identifier" diff --git a/error_cause_header.go b/error_cause_header.go index 98c530fb..fd1cd261 100644 --- a/error_cause_header.go +++ b/error_cause_header.go @@ -8,7 +8,7 @@ import ( "errors" ) -// errorCauseHeader represents the shared header that is shared by all error causes +// errorCauseHeader represents the shared header that is shared by all error causes. type errorCauseHeader struct { code errorCauseCode len uint16 @@ -19,11 +19,11 @@ const ( errorCauseHeaderLength = 4 ) -// ErrInvalidSCTPChunk is returned when an SCTP chunk is invalid +// ErrInvalidSCTPChunk is returned when an SCTP chunk is invalid. var ErrInvalidSCTPChunk = errors.New("invalid SCTP chunk") func (e *errorCauseHeader) marshal() ([]byte, error) { - e.len = uint16(len(e.raw)) + uint16(errorCauseHeaderLength) + e.len = uint16(len(e.raw)) + uint16(errorCauseHeaderLength) //nolint:gosec // G115 raw := make([]byte, e.len) binary.BigEndian.PutUint16(raw[0:], uint16(e.code)) binary.BigEndian.PutUint16(raw[2:], e.len) @@ -40,6 +40,7 @@ func (e *errorCauseHeader) unmarshal(raw []byte) error { } valueLength := e.len - errorCauseHeaderLength e.raw = raw[errorCauseHeaderLength : errorCauseHeaderLength+valueLength] + return nil } @@ -51,7 +52,7 @@ func (e *errorCauseHeader) errorCauseCode() errorCauseCode { return e.code } -// String makes errorCauseHeader printable +// String makes errorCauseHeader printable. func (e errorCauseHeader) String() string { return e.code.String() } diff --git a/error_cause_invalid_mandatory_parameter.go b/error_cause_invalid_mandatory_parameter.go index e73774bf..e8615a4e 100644 --- a/error_cause_invalid_mandatory_parameter.go +++ b/error_cause_invalid_mandatory_parameter.go @@ -3,7 +3,7 @@ package sctp -// errorCauseInvalidMandatoryParameter represents an SCTP error cause +// errorCauseInvalidMandatoryParameter represents an SCTP error cause. type errorCauseInvalidMandatoryParameter struct { errorCauseHeader } @@ -16,7 +16,7 @@ func (e *errorCauseInvalidMandatoryParameter) unmarshal(raw []byte) error { return e.errorCauseHeader.unmarshal(raw) } -// String makes errorCauseInvalidMandatoryParameter printable +// String makes errorCauseInvalidMandatoryParameter printable. func (e *errorCauseInvalidMandatoryParameter) String() string { return e.errorCauseHeader.String() } diff --git a/error_cause_protocol_violation.go b/error_cause_protocol_violation.go index 5861aedd..4228ac36 100644 --- a/error_cause_protocol_violation.go +++ b/error_cause_protocol_violation.go @@ -30,13 +30,14 @@ type errorCauseProtocolViolation struct { additionalInformation []byte } -// Abort chunk errors +// Abort chunk errors. var ( ErrProtocolViolationUnmarshal = errors.New("unable to unmarshal Protocol Violation error") ) func (e *errorCauseProtocolViolation) marshal() ([]byte, error) { e.raw = e.additionalInformation + return e.errorCauseHeader.marshal() } @@ -51,7 +52,7 @@ func (e *errorCauseProtocolViolation) unmarshal(raw []byte) error { return nil } -// String makes errorCauseProtocolViolation printable +// String makes errorCauseProtocolViolation printable. func (e *errorCauseProtocolViolation) String() string { return fmt.Sprintf("%s: %s", e.errorCauseHeader, e.additionalInformation) } diff --git a/error_cause_unrecognized_chunk_type.go b/error_cause_unrecognized_chunk_type.go index 84ea0e10..3a616ae8 100644 --- a/error_cause_unrecognized_chunk_type.go +++ b/error_cause_unrecognized_chunk_type.go @@ -3,7 +3,7 @@ package sctp -// errorCauseUnrecognizedChunkType represents an SCTP error cause +// errorCauseUnrecognizedChunkType represents an SCTP error cause. type errorCauseUnrecognizedChunkType struct { errorCauseHeader unrecognizedChunk []byte @@ -12,6 +12,7 @@ type errorCauseUnrecognizedChunkType struct { func (e *errorCauseUnrecognizedChunkType) marshal() ([]byte, error) { e.code = unrecognizedChunkType e.errorCauseHeader.raw = e.unrecognizedChunk + return e.errorCauseHeader.marshal() } @@ -22,10 +23,11 @@ func (e *errorCauseUnrecognizedChunkType) unmarshal(raw []byte) error { } e.unrecognizedChunk = e.errorCauseHeader.raw + return nil } -// String makes errorCauseUnrecognizedChunkType printable +// String makes errorCauseUnrecognizedChunkType printable. func (e *errorCauseUnrecognizedChunkType) String() string { return e.errorCauseHeader.String() } diff --git a/error_cause_user_initiated_abort.go b/error_cause_user_initiated_abort.go index 871460e2..57ffd14d 100644 --- a/error_cause_user_initiated_abort.go +++ b/error_cause_user_initiated_abort.go @@ -30,6 +30,7 @@ type errorCauseUserInitiatedAbort struct { func (e *errorCauseUserInitiatedAbort) marshal() ([]byte, error) { e.code = userInitiatedAbort e.errorCauseHeader.raw = e.upperLayerAbortReason + return e.errorCauseHeader.marshal() } @@ -40,10 +41,11 @@ func (e *errorCauseUserInitiatedAbort) unmarshal(raw []byte) error { } e.upperLayerAbortReason = e.errorCauseHeader.raw + return nil } -// String makes errorCauseUserInitiatedAbort printable +// String makes errorCauseUserInitiatedAbort printable. func (e *errorCauseUserInitiatedAbort) String() string { return fmt.Sprintf("%s: %s", e.errorCauseHeader.String(), e.upperLayerAbortReason) } diff --git a/examples/ping-pong/ping/conn.go b/examples/ping-pong/ping/conn.go index 3b0b448b..a1206f81 100644 --- a/examples/ping-pong/ping/conn.go +++ b/examples/ping-pong/ping/conn.go @@ -21,7 +21,7 @@ type disconnectedPacketConn struct { // nolint: unused pConn net.PacketConn } -// Read +// Read. func (c *disconnectedPacketConn) Read(p []byte) (int, error) { //nolint:unused i, rAddr, err := c.pConn.ReadFrom(p) if err != nil { @@ -35,42 +35,44 @@ func (c *disconnectedPacketConn) Read(p []byte) (int, error) { //nolint:unused return i, err } -// Write writes len(p) bytes from p to the DTLS connection +// Write writes len(p) bytes from p to the DTLS connection. func (c *disconnectedPacketConn) Write(p []byte) (n int, err error) { //nolint:unused return c.pConn.WriteTo(p, c.RemoteAddr()) } -// Close closes the conn and releases any Read calls +// Close closes the conn and releases any Read calls. func (c *disconnectedPacketConn) Close() error { //nolint:unused return c.pConn.Close() } -// LocalAddr is a stub +// LocalAddr is a stub. func (c *disconnectedPacketConn) LocalAddr() net.Addr { //nolint:unused if c.pConn != nil { return c.pConn.LocalAddr() } + return nil } -// RemoteAddr is a stub +// RemoteAddr is a stub. func (c *disconnectedPacketConn) RemoteAddr() net.Addr { //nolint:unused c.mu.RLock() defer c.mu.RUnlock() + return c.rAddr } -// SetDeadline is a stub +// SetDeadline is a stub. func (c *disconnectedPacketConn) SetDeadline(time.Time) error { //nolint:unused return nil } -// SetReadDeadline is a stub +// SetReadDeadline is a stub. func (c *disconnectedPacketConn) SetReadDeadline(time.Time) error { //nolint:unused return nil } -// SetWriteDeadline is a stub +// SetWriteDeadline is a stub. func (c *disconnectedPacketConn) SetWriteDeadline(time.Time) error { //nolint:unused return nil } diff --git a/examples/ping-pong/ping/main.go b/examples/ping-pong/ping/main.go index 3ee6d293..9e3bf70e 100644 --- a/examples/ping-pong/ping/main.go +++ b/examples/ping-pong/ping/main.go @@ -16,7 +16,7 @@ import ( "github.com/pion/sctp" ) -func main() { +func main() { //nolint:cyclop conn, err := net.Dial("udp", "127.0.0.1:9899") if err != nil { log.Panic(err) diff --git a/examples/ping-pong/pong/conn.go b/examples/ping-pong/pong/conn.go index 3b0b448b..a1206f81 100644 --- a/examples/ping-pong/pong/conn.go +++ b/examples/ping-pong/pong/conn.go @@ -21,7 +21,7 @@ type disconnectedPacketConn struct { // nolint: unused pConn net.PacketConn } -// Read +// Read. func (c *disconnectedPacketConn) Read(p []byte) (int, error) { //nolint:unused i, rAddr, err := c.pConn.ReadFrom(p) if err != nil { @@ -35,42 +35,44 @@ func (c *disconnectedPacketConn) Read(p []byte) (int, error) { //nolint:unused return i, err } -// Write writes len(p) bytes from p to the DTLS connection +// Write writes len(p) bytes from p to the DTLS connection. func (c *disconnectedPacketConn) Write(p []byte) (n int, err error) { //nolint:unused return c.pConn.WriteTo(p, c.RemoteAddr()) } -// Close closes the conn and releases any Read calls +// Close closes the conn and releases any Read calls. func (c *disconnectedPacketConn) Close() error { //nolint:unused return c.pConn.Close() } -// LocalAddr is a stub +// LocalAddr is a stub. func (c *disconnectedPacketConn) LocalAddr() net.Addr { //nolint:unused if c.pConn != nil { return c.pConn.LocalAddr() } + return nil } -// RemoteAddr is a stub +// RemoteAddr is a stub. func (c *disconnectedPacketConn) RemoteAddr() net.Addr { //nolint:unused c.mu.RLock() defer c.mu.RUnlock() + return c.rAddr } -// SetDeadline is a stub +// SetDeadline is a stub. func (c *disconnectedPacketConn) SetDeadline(time.Time) error { //nolint:unused return nil } -// SetReadDeadline is a stub +// SetReadDeadline is a stub. func (c *disconnectedPacketConn) SetReadDeadline(time.Time) error { //nolint:unused return nil } -// SetWriteDeadline is a stub +// SetWriteDeadline is a stub. func (c *disconnectedPacketConn) SetWriteDeadline(time.Time) error { //nolint:unused return nil } diff --git a/examples/ping-pong/pong/main.go b/examples/ping-pong/pong/main.go index 2f7085a3..ca05c02f 100644 --- a/examples/ping-pong/pong/main.go +++ b/examples/ping-pong/pong/main.go @@ -13,7 +13,7 @@ import ( "github.com/pion/sctp" ) -func main() { +func main() { //nolint:cyclop addr := net.UDPAddr{ IP: net.IPv4(127, 0, 0, 1), Port: 9899, diff --git a/packet.go b/packet.go index a2ab8e14..1e4ce529 100644 --- a/packet.go +++ b/packet.go @@ -10,7 +10,7 @@ import ( "hash/crc32" ) -// Create the crc32 table we'll use for the checksum +// Create the crc32 table we'll use for the checksum. var castagnoliTable = crc32.MakeTable(crc32.Castagnoli) // nolint:gochecknoglobals // Allocate and zero this data once. @@ -57,7 +57,7 @@ const ( packetHeaderSize = 12 ) -// SCTP packet errors +// SCTP packet errors. var ( ErrPacketRawTooSmall = errors.New("raw is smaller than the minimum length for a SCTP packet") ErrParseSCTPChunkNotEnoughData = errors.New("unable to parse SCTP chunk, not enough data for complete header") @@ -65,7 +65,7 @@ var ( ErrChecksumMismatch = errors.New("checksum mismatch theirs") ) -func (p *packet) unmarshal(doChecksum bool, raw []byte) error { +func (p *packet) unmarshal(doChecksum bool, raw []byte) error { //nolint:cyclop if len(raw) < packetHeaderSize { return fmt.Errorf("%w: raw only %d bytes, %d is the minimum length", ErrPacketRawTooSmall, len(raw), packetHeaderSize) } @@ -102,47 +102,47 @@ func (p *packet) unmarshal(doChecksum bool, raw []byte) error { return fmt.Errorf("%w: offset %d remaining %d", ErrParseSCTPChunkNotEnoughData, offset, len(raw)) } - var c chunk + var dataChunk chunk switch chunkType(raw[offset]) { case ctInit: - c = &chunkInit{} + dataChunk = &chunkInit{} case ctInitAck: - c = &chunkInitAck{} + dataChunk = &chunkInitAck{} case ctAbort: - c = &chunkAbort{} + dataChunk = &chunkAbort{} case ctCookieEcho: - c = &chunkCookieEcho{} + dataChunk = &chunkCookieEcho{} case ctCookieAck: - c = &chunkCookieAck{} + dataChunk = &chunkCookieAck{} case ctHeartbeat: - c = &chunkHeartbeat{} + dataChunk = &chunkHeartbeat{} case ctPayloadData: - c = &chunkPayloadData{} + dataChunk = &chunkPayloadData{} case ctSack: - c = &chunkSelectiveAck{} + dataChunk = &chunkSelectiveAck{} case ctReconfig: - c = &chunkReconfig{} + dataChunk = &chunkReconfig{} case ctForwardTSN: - c = &chunkForwardTSN{} + dataChunk = &chunkForwardTSN{} case ctError: - c = &chunkError{} + dataChunk = &chunkError{} case ctShutdown: - c = &chunkShutdown{} + dataChunk = &chunkShutdown{} case ctShutdownAck: - c = &chunkShutdownAck{} + dataChunk = &chunkShutdownAck{} case ctShutdownComplete: - c = &chunkShutdownComplete{} + dataChunk = &chunkShutdownComplete{} default: return fmt.Errorf("%w: %s", ErrUnmarshalUnknownChunkType, chunkType(raw[offset]).String()) } - if err := c.unmarshal(raw[offset:]); err != nil { + if err := dataChunk.unmarshal(raw[offset:]); err != nil { return err } - p.chunks = append(p.chunks, c) - chunkValuePadding := getPadding(c.valueLength()) - offset += chunkHeaderSize + c.valueLength() + chunkValuePadding + p.chunks = append(p.chunks, dataChunk) + chunkValuePadding := getPadding(dataChunk.valueLength()) + offset += chunkHeaderSize + dataChunk.valueLength() + chunkValuePadding } return nil @@ -163,11 +163,11 @@ func (p *packet) marshal(doChecksum bool) ([]byte, error) { if err != nil { return nil, err } - raw = append(raw, chunkRaw...) + raw = append(raw, chunkRaw...) //nolint:makezero // todo:fix paddingNeeded := getPadding(len(raw)) if paddingNeeded != 0 { - raw = append(raw, make([]byte, paddingNeeded)...) + raw = append(raw, make([]byte, paddingNeeded)...) //nolint:makezero // todo:fix } } @@ -189,10 +189,11 @@ func generatePacketChecksum(raw []byte) (sum uint32) { sum = crc32.Update(sum, castagnoliTable, raw[0:8]) sum = crc32.Update(sum, castagnoliTable, fourZeroes[:]) sum = crc32.Update(sum, castagnoliTable, raw[12:]) + return sum } -// String makes packet printable +// String makes packet printable. func (p *packet) String() string { format := `Packet: sourcePort: %d @@ -207,6 +208,7 @@ func (p *packet) String() string { for i, chunk := range p.chunks { res += fmt.Sprintf("Chunk %d:\n %s", i, chunk) } + return res } diff --git a/packet_test.go b/packet_test.go index 40557f3d..b49626ae 100644 --- a/packet_test.go +++ b/packet_test.go @@ -21,19 +21,29 @@ func TestPacketUnmarshal(t *testing.T) { case err != nil: t.Errorf("Unmarshal failed for SCTP packet with no chunks: %v", err) case pkt.sourcePort != defaultSCTPSrcDstPort: - t.Errorf("Unmarshal passed for SCTP packet, but got incorrect source port exp: %d act: %d", defaultSCTPSrcDstPort, pkt.sourcePort) + t.Errorf( + "Unmarshal passed for SCTP packet, but got incorrect source port exp: %d act: %d", + defaultSCTPSrcDstPort, pkt.sourcePort, + ) case pkt.destinationPort != defaultSCTPSrcDstPort: - t.Errorf("Unmarshal passed for SCTP packet, but got incorrect destination port exp: %d act: %d", defaultSCTPSrcDstPort, pkt.destinationPort) + t.Errorf( + "Unmarshal passed for SCTP packet, but got incorrect destination port exp: %d act: %d", + defaultSCTPSrcDstPort, pkt.destinationPort, + ) case pkt.verificationTag != 0: - t.Errorf("Unmarshal passed for SCTP packet, but got incorrect verification tag exp: %d act: %d", 0, pkt.verificationTag) + t.Errorf( + "Unmarshal passed for SCTP packet, but got incorrect verification tag exp: %d act: %d", + 0, pkt.verificationTag, + ) } rawChunk := []byte{ 0x13, 0x88, 0x13, 0x88, 0x00, 0x00, 0x00, 0x00, 0x81, 0x46, 0x9d, 0xfc, 0x01, 0x00, 0x00, 0x56, 0x55, - 0xb9, 0x64, 0xa5, 0x00, 0x02, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0xe8, 0x6d, 0x10, 0x30, 0xc0, 0x00, 0x00, 0x04, 0x80, - 0x08, 0x00, 0x09, 0xc0, 0x0f, 0xc1, 0x80, 0x82, 0x00, 0x00, 0x00, 0x80, 0x02, 0x00, 0x24, 0x9f, 0xeb, 0xbb, 0x5c, 0x50, - 0xc9, 0xbf, 0x75, 0x9c, 0xb1, 0x2c, 0x57, 0x4f, 0xa4, 0x5a, 0x51, 0xba, 0x60, 0x17, 0x78, 0x27, 0x94, 0x5c, 0x31, 0xe6, - 0x5d, 0x5b, 0x09, 0x47, 0xe2, 0x22, 0x06, 0x80, 0x04, 0x00, 0x06, 0x00, 0x01, 0x00, 0x00, 0x80, 0x03, 0x00, 0x06, 0x80, 0xc1, 0x00, 0x00, + 0xb9, 0x64, 0xa5, 0x00, 0x02, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0xe8, 0x6d, 0x10, 0x30, 0xc0, 0x00, + 0x00, 0x04, 0x80, 0x08, 0x00, 0x09, 0xc0, 0x0f, 0xc1, 0x80, 0x82, 0x00, 0x00, 0x00, 0x80, 0x02, 0x00, + 0x24, 0x9f, 0xeb, 0xbb, 0x5c, 0x50, 0xc9, 0xbf, 0x75, 0x9c, 0xb1, 0x2c, 0x57, 0x4f, 0xa4, 0x5a, 0x51, + 0xba, 0x60, 0x17, 0x78, 0x27, 0x94, 0x5c, 0x31, 0xe6, 0x5d, 0x5b, 0x09, 0x47, 0xe2, 0x22, 0x06, 0x80, + 0x04, 0x00, 0x06, 0x00, 0x01, 0x00, 0x00, 0x80, 0x03, 0x00, 0x06, 0x80, 0xc1, 0x00, 0x00, } if err := pkt.unmarshal(true, rawChunk); err != nil { @@ -53,7 +63,10 @@ func TestPacketMarshal(t *testing.T) { if err != nil { t.Errorf("Marshal failed for SCTP packet with no chunks: %v", err) } else if !bytes.Equal(headerOnly, headerOnlyMarshaled) { - t.Errorf("Unmarshal/Marshaled header only packet did not match \nheaderOnly: % 02x \nheaderOnlyMarshaled % 02x", headerOnly, headerOnlyMarshaled) + t.Errorf( + "Unmarshal/Marshaled header only packet did not match \nheaderOnly: % 02x \nheaderOnlyMarshaled % 02x", + headerOnly, headerOnlyMarshaled, + ) } } diff --git a/param.go b/param.go index c28d7a5b..8ac3e484 100644 --- a/param.go +++ b/param.go @@ -16,8 +16,8 @@ type param interface { // ErrParamTypeUnhandled is returned if unknown parameter type is specified. var ErrParamTypeUnhandled = errors.New("unhandled ParamType") -func buildParam(t paramType, rawParam []byte) (param, error) { - switch t { +func buildParam(typeParam paramType, rawParam []byte) (param, error) { //nolint:cyclop + switch typeParam { case forwardTSNSupp: return (¶mForwardTSNSupported{}).unmarshal(rawParam) case supportedExt: @@ -41,6 +41,6 @@ func buildParam(t paramType, rawParam []byte) (param, error) { case zeroChecksumAcceptable: return (¶mZeroChecksumAcceptable{}).unmarshal(rawParam) default: - return nil, fmt.Errorf("%w: %v", ErrParamTypeUnhandled, t) + return nil, fmt.Errorf("%w: %v", ErrParamTypeUnhandled, typeParam) } } diff --git a/param_ecn_capable.go b/param_ecn_capable.go index d6cd0d32..65f269a8 100644 --- a/param_ecn_capable.go +++ b/param_ecn_capable.go @@ -10,6 +10,7 @@ type paramECNCapable struct { func (r *paramECNCapable) marshal() ([]byte, error) { r.typ = ecnCapable r.raw = []byte{} + return r.paramHeader.marshal() } @@ -18,5 +19,6 @@ func (r *paramECNCapable) unmarshal(raw []byte) (param, error) { if err != nil { return nil, err } + return r, nil } diff --git a/param_forward_tsn_supported.go b/param_forward_tsn_supported.go index 655a93ff..89ef51aa 100644 --- a/param_forward_tsn_supported.go +++ b/param_forward_tsn_supported.go @@ -19,6 +19,7 @@ type paramForwardTSNSupported struct { func (f *paramForwardTSNSupported) marshal() ([]byte, error) { f.typ = forwardTSNSupp f.raw = []byte{} + return f.paramHeader.marshal() } @@ -27,5 +28,6 @@ func (f *paramForwardTSNSupported) unmarshal(raw []byte) (param, error) { if err != nil { return nil, err } + return f, nil } diff --git a/param_heartbeat_info.go b/param_heartbeat_info.go index 06f70486..f7ab5128 100644 --- a/param_heartbeat_info.go +++ b/param_heartbeat_info.go @@ -11,6 +11,7 @@ type paramHeartbeatInfo struct { func (h *paramHeartbeatInfo) marshal() ([]byte, error) { h.typ = heartbeatInfo h.raw = h.heartbeatInformation + return h.paramHeader.marshal() } @@ -20,5 +21,6 @@ func (h *paramHeartbeatInfo) unmarshal(raw []byte) (param, error) { return nil, err } h.heartbeatInformation = h.raw + return h, nil } diff --git a/param_outgoing_reset_request.go b/param_outgoing_reset_request.go index 45ee28f9..7a39b94d 100644 --- a/param_outgoing_reset_request.go +++ b/param_outgoing_reset_request.go @@ -55,7 +55,7 @@ type paramOutgoingResetRequest struct { streamIdentifiers []uint16 } -// Outgoing reset request parameter errors +// Outgoing reset request parameter errors. var ( ErrSSNResetRequestParamTooShort = errors.New("outgoing SSN reset request parameter too short") ) @@ -69,6 +69,7 @@ func (r *paramOutgoingResetRequest) marshal() ([]byte, error) { for i, sID := range r.streamIdentifiers { binary.BigEndian.PutUint16(r.raw[paramOutgoingResetRequestStreamIdentifiersOffset+2*i:], sID) } + return r.paramHeader.marshal() } diff --git a/param_outgoing_reset_request_test.go b/param_outgoing_reset_request_test.go index 177cb438..a9642b34 100644 --- a/param_outgoing_reset_request_test.go +++ b/param_outgoing_reset_request_test.go @@ -10,7 +10,10 @@ import ( ) func testChunkReconfigParamA() []byte { - return []byte{0x0, 0xd, 0x0, 0x16, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x3, 0x0, 0x4, 0x0, 0x5, 0x0, 0x6} + return []byte{ + 0x00, 0x0d, 0x00, 0x16, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x03, 0x00, 0x04, 0x00, 0x05, 0x00, 0x06, + } } func testChunkReconfigParamB() []byte { diff --git a/param_random.go b/param_random.go index 4a80aa01..a4b68450 100644 --- a/param_random.go +++ b/param_random.go @@ -11,6 +11,7 @@ type paramRandom struct { func (r *paramRandom) marshal() ([]byte, error) { r.typ = random r.raw = r.randomData + return r.paramHeader.marshal() } @@ -20,5 +21,6 @@ func (r *paramRandom) unmarshal(raw []byte) (param, error) { return nil, err } r.randomData = r.raw + return r, nil } diff --git a/param_reconfig_response.go b/param_reconfig_response.go index e35f16a1..54e14179 100644 --- a/param_reconfig_response.go +++ b/param_reconfig_response.go @@ -48,7 +48,7 @@ const ( reconfigResultInProgress reconfigResult = 6 ) -// Reconfiguration response errors +// Reconfiguration response errors. var ( ErrReconfigRespParamTooShort = errors.New("reconfig response parameter too short") ) diff --git a/param_state_cookie.go b/param_state_cookie.go index dbf8992e..e7dfb915 100644 --- a/param_state_cookie.go +++ b/param_state_cookie.go @@ -31,6 +31,7 @@ func newRandomStateCookie() (*paramStateCookie, error) { func (s *paramStateCookie) marshal() ([]byte, error) { s.typ = stateCookie s.raw = s.cookie + return s.paramHeader.marshal() } @@ -40,10 +41,11 @@ func (s *paramStateCookie) unmarshal(raw []byte) (param, error) { return nil, err } s.cookie = s.raw + return s, nil } -// String makes paramStateCookie printable +// String makes paramStateCookie printable. func (s *paramStateCookie) String() string { return fmt.Sprintf("%s: %s", s.paramHeader, s.cookie) } diff --git a/param_zero_checksum.go b/param_zero_checksum.go index ddf7c5af..29028926 100644 --- a/param_zero_checksum.go +++ b/param_zero_checksum.go @@ -27,7 +27,7 @@ type paramZeroChecksumAcceptable struct { edmid uint32 } -// Zero Checksum parameter error +// Zero Checksum parameter error. var ( ErrZeroChecksumParamTooShort = errors.New("zero checksum parameter too short") ) @@ -40,6 +40,7 @@ func (r *paramZeroChecksumAcceptable) marshal() ([]byte, error) { r.typ = zeroChecksumAcceptable r.raw = make([]byte, 4) binary.BigEndian.PutUint32(r.raw, r.edmid) + return r.paramHeader.marshal() } @@ -52,5 +53,6 @@ func (r *paramZeroChecksumAcceptable) unmarshal(raw []byte) (param, error) { return nil, ErrZeroChecksumParamTooShort } r.edmid = binary.BigEndian.Uint32(r.raw) + return r, nil } diff --git a/paramheader.go b/paramheader.go index 5b694428..8b1986af 100644 --- a/paramheader.go +++ b/paramheader.go @@ -47,7 +47,7 @@ const ( paramHeaderLength = 4 ) -// Parameter header parse errors +// Parameter header parse errors. var ( ErrParamHeaderTooShort = errors.New("param header too short") ErrParamHeaderSelfReportedLengthShorter = errors.New("param self reported length is shorter than header length") @@ -60,7 +60,7 @@ func (p *paramHeader) marshal() ([]byte, error) { rawParam := make([]byte, paramLengthPlusHeader) binary.BigEndian.PutUint16(rawParam[0:], uint16(p.typ)) - binary.BigEndian.PutUint16(rawParam[2:], uint16(paramLengthPlusHeader)) + binary.BigEndian.PutUint16(rawParam[2:], uint16(paramLengthPlusHeader)) //nolint:gosec // G115 copy(rawParam[paramHeaderLength:], p.raw) return rawParam, nil @@ -73,10 +73,16 @@ func (p *paramHeader) unmarshal(raw []byte) error { paramLengthPlusHeader := binary.BigEndian.Uint16(raw[2:]) if int(paramLengthPlusHeader) < paramHeaderLength { - return fmt.Errorf("%w: param self reported length (%d) shorter than header length (%d)", ErrParamHeaderSelfReportedLengthShorter, int(paramLengthPlusHeader), paramHeaderLength) + return fmt.Errorf( + "%w: param self reported length (%d) shorter than header length (%d)", + ErrParamHeaderSelfReportedLengthShorter, int(paramLengthPlusHeader), paramHeaderLength, + ) } if len(raw) < int(paramLengthPlusHeader) { - return fmt.Errorf("%w: param length (%d) shorter than its self reported length (%d)", ErrParamHeaderSelfReportedLengthLonger, len(raw), int(paramLengthPlusHeader)) + return fmt.Errorf( + "%w: param length (%d) shorter than its self reported length (%d)", + ErrParamHeaderSelfReportedLengthLonger, len(raw), int(paramLengthPlusHeader), + ) } typ, err := parseParamType(raw[0:]) @@ -95,7 +101,7 @@ func (p *paramHeader) length() int { return p.len } -// String makes paramHeader printable +// String makes paramHeader printable. func (p paramHeader) String() string { return fmt.Sprintf("%s (%d): %s", p.typ, p.len, hex.Dump(p.raw)) } diff --git a/paramtype.go b/paramtype.go index de1b3dd9..c82d2bf7 100644 --- a/paramtype.go +++ b/paramtype.go @@ -9,7 +9,7 @@ import ( "fmt" ) -// paramType represents a SCTP INIT/INITACK parameter +// paramType represents a SCTP INIT/INITACK parameter. type paramType uint16 const ( @@ -43,7 +43,7 @@ const ( adaptLayerInd paramType = 49158 // Adaptation Layer Indication (0xC006) [RFC5061] ) -// Parameter packet errors +// Parameter packet errors. var ( ErrParamPacketTooShort = errors.New("packet to short") ) @@ -52,10 +52,11 @@ func parseParamType(raw []byte) (paramType, error) { if len(raw) < 2 { return paramType(0), ErrParamPacketTooShort } + return paramType(binary.BigEndian.Uint16(raw)), nil } -func (p paramType) String() string { +func (p paramType) String() string { //nolint:cyclop switch p { case heartbeatInfo: return "Heartbeat Info" diff --git a/payload_queue.go b/payload_queue.go index 92eacecd..34efdde6 100644 --- a/payload_queue.go +++ b/payload_queue.go @@ -22,6 +22,7 @@ func (q *payloadQueue) pop(tsn uint32) (*chunkPayloadData, bool) { if q.chunks.Len() > 0 && tsn == q.chunks.Front().tsn { c := q.chunks.PopFront() q.nBytes -= len(c.userData) + return c, true } @@ -38,6 +39,7 @@ func (q *payloadQueue) get(tsn uint32) (*chunkPayloadData, bool) { if tsn < head || int(tsn-head) >= length { return nil, false } + return q.chunks.At(int(tsn - head)), true } diff --git a/payload_queue_test.go b/payload_queue_test.go index a5982b48..b41bb53d 100644 --- a/payload_queue_test.go +++ b/payload_queue_test.go @@ -57,7 +57,7 @@ func TestPayloadQueue(t *testing.T) { t.Run("markAllToRetrasmit", func(t *testing.T) { pq := newPayloadQueue() for i := 0; i < 3; i++ { - pq.pushNoCheck(makePayload(uint32(i+1), 10)) + pq.pushNoCheck(makePayload(uint32(i+1), 10)) //nolint:gosec // G115 } pq.markAsAcked(2) pq.markAllToRetrasmit() @@ -76,7 +76,7 @@ func TestPayloadQueue(t *testing.T) { t.Run("reset retransmit flag on ack", func(t *testing.T) { pq := newPayloadQueue() for i := 0; i < 4; i++ { - pq.pushNoCheck(makePayload(uint32(i+1), 10)) + pq.pushNoCheck(makePayload(uint32(i+1), 10)) //nolint:gosec // G115 } pq.markAllToRetrasmit() diff --git a/pending_queue.go b/pending_queue.go index 6f70fea2..a9a1552b 100644 --- a/pending_queue.go +++ b/pending_queue.go @@ -27,6 +27,7 @@ func (q *pendingBaseQueue) pop() *chunkPayloadData { } c := q.queue[0] q.queue = q.queue[1:] + return c } @@ -34,6 +35,7 @@ func (q *pendingBaseQueue) get(i int) *chunkPayloadData { if len(q.queue) == 0 || i < 0 || i >= len(q.queue) { return nil } + return q.queue[i] } @@ -51,7 +53,7 @@ type pendingQueue struct { unorderedIsSelected bool } -// Pending queue errors +// Pending queue errors. var ( ErrUnexpectedChuckPoppedUnordered = errors.New("unexpected chunk popped (unordered)") ErrUnexpectedChuckPoppedOrdered = errors.New("unexpected chunk popped (ordered)") @@ -79,26 +81,28 @@ func (q *pendingQueue) peek() *chunkPayloadData { if q.unorderedIsSelected { return q.unorderedQueue.get(0) } + return q.orderedQueue.get(0) } if c := q.unorderedQueue.get(0); c != nil { return c } + return q.orderedQueue.get(0) } -func (q *pendingQueue) pop(c *chunkPayloadData) error { - if q.selected { +func (q *pendingQueue) pop(chunkPayload *chunkPayloadData) error { //nolint:cyclop + if q.selected { //nolint:nestif var popped *chunkPayloadData if q.unorderedIsSelected { popped = q.unorderedQueue.pop() - if popped != c { + if popped != chunkPayload { return ErrUnexpectedChuckPoppedUnordered } } else { popped = q.orderedQueue.pop() - if popped != c { + if popped != chunkPayload { return ErrUnexpectedChuckPoppedOrdered } } @@ -106,12 +110,12 @@ func (q *pendingQueue) pop(c *chunkPayloadData) error { q.selected = false } } else { - if !c.beginningFragment { + if !chunkPayload.beginningFragment { return ErrUnexpectedQState } - if c.unordered { + if chunkPayload.unordered { popped := q.unorderedQueue.pop() - if popped != c { + if popped != chunkPayload { return ErrUnexpectedChuckPoppedUnordered } if !popped.endingFragment { @@ -120,7 +124,7 @@ func (q *pendingQueue) pop(c *chunkPayloadData) error { } } else { popped := q.orderedQueue.pop() - if popped != c { + if popped != chunkPayload { return ErrUnexpectedChuckPoppedOrdered } if !popped.endingFragment { @@ -129,7 +133,8 @@ func (q *pendingQueue) pop(c *chunkPayloadData) error { } } } - q.nBytes -= len(c.userData) + q.nBytes -= len(chunkPayload.userData) + return nil } diff --git a/pending_queue_test.go b/pending_queue_test.go index a4c4cae3..af9d96a8 100644 --- a/pending_queue_test.go +++ b/pending_queue_test.go @@ -17,21 +17,22 @@ const ( ) func makeDataChunk(tsn uint32, unordered bool, frag int) *chunkPayloadData { - var b, e bool + var begin, end bool switch frag { case noFragment: - b = true - e = true + begin = true + end = true case fragBegin: - b = true + begin = true case fragEnd: - e = true + end = true } + return &chunkPayloadData{ tsn: tsn, unordered: unordered, - beginningFragment: b, - endingFragment: e, + beginningFragment: begin, + endingFragment: end, userData: make([]byte, 10), // always 10 bytes } } @@ -124,25 +125,25 @@ func TestPendingQueue(t *testing.T) { pq.push(makeDataChunk(3, true, noFragment)) assert.Equal(t, 40, pq.getNumBytes(), "total bytes mismatch") - c := pq.peek() - err := pq.pop(c) + chunkPayload := pq.peek() + err := pq.pop(chunkPayload) assert.NoError(t, err, "should not error") - assert.Equal(t, uint32(1), c.tsn, "TSN should match") + assert.Equal(t, uint32(1), chunkPayload.tsn, "TSN should match") - c = pq.peek() - err = pq.pop(c) + chunkPayload = pq.peek() + err = pq.pop(chunkPayload) assert.NoError(t, err, "should not error") - assert.Equal(t, uint32(3), c.tsn, "TSN should match") + assert.Equal(t, uint32(3), chunkPayload.tsn, "TSN should match") - c = pq.peek() - err = pq.pop(c) + chunkPayload = pq.peek() + err = pq.pop(chunkPayload) assert.NoError(t, err, "should not error") - assert.Equal(t, uint32(0), c.tsn, "TSN should match") + assert.Equal(t, uint32(0), chunkPayload.tsn, "TSN should match") - c = pq.peek() - err = pq.pop(c) + chunkPayload = pq.peek() + err = pq.pop(chunkPayload) assert.NoError(t, err, "should not error") - assert.Equal(t, uint32(2), c.tsn, "TSN should match") + assert.Equal(t, uint32(2), chunkPayload.tsn, "TSN should match") assert.Equal(t, 0, pq.getNumBytes(), "total bytes mismatch") }) @@ -172,10 +173,10 @@ func TestPendingQueue(t *testing.T) { pq := newPendingQueue() pq.push(makeDataChunk(0, false, fragBegin)) - c := pq.peek() - err := pq.pop(c) + chunkPayload := pq.peek() + err := pq.pop(chunkPayload) assert.NoError(t, err, "should not error") - assert.Equal(t, uint32(0), c.tsn, "TSN should match") + assert.Equal(t, uint32(0), chunkPayload.tsn, "TSN should match") pq.push(makeDataChunk(1, true, noFragment)) pq.push(makeDataChunk(2, false, fragMiddle)) @@ -184,10 +185,10 @@ func TestPendingQueue(t *testing.T) { expects := []uint32{2, 3, 1} for _, exp := range expects { - c = pq.peek() - err = pq.pop(c) + chunkPayload = pq.peek() + err = pq.pop(chunkPayload) assert.NoError(t, err, "should not error") - assert.Equal(t, exp, c.tsn, "TSN should match") + assert.Equal(t, exp, chunkPayload.tsn, "TSN should match") } }) } diff --git a/queue.go b/queue.go index be8eebb1..1b82edc9 100644 --- a/queue.go +++ b/queue.go @@ -40,6 +40,7 @@ func (q *queue[T]) PopFront() T { q.buf[q.head] = zeroVal q.head = (q.head + 1) % len(q.buf) q.count-- + return ele } diff --git a/queue_test.go b/queue_test.go index 7a19230f..b80c857c 100644 --- a/queue_test.go +++ b/queue_test.go @@ -10,37 +10,37 @@ import ( ) func TestQueue(t *testing.T) { - q := newQueue[int](32) - assert.Zero(t, q.Len()) + queu := newQueue[int](32) + assert.Zero(t, queu.Len()) // test push & pop for i := 1; i < 33; i++ { - q.PushBack(i) + queu.PushBack(i) } - assert.Equal(t, 32, q.Len()) - assert.Equal(t, 5, q.At(4)) + assert.Equal(t, 32, queu.Len()) + assert.Equal(t, 5, queu.At(4)) for i := 1; i < 33; i++ { - assert.Equal(t, i, q.Front()) - assert.Equal(t, i, q.PopFront()) + assert.Equal(t, i, queu.Front()) + assert.Equal(t, i, queu.PopFront()) } - assert.Zero(t, q.Len()) + assert.Zero(t, queu.Len()) - q.PushBack(10) - q.PushBack(11) - assert.Equal(t, 2, q.Len()) - assert.Equal(t, 11, q.At(1)) - assert.Equal(t, 10, q.Front()) - assert.Equal(t, 10, q.PopFront()) - assert.Equal(t, 11, q.PopFront()) + queu.PushBack(10) + queu.PushBack(11) + assert.Equal(t, 2, queu.Len()) + assert.Equal(t, 11, queu.At(1)) + assert.Equal(t, 10, queu.Front()) + assert.Equal(t, 10, queu.PopFront()) + assert.Equal(t, 11, queu.PopFront()) // test grow capacity for i := 0; i < 64; i++ { - q.PushBack(i) + queu.PushBack(i) } - assert.Equal(t, 64, q.Len()) - assert.Equal(t, 2, q.At(2)) + assert.Equal(t, 64, queu.Len()) + assert.Equal(t, 2, queu.At(2)) for i := 0; i < 64; i++ { - assert.Equal(t, i, q.Front()) - assert.Equal(t, i, q.PopFront()) + assert.Equal(t, i, queu.Front()) + assert.Equal(t, i, queu.PopFront()) } } diff --git a/reassembly_queue.go b/reassembly_queue.go index e0d527c0..51a35be9 100644 --- a/reassembly_queue.go +++ b/reassembly_queue.go @@ -22,7 +22,7 @@ func sortChunksBySSN(a []*chunkSet) { }) } -// chunkSet is a set of chunks that share the same SSN +// chunkSet is a set of chunks that share the same SSN. type chunkSet struct { ssn uint16 // used only with the ordered chunks ppi PayloadProtocolIdentifier @@ -51,6 +51,7 @@ func (set *chunkSet) push(chunk *chunkPayloadData) bool { // Check if we now have a complete set complete := set.isComplete() + return complete } @@ -79,7 +80,7 @@ func (set *chunkSet) isComplete() bool { // 3. var lastTSN uint32 - for i, c := range set.chunks { + for i, chunk := range set.chunks { if i > 0 { // Fragments must have contiguous TSN // From RFC 4960 Section 3.3.1: @@ -87,13 +88,13 @@ func (set *chunkSet) isComplete() bool { // used by the receiver to reassemble the message. This means that the // TSNs for each fragment of a fragmented user message MUST be strictly // sequential. - if c.tsn != lastTSN+1 { + if chunk.tsn != lastTSN+1 { // mid or end fragment is missing return false } } - lastTSN = c.tsn + lastTSN = chunk.tsn } return true @@ -124,7 +125,7 @@ func newReassemblyQueue(si uint16) *reassemblyQueue { } } -func (r *reassemblyQueue) push(chunk *chunkPayloadData) bool { +func (r *reassemblyQueue) push(chunk *chunkPayloadData) bool { //nolint:cyclop var cset *chunkSet if chunk.streamIdentifier != r.si { @@ -143,6 +144,7 @@ func (r *reassemblyQueue) push(chunk *chunkPayloadData) bool { // If found, append the complete set to the unordered array if cset != nil { r.unordered = append(r.unordered, cset) + return true } @@ -169,6 +171,7 @@ func (r *reassemblyQueue) push(chunk *chunkPayloadData) bool { // for O(1) lookups at the cost of 2x memory. if set.ssn == chunk.streamSequenceNumber && set.chunks[0].isFragmented() { cset = set + break } } @@ -194,17 +197,19 @@ func (r *reassemblyQueue) findCompleteUnorderedChunkSet() *chunkSet { var lastTSN uint32 var found bool - for i, c := range r.unorderedChunks { + for i, chunk := range r.unorderedChunks { // seek beigining - if c.beginningFragment { + if chunk.beginningFragment { startIdx = i nChunks = 1 - lastTSN = c.tsn + lastTSN = chunk.tsn - if c.endingFragment { + if chunk.endingFragment { found = true + break } + continue } @@ -213,16 +218,18 @@ func (r *reassemblyQueue) findCompleteUnorderedChunkSet() *chunkSet { } // Check if contiguous in TSN - if c.tsn != lastTSN+1 { + if chunk.tsn != lastTSN+1 { startIdx = -1 + continue } - lastTSN = c.tsn + lastTSN = chunk.tsn nChunks++ - if c.endingFragment { + if chunk.endingFragment { found = true + break } } @@ -261,6 +268,7 @@ func (r *reassemblyQueue) isReadable() bool { } } } + return false } @@ -318,6 +326,7 @@ func (r *reassemblyQueue) forwardTSNForOrdered(lastSSN uint16) { for _, c := range set.chunks { r.subtractNumBytes(len(c.userData)) } + continue } } @@ -354,13 +363,13 @@ func (r *reassemblyQueue) forwardTSNForUnordered(newCumulativeTSN uint32) { func (r *reassemblyQueue) subtractNumBytes(nBytes int) { cur := atomic.LoadUint64(&r.nBytes) - if int(cur) >= nBytes { - atomic.AddUint64(&r.nBytes, -uint64(nBytes)) + if int(cur) >= nBytes { //nolint:gosec // G115 + atomic.AddUint64(&r.nBytes, -uint64(nBytes)) //nolint:gosec // G115 } else { atomic.StoreUint64(&r.nBytes, 0) } } func (r *reassemblyQueue) getNumBytes() int { - return int(atomic.LoadUint64(&r.nBytes)) + return int(atomic.LoadUint64(&r.nBytes)) //nolint:gosec // G115 } diff --git a/reassembly_queue_test.go b/reassembly_queue_test.go index 02478f45..ee9f5e7c 100644 --- a/reassembly_queue_test.go +++ b/reassembly_queue_test.go @@ -10,7 +10,7 @@ import ( "github.com/stretchr/testify/assert" ) -func TestReassemblyQueue(t *testing.T) { +func TestReassemblyQueue(t *testing.T) { //nolint:maintidx t.Run("ordered fragments", func(t *testing.T) { rq := newReassemblyQueue(0) diff --git a/receive_payload_queue.go b/receive_payload_queue.go index 2a4f2bbf..30c9109d 100644 --- a/receive_payload_queue.go +++ b/receive_payload_queue.go @@ -20,6 +20,7 @@ type receivePayloadQueue struct { func newReceivePayloadQueue(maxTSNOffset uint32) *receivePayloadQueue { maxTSNOffset = ((maxTSNOffset + 63) / 64) * 64 + return &receivePayloadQueue{ tsnBitmask: make([]uint64, maxTSNOffset/64), maxTSNOffset: maxTSNOffset, @@ -42,6 +43,7 @@ func (q *receivePayloadQueue) hasChunk(tsn uint32) bool { } index, offset := int(tsn/64)%len(q.tsnBitmask), tsn%64 + return q.tsnBitmask[index]&(1<> uint64(start)) + i := bits.TrailingZeros64(val >> uint64(start)) //nolint:gosec // G115 + return i + start, i+start < end } diff --git a/receive_payload_queue_test.go b/receive_payload_queue_test.go index 43056715..3501d77b 100644 --- a/receive_payload_queue_test.go +++ b/receive_payload_queue_test.go @@ -12,47 +12,47 @@ import ( func TestReceivePayloadQueue(t *testing.T) { maxOffset := uint32(512) - q := newReceivePayloadQueue(maxOffset) + payloadQueue := newReceivePayloadQueue(maxOffset) initTSN := uint32(math.MaxUint32 - 10) - q.init(initTSN - 2) - assert.Equal(t, initTSN-2, q.getcumulativeTSN()) - assert.Zero(t, q.size()) - _, ok := q.getLastTSNReceived() + payloadQueue.init(initTSN - 2) + assert.Equal(t, initTSN-2, payloadQueue.getcumulativeTSN()) + assert.Zero(t, payloadQueue.size()) + _, ok := payloadQueue.getLastTSNReceived() assert.False(t, ok) - assert.Empty(t, q.getGapAckBlocks()) + assert.Empty(t, payloadQueue.getGapAckBlocks()) // force pop empy queue to advance cumulative TSN - assert.False(t, q.pop(true)) - assert.Equal(t, initTSN-1, q.getcumulativeTSN()) - assert.Zero(t, q.size()) - assert.Empty(t, q.getGapAckBlocks()) + assert.False(t, payloadQueue.pop(true)) + assert.Equal(t, initTSN-1, payloadQueue.getcumulativeTSN()) + assert.Zero(t, payloadQueue.size()) + assert.Empty(t, payloadQueue.getGapAckBlocks()) nextTSN := initTSN + maxOffset - 1 - assert.True(t, q.push(nextTSN)) - assert.Equal(t, 1, q.size()) - lastTSN, ok := q.getLastTSNReceived() + assert.True(t, payloadQueue.push(nextTSN)) + assert.Equal(t, 1, payloadQueue.size()) + lastTSN, ok := payloadQueue.getLastTSNReceived() assert.True(t, lastTSN == nextTSN && ok, "lastTSN:%d, ok:%t", lastTSN, ok) - assert.True(t, q.hasChunk(nextTSN)) + assert.True(t, payloadQueue.hasChunk(nextTSN)) - assert.True(t, q.push(initTSN)) - assert.False(t, q.canPush(initTSN-1)) - assert.False(t, q.canPush(initTSN+maxOffset)) - assert.False(t, q.push(initTSN+maxOffset)) - assert.True(t, q.canPush(nextTSN-1)) - assert.Equal(t, 2, q.size()) + assert.True(t, payloadQueue.push(initTSN)) + assert.False(t, payloadQueue.canPush(initTSN-1)) + assert.False(t, payloadQueue.canPush(initTSN+maxOffset)) + assert.False(t, payloadQueue.push(initTSN+maxOffset)) + assert.True(t, payloadQueue.canPush(nextTSN-1)) + assert.Equal(t, 2, payloadQueue.size()) - gaps := q.getGapAckBlocks() + gaps := payloadQueue.getGapAckBlocks() assert.EqualValues(t, []gapAckBlock{ {start: uint16(1), end: uint16(1)}, {start: uint16(maxOffset), end: uint16(maxOffset)}, }, gaps) - assert.True(t, q.pop(false)) - assert.Equal(t, 1, q.size()) - assert.Equal(t, initTSN, q.cumulativeTSN) - assert.False(t, q.pop(false)) - assert.Equal(t, initTSN, q.cumulativeTSN) + assert.True(t, payloadQueue.pop(false)) + assert.Equal(t, 1, payloadQueue.size()) + assert.Equal(t, initTSN, payloadQueue.cumulativeTSN) + assert.False(t, payloadQueue.pop(false)) + assert.Equal(t, initTSN, payloadQueue.cumulativeTSN) - size := q.size() + size := payloadQueue.size() // push tsn with two gap // tsnRange [[start,end]...] tsnRange := [][]uint32{ @@ -61,82 +61,89 @@ func TestReceivePayloadQueue(t *testing.T) { } range0, range1 := tsnRange[0], tsnRange[1] for tsn := range0[0]; sna32LTE(tsn, range0[1]); tsn++ { - assert.True(t, q.push(tsn)) - assert.False(t, q.pop(false)) - assert.True(t, q.hasChunk(tsn)) + assert.True(t, payloadQueue.push(tsn)) + assert.False(t, payloadQueue.pop(false)) + assert.True(t, payloadQueue.hasChunk(tsn)) } size += int(range0[1] - range0[0] + 1) for tsn := range1[0]; sna32LTE(tsn, range1[1]); tsn++ { - assert.True(t, q.push(tsn)) - assert.False(t, q.pop(false)) - assert.True(t, q.hasChunk(tsn)) + assert.True(t, payloadQueue.push(tsn)) + assert.False(t, payloadQueue.pop(false)) + assert.True(t, payloadQueue.hasChunk(tsn)) } size += int(range1[1] - range1[0] + 1) - assert.Equal(t, size, q.size()) - gaps = q.getGapAckBlocks() + assert.Equal(t, size, payloadQueue.size()) + gaps = payloadQueue.getGapAckBlocks() assert.EqualValues(t, []gapAckBlock{ + //nolint:gosec // G115 {start: uint16(range0[0] - initTSN), end: uint16(range0[1] - initTSN)}, + //nolint:gosec // G115 {start: uint16(range1[0] - initTSN), end: uint16(range1[1] - initTSN)}, + //nolint:gosec // G115 {start: uint16(nextTSN - initTSN), end: uint16(nextTSN - initTSN)}, }, gaps) // push duplicate tsns - assert.False(t, q.push(initTSN-2)) - assert.False(t, q.push(range0[0])) - assert.False(t, q.push(range0[0])) - assert.False(t, q.push(nextTSN)) - assert.False(t, q.push(initTSN+maxOffset+1)) - duplicates := q.popDuplicates() + assert.False(t, payloadQueue.push(initTSN-2)) + assert.False(t, payloadQueue.push(range0[0])) + assert.False(t, payloadQueue.push(range0[0])) + assert.False(t, payloadQueue.push(nextTSN)) + assert.False(t, payloadQueue.push(initTSN+maxOffset+1)) + duplicates := payloadQueue.popDuplicates() assert.EqualValues(t, []uint32{initTSN - 2, range0[0], range0[0], nextTSN}, duplicates) // force pop to advance cumulativeTSN to fill the gap [initTSN, initTSN+4] for tsn := initTSN + 1; sna32LT(tsn, range0[0]); tsn++ { - assert.False(t, q.pop(true)) - assert.Equal(t, size, q.size()) - assert.Equal(t, tsn, q.cumulativeTSN) + assert.False(t, payloadQueue.pop(true)) + assert.Equal(t, size, payloadQueue.size()) + assert.Equal(t, tsn, payloadQueue.cumulativeTSN) } for tsn := range0[0]; sna32LTE(tsn, range0[1]); tsn++ { - assert.True(t, q.pop(false)) - assert.Equal(t, tsn, q.getcumulativeTSN()) + assert.True(t, payloadQueue.pop(false)) + assert.Equal(t, tsn, payloadQueue.getcumulativeTSN()) } - assert.False(t, q.pop(false)) - cumulativeTSN := q.getcumulativeTSN() + assert.False(t, payloadQueue.pop(false)) + cumulativeTSN := payloadQueue.getcumulativeTSN() assert.Equal(t, range0[1], cumulativeTSN) - gaps = q.getGapAckBlocks() + gaps = payloadQueue.getGapAckBlocks() assert.EqualValues(t, []gapAckBlock{ + //nolint:gosec // G115 {start: uint16(range1[0] - range0[1]), end: uint16(range1[1] - range0[1])}, + //nolint:gosec // G115 {start: uint16(nextTSN - range0[1]), end: uint16(nextTSN - range0[1])}, }, gaps) // fill the gap with received tsn for tsn := range0[1] + 1; sna32LT(tsn, range1[0]); tsn++ { - assert.True(t, q.push(tsn), tsn) + assert.True(t, payloadQueue.push(tsn), tsn) } for tsn := range0[1] + 1; sna32LTE(tsn, range1[1]); tsn++ { - assert.True(t, q.pop(false)) - assert.Equal(t, tsn, q.getcumulativeTSN()) + assert.True(t, payloadQueue.pop(false)) + assert.Equal(t, tsn, payloadQueue.getcumulativeTSN()) } - assert.False(t, q.pop(false)) - assert.Equal(t, range1[1], q.getcumulativeTSN()) - gaps = q.getGapAckBlocks() + assert.False(t, payloadQueue.pop(false)) + assert.Equal(t, range1[1], payloadQueue.getcumulativeTSN()) + gaps = payloadQueue.getGapAckBlocks() assert.EqualValues(t, []gapAckBlock{ + //nolint:gosec // G115 {start: uint16(nextTSN - range1[1]), end: uint16(nextTSN - range1[1])}, }, gaps) // gap block cross end tsn endTSN := maxOffset - 1 for tsn := nextTSN + 1; sna32LTE(tsn, endTSN); tsn++ { - assert.True(t, q.push(tsn)) + assert.True(t, payloadQueue.push(tsn)) } - gaps = q.getGapAckBlocks() + gaps = payloadQueue.getGapAckBlocks() assert.EqualValues(t, []gapAckBlock{ + //nolint:gosec // G115 {start: uint16(nextTSN - range1[1]), end: uint16(endTSN - range1[1])}, }, gaps) - assert.NotEmpty(t, q.getGapAckBlocksString()) + assert.NotEmpty(t, payloadQueue.getGapAckBlocksString()) } func TestBitfunc(t *testing.T) { diff --git a/rtx_timer.go b/rtx_timer.go index 1fea3931..65838249 100644 --- a/rtx_timer.go +++ b/rtx_timer.go @@ -10,25 +10,25 @@ import ( ) const ( - // RTO.Initial in msec + // RTO.Initial in msec. rtoInitial float64 = 1.0 * 1000 - // RTO.Min in msec + // RTO.Min in msec. rtoMin float64 = 1.0 * 1000 - // RTO.Max in msec + // RTO.Max in msec. defaultRTOMax float64 = 60.0 * 1000 - // RTO.Alpha + // RTO.Alpha. rtoAlpha float64 = 0.125 - // RTO.Beta + // RTO.Beta. rtoBeta float64 = 0.25 - // Max.Init.Retransmits: + // Max.Init.Retransmits. maxInitRetrans uint = 8 - // Path.Max.Retrans + // Path.Max.Retrans. pathMaxRetrans uint = 5 noMaxRetrans uint = 0 @@ -54,6 +54,7 @@ func newRTOManager(rtoMax float64) *rtoManager { if mgr.rtoMax == 0 { mgr.rtoMax = defaultRTOMax } + return &mgr } @@ -76,6 +77,7 @@ func (m *rtoManager) setNewRTT(rtt float64) float64 { m.srtt = (1-rtoAlpha)*m.srtt + rtoAlpha*rtt } m.rto = math.Min(math.Max(m.srtt+4*m.rttvar, rtoMin), m.rtoMax) + return m.srtt } @@ -101,7 +103,7 @@ func (m *rtoManager) reset() { m.rto = rtoInitial } -// set RTO value for testing +// set RTO value for testing. func (m *rtoManager) setRTO(rto float64, noUpdate bool) { m.mutex.Lock() defer m.mutex.Unlock() @@ -126,7 +128,7 @@ const ( rtxTimerClosed ) -// rtxTimer provides the retnransmission timer conforms with RFC 4960 Sec 6.3.1 +// rtxTimer provides the retnransmission timer conforms with RFC 4960 Sec 6.3.1. type rtxTimer struct { timer *time.Timer observer rtxTimerObserver @@ -157,11 +159,13 @@ func newRTXTimer(id int, observer rtxTimerObserver, maxRetrans uint, } timer.timer = time.AfterFunc(math.MaxInt64, timer.timeout) timer.timer.Stop() + return &timer } func (t *rtxTimer) calculateNextTimeout() time.Duration { timeout := calculateNextTimeout(t.rto, t.nRtos, t.rtoMax) + return time.Duration(timeout) * time.Millisecond } @@ -199,6 +203,7 @@ func (t *rtxTimer) start(rto float64) bool { t.state = rtxTimerStarted t.pending++ t.timer.Reset(t.calculateNextTimeout()) + return true } @@ -216,7 +221,7 @@ func (t *rtxTimer) stop() { } // closes the timer. this is similar to stop() but subsequent start() call -// will fail (the timer is no longer usable) +// will fail (the timer is no longer usable). func (t *rtxTimer) close() { t.mutex.Lock() defer t.mutex.Unlock() @@ -228,7 +233,7 @@ func (t *rtxTimer) close() { } // isRunning tests if the timer is running. -// Debug purpose only +// Debug purpose only. func (t *rtxTimer) isRunning() bool { t.mutex.Lock() defer t.mutex.Unlock() @@ -244,7 +249,9 @@ func calculateNextTimeout(rto float64, nRtos uint, rtoMax float64) float64 { // to this doubling operation. if nRtos < 31 { m := 1 << nRtos + return math.Min(rto*float64(m), rtoMax) } + return rtoMax } diff --git a/rtx_timer_test.go b/rtx_timer_test.go index 2a7da5af..4e3e3d2a 100644 --- a/rtx_timer_test.go +++ b/rtx_timer_test.go @@ -115,7 +115,7 @@ func (o *testTimerObserver) onRetransmissionFailure(id int) { o.onRtxFailure(id) } -func TestRtxTimer(t *testing.T) { +func TestRtxTimer(t *testing.T) { //nolint:maintidx t.Run("callback interval", func(t *testing.T) { timerID := 0 var nCbs int32 diff --git a/stream.go b/stream.go index 3be12912..9f88b3eb 100644 --- a/stream.go +++ b/stream.go @@ -17,11 +17,11 @@ import ( ) const ( - // ReliabilityTypeReliable is used for reliable transmission + // ReliabilityTypeReliable is used for reliable transmission. ReliabilityTypeReliable byte = 0 - // ReliabilityTypeRexmit is used for partial reliability by retransmission count + // ReliabilityTypeRexmit is used for partial reliability by retransmission count. ReliabilityTypeRexmit byte = 1 - // ReliabilityTypeTimed is used for partial reliability by retransmission duration + // ReliabilityTypeTimed is used for partial reliability by retransmission duration. ReliabilityTypeTimed byte = 2 ) @@ -29,7 +29,7 @@ const ( // This field identifies the state of stream. type StreamState int -// StreamState enums +// StreamState enums. const ( StreamStateOpen StreamState = iota // Stream object starts with StreamStateOpen StreamStateClosing // Outgoing stream is being reset @@ -45,17 +45,18 @@ func (ss StreamState) String() string { case StreamStateClosed: return "closed" } + return "unknown" } -// SCTP stream errors +// SCTP stream errors. var ( ErrOutboundPacketTooLarge = errors.New("outbound packet larger than maximum message size") ErrStreamClosed = errors.New("stream closed") ErrReadDeadlineExceeded = fmt.Errorf("read deadline exceeded: %w", os.ErrDeadlineExceeded) ) -// Stream represents an SCTP stream +// Stream represents an SCTP stream. type Stream struct { association *Association lock sync.RWMutex @@ -83,6 +84,7 @@ type Stream struct { func (s *Stream) StreamIdentifier() uint16 { s.lock.RLock() defer s.lock.RUnlock() + return s.streamIdentifier } @@ -114,14 +116,15 @@ func (s *Stream) setReliabilityParams(unordered bool, relType byte, relVal uint3 // otherwise. func (s *Stream) Read(p []byte) (int, error) { n, _, err := s.ReadSCTP(p) + return n, err } -// ReadSCTP reads a packet of len(p) bytes and returns the associated Payload +// ReadSCTP reads a packet of len(payload) bytes and returns the associated Payload // Protocol Identifier. // Returns EOF when the stream is reset or an error if the stream is closed // otherwise. -func (s *Stream) ReadSCTP(p []byte) (int, PayloadProtocolIdentifier, error) { +func (s *Stream) ReadSCTP(payload []byte) (int, PayloadProtocolIdentifier, error) { s.lock.Lock() defer s.lock.Unlock() @@ -134,7 +137,7 @@ func (s *Stream) ReadSCTP(p []byte) (int, PayloadProtocolIdentifier, error) { }() for { - n, ppi, err := s.reassemblyQueue.read(p) + n, ppi, err := s.reassemblyQueue.read(payload) if err == nil { return n, ppi, nil } else if errors.Is(err, io.ErrShortBuffer) { @@ -150,7 +153,7 @@ func (s *Stream) ReadSCTP(p []byte) (int, PayloadProtocolIdentifier, error) { } } -// SetReadDeadline sets the read deadline in an identical way to net.Conn +// SetReadDeadline sets the read deadline in an identical way to net.Conn. func (s *Stream) SetReadDeadline(deadline time.Time) error { s.lock.Lock() defer s.lock.Unlock() @@ -175,6 +178,7 @@ func (s *Stream) SetReadDeadline(deadline time.Time) error { select { case <-readTimeoutCancel: t.Stop() + return case <-t.C: select { @@ -193,6 +197,7 @@ func (s *Stream) SetReadDeadline(deadline time.Time) error { } }(s.readTimeoutCancel) } + return nil } @@ -258,16 +263,17 @@ func (s *Stream) handleForwardTSNForUnordered(newCumulativeTSN uint32) { } } -// Write writes len(p) bytes from p with the default Payload Protocol Identifier -func (s *Stream) Write(p []byte) (n int, err error) { +// Write writes len(payload) bytes from payload with the default Payload Protocol Identifier. +func (s *Stream) Write(payload []byte) (n int, err error) { ppi := PayloadProtocolIdentifier(atomic.LoadUint32((*uint32)(&s.defaultPayloadType))) - return s.WriteSCTP(p, ppi) + + return s.WriteSCTP(payload, ppi) } -// WriteSCTP writes len(p) bytes from p to the DTLS connection -func (s *Stream) WriteSCTP(p []byte, ppi PayloadProtocolIdentifier) (int, error) { +// WriteSCTP writes len(payload) bytes from payload to the DTLS connection. +func (s *Stream) WriteSCTP(payload []byte, ppi PayloadProtocolIdentifier) (int, error) { maxMessageSize := s.association.MaxMessageSize() - if len(p) > int(maxMessageSize) { + if len(payload) > int(maxMessageSize) { return 0, fmt.Errorf("%w: %v", ErrOutboundPacketTooLarge, maxMessageSize) } @@ -281,8 +287,8 @@ func (s *Stream) WriteSCTP(p []byte, ppi PayloadProtocolIdentifier) (int, error) if s.association.isBlockWrite() { s.writeLock.Lock() } - chunks, unordered := s.packetize(p, ppi) - n := len(p) + chunks, unordered := s.packetize(payload, ppi) + n := len(payload) err := s.association.sendPayloadData(s.writeDeadline, chunks) if err != nil { s.lock.Lock() @@ -296,20 +302,24 @@ func (s *Stream) WriteSCTP(p []byte, ppi PayloadProtocolIdentifier) (int, error) if s.association.isBlockWrite() { s.writeLock.Unlock() } + return n, err } -// SetWriteDeadline sets the write deadline in an identical way to net.Conn, it will only work for blocking writes +// SetWriteDeadline sets the write deadline in an identical way to net.Conn, +// it will only work for blocking writes. func (s *Stream) SetWriteDeadline(deadline time.Time) error { s.writeDeadline.Set(deadline) + return nil } -// SetDeadline sets the read and write deadlines in an identical way to net.Conn +// SetDeadline sets the read and write deadlines in an identical way to net.Conn. func (s *Stream) SetDeadline(t time.Time) error { if err := s.SetReadDeadline(t); err != nil { return err } + return s.SetWriteDeadline(t) } @@ -317,8 +327,8 @@ func (s *Stream) packetize(raw []byte, ppi PayloadProtocolIdentifier) ([]*chunkP s.lock.Lock() defer s.lock.Unlock() - i := uint32(0) - remaining := uint32(len(raw)) + offset := uint32(0) + remaining := uint32(len(raw)) //nolint:gosec // G115 // From draft-ietf-rtcweb-data-protocol-09, section 6: // All Data Channel Establishment Protocol messages MUST be sent using @@ -333,13 +343,13 @@ func (s *Stream) packetize(raw []byte, ppi PayloadProtocolIdentifier) ([]*chunkP // Copy the userdata since we'll have to store it until acked // and the caller may re-use the buffer in the mean time userData := make([]byte, fragmentSize) - copy(userData, raw[i:i+fragmentSize]) + copy(userData, raw[offset:offset+fragmentSize]) chunk := &chunkPayloadData{ streamIdentifier: s.streamIdentifier, userData: userData, unordered: unordered, - beginningFragment: i == 0, + beginningFragment: offset == 0, endingFragment: remaining-fragmentSize == 0, immediateSack: false, payloadType: ppi, @@ -354,7 +364,7 @@ func (s *Stream) packetize(raw []byte, ppi PayloadProtocolIdentifier) ([]*chunkP chunks = append(chunks, chunk) remaining -= fragmentSize - i += fragmentSize + offset += fragmentSize } // RFC 4960 Sec 6.6 @@ -387,8 +397,10 @@ func (s *Stream) Close() error { s.state = StreamStateClosed } s.log.Debugf("[%s] state change: open => %s", s.name, s.state.String()) + return s.streamIdentifier, true } + return s.streamIdentifier, false }(); resetOutbound { // Reset the outgoing stream @@ -459,6 +471,7 @@ func (s *Stream) onBufferReleased(nBytesReleased int) { f := s.onBufferedAmountLow s.lock.Unlock() f() + return } @@ -499,5 +512,6 @@ func (s *Stream) onInboundStreamReset() { func (s *Stream) State() StreamState { s.lock.RLock() defer s.lock.RUnlock() + return s.state } diff --git a/stream_test.go b/stream_test.go index 91198c61..ba82dc2d 100644 --- a/stream_test.go +++ b/stream_test.go @@ -26,47 +26,47 @@ func TestSessionBufferedAmount(t *testing.T) { }) t.Run("OnBufferedAmountLow", func(t *testing.T) { - s := &Stream{ + stream := &Stream{ log: logging.NewDefaultLoggerFactory().NewLogger("sctp-test"), } - s.bufferedAmount = 4096 - s.SetBufferedAmountLowThreshold(2048) + stream.bufferedAmount = 4096 + stream.SetBufferedAmountLowThreshold(2048) nCbs := 0 - s.OnBufferedAmountLow(func() { + stream.OnBufferedAmountLow(func() { nCbs++ }) // Negative value should be ignored (by design) - s.onBufferReleased(-32) // bufferedAmount = 3072 - assert.Equal(t, uint64(4096), s.BufferedAmount(), "unexpected bufferedAmount") + stream.onBufferReleased(-32) // bufferedAmount = 3072 + assert.Equal(t, uint64(4096), stream.BufferedAmount(), "unexpected bufferedAmount") assert.Equal(t, 0, nCbs, "callback count mismatch") // Above to above, no callback - s.onBufferReleased(1024) // bufferedAmount = 3072 - assert.Equal(t, uint64(3072), s.BufferedAmount(), "unexpected bufferedAmount") + stream.onBufferReleased(1024) // bufferedAmount = 3072 + assert.Equal(t, uint64(3072), stream.BufferedAmount(), "unexpected bufferedAmount") assert.Equal(t, 0, nCbs, "callback count mismatch") // Above to equal, callback should be made - s.onBufferReleased(1024) // bufferedAmount = 2048 - assert.Equal(t, uint64(2048), s.BufferedAmount(), "unexpected bufferedAmount") + stream.onBufferReleased(1024) // bufferedAmount = 2048 + assert.Equal(t, uint64(2048), stream.BufferedAmount(), "unexpected bufferedAmount") assert.Equal(t, 1, nCbs, "callback count mismatch") // Eaual to below, no callback - s.onBufferReleased(1024) // bufferedAmount = 1024 - assert.Equal(t, uint64(1024), s.BufferedAmount(), "unexpected bufferedAmount") + stream.onBufferReleased(1024) // bufferedAmount = 1024 + assert.Equal(t, uint64(1024), stream.BufferedAmount(), "unexpected bufferedAmount") assert.Equal(t, 1, nCbs, "callback count mismatch") // Blow to below, no callback - s.onBufferReleased(1024) // bufferedAmount = 0 - assert.Equal(t, uint64(0), s.BufferedAmount(), "unexpected bufferedAmount") + stream.onBufferReleased(1024) // bufferedAmount = 0 + assert.Equal(t, uint64(0), stream.BufferedAmount(), "unexpected bufferedAmount") assert.Equal(t, 1, nCbs, "callback count mismatch") // Capped at 0, no callback - s.onBufferReleased(1024) // bufferedAmount = 0 - assert.Equal(t, uint64(0), s.BufferedAmount(), "unexpected bufferedAmount") + stream.onBufferReleased(1024) // bufferedAmount = 0 + assert.Equal(t, uint64(0), stream.BufferedAmount(), "unexpected bufferedAmount") assert.Equal(t, 1, nCbs, "callback count mismatch") }) } diff --git a/util.go b/util.go index b8afc123..257e289a 100644 --- a/util.go +++ b/util.go @@ -16,10 +16,11 @@ func padByte(in []byte, cnt int) []byte { cnt = 0 } padding := make([]byte, cnt) + return append(in, padding...) } -// Serial Number Arithmetic (RFC 1982) +// Serial Number Arithmetic (RFC 1982). func sna32LT(i1, i2 uint32) bool { return (i1 < i2 && i2-i1 < 1<<31) || (i1 > i2 && i1-i2 > 1<<31) } diff --git a/vnet_test.go b/vnet_test.go index 90152171..92a5b639 100644 --- a/vnet_test.go +++ b/vnet_test.go @@ -52,7 +52,7 @@ func (venv *vNetEnv) dropNextCookieAckChunk(numToDrop int) { venv.numToDropCookieAck = numToDrop } -func buildVNetEnv(cfg *vNetEnvConfig) (*vNetEnv, error) { +func buildVNetEnv(cfg *vNetEnvConfig) (*vNetEnv, error) { //nolint:cyclop log := cfg.log var venv *vNetEnv @@ -94,6 +94,7 @@ func buildVNetEnv(cfg *vNetEnvConfig) (*vNetEnv, error) { toDrop = true venv.numToDropData-- log.Infof("Chunk filter: drop TSN %d", tsn) + break loop } } @@ -102,6 +103,7 @@ func buildVNetEnv(cfg *vNetEnvConfig) (*vNetEnv, error) { toDrop = true venv.numToDropReconfig-- log.Infof("Chunk filter: drop RECONFIG %s", chunk.String()) + break loop } case *chunkCookieEcho: @@ -109,6 +111,7 @@ func buildVNetEnv(cfg *vNetEnvConfig) (*vNetEnv, error) { toDrop = true venv.numToDropCookieEcho-- log.Infof("Chunk filter: drop %s", chunk.String()) + break loop } case *chunkCookieAck: @@ -116,10 +119,12 @@ func buildVNetEnv(cfg *vNetEnvConfig) (*vNetEnv, error) { toDrop = true venv.numToDropCookieAck-- log.Infof("Chunk filter: drop %s", chunk.String()) + break loop } } } + return !toDrop } } @@ -163,7 +168,9 @@ func buildVNetEnv(cfg *vNetEnvConfig) (*vNetEnv, error) { return venv, nil } -func testRwndFull(t *testing.T, unordered bool) { +func testRwndFull(t *testing.T, unordered bool) { //nolint:cyclop + t.Helper() + loggerFactory := logging.NewDefaultLoggerFactory() log := loggerFactory.NewLogger("test") @@ -196,7 +203,7 @@ func testRwndFull(t *testing.T, unordered bool) { maxReceiveBufferSize := uint32(64 * 1024) msgSize := int(float32(maxReceiveBufferSize)/2) + int(initialMTU) msg := make([]byte, msgSize) - rand.Read(msg) // nolint:errcheck,gosec + rand.Read(msg) // nolint:errcheck,gosec,staticcheck // TODO: fix? go func() { defer close(serverShutDown) @@ -397,8 +404,10 @@ func TestRwndFull(t *testing.T) { }) } -func TestStreamClose(t *testing.T) { +func TestStreamClose(t *testing.T) { //nolint:cyclop loopBackTest := func(t *testing.T, dropReconfigChunk bool) { + t.Helper() + lim := test.TimeOut(time.Second * 10) defer lim.Stop() @@ -468,6 +477,7 @@ func TestStreamClose(t *testing.T) { log.Infof("server: Read returned %v", errRead) _ = stream.Close() // nolint:errcheck assert.Equal(t, StreamStateClosed, stream.State()) + break } @@ -521,6 +531,7 @@ func TestStreamClose(t *testing.T) { if err2 != nil { log.Infof("client: Read returned %v", err2) assert.Equal(t, StreamStateClosed, stream.State()) + break } @@ -581,7 +592,7 @@ func TestStreamClose(t *testing.T) { // and confirmes the fix. // To reproduce the case mentioned above: // * Use simultaneous-open (SCTP) -// * Drop both of the first COOKIE-ECHO and COOKIE-ACK +// * Drop both of the first COOKIE-ECHO and COOKIE-ACK. func TestCookieEchoRetransmission(t *testing.T) { lim := test.TimeOut(time.Second * 10) defer lim.Stop()