Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Invalidate Discord authorization tokens on public pastes #36

Merged
merged 15 commits into from
May 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions config.template.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,7 @@ global_limit = { rate = 21600, per = 86400, priority = 1, bucket = "ip" }
char_limit = 300_000
file_limit = 5
name_limit = 25

[GITHUB] # optional key
token = "..." # a github token capable of creating gists, non-optional if the above key is provided
timeout = 10 # how long to wait between posting gists if there's an influx of tokens posted. Non-optional
114 changes: 107 additions & 7 deletions core/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,15 @@
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""

from __future__ import annotations

import asyncio
import datetime
import logging
import re
from typing import TYPE_CHECKING, Any, Self

import aiohttp
import asyncpg

from core import CONFIG
Expand All @@ -31,26 +35,114 @@

if TYPE_CHECKING:
_Pool = asyncpg.Pool[asyncpg.Record]
from types_.config import Github
from types_.github import PostGist
else:
_Pool = asyncpg.Pool


logger: logging.Logger = logging.getLogger(__name__)
DISCORD_TOKEN_REGEX: re.Pattern[str] = re.compile(r"[a-zA-Z0-9_-]{23,28}\.[a-zA-Z0-9_-]{6,7}\.[a-zA-Z0-9_-]{27,}")
LOGGER: logging.Logger = logging.getLogger(__name__)


class Database:
pool: _Pool

def __init__(self, *, dsn: str) -> None:
def __init__(self, *, dsn: str, session: aiohttp.ClientSession | None = None, github_config: Github | None) -> None:
self._dsn: str = dsn
self.session: aiohttp.ClientSession | None = session
self._handling_tokens = bool(self.session and github_config)

if self._handling_tokens:
LOGGER.info("Will handle compromised discord info.")
assert github_config # guarded by if here

self._gist_token = github_config["token"]
self._gist_timeout = github_config["timeout"]
# tokens bucket for gist posting: {paste_id: token\ntoken}
self.__tokens_bucket: dict[str, str] = {}
self.__token_lock = asyncio.Lock()
self.__token_task = asyncio.create_task(self._token_task())

async def __aenter__(self) -> Self:
await self.connect()
return self

async def __aexit__(self, *_: Any) -> None:
task: asyncio.Task[None] | None = getattr(self, "__token_task", None)
if task:
task.cancel()

await self.close()

async def _token_task(self) -> None:
# won't run unless pre-reqs are met in __init__
while True:
if self.__tokens_bucket:
async with self.__token_lock:
await self._post_gist_of_tokens()

await asyncio.sleep(self._gist_timeout)

def _handle_discord_tokens(self, *bodies: dict[str, str], paste_id: str) -> None:
formatted_bodies = "\n".join(b["content"] for b in bodies)

tokens = list(DISCORD_TOKEN_REGEX.finditer(formatted_bodies))

if not tokens:
return

LOGGER.info(
"Discord bot token located and added to token bucket. Current bucket size is: %s", len(self.__tokens_bucket)
)

tokens = "\n".join([m[0] for m in tokens])
self.__tokens_bucket[paste_id] = tokens

async def _post_gist_of_tokens(self) -> None:
assert self.session # guarded in caller
json_payload: PostGist = {
"description": "MystBin found these Discord tokens in a public paste, and posted them here to invalidate them. If you intended to share these, please apply a password to the paste.",
"files": {},
"public": True,
}

github_headers = {
"Accept": "application/vnd.github+json",
"Authorization": f"Bearer {self._gist_token}",
"X-GitHub-Api-Version": "2022-11-28",
}

current_tokens = self.__tokens_bucket
self.__tokens_bucket = {}

for paste_id, tokens in current_tokens.items():
filename = str(datetime.datetime.now(datetime.UTC)) + "-tokens.txt"
json_payload["files"][filename] = {"content": f"https://mystb.in/{paste_id}:\n{tokens}"}

success = False

try:
async with self.session.post(
"https://api.github.com/gists", headers=github_headers, json=json_payload
) as resp:
success = resp.ok

if not success:
response_body = await resp.text()
LOGGER.error(
"Failed to create gist with token bucket with response status code %s and response body:\n\n%s",
resp.status,
response_body,
)
except (aiohttp.ClientError, aiohttp.ClientOSError) as error:
success = False
LOGGER.error("Failed to handle gist creation due to a client or operating system error", exc_info=error)

if success:
LOGGER.info("Gist created and invalidated tokens from %s pastes.", len(current_tokens))
else:
self.__tokens_bucket.update(current_tokens)

async def connect(self) -> None:
try:
pool: asyncpg.Pool[asyncpg.Record] | None = await asyncpg.create_pool(dsn=self._dsn)
Expand All @@ -64,15 +156,15 @@ async def connect(self) -> None:
await pool.execute(fp.read())

self.pool = pool
logger.info("Successfully connected to the database.")
LOGGER.info("Successfully connected to the database.")

async def close(self) -> None:
try:
await asyncio.wait_for(self.pool.close(), timeout=10)
except TimeoutError:
logger.warning("Failed to greacefully close the database connection...")
LOGGER.warning("Failed to greacefully close the database connection...")
else:
logger.info("Successfully closed the database connection.")
LOGGER.info("Successfully closed the database connection.")

