Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a cors-always-allowed-origins option #8233

Merged
merged 2 commits into from
Feb 5, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions edb/server/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,8 @@ class ServerConfig(NamedTuple):

admin_ui: bool

override_cors_allowed_origins: Optional[str]


class PathPath(click.Path):
name = 'path'
Expand Down Expand Up @@ -1081,6 +1083,16 @@ def resolve_envvar_value(self, ctx: click.Context):
),
default='default',
help='Enable admin UI.'),
click.option(
'--override-cors-allowed-origins',
envvar="GEL_SERVER_OVERRIDE_CORS_ALLOWED_ORIGINS",
jaclarke marked this conversation as resolved.
Show resolved Hide resolved
cls=EnvvarResolver,
hidden=True,
help='A comma separated list of origins to always allow CORS requests '
'from regardless of the `cors_allow_orgin` config. The `*` '
'character can be used as a wildcard. Intended for use by cloud '
'to always allow the cloud UI to make requests to the instance.'
),
click.option(
'--disable-dynamic-system-config', is_flag=True,
envvar="GEL_SERVER_DISABLE_DYNAMIC_SYSTEM_CONFIG",
Expand Down
1 change: 1 addition & 0 deletions edb/server/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,7 @@ async def _run_server(
pidfile_dir=args.pidfile_dir,
new_instance=new_instance,
admin_ui=args.admin_ui,
override_cors_origins=args.override_cors_allowed_origins,
disable_dynamic_system_config=args.disable_dynamic_system_config,
compiler_state=compiler.state,
tenant=tenant,
Expand Down
1 change: 1 addition & 0 deletions edb/server/multitenant.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,7 @@ async def run_server(
default_auth_method=args.default_auth_method,
testmode=args.testmode,
admin_ui=args.admin_ui,
override_cors_origins=args.override_cors_allowed_origins,
disable_dynamic_system_config=args.disable_dynamic_system_config,
compiler_pool_size=args.compiler_pool_size,
compiler_pool_mode=srvargs.CompilerPoolMode.MultiTenant,
Expand Down
12 changes: 10 additions & 2 deletions edb/server/protocol/protocol.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -840,13 +840,21 @@ cdef class HttpProtocol:
config = self.tenant.get_sys_config().get('cors_allow_origins')

allowed_origins = config.value if config else None
overrides = self.server.get_override_cors_origins()

if allowed_origins is None:
if allowed_origins is None and overrides == []:
return False

origin = request.origin.decode() if request.origin else None

origin_allowed = origin is not None and (
origin in allowed_origins or '*' in allowed_origins)
any(
override.match(origin) if isinstance(override, re.Pattern)
else origin == override
for override in overrides
)
or (origin in allowed_origins or '*' in allowed_origins)
)

if origin_allowed:
response.custom_headers['Access-Control-Allow-Origin'] = origin
Expand Down
14 changes: 14 additions & 0 deletions edb/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@


from __future__ import annotations
import re
from typing import (
Any,
Callable,
Expand Down Expand Up @@ -160,6 +161,7 @@ def __init__(
default_auth_method: srvargs.ServerAuthMethods = (
srvargs.DEFAULT_AUTH_METHODS),
admin_ui: bool = False,
override_cors_origins: Optional[str] = None,
disable_dynamic_system_config: bool = False,
compiler_state: edbcompiler.CompilerState,
use_monitor_fs: bool = False,
Expand Down Expand Up @@ -252,6 +254,15 @@ def __init__(

self._admin_ui = admin_ui

self._override_cors_origins = [
re.compile(
'^' + origin
.replace('.', '\\.')
.replace('*', '.*') + '$'
) if '*' in origin else origin
for origin in override_cors_origins.split(',')
] if override_cors_origins else []

self._file_watch_handles = []
self._tls_certs_reload_retry_handle: Any | asyncio.TimerHandle = None

Expand Down Expand Up @@ -308,6 +319,9 @@ def in_test_mode(self):
def is_admin_ui_enabled(self):
return self._admin_ui

def get_override_cors_origins(self):
return self._override_cors_origins

def on_binary_client_created(self) -> str:
self._binary_proto_id_counter += 1

Expand Down
Loading