Skip to content

Commit

Permalink
feat: implement different rotation strategies based on config
Browse files Browse the repository at this point in the history
  • Loading branch information
aaronhnsy committed Jul 26, 2022
1 parent f564b38 commit bae1779
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 47 deletions.
19 changes: 12 additions & 7 deletions swish/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,19 @@
import yt_dlp

from .config import CONFIG
from .rotator import IpRotator
from .player import Player
from .types.payloads import *


LOG: logging.Logger = logging.getLogger('swish.app')
from .rotator import BanRotator, NanosecondRotator
from .types.payloads import ReceivedPayload


__all__ = (
'App',
)


LOG: logging.Logger = logging.getLogger('swish.app')


class App(aiohttp.web.Application):

def __init__(self) -> None:
Expand Down Expand Up @@ -188,10 +188,16 @@ def _decode_track_id(_id: str, /) -> dict[str, Any]:
'none': ''
}

_ROTATOR_MAPPING: dict[str, type[NanosecondRotator] | type[BanRotator]] = {
'nanosecond-rotator': NanosecondRotator,
'ban-rotator': BanRotator
}

async def _ytdl_search(self, query: str, internal: bool) -> Any:

self._SEARCH_OPTIONS['source_address'] = IpRotator.rotate()
self._SEARCH_OPTIONS['extract_flat'] = not internal
if CONFIG.rotation.enabled:
self._SEARCH_OPTIONS['source_address'] = self._ROTATOR_MAPPING[CONFIG.rotation.method].rotate()

with yt_dlp.YoutubeDL(self._SEARCH_OPTIONS) as YTDL:
with contextlib.redirect_stdout(open(os.devnull, 'w')):
Expand All @@ -216,7 +222,6 @@ async def _get_tracks(self, query: str) -> list[dict[str, Any]]:
tracks: list[dict[str, Any]] = []

for entry in entries:

info: dict[str, Any] = {
'title': entry['title'],
'identifier': entry['id'],
Expand Down
9 changes: 5 additions & 4 deletions swish/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@
'CONFIG',
)

