From b74e2a6cc1b1d69dda6298bac22ecefa539859c4 Mon Sep 17 00:00:00 2001 From: Lendemor Date: Thu, 13 Feb 2025 20:56:54 +0100 Subject: [PATCH] auto hide badge for pro+ users for cloud deployments --- reflex/app.py | 19 ++++++++++++++++--- reflex/config.py | 9 +++++++-- reflex/constants/__init__.py | 2 ++ reflex/constants/compiler.py | 9 +++++++++ reflex/reflex.py | 12 +++++++++--- reflex/utils/exec.py | 9 +++++++++ reflex/utils/prerequisites.py | 26 ++++++++++++++++++++------ 7 files changed, 72 insertions(+), 14 deletions(-) diff --git a/reflex/app.py b/reflex/app.py index d290b8f49fd..bec42602044 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -108,7 +108,7 @@ prerequisites, types, ) -from reflex.utils.exec import is_prod_mode, is_testing_env +from reflex.utils.exec import get_compile_context, is_prod_mode, is_testing_env from reflex.utils.imports import ImportVar if TYPE_CHECKING: @@ -198,14 +198,17 @@ def default_overlay_component() -> Component: Returns: The default overlay_component, which is a connection_modal. """ - config = get_config() from reflex.components.component import memo def default_overlay_components(): return Fragment.create( connection_pulser(), connection_toaster(), - *([backend_disabled()] if config.is_reflex_cloud else []), + *( + [backend_disabled()] + if get_compile_context() == constants.CompileContext.DEPLOY + else [] + ), *codespaces.codespaces_auto_redirect(), ) @@ -1071,6 +1074,16 @@ def memoized_toast_provider(): self._validate_var_dependencies() self._setup_overlay_component() + + if config.show_built_with_reflex is None: + if ( + get_compile_context() == constants.CompileContext.DEPLOY + and prerequisites.get_user_tier() in ["pro", "team", "enterprise"] + ): + config.show_built_with_reflex = False + else: + config.show_built_with_reflex = True + if is_prod_mode() and config.show_built_with_reflex: self._setup_sticky_badge() diff --git a/reflex/config.py b/reflex/config.py index 233087938d2..6512c9b89f4 100644 --- a/reflex/config.py +++ b/reflex/config.py @@ -484,6 +484,11 @@ class PerformanceMode(enum.Enum): class EnvironmentVariables: """Environment variables class to instantiate environment variables.""" + # Indicate the current command that was invoked in the reflex CLI. + REFLEX_COMPILE_CONTEXT: EnvVar[constants.CompileContext] = env_var( + constants.CompileContext.UNDEFINED, internal=True + ) + # Whether to use npm over bun to install frontend packages. REFLEX_USE_NPM: EnvVar[bool] = env_var(False) @@ -529,7 +534,7 @@ class EnvironmentVariables: REFLEX_COMPILE_THREADS: EnvVar[Optional[int]] = env_var(None) # The directory to store reflex dependencies. - REFLEX_DIR: EnvVar[Path] = env_var(Path(constants.Reflex.DIR)) + REFLEX_DIR: EnvVar[Path] = env_var(constants.Reflex.DIR) # Whether to print the SQL queries if the log level is INFO or lower. SQLALCHEMY_ECHO: EnvVar[bool] = env_var(False) @@ -737,7 +742,7 @@ class Config: # pyright: ignore [reportIncompatibleVariableOverride] env_file: Optional[str] = None # Whether to display the sticky "Built with Reflex" badge on all pages. - show_built_with_reflex: bool = True + show_built_with_reflex: Optional[bool] = None # Whether the app is running in the reflex cloud environment. is_reflex_cloud: bool = False diff --git a/reflex/constants/__init__.py b/reflex/constants/__init__.py index f5946bf5e90..5a918338dd4 100644 --- a/reflex/constants/__init__.py +++ b/reflex/constants/__init__.py @@ -25,6 +25,7 @@ from .compiler import ( NOCOMPILE_FILE, SETTER_PREFIX, + CompileContext, CompileVars, ComponentName, Ext, @@ -65,6 +66,7 @@ ColorMode, Config, COOKIES, + CompileContext, ComponentName, CustomComponents, DefaultPage, diff --git a/reflex/constants/compiler.py b/reflex/constants/compiler.py index 9bc9978dc41..40134c15bba 100644 --- a/reflex/constants/compiler.py +++ b/reflex/constants/compiler.py @@ -111,6 +111,15 @@ def zip(self): return self.value.lower() + Ext.ZIP +class CompileContext(str, Enum): + """The context in which the compiler is running.""" + + RUN = "run" + EXPORT = "export" + DEPLOY = "deploy" + UNDEFINED = "undefined" + + class Imports(SimpleNamespace): """Common sets of import vars.""" diff --git a/reflex/reflex.py b/reflex/reflex.py index 41048555161..8a0ffa495b8 100644 --- a/reflex/reflex.py +++ b/reflex/reflex.py @@ -196,7 +196,7 @@ def _run( prerequisites.check_latest_package_version(constants.Reflex.MODULE_NAME) if frontend: - if not config.show_built_with_reflex: + if config.show_built_with_reflex is False: # The sticky badge may be disabled at runtime for team/enterprise tiers. prerequisites.check_config_option_in_tier( option_name="show_built_with_reflex", @@ -306,6 +306,8 @@ def run( if frontend and backend: console.error("Cannot use both --frontend-only and --backend-only options.") raise typer.Exit(1) + + environment.REFLEX_COMPILE_CONTEXT.set(constants.CompileContext.RUN) environment.REFLEX_BACKEND_ONLY.set(backend) environment.REFLEX_FRONTEND_ONLY.set(frontend) @@ -352,17 +354,19 @@ def export( from reflex.utils import export as export_utils from reflex.utils import prerequisites + environment.REFLEX_COMPILE_CONTEXT.set(constants.CompileContext.EXPORT) + frontend, backend = prerequisites.check_running_mode(frontend, backend) if prerequisites.needs_reinit(frontend=frontend or not backend): _init(name=config.app_name, loglevel=loglevel) - if frontend and not config.show_built_with_reflex: + if frontend and config.show_built_with_reflex is False: # The sticky badge may be disabled on export for team/enterprise tiers. prerequisites.check_config_option_in_tier( option_name="show_built_with_reflex", allowed_tiers=["team", "enterprise"], - fallback_value=False, + fallback_value=True, help_link=SHOW_BUILT_WITH_REFLEX_INFO, ) @@ -560,6 +564,8 @@ def deploy( check_version() + environment.REFLEX_COMPILE_CONTEXT.set(constants.CompileContext.DEPLOY) + if not config.show_built_with_reflex: # The sticky badge may be disabled on deploy for pro/team/enterprise tiers. prerequisites.check_config_option_in_tier( diff --git a/reflex/utils/exec.py b/reflex/utils/exec.py index b16aaea1cc9..5474ae82a0c 100644 --- a/reflex/utils/exec.py +++ b/reflex/utils/exec.py @@ -584,3 +584,12 @@ def is_prod_mode() -> bool: """ current_mode = environment.REFLEX_ENV_MODE.get() return current_mode == constants.Env.PROD + + +def get_compile_context() -> constants.CompileContext: + """Check if the app is compiled for deploy. + + Returns: + Whether the app is being compiled for deploy. + """ + return environment.REFLEX_COMPILE_CONTEXT.get() diff --git a/reflex/utils/prerequisites.py b/reflex/utils/prerequisites.py index 3cd65a7eb92..145b5324c1f 100644 --- a/reflex/utils/prerequisites.py +++ b/reflex/utils/prerequisites.py @@ -2001,6 +2001,22 @@ def is_generation_hash(template: str) -> bool: return re.match(r"^[0-9a-f]{32,}$", template) is not None +def get_user_tier(): + """Get the current user's tier. + + Returns: + The current user's tier. + """ + from reflex_cli.v2.utils import hosting + + authenticated_token = hosting.authenticated_token() + return ( + authenticated_token[1].get("tier", "").lower() + if authenticated_token[0] + else "anonymous" + ) + + def check_config_option_in_tier( option_name: str, allowed_tiers: list[str], @@ -2015,23 +2031,21 @@ def check_config_option_in_tier( fallback_value: The fallback value if the option is not allowed. help_link: The help link to show to a user that is authenticated. """ - from reflex_cli.v2.utils import hosting - config = get_config() - authenticated_token = hosting.authenticated_token() - if not authenticated_token[0]: + current_tier = get_user_tier() + + if current_tier == "anonymous": the_remedy = ( "You are currently logged out. Run `reflex login` to access this option." ) - current_tier = "anonymous" else: - current_tier = authenticated_token[1].get("tier", "").lower() the_remedy = ( f"Your current subscription tier is `{current_tier}`. " f"Please upgrade to {allowed_tiers} to access this option. " ) if help_link: the_remedy += f"See {help_link} for more information." + if current_tier not in allowed_tiers: console.warn(f"Config option `{option_name}` is restricted. {the_remedy}") setattr(config, option_name, fallback_value)