Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support multiple cooldowns #470

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 83 additions & 64 deletions twitchio/ext/commands/cooldowns.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,22 +47,31 @@ class Bucket(enum.Enum):
The default bucket.
channel: :class:`enum.Enum`
Cooldown is shared amongst all chatters per channel.
user: :class:`enum.Enum`
Cooldown operates on a per user basis across all channels.
member: :class:`enum.Enum`
Cooldown operates on a per channel basis per user.
user: :class:`enum.Enum`
Cooldown operates on a user basis across all channels.
turbo: :class:`enum.Enum`
Cooldown for turbo users.
subscriber: :class:`enum.Enum`
Cooldown for subscribers.
vip: :class:`enum.Enum`
Cooldown for VIPs.
mod: :class:`enum.Enum`
Cooldown for mods.
broadcaster: :class:`enum.Enum`
Cooldown for the broadcaster.
"""

default = 0
channel = 1
member = 2
user = 3
subscriber = 4
mod = 5
user = 2
member = 3
turbo = 4
subscriber = 5
vip = 6
mod = 7
broadcaster = 8


class Cooldown:
Expand Down Expand Up @@ -100,80 +109,90 @@ async def my_command(self, ctx: commands.Context):
@commands.command()
async def my_command(self, ctx: commands.Context):
pass

# Restrict a command to 5 times every 60 seconds globally for a user,
# 5 times every 30 seconds if the user is turbo,
# and 1 time every 1 second if they're the channel broadcaster
@commands.cooldown(rate=5, per=60, bucket=commands.Bucket.user)
@commands.cooldown(rate=5, per=30, bucket=commands.Bucket.turbo)
@commands.cooldown(rate=1, per=1, bucket=commands.Bucket.broadcaster)
@commands.command()
async def my_command(self, ctx: commands.Context):
pass
"""

__slots__ = ("_rate", "_per", "bucket", "_window", "_tokens", "_cache")
__slots__ = ("_rate", "_per", "bucket", "_cache")

def __init__(self, rate: int, per: float, bucket: Bucket):
def __init__(self, rate: int, per: float, bucket: Bucket) -> None:
self._rate = rate
self._per = per
self.bucket = bucket

self._cache = {}

def update_bucket(self, ctx):
now = time.time()

bucket_keys = self._bucket_keys(ctx)
buckets = []

for bucket in bucket_keys:
(tokens, window) = self._cache[bucket]

if tokens == self._rate:
retry = self._per - (now - window)
raise CommandOnCooldown(command=ctx.command, retry_after=retry)
def _update_cooldown(self, bucket_key, now) -> int | None:
tokens = self._cache[bucket_key]

tokens += 1
if len(tokens) == self._rate:
retry = self._per - (now - tokens[0])
return retry

if tokens == self._rate:
window = now
tokens.append(now)

self._cache[bucket] = (tokens, window)

def reset(self):
def reset(self) -> None:
self._cache = {}

def _bucket_keys(self, ctx):
buckets = []

for bucket in ctx.command._cooldowns:
if bucket.bucket == Bucket.default:
buckets.append("default")

if bucket.bucket == Bucket.channel:
buckets.append(ctx.channel.name)

if bucket.bucket == Bucket.member:
buckets.append((ctx.channel.name, ctx.author.id))
if bucket.bucket == Bucket.user:
buckets.append(ctx.author.id)

if bucket.bucket == Bucket.subscriber:
buckets.append((ctx.channel.name, ctx.author.id, 0))
if bucket.bucket == Bucket.mod:
buckets.append((ctx.channel.name, ctx.author.id, 1))

return buckets

def _update_cache(self, now=None):
now = now or time.time()
dead = [key for key, cooldown in self._cache.items() if now > cooldown[1] + self._per]

for bucket in dead:
del self._cache[bucket]

def get_buckets(self, ctx):
def _bucket_key(self, ctx):
key = None

if self.bucket == Bucket.default:
key = "default"
elif self.bucket == Bucket.channel:
key = ctx.channel.name
elif self.bucket == Bucket.user:
key = ctx.author.id
elif self.bucket == Bucket.member:
key = (ctx.channel.name, ctx.author.id)
elif self.bucket == Bucket.turbo and ctx.author.is_turbo:
key = (ctx.channel.name, ctx.author.id)
elif self.bucket == Bucket.subscriber and ctx.author.is_subscriber:
key = (ctx.channel.name, ctx.author.id)
elif self.bucket == Bucket.vip and ctx.author.is_vip:
key = (ctx.channel.name, ctx.author.id)
elif self.bucket == Bucket.mod and ctx.author.is_mod:
key = (ctx.channel.name, ctx.author.id)
elif self.bucket == Bucket.broadcaster and ctx.author.is_broadcaster:
key = (ctx.channel.name, ctx.author.id)

return key

def _update_cache(self, now) -> None:
expired_bucket_keys = []

for bucket_key, tokens in self._cache.items():
expired_tokens = []

for token in tokens:
if now - token > self._per:
expired_tokens.append(token)

for expired_token in expired_tokens:
tokens.remove(expired_token)

if not tokens:
expired_bucket_keys.append(bucket_key)

for expired_bucket_key in expired_bucket_keys:
del self._cache[expired_bucket_key]

def on_cooldown(self, ctx) -> int | None:
now = time.time()

self._update_cache(now)

bucket_keys = self._bucket_keys(ctx)
buckets = []

for index, bucket in enumerate(bucket_keys):
buckets.append(ctx.command._cooldowns[index])
if bucket not in self._cache:
self._cache[bucket] = (0, now)
bucket_key = self._bucket_key(ctx)
if bucket_key:
if not bucket_key in self._cache:
self._cache[bucket_key] = []

return buckets
return self._update_cooldown(bucket_key, now)
22 changes: 10 additions & 12 deletions twitchio/ext/commands/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,8 @@ async def try_run(func, *, to_command=False):
limited = self._run_cooldowns(context)

if limited:
context.bot.run_event("command_error", context, limited[0])
e = CommandOnCooldown(command=context.command, retry_after=limited)
context.bot.run_event("command_error", context, e)
return
instance = self._instance
args = [instance, context] if instance else [context]
Expand All @@ -377,19 +378,16 @@ async def try_run(func, *, to_command=False):
await try_run(self._after_invoke(*args), to_command=True)
await try_run(context.bot.global_after_invoke(context))

def _run_cooldowns(self, context: Context) -> Optional[List[CommandOnCooldown]]:
try:
buckets = self._cooldowns[0].get_buckets(context)
except IndexError:
def _run_cooldowns(self, context: Context) -> Optional[int]:
if not self._cooldowns:
return None
expired = []

try:
for bucket in buckets:
bucket.update_bucket(context)
except CommandOnCooldown as e:
expired.append(e)
return expired
retries = []
for c in self._cooldowns:
retry = c.on_cooldown(context)
retries.append(retry)
if all(retries):
return min(retries)

async def handle_checks(self, context: Context) -> Union[Literal[True], Exception]:
# TODO Docs
Expand Down