DEFAULT: dict[str, Any] = {

DEFAULT_CONFIG: dict[str, Any] = {
'server': {
'host': '127.0.0.1',
'port': 8000,
Expand Down Expand Up @@ -96,15 +97,15 @@ class Config:


try:
CONFIG: Config = dacite.from_dict(Config, toml.load('swish.toml'))
CONFIG: Config = dacite.from_dict(Config, toml.load('../swish.toml'))

except (toml.TomlDecodeError, FileNotFoundError):

with open('swish.toml', 'w') as fp:
toml.dump(DEFAULT, fp)
toml.dump(DEFAULT_CONFIG, fp)

print('Could not find or parse swish.toml, using default configuration values.')
CONFIG: Config = dacite.from_dict(Config, DEFAULT)
CONFIG: Config = dacite.from_dict(Config, DEFAULT_CONFIG)


except dacite.DaciteError as error:
Expand Down
21 changes: 17 additions & 4 deletions swish/player.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,29 @@

import aiohttp
import aiohttp.web
from discord.backoff import ExponentialBackoff
import discord.backoff
from discord.ext.native_voice import native_voice # type: ignore

from .types.payloads import *

from .types.payloads import (
PayloadHandlers,
ReceivedPayload,
SentPayloadOp,
VoiceUpdateData,
PlayData,
SetPauseStateData,
SetPositionData,
SetFilterData,
)

if TYPE_CHECKING:
from .app import App


__all__ = (
'Player',
)


LOG: logging.Logger = logging.getLogger('swish.player')


Expand Down Expand Up @@ -105,7 +118,7 @@ async def _connect(self) -> None:
async def _reconnect_handler(self) -> None:

loop = asyncio.get_running_loop()
backoff = ExponentialBackoff()
backoff = discord.backoff.ExponentialBackoff()

while True:

Expand Down
102 changes: 70 additions & 32 deletions swish/rotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,66 +19,104 @@
from __future__ import annotations

import ipaddress
import itertools
import logging
import random
import time

import discord.utils
from collections.abc import Iterator

from .config import CONFIG
from .utilities import plural


__all__ = (
'BaseRotator',
'NanosecondRotator',
'BanRotator',
)


LOG: logging.Logger = logging.getLogger('swish.rotator')


Network = ipaddress.IPv4Network | ipaddress.IPv6Network
IP = ipaddress.IPv4Address | ipaddress.IPv6Address


class IpRotator:
class BaseRotator:

_blocks = CONFIG['IP']['blocks']
_networks: list[Network] = [ipaddress.ip_network(ip) for ip in _blocks]
_enabled: bool
_networks: list[Network]
_address_count: int

_cycle: Iterator[Network]
_current_network: Network

if CONFIG.rotation.blocks:
_enabled = True
_networks = [ipaddress.ip_network(block) for block in CONFIG.rotation.blocks]
_address_count = sum(network.num_addresses for network in _networks)
LOG.info(
f'IP rotation enabled using {plural(_address_count, "IP")} from {plural(len(_networks), "network block")}.'
)
_cycle = itertools.cycle(_networks)
_current_network = next(_cycle)

if _networks:
_total: int = sum(network.num_addresses for network in _networks)
LOG.info(f'IP rotation enabled using {_total} total addresses.')
else:
_total: int = 0
LOG.warning('No IP blocks configured. Increased risk of rate-limiting.')
_enabled = False
_networks = []
_address_count = 0
_cycle = discord.utils.MISSING
_current_network = discord.utils.MISSING

_banned: list[IP] = []
_current: IP | None = None
LOG.warning('No network blocks configured, increased risk of ratelimiting.')

_ns = time.time_ns()
@classmethod
def rotate(cls) -> ...:
raise NotImplementedError


class BanRotator(BaseRotator):

_offset: int = 0

@classmethod
def rotate(cls) -> str:

if not cls._networks:
if not cls._enabled:
return '0.0.0.0'

# TODO: Only ban on 429
"""if cls._current:
cls._banned.append(cls._current)
LOG.debug(f'Excluded IP: {cls._current}')"""
if cls._offset >= cls._current_network.num_addresses:
cls._current_network = next(cls._cycle)
cls._offset = 0

address = cls._current_network[cls._offset]
cls._offset += 1

return str(address)

net = random.choice(cls._networks)
if net.prefixlen == 128:

class NanosecondRotator(BaseRotator):

_ns: int = time.time_ns()

@classmethod
def rotate(cls) -> str:

if not cls._enabled or cls._address_count < 2 ** 64:
return '0.0.0.0'

while True:
NSOFFSET = time.time_ns() - cls._ns

if NSOFFSET > cls._total:
cls._ns = time.time_ns()
continue
offset = time.time_ns() - cls._ns

ip = net[NSOFFSET]
if ip == cls._current or ip in cls._banned:
if offset > cls._address_count:
cls._ns = time.time_ns()
continue
elif offset >= cls._current_network.num_addresses:
cls._current_network = next(cls._cycle)
offset -= cls._current_network.num_addresses
else:
break

# WARNING: Very verbose...
# LOG.info(f'Rotated to new IP: {ip}')
cls._current = ip
break

return str(cls._current)
return str(cls._current_network[offset])
29 changes: 29 additions & 0 deletions swish/utilities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""Swish. A standalone audio player and server for bots on Discord.
Copyright (C) 2022 PythonistaGuild <https://github.com/PythonistaGuild>
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
"""

from __future__ import annotations

from collections.abc import Callable


__all__ = (
'plural',
)


plural: Callable[[int, str], str] = lambda count, thing: f'{count} {thing}s' if count > 1 else f'{count} {thing}'

0 comments on commit bae1779

Please sign in to comment.