Skip to content

Commit

Permalink
Dev yyw
Browse files Browse the repository at this point in the history
  • Loading branch information
PinkGranite committed Jan 8, 2025
1 parent 3814b61 commit cbed90f
Show file tree
Hide file tree
Showing 28 changed files with 3,784 additions and 392 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
.venv/
wheelhouse/
build/
.venv/
.python-version
log/
test*
cache/
Expand All @@ -20,3 +22,5 @@ __*
!__init__.py
docs/
TODO
pycityagent-sim
pycityagent-ui
5 changes: 3 additions & 2 deletions pycityagent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@

from .agent import Agent, CitizenAgent, InstitutionAgent
from .environment import Simulator
import logging
from .llm import SentenceEmbedding
from .simulation import AgentSimulation
import logging

# 创建一个 pycityagent 记录器
logger = logging.getLogger("pycityagent")
Expand All @@ -20,4 +21,4 @@
handler.setFormatter(formatter)
logger.addHandler(handler)

__all__ = ["Agent", "Simulator", "CitizenAgent", "InstitutionAgent","SentenceEmbedding",]
__all__ = ["Agent", "Simulator", "CitizenAgent", "InstitutionAgent","SentenceEmbedding","AgentSimulation"]
113 changes: 109 additions & 4 deletions pycityagent/agent.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""智能体模板类及其定义"""

from __future__ import annotations
import asyncio
import inspect
import json
import logging
import random
Expand All @@ -9,14 +9,17 @@
from copy import deepcopy
from datetime import datetime, timezone
from enum import Enum
from typing import Any, Optional
from typing import Any, List, Optional, Type, get_type_hints
from uuid import UUID

import fastavro
from pyparsing import Dict
import ray
from mosstool.util.format_converter import dict2pb
from pycityproto.city.person.v2 import person_pb2 as person_pb2

from pycityagent.workflow import Block

from .economy import EconomyClient
from .environment import Simulator
from .environment.sim.person_service import PersonService
Expand Down Expand Up @@ -46,6 +49,8 @@ class Agent(ABC):
"""
Agent base class
"""
configurable_fields: List[str] = []
default_values: dict[str, Any] = {}

def __init__(
self,
Expand Down Expand Up @@ -101,6 +106,98 @@ def __getstate__(self):
del state["_llm_client"]
return state

@classmethod
def export_class_config(cls) -> Dict[str, Dict]:
result = {
"agent_name": cls.__name__,
"config": {},
"blocks": []
}
config = {
field: cls.default_values.get(field, "default_value")
for field in cls.configurable_fields
}
result["config"] = config
# 解析类中的注解,找到Block类型的字段
hints = get_type_hints(cls)
for attr_name, attr_type in hints.items():
if inspect.isclass(attr_type) and issubclass(attr_type, Block):
block_config = attr_type.export_class_config()
result["blocks"].append({
"name": attr_name,
"config": block_config,
"children": cls._export_subblocks(attr_type)
})
return result

@classmethod
def _export_subblocks(cls, block_cls: Type[Block]) -> List[Dict]:
children = []
hints = get_type_hints(block_cls) # 获取类的注解
for attr_name, attr_type in hints.items():
if inspect.isclass(attr_type) and issubclass(attr_type, Block):
block_config = attr_type.export_class_config()
children.append({
"name": attr_name,
"config": block_config,
"children": cls._export_subblocks(attr_type)
})
return children

@classmethod
def export_to_file(cls, filepath: str) -> None:
config = cls.export_class_config()
with open(filepath, "w") as f:
json.dump(config, f, indent=4)

@classmethod
def import_block_config(cls, config: Dict[str, List[Dict]]) -> "Agent":
agent = cls(name=config["agent_name"])

def build_block(block_data: Dict) -> Block:
block_cls = globals()[block_data["name"]]
block_instance = block_cls.import_config(block_data)
return block_instance

# 创建顶层Block
for block_data in config["blocks"]:
block = build_block(block_data)
setattr(agent, block.name.lower(), block)

return agent

@classmethod
def import_from_file(cls, filepath: str) -> "Agent":
with open(filepath, "r") as f:
config = json.load(f)
return cls.import_block_config(config)

def load_from_config(self, config: Dict[str, List[Dict]]) -> None:
"""
使用配置更新当前Agent实例的Block层次结构。
"""
# 更新当前Agent的基础参数
for field in self.configurable_fields:
if field in config["config"]:
if config["config"][field] != "default_value":
setattr(self, field, config["config"][field])

# 递归更新或创建顶层Block
for block_data in config.get("blocks", []):
block_name = block_data["name"]
existing_block = getattr(self, block_name, None)

if existing_block:
# 如果Block已经存在,则递归更新
existing_block.load_from_config(block_data)
else:
raise KeyError(f"Block '{block_name}' not found in agent '{self.__class__.__name__}'")

def load_from_file(self, filepath: str) -> None:
with open(filepath, "r") as f:
config = json.load(f)
self.load_from_config(config)

def set_messager(self, messager: Messager): # type:ignore
"""
Set the messager of the agent.
Expand Down Expand Up @@ -218,6 +315,11 @@ def copy_writer(self):
f"Copy Writer access before assignment, please `set_pgsql_writer` first!"
)
return self._pgsql_writer

