-
Notifications
You must be signed in to change notification settings - Fork 26
/
Copy pathfast_api_depends_test.py
74 lines (59 loc) · 2.26 KB
/
fast_api_depends_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import os
import sys
from multiprocessing import Process

import pytest
import uvicorn
from fastapi import APIRouter, Depends, FastAPI, Header, WebSocket, WebSocketException, status
from websockets.exceptions import InvalidStatus

from fastapi_websocket_rpc.rpc_methods import RpcUtilityMethods
from fastapi_websocket_rpc.utils import gen_uid
from fastapi_websocket_rpc.websocket_rpc_client import WebSocketRpcClient
from fastapi_websocket_rpc.websocket_rpc_endpoint import WebsocketRPCEndpoint
# Port shared by the test server and the client; override with the PORT env var.
PORT = int(os.getenv("PORT") or "8000")
# Random id generated once per run; it fills the {client_id} segment of the route.
CLIENT_ID = gen_uid()
uri = "ws://localhost:{}/ws/{}".format(PORT, CLIENT_ID)
# Token the server-side dependency compares against the X-Token request header.
SECRET_TOKEN = "fake-super-secret-token"
async def check_token_header(websocket: WebSocket, x_token: str = Header(...)) -> None:
    """FastAPI dependency that rejects websocket connections carrying a wrong token.

    Compares the required ``X-Token`` request header against ``SECRET_TOKEN``.

    Args:
        websocket: The incoming connection. Not used here; kept so the
            dependency's signature is unchanged for existing callers.
        x_token: Value of the required ``X-Token`` header.

    Raises:
        WebSocketException: With close code 1008 (policy violation) when the
            token does not match. Raising, rather than closing the socket and
            returning, stops FastAPI from running the endpoint at all. Because
            the handshake was never accepted, the client is refused with an
            HTTP 403.
    """
    if x_token != SECRET_TOKEN:
        # 1008 is a valid WebSocket close code; the previous 403 is not.
        raise WebSocketException(code=status.WS_1008_POLICY_VIOLATION)
def setup_server():
    """Build the FastAPI app with a token-guarded RPC websocket route and serve it.

    Blocks inside ``uvicorn.run``; meant to be used as a child-process target.
    """
    rpc_endpoint = WebsocketRPCEndpoint(RpcUtilityMethods())
    rpc_router = APIRouter()

    async def websocket_rpc_endpoint(websocket: WebSocket, client_id: str, token=Depends(check_token_header)):
        # `token` is never read: declaring the parameter is what makes FastAPI
        # run check_token_header for every incoming connection.
        await rpc_endpoint.main_loop(websocket, client_id)

    rpc_router.add_api_websocket_route("/ws/{client_id}", websocket_rpc_endpoint)

    application = FastAPI()
    application.include_router(rpc_router)
    uvicorn.run(application, port=PORT)
@pytest.fixture()
def server():
    """Run the RPC test server in a child process for the duration of a test.

    Yields the ``Process`` handle. On teardown the child is killed and then
    joined, so it has fully exited before the next test starts another server
    on the same port.
    """
    proc = Process(target=setup_server, daemon=True)
    proc.start()
    yield proc
    proc.kill()
    # kill() only sends the signal; join() waits for the child to actually exit
    # and reaps it, so its listening socket is gone and no zombie is left behind.
    proc.join()
@pytest.mark.asyncio
async def test_valid_token(server):
    """
    Connect with the correct X-TOKEN header and check that the server's
    ``echo`` RPC method returns the text it was sent.
    """
    good_headers = [("X-TOKEN", SECRET_TOKEN)]
    async with WebSocketRpcClient(
        uri,
        RpcUtilityMethods(),
        default_response_timeout=4,
        additional_headers=good_headers,
    ) as rpc_client:
        payload = "Hello World!"
        echoed = await rpc_client.other.echo(text=payload)
        assert echoed.result == payload
@pytest.mark.asyncio
async def test_invalid_token(server):
    """
    Connect with a wrong X-TOKEN header and check that the server refuses the
    websocket handshake with HTTP 403.
    """
    # pytest.raises also fails the test if nothing is raised at all (e.g. if the
    # client's __aexit__ swallowed the failure below). The old try/except
    # would have passed silently in that case.
    with pytest.raises(InvalidStatus) as exc_info:
        async with WebSocketRpcClient(uri, RpcUtilityMethods(), default_response_timeout=4, additional_headers=[("X-TOKEN", "bad-token")]) as client:
            # Reaching this line means the server accepted the bad token.
            pytest.fail(f"server accepted a connection with a bad token: {client!r}")
    assert exc_info.value.response.status_code == 403