async def fetch_paste(self, identifier: str, *, password: str | None) -> PasteModel | None:
assert self.pool
Expand Down Expand Up @@ -159,6 +251,8 @@ async def create_paste(self, *, data: dict[str, Any]) -> PasteModel:
tokens = [t for t in utils.TOKEN_REGEX.findall(content) if utils.validate_discord_token(t)]
if tokens:
annotation = "Contains possibly sensitive information: Discord Token(s)"
if not password:
annotation += ", which have now been invalidated."

row: asyncpg.Record | None = await connection.fetchrow(
file_query, paste.id, content, name, loc, annotation
Expand All @@ -167,7 +261,13 @@ async def create_paste(self, *, data: dict[str, Any]) -> PasteModel:
if row:
paste.files.append(FileModel(row))

return paste
if not password:
# if the user didn't provide a password (a public paste)
# we check for discord tokens
LOGGER.info("Located tokens")
self._handle_discord_tokens(*data["files"], paste_id=paste.id)

return paste

async def fetch_paste_security(self, *, token: str) -> PasteModel | None:
query: str = """SELECT * FROM pastes WHERE safety = $1"""
Expand Down
10 changes: 8 additions & 2 deletions core/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import logging

import aiohttp
import starlette_plus
from starlette.middleware import Middleware
from starlette.routing import Mount, Route
Expand All @@ -34,11 +35,16 @@


class Application(starlette_plus.Application):
def __init__(self, *, database: Database) -> None:
def __init__(self, *, database: Database, session: aiohttp.ClientSession | None = None) -> None:
self.database: Database = database
self.session: aiohttp.ClientSession | None = session
self.schemas: SchemaGenerator | None = None

views: list[starlette_plus.View] = [HTMXView(self), APIView(self), DocsView(self)]
views: list[starlette_plus.View] = [
HTMXView(self),
APIView(self),
DocsView(self),
]
routes: list[Mount | Route] = [Mount("/static", app=StaticFiles(directory="web/static"), name="static")]

limit_redis = starlette_plus.Redis(url=CONFIG["REDIS"]["limiter"]) if CONFIG["REDIS"]["limiter"] else None
Expand Down
5 changes: 4 additions & 1 deletion launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import asyncio
import logging

import aiohttp
import starlette_plus
import uvicorn

Expand All @@ -31,7 +32,9 @@


async def main() -> None:
async with core.Database(dsn=core.CONFIG["DATABASE"]["dsn"]) as database:
async with aiohttp.ClientSession() as session, core.Database(
dsn=core.CONFIG["DATABASE"]["dsn"], session=session, github_config=core.CONFIG.get("GITHUB")
) as database:
app: core.Application = core.Application(database=database)

host: str = core.CONFIG["SERVER"]["host"]
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ asyncpg-stubs
bleach
humanize
python-multipart
pyyaml
pyyaml
aiohttp
8 changes: 7 additions & 1 deletion types_/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""

from typing import TypedDict
from typing import NotRequired, TypedDict

import starlette_plus

Expand Down Expand Up @@ -52,9 +52,15 @@ class Pastes(TypedDict):
name_limit: int


class Github(TypedDict):
token: str
timeout: float


class Config(TypedDict):
SERVER: Server
DATABASE: Database
REDIS: Redis
LIMITS: Limits
PASTES: Pastes
GITHUB: NotRequired[Github]
29 changes: 29 additions & 0 deletions types_/github.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""MystBin. Share code easily.

Copyright (C) 2020-Current PythonistaGuild

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""

from typing import TypedDict


class GistContent(TypedDict):
content: str


class PostGist(TypedDict):
description: str
files: dict[str, GistContent]
public: bool
8 changes: 6 additions & 2 deletions views/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,9 @@ async def paste_post(self, request: starlette_plus.Request) -> starlette_plus.Re

Max file limit is `5`.\n\n

If the paste is regarded as public, and contains Discord authorization tokens,
then these will be invalidated upon paste creation.\n\n

requestBody:
description: The paste data. `password` and `expires` are optional.
content:
Expand Down Expand Up @@ -245,7 +248,6 @@ async def paste_post(self, request: starlette_plus.Request) -> starlette_plus.Re
type: string
example: You are requesting too fast.
"""

content_type: str | None = request.headers.get("content-type", None)
body: dict[str, Any] | str
data: dict[str, Any]
Expand All @@ -259,6 +261,7 @@ async def paste_post(self, request: starlette_plus.Request) -> starlette_plus.Re
body = (await request.body()).decode(encoding="UTF-8")

data = {"files": [{"content": body, "filename": None}]} if isinstance(body, str) else body

if resp := validate_paste(data):
return resp

Expand All @@ -270,9 +273,10 @@ async def paste_post(self, request: starlette_plus.Request) -> starlette_plus.Re
return starlette_plus.JSONResponse({"error": f'Unable to parse "expiry" parameter: {e}'}, status_code=400)

data["expires"] = expiry
data["password"] = data.get("password", None)
data["password"] = data.get("password")

paste = await self.app.database.create_paste(data=data)

to_return: dict[str, Any] = paste.serialize(exclude=["password", "password_ok"])
to_return.pop("files", None)

Expand Down