diff --git a/CHANGELOG.md b/CHANGELOG.md
index 924e55d14..a18801c09 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -54,6 +54,10 @@
- **loader**: ability to run in-memory models
- **schedulers**: ability to create model-less schedulers
- **quantiation**: code refactor into dedicated module
+- **Authentication**:
+ - perform auth check on ui startup
+ - unified standard and modern-ui authentication method
+ - cleanup auth logging
- **Fixes**:
- non-full vae decode
- send-to image transfer
diff --git a/javascript/login.js b/javascript/login.js
new file mode 100644
index 000000000..41bb5e941
--- /dev/null
+++ b/javascript/login.js
@@ -0,0 +1,72 @@
+const loginCSS = `
+ position: fixed;
+ top: 0;
+ left: 0;
+ width: 100%;
+ height: 100%;
+ background: var(--background-fill-primary);
+ color: var(--body-text-color-subdued);
+ font-family: monospace;
+ z-index: 100;
+`;
+
+const loginHTML = `
+
+
Login
+
+
+
+
+
+
+
+`;
+
+function forceLogin() {
+ const form = document.createElement('form');
+ form.method = 'POST';
+ form.action = '/login';
+ form.id = 'loginForm';
+ form.style.cssText = loginCSS;
+ form.innerHTML = loginHTML;
+ document.body.appendChild(form);
+ const username = form.querySelector('#loginUsername');
+ const password = form.querySelector('#loginPassword');
+ const status = form.querySelector('#loginStatus');
+
+ form.addEventListener('submit', (event) => {
+ event.preventDefault();
+ const formData = new FormData(form);
+ formData.append('username', username.value);
+ formData.append('password', password.value);
+ console.warn('login', formData);
+ fetch('/login', {
+ method: 'POST',
+ body: formData,
+ })
+ .then(async (res) => {
+ const json = await res.json();
+ const txt = `${res.status}: ${res.statusText} - ${json.detail}`;
+ status.textContent = txt;
+ console.log('login', txt);
+ if (res.status === 200) location.reload();
+ })
+ .catch((err) => {
+ status.textContent = err;
+ console.error('login', err);
+ });
+ });
+}
+
+function loginCheck() {
+ fetch('/login_check', {})
+ .then((res) => {
+ if (res.status === 200) console.log('login ok');
+ else forceLogin();
+ })
+ .catch((err) => {
+ console.error('login', err);
+ });
+}
+
+window.onload = loginCheck;
diff --git a/modules/api/middleware.py b/modules/api/middleware.py
index 7eb2c40e8..b5f02bd60 100644
--- a/modules/api/middleware.py
+++ b/modules/api/middleware.py
@@ -45,7 +45,7 @@ async def log_and_time(req: Request, call_next):
if (cmd_opts.api_log or cmd_opts.api_only) and endpoint.startswith('/sdapi'):
if '/sdapi/v1/log' in endpoint or '/sdapi/v1/browser' in endpoint:
return res
- log.info('API {user} {code} {prot}/{ver} {method} {endpoint} {cli} {duration}'.format( # pylint: disable=consider-using-f-string, logging-format-interpolation
+ log.info('API user={user} code={code} {prot}/{ver} {method} {endpoint} {cli} {duration}'.format( # pylint: disable=consider-using-f-string, logging-format-interpolation
user = app.tokens.get(token) if hasattr(app, 'tokens') else None,
code = res.status_code,
ver = req.scope.get('http_version', '0.0'),
@@ -69,10 +69,14 @@ def handle_exception(req: Request, e: Exception):
"body": vars(e).get('body', ''),
"errors": str(e),
}
+ if err['code'] == 401 and 'file=' in req.url.path: # dont spam with unauth
+ return JSONResponse(status_code=err['code'], content=jsonable_encoder(err))
+
log.error(f"API error: {req.method}: {req.url} {err}")
+
if not isinstance(e, HTTPException) and err['error'] != 'TypeError': # do not print backtrace on known httpexceptions
errors.display(e, 'HTTP API', [anyio, fastapi, uvicorn, starlette])
- elif err['code'] == 404 or err['code'] == 401:
+ elif err['code'] in [404, 401, 400]:
pass
else:
log.debug(e, exc_info=True) # print stack trace
diff --git a/modules/ui_javascript.py b/modules/ui_javascript.py
index ccf5f8c0d..c8847c7d8 100644
--- a/modules/ui_javascript.py
+++ b/modules/ui_javascript.py
@@ -17,12 +17,13 @@ def webpath(fn):
def html_head():
head = ''
main = ['script.js']
+ skip = ['login.js']
for js in main:
script_js = os.path.join(script_path, "javascript", js)
head += f'\n'
added = []
for script in modules.scripts.list_scripts("javascript", ".js"):
- if script.filename in main:
+ if script.filename in main or script.filename in skip:
continue
head += f'\n'
added.append(script.path)
@@ -43,6 +44,14 @@ def html_body():
return body
+def html_login():
+ fn = os.path.join(script_path, "javascript", "login.js")
+ with open(fn, 'r', encoding='utf8') as f:
+ inline = f.read()
+ js = f'\n'
+ return js
+
+
def html_css(css: str):
def stylesheet(fn):
return f''
@@ -78,17 +87,19 @@ def stylesheet(fn):
def reload_javascript():
base_css = theme.reload_gradio_theme()
- head = html_head()
- css = html_css(base_css)
- body = html_body()
title = 'SD.Next'
manifest = f''
+ login = html_login()
+ js = html_head()
+ css = html_css(base_css)
+ body = html_body()
def template_response(*args, **kwargs):
res = shared.GradioTemplateResponseOriginal(*args, **kwargs)
res.body = res.body.replace(b'', f'{title}'.encode("utf8"))
- res.body = res.body.replace(b'', f'{head}'.encode("utf8"))
res.body = res.body.replace(b'', f'{manifest}'.encode("utf8"))
+ res.body = res.body.replace(b'', f'{login}'.encode("utf8"))
+ res.body = res.body.replace(b'', f'{js}'.encode("utf8"))
res.body = res.body.replace(b'