async def messager_ping(self):
if self._messager is None:
raise RuntimeError("Messager is not set")
return await self._messager.ping()

async def generate_user_survey_response(self, survey: dict) -> str:
"""生成回答 —— 可重写
Expand Down Expand Up @@ -527,6 +629,8 @@ async def run(self) -> None:
统一的Agent执行入口
当_blocked为True时,不执行forward方法
"""
if self._messager is not None:
await self._messager.ping.remote()
if not self._blocked:
await self.forward()

Expand Down Expand Up @@ -621,10 +725,11 @@ async def _bind_to_economy(self):
except:
pass
person_id = await self.memory.get("id")
currency = await self.memory.get("currency")
await self._economy_client.add_agents(
{
"id": person_id,
"currency": await self.memory.get("currency"),
"currency": currency,
}
)
self._has_bound_to_economy = True
Expand Down
20 changes: 20 additions & 0 deletions pycityagent/cityagent/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from .societyagent import SocietyAgent
from .firmagent import FirmAgent
from .bankagent import BankAgent
from .nbsagent import NBSAgent
from .governmentagent import GovernmentAgent
from .memory_config import memory_config_societyagent, memory_config_government, memory_config_firm, memory_config_bank, memory_config_nbs

__all__ = [
"SocietyAgent",
"FirmAgent",
"BankAgent",
"NBSAgent",
"GovernmentAgent",
"memory_config_societyagent",
"memory_config_government",
"memory_config_firm",
"memory_config_bank",
"memory_config_nbs",
]

54 changes: 54 additions & 0 deletions pycityagent/cityagent/bankagent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import asyncio
from typing import Optional

import numpy as np
from pycityagent import Simulator, InstitutionAgent
from pycityagent.llm.llm import LLM
from pycityagent.economy import EconomyClient
from pycityagent.message import Messager
from pycityagent.memory import Memory
import logging

logger = logging.getLogger("pycityagent")

class BankAgent(InstitutionAgent):
def __init__(self,
name: str,
llm_client: Optional[LLM] = None,
simulator: Optional[Simulator] = None,
memory: Optional[Memory] = None,
economy_client: Optional[EconomyClient] = None,
messager: Optional[Messager] = None,
avro_file: Optional[dict] = None,
) -> None:
super().__init__(name=name, llm_client=llm_client, simulator=simulator, memory=memory, economy_client=economy_client, messager=messager, avro_file=avro_file)
self.initailzed = False
self.last_time_trigger = None
self.time_diff = 30 * 24 * 60 * 60
self.forward_times = 0

async def month_trigger(self):
now_time = await self.simulator.get_time()
if self.last_time_trigger is None:
self.last_time_trigger = now_time
return False
if now_time - self.last_time_trigger >= self.time_diff:
self.last_time_trigger = now_time
return True
return False

async def gather_messages(self, agent_ids, content):
infos = await super().gather_messages(agent_ids, content)
return [info['content'] for info in infos]

async def forward(self):
if await self.month_trigger():
citizens = await self.memory.get("citizens")
while True:
agents_forward = await self.gather_messages(citizens, 'forward')
if np.all(np.array(agents_forward) > self.forward_times):
break
await asyncio.sleep(1)
self.forward_times += 1
for uuid in citizens:
await self.send_message_to_agent(uuid, f"bank_forward@{self.forward_times}")
20 changes: 20 additions & 0 deletions pycityagent/cityagent/blocks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from .mobility_block import MobilityBlock
from .cognition_block import CognitionBlock
from .plan_block import PlanBlock
from .needs_block import NeedsBlock
from .social_block import SocialBlock
from .economy_block import EconomyBlock
from .other_block import OtherBlock
from .time_block import TimeBlock

__all__ = [
"MobilityBlock",
"CognitionBlock",
"PlanBlock",
"NeedsBlock",
"SocialBlock",
"EconomyBlock",
"OtherBlock",
"LongTermDecisionBlock",
"TimeBlock",
]
Loading

0 comments on commit cbed90f

Please sign in to comment.