diff --git a/server/ipqueue.go b/server/ipqueue.go
index b26a749ed7f..c7c15276f22 100644
--- a/server/ipqueue.go
+++ b/server/ipqueue.go
@@ -14,6 +14,7 @@ package server
 
 import (
+	"errors"
 	"sync"
 	"sync/atomic"
 )
@@ -28,36 +29,72 @@ type ipQueue[T any] struct {
 	elts []T
 	pos  int
 	pool *sync.Pool
-	mrs  int
+	sz   uint64 // Calculated size (only if calc != nil)
 	name string
 	m    *sync.Map
+	ipQueueOpts[T]
 }
 
-type ipQueueOpts struct {
-	maxRecycleSize int
+type ipQueueOpts[T any] struct {
+	mrs  int              // Max recycle size
+	calc func(e T) uint64 // Calc function for tracking size
+	msz  uint64           // Limit by total calculated size
+	mlen int              // Limit by number of entries
 }
 
-type ipQueueOpt func(*ipQueueOpts)
+type ipQueueOpt[T any] func(*ipQueueOpts[T])
 
 // This option allows setting the maximum recycle size when attempting
 // to put back a slice to the pool.
-func ipQueue_MaxRecycleSize(max int) ipQueueOpt {
-	return func(o *ipQueueOpts) {
-		o.maxRecycleSize = max
+func ipqMaxRecycleSize[T any](max int) ipQueueOpt[T] {
+	return func(o *ipQueueOpts[T]) {
+		o.mrs = max
 	}
 }
 
-func newIPQueue[T any](s *Server, name string, opts ...ipQueueOpt) *ipQueue[T] {
-	qo := ipQueueOpts{maxRecycleSize: ipQueueDefaultMaxRecycleSize}
-	for _, o := range opts {
-		o(&qo)
+// This option enables total queue size tracking by passing in a function
+// that evaluates the size of each entry as it is pushed/popped, and it
+// enables the size() function.
+func ipqSizeCalculation[T any](calc func(e T) uint64) ipQueueOpt[T] {
+	return func(o *ipQueueOpts[T]) {
+		o.calc = calc
+	}
+}
+
+// This option allows setting the maximum queue size. Once the limit is
+// reached, push() will return errIPQSizeLimitReached and no more entries
+// will be stored until some are popped. The ipqSizeCalculation option
+// must be provided for this to work.
+func ipqLimitBySize[T any](max uint64) ipQueueOpt[T] {
+	return func(o *ipQueueOpts[T]) {
+		o.msz = max
+	}
+}
+
+// This option allows setting the maximum queue length. Once the limit is
+// reached, push() will return errIPQLenLimitReached and no more entries
+// will be stored until some are popped.
+func ipqLimitByLen[T any](max int) ipQueueOpt[T] {
+	return func(o *ipQueueOpts[T]) {
+		o.mlen = max
 	}
+}
+
+var errIPQLenLimitReached = errors.New("IPQ len limit reached")
+var errIPQSizeLimitReached = errors.New("IPQ size limit reached")
+
+func newIPQueue[T any](s *Server, name string, opts ...ipQueueOpt[T]) *ipQueue[T] {
 	q := &ipQueue[T]{
 		ch:   make(chan struct{}, 1),
-		mrs:  qo.maxRecycleSize,
 		pool: &sync.Pool{},
 		name: name,
 		m:    &s.ipQueues,
+		ipQueueOpts: ipQueueOpts[T]{
+			mrs: ipQueueDefaultMaxRecycleSize,
+		},
+	}
+	for _, o := range opts {
+		o(&q.ipQueueOpts)
 	}
 	s.ipQueues.Store(name, q)
 	return q
@@ -66,10 +103,14 @@ func newIPQueue[T any](s *Server, name string, opts ...ipQueueOpt) *ipQueue[T] {
 // Add the element `e` to the queue, notifying the queue channel's `ch` if the
 // entry is the first to be added, and returns the length of the queue after
-// this element is added.
+// this element is added, or an error if a configured queue limit was reached.
-func (q *ipQueue[T]) push(e T) int {
+func (q *ipQueue[T]) push(e T) (int, error) {
 	var signal bool
 	q.Lock()
 	l := len(q.elts) - q.pos
+	if q.mlen > 0 && l == q.mlen {
+		q.Unlock()
+		return l, errIPQLenLimitReached
+	}
 	if l == 0 {
 		signal = true
 		eltsi := q.pool.Get()
@@ -82,8 +123,15 @@ func (q *ipQueue[T]) push(e T) int {
 			q.elts = make([]T, 0, 32)
 		}
 	}
+	if q.calc != nil {
+		sz := q.calc(e)
+		if q.msz > 0 && q.sz+sz > q.msz {
+			q.Unlock()
+			return l, errIPQSizeLimitReached
+		}
+		q.sz += sz
+	}
 	q.elts = append(q.elts, e)
-	l++
 	q.Unlock()
 	if signal {
 		select {
@@ -91,7 +139,7 @@ func (q *ipQueue[T]) push(e T) int {
 		default:
 		}
 	}
-	return l
+	return l + 1, nil
 }
 
 // Returns the whole list of elements currently present in the queue,
@@ -116,6 +164,11 @@ func (q *ipQueue[T]) pop() []T {
 	}
 	q.elts, q.pos = nil, 0
 	atomic.AddInt64(&q.inprogress, int64(len(elts)))
+	if q.calc != nil {
+		for _, e := range elts {
+			q.sz -= q.calc(e)
+		}
+	}
 	q.Unlock()
 	return elts
 }
@@ -140,6 +193,9 @@ func (q *ipQueue[T]) popOne() (T, bool) {
 	}
 	e := q.elts[q.pos]
 	q.pos++
+	if q.calc != nil {
+		q.sz -= q.calc(e)
+	}
 	l--
 	if l > 0 {
 		// We need to re-signal
@@ -184,9 +240,16 @@ func (q *ipQueue[T]) recycle(elts *[]T) {
 // Returns the current length of the queue.
 func (q *ipQueue[T]) len() int {
 	q.Lock()
-	l := len(q.elts) - q.pos
-	q.Unlock()
-	return l
+	defer q.Unlock()
+	return len(q.elts) - q.pos
+}
+
+// Returns the calculated size of the queue (if ipqSizeCalculation has been
+// passed in), otherwise returns zero.
+func (q *ipQueue[T]) size() uint64 {
+	q.Lock()
+	defer q.Unlock()
+	return q.sz
 }
 
 // Empties the queue and consumes the notification signal if present.
@@ -202,6 +265,7 @@ func (q *ipQueue[T]) drain() {
 		q.resetAndReturnToPool(&q.elts)
 		q.elts, q.pos = nil, 0
 	}
+	q.sz = 0
 	// Consume the signal if it was present to reduce the chance of a reader
 	// routine thinking that there is something in the queue...
 	select {
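
Reviewer note: a minimal usage sketch of the new options, assuming a *Server `s` and a []byte payload `msg`; the `handleDrop` helper and the 64-entry/64KB limits are hypothetical, chosen only for illustration. push() now returns the post-push length plus an error, which callers should check whenever ipqLimitByLen or ipqLimitBySize is configured:

    // Queue of byte slices, capped both by entry count and by total
    // calculated payload size.
    q := newIPQueue[[]byte](s, "example",
        ipqSizeCalculation[[]byte](func(e []byte) uint64 { return uint64(len(e)) }),
        ipqLimitByLen[[]byte](64),       // hypothetical limit: at most 64 entries
        ipqLimitBySize[[]byte](64*1024), // hypothetical limit: at most 64KB of payload
    )
    if _, err := q.push(msg); err != nil {
        // errIPQLenLimitReached or errIPQSizeLimitReached: the entry was
        // NOT stored, so the caller decides whether to retry or discard.
        handleDrop(msg, err) // hypothetical drop handler
    }

Note that the options are generic, so callers must instantiate them with the queue's element type, as the tests below do.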
diff --git a/server/ipqueue_test.go b/server/ipqueue_test.go
index 5034a246663..e5fb96825da 100644
--- a/server/ipqueue_test.go
+++ b/server/ipqueue_test.go
@@ -42,7 +42,7 @@ func TestIPQueueBasic(t *testing.T) {
 	}
 
 	// Try to change the max recycle size
-	q2 := newIPQueue[int](s, "test2", ipQueue_MaxRecycleSize(10))
+	q2 := newIPQueue[int](s, "test2", ipqMaxRecycleSize[int](10))
 	if q2.mrs != 10 {
 		t.Fatalf("Expected max recycle size to be 10, got %v", q2.mrs)
 	}
@@ -290,7 +290,7 @@ func TestIPQueueRecycle(t *testing.T) {
 	for iter := 0; iter < 5; iter++ {
 		var sz int
 		for i := 0; i < total; i++ {
-			sz = q.push(i)
+			sz, _ = q.push(i)
 		}
 		if sz != total {
 			t.Fatalf("Expected size to be %v, got %v", total, sz)
@@ -298,7 +298,7 @@ func TestIPQueueRecycle(t *testing.T) {
 	values := q.pop()
 	preRecycleCap := cap(values)
 	q.recycle(&values)
-	sz = q.push(1001)
+	sz, _ = q.push(1001)
 	if sz != 1 {
 		t.Fatalf("Expected size to be %v, got %v", 1, sz)
 	}
@@ -317,7 +317,7 @@ func TestIPQueueRecycle(t *testing.T) {
 		}
 	}
 
-	q = newIPQueue[int](s, "test2", ipQueue_MaxRecycleSize(10))
+	q = newIPQueue[int](s, "test2", ipqMaxRecycleSize[int](10))
 	for i := 0; i < 100; i++ {
 		q.push(i)
 	}
@@ -389,3 +389,105 @@ func TestIPQueueDrain(t *testing.T) {
 		}
 	}
 }
+
+func TestIPQueueSizeCalculation(t *testing.T) {
+	type testType = [16]byte
+	var testValue testType
+
+	calc := ipqSizeCalculation[testType](func(e testType) uint64 {
+		return uint64(len(e))
+	})
+	s := &Server{}
+	q := newIPQueue[testType](s, "test", calc)
+
+	for i := 0; i < 10; i++ {
+		q.push(testValue)
+		require_Equal(t, q.len(), i+1)
+		require_Equal(t, q.size(), uint64(i+1)*uint64(len(testValue)))
+	}
+
+	for i := 10; i > 5; i-- {
+		q.popOne()
+		require_Equal(t, q.len(), i-1)
+		require_Equal(t, q.size(), uint64(i-1)*uint64(len(testValue)))
+	}
+
+	q.pop()
+	require_Equal(t, q.len(), 0)
+	require_Equal(t, q.size(), 0)
+}
+
+func TestIPQueueSizeCalculationWithLimits(t *testing.T) {
+	type testType = [16]byte
+	var testValue testType
+
+	calc := ipqSizeCalculation[testType](func(e testType) uint64 {
+		return uint64(len(e))
+	})
+	s := &Server{}
+
+	t.Run("LimitByLen", func(t *testing.T) {
+		q := newIPQueue[testType](s, "test", calc, ipqLimitByLen[testType](5))
+		for i := 0; i < 10; i++ {
+			n, err := q.push(testValue)
+			if i >= 5 {
+				require_Error(t, err, errIPQLenLimitReached)
+			} else {
+				require_NoError(t, err)
+			}
+			require_LessThan(t, n, 6)
+		}
+	})
+
+	t.Run("LimitBySize", func(t *testing.T) {
+		q := newIPQueue[testType](s, "test", calc, ipqLimitBySize[testType](16*5))
+		for i := 0; i < 10; i++ {
+			n, err := q.push(testValue)
+			if i >= 5 {
+				require_Error(t, err, errIPQSizeLimitReached)
+			} else {
+				require_NoError(t, err)
+			}
+			require_LessThan(t, n, 6)
+		}
+	})
+}
+
+func BenchmarkIPQueueSizeCalculation(b *testing.B) {
+	type testType = [16]byte
+	var testValue testType
+
+	s := &Server{}
+
+	run := func(b *testing.B, q *ipQueue[testType]) {
+		b.SetBytes(16)
+		for i := 0; i < b.N; i++ {
+			q.push(testValue)
+		}
+		for i := b.N; i > 0; i-- {
+			q.popOne()
+		}
+	}
+
+	// Measures the baseline cost without any calculation function.
+	b.Run("WithoutCalc", func(b *testing.B) {
+		run(b, newIPQueue[testType](s, "test"))
+	})
+
+	// Measures the raw overhead of having a calculation function.
+	b.Run("WithEmptyCalc", func(b *testing.B) {
+		calc := ipqSizeCalculation[testType](func(e testType) uint64 {
+			return 0
+		})
+		run(b, newIPQueue[testType](s, "test", calc))
+	})
+
+	// Measures the overhead of having a calculation function that
+	// actually measures something useful.
+	b.Run("WithLenCalc", func(b *testing.B) {
+		calc := ipqSizeCalculation[testType](func(e testType) uint64 {
+			return uint64(len(e))
+		})
+		run(b, newIPQueue[testType](s, "test", calc))
+	})
+}
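
Reviewer note: because run() calls b.SetBytes(16), the benchmark reports throughput in MB/s, which makes the relative cost of the calc callback easy to compare across the three sub-benchmarks. A sketch of a local invocation using standard go test flags (the -run '^$' pattern skips the unit tests):

    go test ./server -run '^$' -bench 'BenchmarkIPQueueSizeCalculation' -benchmem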
+ b.Run("WithLenCalc", func(b *testing.B) { + calc := ipqSizeCalculation[testType](func(e testType) uint64 { + return uint64(len(e)) + }) + run(b, newIPQueue[testType](s, "test", calc)) + }) +} diff --git a/server/jetstream_api.go b/server/jetstream_api.go index f33c4b3953f..9558ae49fe7 100644 --- a/server/jetstream_api.go +++ b/server/jetstream_api.go @@ -863,7 +863,7 @@ func (js *jetStream) apiDispatch(sub *subscription, c *client, acc *Account, sub // header from the msg body. No other references are needed. // Check pending and warn if getting backed up. const warnThresh = 128 - pending := s.jsAPIRoutedReqs.push(&jsAPIRoutedReq{jsub, sub, acc, subject, reply, copyBytes(rmsg), c.pa}) + pending, _ := s.jsAPIRoutedReqs.push(&jsAPIRoutedReq{jsub, sub, acc, subject, reply, copyBytes(rmsg), c.pa}) if pending >= warnThresh { s.rateLimitFormatWarnf("JetStream request queue has high pending count: %d", pending) }