Skip to content

Commit

Permalink
feat: Add get providers from entry points
Browse files Browse the repository at this point in the history
  • Loading branch information
loonghao committed Apr 18, 2024
1 parent 284ee16 commit 604dbcf
Show file tree
Hide file tree
Showing 3 changed files with 489 additions and 499 deletions.
43 changes: 40 additions & 3 deletions notifiers/core.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
from abc import ABC
from abc import abstractmethod
from importlib.metadata import entry_points

import jsonschema
import requests
Expand Down Expand Up @@ -340,16 +341,52 @@ def get_notifier(provider_name: str, strict: bool = False) -> Provider:
:return: :class:`Provider` or None
:raises ValueError: In case ``strict`` is True and provider not found
"""
if provider_name in _all_providers:
providers = get_all_providers()
if provider_name in providers:
log.debug("found a match for '%s', returning", provider_name)
return _all_providers[provider_name]()
return providers[provider_name]()
elif strict:
raise NoSuchNotifierError(name=provider_name)


def get_providers_from_entry_points(group_name: str = "notifiers") -> dict:
"""
Get a dictionary of plugins from the entry points based on the given group name.
This function will search for the entry points with the specified group name
and return a dictionary where the keys are the names of the entry points and
the values are the corresponding entry point values.
:param group_name: The group name of the entry points to search for.
:return: Dict: A dictionary containing the entry point names as keys and their corresponding values as values.
Example:
>>> get_providers_from_entry_points("notifiers")
{"plugin1": "package.module:PluginClass", "plugin2": "package2.module:OtherPluginClass"}
"""
result: dict = {}
points = entry_points()
for item in points.get(group_name, []):
if item.group == group_name:
result[item.name] = item.value
return result


def get_all_providers() -> dict:
"""Get all providers from the entry points and the default providers.
:return: Dict: A dictionary containing the entry point names as keys and their corresponding values as values.
"""
default_providers = _all_providers.copy()
entry_point_providers = get_providers_from_entry_points()
default_providers.update(entry_point_providers)
return default_providers


def all_providers() -> list:
"""Returns a list of all :class:`~notifiers.core.Provider` names"""
return list(_all_providers.keys())
return list(get_all_providers().keys())


def notify(provider_name: str, **kwargs) -> Response:
Expand Down
Loading

0 comments on commit 604dbcf

Please sign in to comment.