generated from denpeshkov/go-template
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathswcounter.go
126 lines (109 loc) · 2.98 KB
/
swcounter.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
package throttle
import (
	"context"
	"crypto/sha1" //nolint:gosec
	_ "embed"
	"encoding/hex"
	"fmt"
	"io"
	"strconv"
	"strings"
	"sync"
	"time"
)
// swCounterScript holds the Lua source of the sliding-window counter, embedded
// from swcounter.lua at build time. NewSWCounterLimiter hashes it with SHA-1 for
// EVALSHA, and execScript sends the full text via SCRIPT LOAD when Redis replies
// NOSCRIPT.
//
//go:embed swcounter.lua
var swCounterScript string
// SWCounterLimiter implements a rate limiter using the Sliding-Window Counter algorithm.
// It works with a 1ms resolution.
//
// Build it with [NewSWCounterLimiter], which validates the limit and computes
// the script digest; the zero value has neither a Redis client nor a digest.
type SWCounterLimiter struct {
	// rds runs the Lua script (EVALSHA / SCRIPT LOAD) and deletes keys in Reset.
	rds Rediser
	// sha is the hex-encoded SHA-1 of swCounterScript, used with EVALSHA.
	sha string
	// keyTTL is extra expiry added on top of two limit intervals (see Allow).
	keyTTL time.Duration
	// clock is the time source; NewSWCounterLimiter defaults it to time.Now.
	clock func() time.Time
	// mu guards lim; Allow also holds it while reading clock and keyTTL.
	mu sync.Mutex
	// lim is the current limit, replaced by SetLimit.
	lim Limit
}
// NewSWCounterLimiter builds a [SWCounterLimiter] that enforces limit using rds.
//
// The limit is validated up front; if it is invalid, a nil limiter and the
// validation error are returned. Options are applied in order on top of the
// defaults (an extra key TTL of one second and time.Now as the clock). The
// SHA-1 digest of the embedded Lua script is computed once here so that every
// Allow call can go through EVALSHA.
func NewSWCounterLimiter(rds Rediser, limit Limit, opts ...Option) (*SWCounterLimiter, error) {
	err := limit.Valid()
	if err != nil {
		return nil, err
	}

	cfg := options{clock: time.Now, keyTTL: time.Second}
	for _, opt := range opts {
		opt.apply(&cfg)
	}

	// SHA-1 is dictated by Redis EVALSHA; it is not used for security here.
	digest := sha1.New() //nolint:gosec
	// Writes to a hash.Hash never return an error, so the result is discarded.
	_, _ = io.WriteString(digest, swCounterScript)

	limiter := &SWCounterLimiter{
		rds:    rds,
		sha:    hex.EncodeToString(digest.Sum(nil)),
		keyTTL: cfg.keyTTL,
		clock:  cfg.clock,
		lim:    limit,
	}
	return limiter, nil
}
// Allow determines whether the event for the specified key is permitted at the current time.
//
// The limit, clock reading and extra key TTL are snapshotted under the mutex so
// a concurrent SetLimit cannot be observed half-way. A limit allowing zero
// events always denies, with an infinite delay, without contacting Redis.
// Otherwise the sliding-window script is evaluated with the previous and
// current window indices (millisecond time divided by the interval in
// milliseconds), and its [limited, remaining, delayMs] reply is decoded into a
// Status.
//
// An error is returned if the interval is shorter than the 1ms resolution, if
// the script cannot be executed, or if its reply does not have the expected
// shape.
func (l *SWCounterLimiter) Allow(ctx context.Context, key string) (Status, error) {
	l.mu.Lock()
	lim := l.lim
	now := l.clock().UTC().UnixMilli()
	keyTTL := l.keyTTL
	l.mu.Unlock()

	if lim.Events == 0 {
		return Status{Limited: true, Remaining: 0, Delay: Inf}, nil
	}

	// The limiter works at 1ms resolution. A sub-millisecond interval truncates
	// to 0 and would make the window divisions below panic.
	interval := lim.Interval.Milliseconds()
	if interval <= 0 {
		return Status{}, fmt.Errorf("throttle: interval %v is below the 1ms resolution", lim.Interval)
	}
	prevWindow := strconv.FormatInt((now-interval)/interval, 10)
	curWindow := strconv.FormatInt(now/interval, 10)

	// Keep data for at least the previous and current windows, plus the
	// configured extra TTL. A negative keyTTL never shortens the expiry.
	ttl := max(2*lim.Interval+keyTTL, 2*lim.Interval)

	// NOTE(review): the window ids travel in KEYS. If swcounter.lua uses them as
	// hash fields rather than key names, they belong in ARGV for Redis Cluster
	// slot safety. Confirm against the script.
	keys := []string{key, prevWindow, curWindow}
	args := []any{lim.Events, interval, now, ttl.Milliseconds()}
	// Spread args so each value reaches the script as its own argument rather
	// than as a single nested slice.
	v, err := l.execScript(ctx, keys, args...)
	if err != nil {
		return Status{}, err
	}

	values, ok := v.([]any)
	if !ok || len(values) < 3 {
		return Status{}, fmt.Errorf("throttle: unexpected script reply %#v", v)
	}
	limited, okLimited := values[0].(int64)
	remaining, okRemaining := values[1].(int64)
	delayMs, okDelay := values[2].(int64)
	if !okLimited || !okRemaining || !okDelay {
		return Status{}, fmt.Errorf("throttle: unexpected script reply %#v", v)
	}
	return Status{
		Limited:   limited != 0,
		Remaining: int(remaining),
		Delay:     time.Duration(delayMs) * time.Millisecond,
	}, nil
}
// Limit reports the limit currently in effect. The read happens while holding
// the limiter's mutex, so it never races with SetLimit.
func (l *SWCounterLimiter) Limit() Limit {
	l.mu.Lock()
	current := l.lim
	l.mu.Unlock()
	return current
}
// SetLimit replaces the limit used by subsequent Allow calls.
//
// newLimit is validated first. When it is invalid, the current limit is left
// unchanged and the validation error is returned. ctx is accepted but not used.
func (l *SWCounterLimiter) SetLimit(ctx context.Context, newLimit Limit) error {
	err := newLimit.Valid()
	if err == nil {
		l.mu.Lock()
		l.lim = newLimit
		l.mu.Unlock()
	}
	return err
}
// Reset clears all limitations and previous usage for the specified keys by
// deleting them from Redis. With no keys it does nothing and makes no Redis
// call. The deleted-count returned by Del is ignored; only its error is
// propagated.
//
// NOTE(review): this assumes the script stores all per-key state under the key
// itself; confirm against swcounter.lua.
func (l *SWCounterLimiter) Reset(ctx context.Context, keys ...string) error {
	if len(keys) > 0 {
		_, err := l.rds.Del(ctx, keys...)
		return err
	}
	return nil
}
// execScript runs the embedded sliding-window script via EVALSHA.
//
// When Redis replies with a NOSCRIPT error (the script is not cached on the
// server), the script source is uploaded with SCRIPT LOAD and EVALSHA is tried
// exactly once more. Any other error, a load failure, or a failed retry is
// returned with a nil result.
func (l *SWCounterLimiter) execScript(ctx context.Context, keys []string, args ...any) (any, error) {
	res, err := l.rds.EvalSHA(ctx, l.sha, keys, args...)
	if err == nil {
		return res, nil
	}
	if !strings.HasPrefix(err.Error(), "NOSCRIPT") {
		return nil, err
	}

	if _, loadErr := l.rds.ScriptLoad(ctx, swCounterScript); loadErr != nil {
		return nil, loadErr
	}

	res, err = l.rds.EvalSHA(ctx, l.sha, keys, args...)
	if err != nil {
		return nil, err
	}
	return res, nil
}