Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add synchronization to payloadQueue map access #115

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 24 additions & 2 deletions payload_queue.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,37 +3,45 @@ package sctp
import (
"fmt"
"sort"
"sync"
)

type payloadQueue struct {
chunkMap map[uint32]*chunkPayloadData
mapMutex *sync.RWMutex
sorted []uint32
dupTSN []uint32
nBytes int
}

func newPayloadQueue() *payloadQueue {
return &payloadQueue{chunkMap: map[uint32]*chunkPayloadData{}}
return &payloadQueue{
chunkMap: map[uint32]*chunkPayloadData{},
mapMutex: &sync.RWMutex{},
}
}

func (q *payloadQueue) updateSortedKeys() {
if q.sorted != nil {
return
}

q.mapMutex.RLock()
q.sorted = make([]uint32, len(q.chunkMap))
i := 0
for k := range q.chunkMap {
q.sorted[i] = k
i++
}
q.mapMutex.RUnlock()

sort.Slice(q.sorted, func(i, j int) bool {
return sna32LT(q.sorted[i], q.sorted[j])
})
}

func (q *payloadQueue) canPush(p *chunkPayloadData, cumulativeTSN uint32) bool {
q.mapMutex.RLock()
defer q.mapMutex.RUnlock()
_, ok := q.chunkMap[p.tsn]
if ok || sna32LTE(p.tsn, cumulativeTSN) {
return false
Expand All @@ -42,6 +50,8 @@ func (q *payloadQueue) canPush(p *chunkPayloadData, cumulativeTSN uint32) bool {
}

func (q *payloadQueue) pushNoCheck(p *chunkPayloadData) {
q.mapMutex.Lock()
defer q.mapMutex.Unlock()
q.chunkMap[p.tsn] = p
q.nBytes += len(p.userData)
q.sorted = nil
Expand All @@ -51,6 +61,8 @@ func (q *payloadQueue) pushNoCheck(p *chunkPayloadData) {
// older than our cumulativeTSN marker, it will be recored as duplications,
// which can later be retrieved using popDuplicates.
func (q *payloadQueue) push(p *chunkPayloadData, cumulativeTSN uint32) bool {
q.mapMutex.Lock()
defer q.mapMutex.Unlock()
_, ok := q.chunkMap[p.tsn]
if ok || sna32LTE(p.tsn, cumulativeTSN) {
// Found the packet, log in dups
Expand All @@ -67,6 +79,8 @@ func (q *payloadQueue) push(p *chunkPayloadData, cumulativeTSN uint32) bool {
// pop pops only if the oldest chunk's TSN matches the given TSN.
func (q *payloadQueue) pop(tsn uint32) (*chunkPayloadData, bool) {
q.updateSortedKeys()
q.mapMutex.Lock()
defer q.mapMutex.Unlock()

if len(q.chunkMap) > 0 && tsn == q.sorted[0] {
q.sorted = q.sorted[1:]
Expand All @@ -82,6 +96,8 @@ func (q *payloadQueue) pop(tsn uint32) (*chunkPayloadData, bool) {

// get returns reference to chunkPayloadData with the given TSN value.
func (q *payloadQueue) get(tsn uint32) (*chunkPayloadData, bool) {
q.mapMutex.RLock()
defer q.mapMutex.RUnlock()
c, ok := q.chunkMap[tsn]
return c, ok
}
Expand Down Expand Up @@ -140,6 +156,8 @@ func (q *payloadQueue) getGapAckBlocksString(cumulativeTSN uint32) string {

func (q *payloadQueue) markAsAcked(tsn uint32) int {
var nBytesAcked int
q.mapMutex.RLock()
defer q.mapMutex.RUnlock()
if c, ok := q.chunkMap[tsn]; ok {
c.acked = true
nBytesAcked = len(c.userData)
Expand All @@ -161,6 +179,8 @@ func (q *payloadQueue) getLastTSNReceived() (uint32, bool) {
}

func (q *payloadQueue) markAllToRetrasmit() {
q.mapMutex.RLock()
defer q.mapMutex.RUnlock()
for _, c := range q.chunkMap {
if c.acked || c.abandoned() {
continue
Expand All @@ -174,5 +194,7 @@ func (q *payloadQueue) getNumBytes() int {
}

func (q *payloadQueue) size() int {
q.mapMutex.RLock()
defer q.mapMutex.RUnlock()
return len(q.chunkMap)
}