Skip to content

Commit

Permalink
add lock_decorator for Tools
Browse files Browse the repository at this point in the history
  • Loading branch information
chenchenplus committed Jan 8, 2025
1 parent 8656a8f commit eef621d
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 44 deletions.
14 changes: 7 additions & 7 deletions pycityagent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from copy import deepcopy
from datetime import datetime, timezone
from enum import Enum
from typing import Any, List, Optional, Type, get_type_hints
from typing import Any, Optional, Type, get_type_hints
from uuid import UUID

import fastavro
Expand Down Expand Up @@ -49,7 +49,7 @@ class Agent(ABC):
"""
Agent base class
"""
configurable_fields: List[str] = []
configurable_fields: list[str] = []
default_values: dict[str, Any] = {}

def __init__(
Expand Down Expand Up @@ -107,7 +107,7 @@ def __getstate__(self):
return state

@classmethod
def export_class_config(cls) -> Dict[str, Dict]:
def export_class_config(cls) -> dict[str, Dict]:
result = {
"agent_name": cls.__name__,
"config": {},
Expand All @@ -131,7 +131,7 @@ def export_class_config(cls) -> Dict[str, Dict]:
return result

@classmethod
def _export_subblocks(cls, block_cls: Type[Block]) -> List[Dict]:
def _export_subblocks(cls, block_cls: Type[Block]) -> list[Dict]:
children = []
hints = get_type_hints(block_cls) # 获取类的注解
for attr_name, attr_type in hints.items():
Expand All @@ -151,7 +151,7 @@ def export_to_file(cls, filepath: str) -> None:
json.dump(config, f, indent=4)

@classmethod
def import_block_config(cls, config: Dict[str, List[Dict]]) -> "Agent":
def import_block_config(cls, config: dict[str, list[Dict]]) -> "Agent":
agent = cls(name=config["agent_name"])

def build_block(block_data: Dict) -> Block:
Expand All @@ -172,7 +172,7 @@ def import_from_file(cls, filepath: str) -> "Agent":
config = json.load(f)
return cls.import_block_config(config)

def load_from_config(self, config: Dict[str, List[Dict]]) -> None:
def load_from_config(self, config: dict[str, list[Dict]]) -> None:
"""
使用配置更新当前Agent实例的Block层次结构。
"""
Expand All @@ -185,7 +185,7 @@ def load_from_config(self, config: Dict[str, List[Dict]]) -> None:
# 递归更新或创建顶层Block
for block_data in config.get("blocks", []):
block_name = block_data["name"]
existing_block = getattr(self, block_name, None)
existing_block = getattr(self, block_name, None) # type:ignore

if existing_block:
# 如果Block已经存在,则递归更新
Expand Down
39 changes: 22 additions & 17 deletions pycityagent/simulation/agentgroup.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import asyncio
from collections.abc import Callable
import json
import logging
import time
import uuid
from collections.abc import Callable
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Optional, Type, Union
Expand Down Expand Up @@ -34,7 +34,10 @@ def __init__(
self,
agent_class: Union[type[Agent], list[type[Agent]]],
number_of_agents: Union[int, list[int]],
memory_config_function_group: Union[Callable[[], tuple[dict, dict, dict]], list[Callable[[], tuple[dict, dict, dict]]]],
memory_config_function_group: Union[
Callable[[], tuple[dict, dict, dict]],
list[Callable[[], tuple[dict, dict, dict]]],
],
config: dict,
exp_id: str | UUID,
exp_name: str,
Expand Down Expand Up @@ -81,14 +84,14 @@ def __init__(
# prepare Messager
if "mqtt" in config["simulator_request"]:
self.messager = Messager.remote(
hostname=config["simulator_request"]["mqtt"]["server"],
hostname=config["simulator_request"]["mqtt"]["server"], # type:ignore
port=config["simulator_request"]["mqtt"]["port"],
username=config["simulator_request"]["mqtt"].get("username", None),
password=config["simulator_request"]["mqtt"].get("password", None),
)
else:
self.messager = None

self.message_dispatch_task = None
self._pgsql_writer = pgsql_writer
self._last_asyncio_pg_task = None # 将SQL写入的IO隐藏到计算任务后
Expand Down Expand Up @@ -168,39 +171,36 @@ def __init__(
@property
def agent_count(self):
return self.number_of_agents

@property
def agent_uuids(self):
return list(self.id2agent.keys())

@property
def agent_type(self):
return self.agent_class

def get_agent_count(self):
return self.agent_count

def get_agent_uuids(self):
return self.agent_uuids

def get_agent_type(self):
return self.agent_type

async def __aexit__(self, exc_type, exc_value, traceback):
self.message_dispatch_task.cancel() # type: ignore
await asyncio.gather(self.message_dispatch_task, return_exceptions=True) # type: ignore

async def __aexit__(self, exc_type, exc_value, traceback):
self.message_dispatch_task.cancel() # type: ignore
await asyncio.gather(self.message_dispatch_task, return_exceptions=True) # type: ignore

async def init_agents(self):
logger.debug(f"-----Initializing Agents in AgentGroup {self._uuid} ...")
logger.debug(f"-----Binding Agents to Simulator in AgentGroup {self._uuid} ...")
for agent in self.agents:
await agent.bind_to_simulator() # type: ignore
self.id2agent = {agent._uuid: agent for agent in self.agents}
logger.debug(f"-----Binding Agents to Messager in AgentGroup {self._uuid} ...")
assert self.messager is not None
await self.messager.connect.remote()
if await self.messager.is_connected.remote():
await self.messager.start_listening.remote()
Expand Down Expand Up @@ -293,24 +293,28 @@ async def init_agents(self):
self.initialized = True
logger.debug(f"-----AgentGroup {self._uuid} initialized")

async def filter(self,
types: Optional[list[Type[Agent]]] = None,
keys: Optional[list[str]] = None,
values: Optional[list[Any]] = None) -> list[str]:
async def filter(
self,
types: Optional[list[Type[Agent]]] = None,
keys: Optional[list[str]] = None,
values: Optional[list[Any]] = None,
) -> list[str]:
filtered_uuids = []
for agent in self.agents:
add = True
if types:
if agent.__class__ in types:
if keys:
for key in keys:
assert values is not None
if not agent.memory.get(key) == values[keys.index(key)]:
add = False
break
if add:
filtered_uuids.append(agent._uuid)
elif keys:
for key in keys:
assert values is not None
if not agent.memory.get(key) == values[keys.index(key)]:
add = False
break
Expand All @@ -335,6 +339,7 @@ async def update(self, target_agent_uuid: str, target_key: str, content: Any):
async def message_dispatch(self):
logger.debug(f"-----Starting message dispatch for group {self._uuid}")
while True:
assert self.messager is not None
if not await self.messager.is_connected.remote():
logger.warning(
"Messager is not connected. Skipping message processing."
Expand Down
39 changes: 20 additions & 19 deletions pycityagent/workflow/tool.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
from typing import Any, Optional, Union
import asyncio
import time
from collections import defaultdict
from collections.abc import Callable, Sequence
from typing import Any, Optional, Union

from mlflow.entities import Metric
import time

from ..environment import LEVEL_ONE_PRE, POI_TYPE_DICT
from ..agent import Agent
from ..utils.decorators import lock_decorator
from ..environment import (LEVEL_ONE_PRE, POI_TYPE_DICT, AoiService,
PersonService)
from ..workflow import Block


class Tool:
Expand Down Expand Up @@ -34,31 +40,23 @@ def __call__(self, *args: Any, **kwds: Any) -> Any:
raise NotImplementedError

@property
def agent(self):
def agent(self) -> Agent:
instance = self._instance # type:ignore
if not isinstance(instance, self._get_agent_class()):
if not isinstance(instance, Agent):
raise RuntimeError(
f"Tool bind to object `{type(instance).__name__}`, not an `Agent` object!"
)
return instance

@property
def block(self):
def block(self) -> Block:
instance = self._instance # type:ignore
if not isinstance(instance, self._get_block_class()):
if not isinstance(instance, Block):
raise RuntimeError(
f"Tool bind to object `{type(instance).__name__}`, not an `Block` object!"
)
return instance

def _get_agent_class(self):
from ..agent import Agent
return Agent

def _get_block_class(self):
from ..workflow import Block
return Block


class GetMap(Tool):
"""Retrieve the map from the simulator. Can be bound only to an `Agent` instance."""
Expand Down Expand Up @@ -140,7 +138,7 @@ async def __call__(

class UpdateWithSimulator(Tool):
def __init__(self) -> None:
pass
self._lock = asyncio.Lock()

async def _update_motion_with_sim(
self,
Expand All @@ -164,6 +162,7 @@ async def _update_motion_with_sim(
except KeyError as e:
continue

@lock_decorator
async def __call__(
self,
):
Expand All @@ -173,8 +172,9 @@ async def __call__(

class ResetAgentPosition(Tool):
def __init__(self) -> None:
pass
self._lock = asyncio.Lock()

@lock_decorator
async def __call__(
self,
aoi_id: Optional[int] = None,
Expand All @@ -198,7 +198,8 @@ def __init__(self, log_batch_size: int = 100) -> None:
self._log_batch_size = log_batch_size
# TODO: support other log types
self.metric_log_cache: dict[str, list[Metric]] = defaultdict(list)

self._lock = asyncio.Lock()
@lock_decorator
async def __call__(
self,
metric: Union[Sequence[Union[Metric, dict]], Union[Metric, dict]],
Expand Down Expand Up @@ -230,7 +231,7 @@ async def __call__(
_cache = _cache[batch_size:]
if clear_cache:
await self._clear_cache()

@lock_decorator
async def _clear_cache(
self,
):
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "pycityagent"
version = "2.0.0a48" # change it for each release
version = "2.0.0a49" # change it for each release
description = "LLM-based city environment agent building library"
authors = [
{ name = "Yuwei Yan", email = "[email protected]" },
Expand Down

0 comments on commit eef621d

Please sign in to comment.