Skip to content

Commit

Permalink
Add overload for async singleton call with HassKey (home-assistant#13…
Browse files Browse the repository at this point in the history
  • Loading branch information
cdce8p authored Jan 17, 2025
1 parent 2ec971a commit abc256f
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 11 deletions.
9 changes: 6 additions & 3 deletions homeassistant/components/esphome/dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,17 @@
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers.singleton import singleton
from homeassistant.helpers.storage import Store
from homeassistant.util.hass_dict import HassKey

from .const import DOMAIN
from .coordinator import ESPHomeDashboardCoordinator

_LOGGER = logging.getLogger(__name__)


KEY_DASHBOARD_MANAGER = "esphome_dashboard_manager"
KEY_DASHBOARD_MANAGER: HassKey[ESPHomeDashboardManager] = HassKey(
"esphome_dashboard_manager"
)

STORAGE_KEY = "esphome.dashboard"
STORAGE_VERSION = 1
Expand All @@ -33,7 +36,7 @@ async def async_setup(hass: HomeAssistant) -> None:
await async_get_or_create_dashboard_manager(hass)


@singleton(KEY_DASHBOARD_MANAGER)
@singleton(KEY_DASHBOARD_MANAGER, async_=True)
async def async_get_or_create_dashboard_manager(
hass: HomeAssistant,
) -> ESPHomeDashboardManager:
Expand Down Expand Up @@ -140,7 +143,7 @@ def async_get_dashboard(hass: HomeAssistant) -> ESPHomeDashboardCoordinator | No
where manager can be an asyncio.Event instead of the actual manager
because the singleton decorator is not yet done.
"""
manager: ESPHomeDashboardManager | None = hass.data.get(KEY_DASHBOARD_MANAGER)
manager = hass.data.get(KEY_DASHBOARD_MANAGER)
return manager.async_get() if manager else None


Expand Down
70 changes: 62 additions & 8 deletions homeassistant/helpers/singleton.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,22 @@
from __future__ import annotations

import asyncio
from collections.abc import Callable
from collections.abc import Callable, Coroutine
import functools
from typing import Any, cast, overload
from typing import Any, Literal, assert_type, cast, overload

from homeassistant.core import HomeAssistant
from homeassistant.loader import bind_hass
from homeassistant.util.hass_dict import HassKey

type _FuncType[_T] = Callable[[HomeAssistant], _T]
type _Coro[_T] = Coroutine[Any, Any, _T]


@overload
def singleton[_T](
data_key: HassKey[_T], *, async_: Literal[True]
) -> Callable[[_FuncType[_Coro[_T]]], _FuncType[_Coro[_T]]]: ...


@overload
Expand All @@ -24,29 +31,37 @@ def singleton[_T](
def singleton[_T](data_key: str) -> Callable[[_FuncType[_T]], _FuncType[_T]]: ...


def singleton[_T](data_key: Any) -> Callable[[_FuncType[_T]], _FuncType[_T]]:
def singleton[_S, _T, _U](
data_key: Any, *, async_: bool = False
) -> Callable[[_FuncType[_S]], _FuncType[_S]]:
"""Decorate a function that should be called once per instance.
Result will be cached and simultaneous calls will be handled.
"""

def wrapper(func: _FuncType[_T]) -> _FuncType[_T]:
@overload
def wrapper(func: _FuncType[_Coro[_T]]) -> _FuncType[_Coro[_T]]: ...

@overload
def wrapper(func: _FuncType[_U]) -> _FuncType[_U]: ...

def wrapper(func: _FuncType[_Coro[_T] | _U]) -> _FuncType[_Coro[_T] | _U]:
"""Wrap a function with caching logic."""
if not asyncio.iscoroutinefunction(func):

@functools.lru_cache(maxsize=1)
@bind_hass
@functools.wraps(func)
def wrapped(hass: HomeAssistant) -> _T:
def wrapped(hass: HomeAssistant) -> _U:
if data_key not in hass.data:
hass.data[data_key] = func(hass)
return cast(_T, hass.data[data_key])
return cast(_U, hass.data[data_key])

return wrapped

@bind_hass
@functools.wraps(func)
async def async_wrapped(hass: HomeAssistant) -> Any:
async def async_wrapped(hass: HomeAssistant) -> _T:
if data_key not in hass.data:
evt = hass.data[data_key] = asyncio.Event()
result = await func(hass)
Expand All @@ -62,6 +77,45 @@ async def async_wrapped(hass: HomeAssistant) -> Any:

return cast(_T, obj_or_evt)

return async_wrapped # type: ignore[return-value]
return async_wrapped

return wrapper


async def _test_singleton_typing(hass: HomeAssistant) -> None:
"""Test singleton overloads work as intended.
This is tested during the mypy run. Do not move it to 'tests'!
"""
# Test HassKey
key = HassKey[int]("key")

@singleton(key)
def func(hass: HomeAssistant) -> int:
return 2

@singleton(key, async_=True)
async def async_func(hass: HomeAssistant) -> int:
return 2

assert_type(func(hass), int)
assert_type(await async_func(hass), int)

# Test invalid use of 'async_' with sync function
@singleton(key, async_=True) # type: ignore[arg-type]
def func_error(hass: HomeAssistant) -> int:
return 2

# Test string key
other_key = "key"

@singleton(other_key)
def func2(hass: HomeAssistant) -> str:
return ""

@singleton(other_key)
async def async_func2(hass: HomeAssistant) -> str:
return ""

assert_type(func2(hass), str)
assert_type(await async_func2(hass), str)

0 comments on commit abc256f

Please sign in to comment.