Skip to content

Commit

Permalink
add pool back
Browse files Browse the repository at this point in the history
  • Loading branch information
aaronhnsy committed Apr 28, 2022
1 parent fa623a9 commit 708c733
Show file tree
Hide file tree
Showing 5 changed files with 141 additions and 13 deletions.
1 change: 1 addition & 0 deletions slate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .node import *
from .objects import *
from .player import *
from .pool import *
from .queue import *
from .utils import *

Expand Down
15 changes: 15 additions & 0 deletions slate/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@

__all__ = (
"SlateError",
"NodeAlreadyExists",
"NodeNotFound",
"NoNodesConnected",
"NodeConnectionError",
"InvalidNodePassword",
"NodeNotConnected",
Expand All @@ -26,6 +29,18 @@ class SlateError(Exception):
pass


class NodeAlreadyExists(SlateError):
pass


class NodeNotFound(SlateError):
pass


class NoNodesConnected(SlateError):
pass


class NodeConnectionError(SlateError):
pass

Expand Down
8 changes: 3 additions & 5 deletions slate/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,10 @@
__all__ = (
"Node",
)

__log__: logging.Logger = logging.getLogger("slate.node")


class Node(Generic[BotT, ContextT, PlayerT]):
class Node(Generic[BotT, PlayerT]):
"""
Node's handle interactions between your bot and a provider server such as obsidian or lavalink. This includes
connecting to the websocket, searching for tracks, and managing player state.
Expand Down Expand Up @@ -274,7 +273,7 @@ async def _request(

session = await self._get_session()

url = f"{self._rest_url}{path}"
url = f"{self.rest_url}{path}"
headers = {
"Authorization": self.password,
"Client-Name": "Slate"
Expand Down Expand Up @@ -419,8 +418,7 @@ async def _search_other(

async def search(
self,
search: str,
/,
search: str, /,
*,
source: Source = Source.NONE,
ctx: ContextT | None = None,
Expand Down
18 changes: 10 additions & 8 deletions slate/player.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from .objects.events import TrackEnd, TrackException, TrackStart, TrackStuck, WebsocketClosed, WebsocketOpen
from .objects.filters import Filter
from .objects.track import Track
from .pool import Pool
from .types import BotT, ContextT, VoiceChannel
from .utils import MISSING

Expand Down Expand Up @@ -50,27 +51,28 @@ class Player(discord.VoiceProtocol, Generic[BotT, ContextT]):
def __init__(
self,
client: BotT = MISSING,
channel: VoiceChannel = MISSING,
channel: VoiceChannel = MISSING, /,
*,
node: Node[BotT, ContextT, Self]
node: Node[BotT, Self] | None = None,
) -> None:
"""
Parameters
----------
node
The node this player should be attached to.
The node this player should be attached to, if :obj:`None` the player will be attached to the first node
found from the pool.
Warnings
--------
To connect to a voice channel you must construct an instance of this class, setting the ``node`` argument
(and extras, if subclassing) but **not** the ``client`` or ``channel`` arguments. You can then pass it to the ``cls``
argument of :meth:`discord.abc.Connectable.connect`.
(and extras, if subclassing) but **not** the ``client`` or ``channel`` arguments. You can then pass it to
the ``cls`` argument of :meth:`discord.abc.Connectable.connect`.
"""

self.client: BotT = client
self.channel: VoiceChannel = channel

self._node: Node[BotT, ContextT, Self] = node
self._node: Node[BotT, Self] = node or Pool.get_node() # type: ignore

self._voice_server_update_data: discord.types.voice.VoiceServerUpdate | None = None
self._session_id: str | None = None
Expand All @@ -88,7 +90,7 @@ def __init__(
def __call__(
self,
client: discord.Client,
channel: discord.abc.Connectable,
channel: discord.abc.Connectable, /,
) -> Self:

self.client = client
Expand Down Expand Up @@ -201,7 +203,7 @@ def voice_channel(self) -> VoiceChannel:
return self.channel

@property
def node(self) -> Node[BotT, ContextT, Self]:
def node(self) -> Node[BotT, Self]:
"""
The node this player is attached to.
"""
Expand Down
112 changes: 112 additions & 0 deletions slate/pool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# Future
from __future__ import annotations

# Standard Library
import logging
from typing import Generic

# Packages
import aiohttp

# Local
from .exceptions import NodeAlreadyExists, NodeNotFound, NoNodesConnected
from .node import Node
from .objects.enums import Provider
from .types import BotT, JSONDumps, JSONLoads, PlayerT


__all__ = (
"Pool",
)
__log__: logging.Logger = logging.getLogger("slate.pool")


class Pool(Generic[BotT, PlayerT]):

def __repr__(self) -> str:
return f"<slate.Pool node_count={len(self.nodes)}>"

nodes: dict[str, Node[BotT, PlayerT]] = {}

@classmethod
async def create_node(
cls,
*,
bot: BotT,
session: aiohttp.ClientSession | None = None,
# Connection information
provider: Provider,
identifier: str,
host: str,
port: str,
password: str,
secure: bool = False,
resume_key: str | None = None,
# URLs
rest_url: str | None = None,
ws_url: str | None = None,
# JSON callables
json_dumps: JSONDumps | None = None,
json_loads: JSONLoads | None = None,
# Spotify
spotify_client_id: str | None = None,
spotify_client_secret: str | None = None,
) -> Node[BotT, PlayerT]:

if identifier in cls.nodes:
raise NodeAlreadyExists(f"A node with the identifier '{identifier}' already exists.")

node = Node(
bot=bot,
session=session,
provider=provider,
identifier=identifier,
host=host,
port=port,
password=password,
secure=secure,
resume_key=resume_key,
rest_url=rest_url,
ws_url=ws_url,
json_dumps=json_dumps,
json_loads=json_loads,
spotify_client_id=spotify_client_id,
spotify_client_secret=spotify_client_secret,
)
await node.connect()

cls.nodes[identifier] = node
__log__.info(f"Add node '{node.identifier}' to the pool.")

return node

@classmethod
def get_node(
cls,
identifier: str | None = None,
) -> Node[BotT, PlayerT]:

if not cls.nodes:
raise NoNodesConnected("There are no nodes connected.")

if not identifier:
return list(cls.nodes.values())[0]

if node := cls.nodes.get(identifier):
return node

raise NodeNotFound(f"A node with the identifier '{identifier}' was not found.")

@classmethod
async def remove_node(
cls,
identifier: str,
*,
force: bool = False,
) -> None:

node = cls.get_node(identifier)
await node.disconnect(force=force)

del cls.nodes[node.identifier]
__log__.info(f"Removed node '{identifier}' from the pool.")

0 comments on commit 708c733

Please sign in to comment.