diff --git a/CHANGELOG.md b/CHANGELOG.md index 924e55d14..a18801c09 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -54,6 +54,10 @@ - **loader**: ability to run in-memory models - **schedulers**: ability to create model-less schedulers - **quantiation**: code refactor into dedicated module +- **Authentication**: + - perform auth check on ui startup + - unified standard and modern-ui authentication method + - cleanup auth logging - **Fixes**: - non-full vae decode - send-to image transfer diff --git a/javascript/login.js b/javascript/login.js new file mode 100644 index 000000000..41bb5e941 --- /dev/null +++ b/javascript/login.js @@ -0,0 +1,72 @@ +const loginCSS = ` + position: fixed; + top: 0; + left: 0; + width: 100%; + height: 100%; + background: var(--background-fill-primary); + color: var(--body-text-color-subdued); + font-family: monospace; + z-index: 100; +`; + +const loginHTML = ` +
+

Login

+ + + + +
+ +
+`; + +function forceLogin() { + const form = document.createElement('form'); + form.method = 'POST'; + form.action = '/login'; + form.id = 'loginForm'; + form.style.cssText = loginCSS; + form.innerHTML = loginHTML; + document.body.appendChild(form); + const username = form.querySelector('#loginUsername'); + const password = form.querySelector('#loginPassword'); + const status = form.querySelector('#loginStatus'); + + form.addEventListener('submit', (event) => { + event.preventDefault(); + const formData = new FormData(form); + formData.append('username', username.value); + formData.append('password', password.value); + console.warn('login', formData); + fetch('/login', { + method: 'POST', + body: formData, + }) + .then(async (res) => { + const json = await res.json(); + const txt = `${res.status}: ${res.statusText} - ${json.detail}`; + status.textContent = txt; + console.log('login', txt); + if (res.status === 200) location.reload(); + }) + .catch((err) => { + status.textContent = err; + console.error('login', err); + }); + }); +} + +function loginCheck() { + fetch('/login_check', {}) + .then((res) => { + if (res.status === 200) console.log('login ok'); + else forceLogin(); + }) + .catch((err) => { + console.error('login', err); + }); +} + +window.onload = loginCheck; diff --git a/modules/api/middleware.py b/modules/api/middleware.py index 7eb2c40e8..b5f02bd60 100644 --- a/modules/api/middleware.py +++ b/modules/api/middleware.py @@ -45,7 +45,7 @@ async def log_and_time(req: Request, call_next): if (cmd_opts.api_log or cmd_opts.api_only) and endpoint.startswith('/sdapi'): if '/sdapi/v1/log' in endpoint or '/sdapi/v1/browser' in endpoint: return res - log.info('API {user} {code} {prot}/{ver} {method} {endpoint} {cli} {duration}'.format( # pylint: disable=consider-using-f-string, logging-format-interpolation + log.info('API user={user} code={code} {prot}/{ver} {method} {endpoint} {cli} {duration}'.format( # pylint: disable=consider-using-f-string, logging-format-interpolation user = app.tokens.get(token) if hasattr(app, 'tokens') else None, code = res.status_code, ver = req.scope.get('http_version', '0.0'), @@ -69,10 +69,14 @@ def handle_exception(req: Request, e: Exception): "body": vars(e).get('body', ''), "errors": str(e), } + if err['code'] == 401 and 'file=' in req.url.path: # dont spam with unauth + return JSONResponse(status_code=err['code'], content=jsonable_encoder(err)) + log.error(f"API error: {req.method}: {req.url} {err}") + if not isinstance(e, HTTPException) and err['error'] != 'TypeError': # do not print backtrace on known httpexceptions errors.display(e, 'HTTP API', [anyio, fastapi, uvicorn, starlette]) - elif err['code'] == 404 or err['code'] == 401: + elif err['code'] in [404, 401, 400]: pass else: log.debug(e, exc_info=True) # print stack trace diff --git a/modules/ui_javascript.py b/modules/ui_javascript.py index ccf5f8c0d..c8847c7d8 100644 --- a/modules/ui_javascript.py +++ b/modules/ui_javascript.py @@ -17,12 +17,13 @@ def webpath(fn): def html_head(): head = '' main = ['script.js'] + skip = ['login.js'] for js in main: script_js = os.path.join(script_path, "javascript", js) head += f'\n' added = [] for script in modules.scripts.list_scripts("javascript", ".js"): - if script.filename in main: + if script.filename in main or script.filename in skip: continue head += f'\n' added.append(script.path) @@ -43,6 +44,14 @@ def html_body(): return body +def html_login(): + fn = os.path.join(script_path, "javascript", "login.js") + with open(fn, 'r', encoding='utf8') as f: + inline = f.read() + js = f'\n' + return js + + def html_css(css: str): def stylesheet(fn): return f'' @@ -78,17 +87,19 @@ def stylesheet(fn): def reload_javascript(): base_css = theme.reload_gradio_theme() - head = html_head() - css = html_css(base_css) - body = html_body() title = 'SD.Next' manifest = f'' + login = html_login() + js = html_head() + css = html_css(base_css) + body = html_body() def template_response(*args, **kwargs): res = shared.GradioTemplateResponseOriginal(*args, **kwargs) res.body = res.body.replace(b'', f'{title}'.encode("utf8")) - res.body = res.body.replace(b'', f'{head}'.encode("utf8")) res.body = res.body.replace(b'', f'{manifest}'.encode("utf8")) + res.body = res.body.replace(b'', f'{login}'.encode("utf8")) + res.body = res.body.replace(b'', f'{js}'.encode("utf8")) res.body = res.body.replace(b'', f'{css}{body}'.encode("utf8")) lines = res.body.decode("utf8").split('\n') for line in lines: