From 6e08a5eec9e421ae45ac6b322baba3a77b9dde6a Mon Sep 17 00:00:00 2001 From: Cecylia Bocovich Date: Tue, 25 Feb 2020 17:49:39 -0500 Subject: [PATCH] Add synchronization to payloadQueue map access Map access are not concurrency safe, and the chunkmap of a payloadQueue is accessed by multiple goroutines. This commit adds a RWMutex to chunkmap accesses. --- payload_queue.go | 26 ++++++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/payload_queue.go b/payload_queue.go index 8bf4f5f1..f6046b07 100644 --- a/payload_queue.go +++ b/payload_queue.go @@ -3,30 +3,36 @@ package sctp import ( "fmt" "sort" + "sync" ) type payloadQueue struct { chunkMap map[uint32]*chunkPayloadData + mapMutex *sync.RWMutex sorted []uint32 dupTSN []uint32 nBytes int } func newPayloadQueue() *payloadQueue { - return &payloadQueue{chunkMap: map[uint32]*chunkPayloadData{}} + return &payloadQueue{ + chunkMap: map[uint32]*chunkPayloadData{}, + mapMutex: &sync.RWMutex{}, + } } func (q *payloadQueue) updateSortedKeys() { if q.sorted != nil { return } - + q.mapMutex.RLock() q.sorted = make([]uint32, len(q.chunkMap)) i := 0 for k := range q.chunkMap { q.sorted[i] = k i++ } + q.mapMutex.RUnlock() sort.Slice(q.sorted, func(i, j int) bool { return sna32LT(q.sorted[i], q.sorted[j]) @@ -34,6 +40,8 @@ func (q *payloadQueue) updateSortedKeys() { } func (q *payloadQueue) canPush(p *chunkPayloadData, cumulativeTSN uint32) bool { + q.mapMutex.RLock() + defer q.mapMutex.RUnlock() _, ok := q.chunkMap[p.tsn] if ok || sna32LTE(p.tsn, cumulativeTSN) { return false @@ -42,6 +50,8 @@ func (q *payloadQueue) canPush(p *chunkPayloadData, cumulativeTSN uint32) bool { } func (q *payloadQueue) pushNoCheck(p *chunkPayloadData) { + q.mapMutex.Lock() + defer q.mapMutex.Unlock() q.chunkMap[p.tsn] = p q.nBytes += len(p.userData) q.sorted = nil @@ -51,6 +61,8 @@ func (q *payloadQueue) pushNoCheck(p *chunkPayloadData) { // older than our cumulativeTSN marker, it will be recored as duplications, // which can later be retrieved using popDuplicates. func (q *payloadQueue) push(p *chunkPayloadData, cumulativeTSN uint32) bool { + q.mapMutex.Lock() + defer q.mapMutex.Unlock() _, ok := q.chunkMap[p.tsn] if ok || sna32LTE(p.tsn, cumulativeTSN) { // Found the packet, log in dups @@ -67,6 +79,8 @@ func (q *payloadQueue) push(p *chunkPayloadData, cumulativeTSN uint32) bool { // pop pops only if the oldest chunk's TSN matches the given TSN. func (q *payloadQueue) pop(tsn uint32) (*chunkPayloadData, bool) { q.updateSortedKeys() + q.mapMutex.Lock() + defer q.mapMutex.Unlock() if len(q.chunkMap) > 0 && tsn == q.sorted[0] { q.sorted = q.sorted[1:] @@ -82,6 +96,8 @@ func (q *payloadQueue) pop(tsn uint32) (*chunkPayloadData, bool) { // get returns reference to chunkPayloadData with the given TSN value. func (q *payloadQueue) get(tsn uint32) (*chunkPayloadData, bool) { + q.mapMutex.RLock() + defer q.mapMutex.RUnlock() c, ok := q.chunkMap[tsn] return c, ok } @@ -140,6 +156,8 @@ func (q *payloadQueue) getGapAckBlocksString(cumulativeTSN uint32) string { func (q *payloadQueue) markAsAcked(tsn uint32) int { var nBytesAcked int + q.mapMutex.RLock() + defer q.mapMutex.RUnlock() if c, ok := q.chunkMap[tsn]; ok { c.acked = true nBytesAcked = len(c.userData) @@ -161,6 +179,8 @@ func (q *payloadQueue) getLastTSNReceived() (uint32, bool) { } func (q *payloadQueue) markAllToRetrasmit() { + q.mapMutex.RLock() + defer q.mapMutex.RUnlock() for _, c := range q.chunkMap { if c.acked || c.abandoned() { continue @@ -174,5 +194,7 @@ func (q *payloadQueue) getNumBytes() int { } func (q *payloadQueue) size() int { + q.mapMutex.RLock() + defer q.mapMutex.RUnlock() return len(q.chunkMap) }