From be4a78c665f16aec1a78abc253ad1c85e907c8b6 Mon Sep 17 00:00:00 2001 From: Bart Moorman Date: Sun, 22 Sep 2024 07:51:35 -0600 Subject: [PATCH 1/3] Support multiple cooldowns (#1) * Add support for multiple cooldowns per command --- twitchio/ext/commands/cooldowns.py | 143 +++++++++++++++++------------ twitchio/ext/commands/core.py | 22 ++--- 2 files changed, 92 insertions(+), 73 deletions(-) diff --git a/twitchio/ext/commands/cooldowns.py b/twitchio/ext/commands/cooldowns.py index ded01efc..81ee63ec 100644 --- a/twitchio/ext/commands/cooldowns.py +++ b/twitchio/ext/commands/cooldowns.py @@ -47,22 +47,31 @@ class Bucket(enum.Enum): The default bucket. channel: :class:`enum.Enum` Cooldown is shared amongst all chatters per channel. + user: :class:`enum.Enum` + Cooldown operates on a per user basis across all channels. member: :class:`enum.Enum` Cooldown operates on a per channel basis per user. - user: :class:`enum.Enum` - Cooldown operates on a user basis across all channels. + turbo: :class:`enum.Enum` + Cooldown for turbo users. subscriber: :class:`enum.Enum` Cooldown for subscribers. + vip: :class:`enum.Enum` + Cooldown for VIPs. mod: :class:`enum.Enum` Cooldown for mods. + broadcaster: :class:`enum.Enum` + Cooldown for the broadcaster. """ default = 0 channel = 1 - member = 2 - user = 3 - subscriber = 4 - mod = 5 + user = 2 + member = 3 + turbo = 4 + subscriber = 5 + vip = 6 + mod = 7 + broadcaster = 8 class Cooldown: @@ -100,80 +109,92 @@ async def my_command(self, ctx: commands.Context): @commands.command() async def my_command(self, ctx: commands.Context): pass + + # Restrict a command to 5 times every 60 seconds globally for a user, + # 5 times every 30 seconds if the user is turbo, + # and 1 time every 1 second if they're the channel broadcaster + @commands.cooldown(rate=5, per=60, bucket=commands.Bucket.user) + @commands.cooldown(rate=5, per=30, bucket=commands.Bucket.turbo) + @commands.cooldown(rate=1, per=1, bucket=commands.Bucket.broadcaster) + @commands.command() + async def my_command(self, ctx: commands.Context): + pass """ - __slots__ = ("_rate", "_per", "bucket", "_window", "_tokens", "_cache") + __slots__ = ("_rate", "_per", "bucket", "_cache") - def __init__(self, rate: int, per: float, bucket: Bucket): + def __init__(self, rate: int, per: float, bucket: Bucket) -> None: self._rate = rate self._per = per self.bucket = bucket self._cache = {} - def update_bucket(self, ctx): - now = time.time() - - bucket_keys = self._bucket_keys(ctx) - buckets = [] + def update_cooldown(self, key, now) -> int | None: + cooldown = self._cache[key] - for bucket in bucket_keys: - (tokens, window) = self._cache[bucket] + if cooldown["tokens"] == self._rate: + retry = self._per - (now - cooldown["start_time"]) + return retry - if tokens == self._rate: - retry = self._per - (now - window) - raise CommandOnCooldown(command=ctx.command, retry_after=retry) + if cooldown["tokens"] == 1 and self._rate > 1: + cooldown["next_start_time"] = now - tokens += 1 + cooldown["tokens"] += 1 - if tokens == self._rate: - window = now + if cooldown["tokens"] == self._rate and not self._rate == 1: + cooldown["start_time"] = cooldown["next_start_time"] - self._cache[bucket] = (tokens, window) + self._cache[key] = cooldown - def reset(self): + def reset(self) -> None: self._cache = {} - def _bucket_keys(self, ctx): - buckets = [] - - for bucket in ctx.command._cooldowns: - if bucket.bucket == Bucket.default: - buckets.append("default") - - if bucket.bucket == Bucket.channel: - buckets.append(ctx.channel.name) - - if bucket.bucket == Bucket.member: - buckets.append((ctx.channel.name, ctx.author.id)) - if bucket.bucket == Bucket.user: - buckets.append(ctx.author.id) - - if bucket.bucket == Bucket.subscriber: - buckets.append((ctx.channel.name, ctx.author.id, 0)) - if bucket.bucket == Bucket.mod: - buckets.append((ctx.channel.name, ctx.author.id, 1)) - - return buckets - - def _update_cache(self, now=None): - now = now or time.time() - dead = [key for key, cooldown in self._cache.items() if now > cooldown[1] + self._per] - - for bucket in dead: - del self._cache[bucket] - - def get_buckets(self, ctx): + def _key(self, ctx): + key = None + + if self.bucket == Bucket.default: + key = "default" + elif self.bucket == Bucket.channel: + key = ctx.channel.name + elif self.bucket == Bucket.user: + key = ctx.author.id + elif self.bucket == Bucket.member: + key = (ctx.channel.name, ctx.author.id) + elif self.bucket == Bucket.turbo and ctx.author.is_turbo: + key = (ctx.channel.name, ctx.author.id) + elif self.bucket == Bucket.subscriber and ctx.author.is_subscriber: + key = (ctx.channel.name, ctx.author.id) + elif self.bucket == Bucket.vip and ctx.author.is_vip: + key = (ctx.channel.name, ctx.author.id) + elif self.bucket == Bucket.mod and ctx.author.is_mod: + key = (ctx.channel.name, ctx.author.id) + elif self.bucket == Bucket.broadcaster and ctx.author.is_broadcaster: + key = (ctx.channel.name, ctx.author.id) + + return key + + def _update_cache(self, now) -> None: + expired = [] + for key, cooldown in self._cache.items(): + if now > cooldown["start_time"] + self._per: + if cooldown["tokens"] > 1: + cooldown["tokens"] -= 1 + cooldown["start_time"] = cooldown["next_start_time"] + cooldown["next_start_time"] = now + else: + expired.append(key) + for key in expired: + del self._cache[key] + + def on_cooldown(self, ctx) -> None: now = time.time() self._update_cache(now) - bucket_keys = self._bucket_keys(ctx) - buckets = [] - - for index, bucket in enumerate(bucket_keys): - buckets.append(ctx.command._cooldowns[index]) - if bucket not in self._cache: - self._cache[bucket] = (0, now) + key = self._key(ctx) + if key: + if not key in self._cache: + self._cache[key] = {"tokens": 0, "start_time": now, "next_start_time": None} - return buckets + return self.update_cooldown(key, now) diff --git a/twitchio/ext/commands/core.py b/twitchio/ext/commands/core.py index 2b40bdfb..9337eb31 100644 --- a/twitchio/ext/commands/core.py +++ b/twitchio/ext/commands/core.py @@ -356,7 +356,8 @@ async def try_run(func, *, to_command=False): limited = self._run_cooldowns(context) if limited: - context.bot.run_event("command_error", context, limited[0]) + e = CommandOnCooldown(command=context.command, retry_after=limited) + context.bot.run_event("command_error", context, e) return instance = self._instance args = [instance, context] if instance else [context] @@ -377,19 +378,16 @@ async def try_run(func, *, to_command=False): await try_run(self._after_invoke(*args), to_command=True) await try_run(context.bot.global_after_invoke(context)) - def _run_cooldowns(self, context: Context) -> Optional[List[CommandOnCooldown]]: - try: - buckets = self._cooldowns[0].get_buckets(context) - except IndexError: + def _run_cooldowns(self, context: Context) -> Optional[int]: + if not self._cooldowns: return None - expired = [] - try: - for bucket in buckets: - bucket.update_bucket(context) - except CommandOnCooldown as e: - expired.append(e) - return expired + retries = [] + for c in self._cooldowns: + retry = c.on_cooldown(context) + retries.append(retry) + if all(retries): + return min(retries) async def handle_checks(self, context: Context) -> Union[Literal[True], Exception]: # TODO Docs From 978aa1eadeea3f9fd8a9a4de91656105b4e73554 Mon Sep 17 00:00:00 2001 From: Bart Moorman Date: Sun, 22 Sep 2024 16:37:42 -0600 Subject: [PATCH 2/3] Make the Cooldown.update_cooldown() private --- twitchio/ext/commands/cooldowns.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/twitchio/ext/commands/cooldowns.py b/twitchio/ext/commands/cooldowns.py index 81ee63ec..21dc670c 100644 --- a/twitchio/ext/commands/cooldowns.py +++ b/twitchio/ext/commands/cooldowns.py @@ -130,7 +130,7 @@ def __init__(self, rate: int, per: float, bucket: Bucket) -> None: self._cache = {} - def update_cooldown(self, key, now) -> int | None: + def _update_cooldown(self, key, now) -> int | None: cooldown = self._cache[key] if cooldown["tokens"] == self._rate: @@ -197,4 +197,4 @@ def on_cooldown(self, ctx) -> None: if not key in self._cache: self._cache[key] = {"tokens": 0, "start_time": now, "next_start_time": None} - return self.update_cooldown(key, now) + return self._update_cooldown(key, now) From 83647129e653141da7ee0e286db73ff08859d5c9 Mon Sep 17 00:00:00 2001 From: Bart Moorman Date: Mon, 23 Sep 2024 12:30:24 -0600 Subject: [PATCH 3/3] Fix logic when rate > 1 --- twitchio/ext/commands/cooldowns.py | 62 +++++++++++++++--------------- 1 file changed, 30 insertions(+), 32 deletions(-) diff --git a/twitchio/ext/commands/cooldowns.py b/twitchio/ext/commands/cooldowns.py index 21dc670c..96c6cea9 100644 --- a/twitchio/ext/commands/cooldowns.py +++ b/twitchio/ext/commands/cooldowns.py @@ -130,27 +130,19 @@ def __init__(self, rate: int, per: float, bucket: Bucket) -> None: self._cache = {} - def _update_cooldown(self, key, now) -> int | None: - cooldown = self._cache[key] + def _update_cooldown(self, bucket_key, now) -> int | None: + tokens = self._cache[bucket_key] - if cooldown["tokens"] == self._rate: - retry = self._per - (now - cooldown["start_time"]) + if len(tokens) == self._rate: + retry = self._per - (now - tokens[0]) return retry - if cooldown["tokens"] == 1 and self._rate > 1: - cooldown["next_start_time"] = now - - cooldown["tokens"] += 1 - - if cooldown["tokens"] == self._rate and not self._rate == 1: - cooldown["start_time"] = cooldown["next_start_time"] - - self._cache[key] = cooldown + tokens.append(now) def reset(self) -> None: self._cache = {} - def _key(self, ctx): + def _bucket_key(self, ctx): key = None if self.bucket == Bucket.default: @@ -175,26 +167,32 @@ def _key(self, ctx): return key def _update_cache(self, now) -> None: - expired = [] - for key, cooldown in self._cache.items(): - if now > cooldown["start_time"] + self._per: - if cooldown["tokens"] > 1: - cooldown["tokens"] -= 1 - cooldown["start_time"] = cooldown["next_start_time"] - cooldown["next_start_time"] = now - else: - expired.append(key) - for key in expired: - del self._cache[key] - - def on_cooldown(self, ctx) -> None: + expired_bucket_keys = [] + + for bucket_key, tokens in self._cache.items(): + expired_tokens = [] + + for token in tokens: + if now - token > self._per: + expired_tokens.append(token) + + for expired_token in expired_tokens: + tokens.remove(expired_token) + + if not tokens: + expired_bucket_keys.append(bucket_key) + + for expired_bucket_key in expired_bucket_keys: + del self._cache[expired_bucket_key] + + def on_cooldown(self, ctx) -> int | None: now = time.time() self._update_cache(now) - key = self._key(ctx) - if key: - if not key in self._cache: - self._cache[key] = {"tokens": 0, "start_time": now, "next_start_time": None} + bucket_key = self._bucket_key(ctx) + if bucket_key: + if not bucket_key in self._cache: + self._cache[bucket_key] = [] - return self._update_cooldown(key, now) + return self._update_cooldown(bucket_key, now)