Skip to content

Commit

Permalink
Preliminary implementation of Message Interceptor
Browse files Browse the repository at this point in the history
  • Loading branch information
chenchenplus committed Jan 15, 2025
1 parent f04efb4 commit dfb5aba
Show file tree
Hide file tree
Showing 25 changed files with 1,110 additions and 371 deletions.
43 changes: 39 additions & 4 deletions pycityagent/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
from ..environment import Simulator
from ..llm import LLM
from ..memory import Memory
from ..message.messager import Messager
from ..message import MessageInterceptor, Messager
from ..metrics import MlflowClient
from .agent_base import Agent, AgentType

logger = logging.getLogger("pycityagent")
Expand All @@ -32,6 +33,7 @@ def __init__(
memory: Optional[Memory] = None,
economy_client: Optional[EconomyClient] = None,
messager: Optional[Messager] = None, # type:ignore
message_interceptor: Optional[MessageInterceptor] = None, # type:ignore
avro_file: Optional[dict] = None,
) -> None:
super().__init__(
Expand All @@ -40,10 +42,12 @@ def __init__(
llm_client=llm_client,
economy_client=economy_client,
messager=messager,
message_interceptor=message_interceptor,
simulator=simulator,
memory=memory,
avro_file=avro_file,
)
self._mlflow_client = None

async def bind_to_simulator(self):
await self._bind_to_simulator()
Expand Down Expand Up @@ -78,9 +82,7 @@ async def _bind_to_simulator(self):
dict_person[_key] = _value
except KeyError as e:
continue
resp = await simulator.add_person(
dict2pb(dict_person, person_pb2.Person())
)
resp = await simulator.add_person(dict2pb(dict_person, person_pb2.Person()))
person_id = resp["person_id"]
await status.update("id", person_id, protect_llm_read_only_fields=False)
logger.debug(f"Binding to Person `{person_id}` just added to Simulator")
Expand Down Expand Up @@ -123,6 +125,21 @@ async def handle_gather_message(self, payload: dict):
}
await self._send_message(sender_id, payload, "gather")

@property
def mlflow_client(self) -> MlflowClient:
"""The Agent's MlflowClient"""
if self._mlflow_client is None:
raise RuntimeError(
f"MlflowClient access before assignment, please `set_mlflow_client` first!"
)
return self._mlflow_client

def set_mlflow_client(self, mlflow_client: MlflowClient):
"""
Set the mlflow_client of the agent.
"""
self._mlflow_client = mlflow_client


class InstitutionAgent(Agent):
"""
Expand All @@ -137,6 +154,7 @@ def __init__(
memory: Optional[Memory] = None,
economy_client: Optional[EconomyClient] = None,
messager: Optional[Messager] = None, # type:ignore
message_interceptor: Optional[MessageInterceptor] = None, # type:ignore
avro_file: Optional[dict] = None,
) -> None:
super().__init__(
Expand All @@ -145,10 +163,12 @@ def __init__(
llm_client=llm_client,
economy_client=economy_client,
messager=messager,
message_interceptor=message_interceptor,
simulator=simulator,
memory=memory,
avro_file=avro_file,
)
self._mlflow_client = None
# 添加响应收集器
self._gather_responses: dict[str, asyncio.Future] = {}

Expand Down Expand Up @@ -308,3 +328,18 @@ async def gather_messages(self, agent_uuids: list[str], target: str) -> list[dic
# 清理Future
for key in futures:
self._gather_responses.pop(key, None)

@property
def mlflow_client(self) -> MlflowClient:
"""The Agent's MlflowClient"""
if self._mlflow_client is None:
raise RuntimeError(
f"MlflowClient access before assignment, please `set_mlflow_client` first!"
)
return self._mlflow_client

def set_mlflow_client(self, mlflow_client: MlflowClient):
"""
Set the mlflow_client of the agent.
"""
self._mlflow_client = mlflow_client
64 changes: 39 additions & 25 deletions pycityagent/agent/agent_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,13 @@
import fastavro
import ray
from pycityproto.city.person.v2 import person_pb2 as person_pb2
from pyparsing import Dict

from ..economy import EconomyClient
from ..environment import Simulator
from ..environment.sim.person_service import PersonService
from ..llm import LLM
from ..memory import Memory
from ..message.messager import Messager
from ..message import MessageInterceptor, Messager
from ..metrics import MlflowClient
from ..utils import DIALOG_SCHEMA, SURVEY_SCHEMA, process_survey_for_llm
from ..workflow import Block
Expand Down Expand Up @@ -56,7 +55,8 @@ def __init__(
type: AgentType = AgentType.Unspecified,
llm_client: Optional[LLM] = None,
economy_client: Optional[EconomyClient] = None,
messager: Optional[Messager] = None, # type:ignore
messager: Optional[ray.ObjectRef] = None,
message_interceptor: Optional[ray.ObjectRef] = None,
simulator: Optional[Simulator] = None,
memory: Optional[Memory] = None,
avro_file: Optional[dict[str, str]] = None,
Expand All @@ -82,6 +82,7 @@ def __init__(
self._llm_client = llm_client
self._economy_client = economy_client
self._messager = messager
self._message_interceptor = message_interceptor
self._simulator = simulator
self._memory = memory
self._exp_id = -1
Expand All @@ -102,12 +103,12 @@ def __getstate__(self):
return state

@classmethod
def export_class_config(cls) -> Dict[str, Dict]:
def export_class_config(cls) -> dict[str, dict]:
result = {
"agent_name": cls.__name__,
"config": {},
"description": {},
"blocks": []
"blocks": [],
}
config = {
field: cls.default_values.get(field, "default_value")
Expand All @@ -123,16 +124,18 @@ def export_class_config(cls) -> Dict[str, Dict]:
for attr_name, attr_type in hints.items():
if inspect.isclass(attr_type) and issubclass(attr_type, Block):
block_config = attr_type.export_class_config()
result["blocks"].append({
"name": attr_name,
"config": block_config[0],
"description": block_config[1],
"children": cls._export_subblocks(attr_type)
})
result["blocks"].append(
{
"name": attr_name,
"config": block_config[0], # type:ignore
"description": block_config[1], # type:ignore
"children": cls._export_subblocks(attr_type),
}
)
return result

@classmethod
def _export_subblocks(cls, block_cls: type[Block]) -> list[Dict]:
def _export_subblocks(cls, block_cls: type[Block]) -> list[dict]:
children = []
hints = get_type_hints(block_cls) # 获取类的注解
for attr_name, attr_type in hints.items():
Expand All @@ -141,8 +144,8 @@ def _export_subblocks(cls, block_cls: type[Block]) -> list[Dict]:
children.append(
{
"name": attr_name,
"config": block_config[0],
"description": block_config[1],
"config": block_config[0], # type:ignore
"description": block_config[1], # type:ignore
"children": cls._export_subblocks(attr_type),
}
)
Expand Down Expand Up @@ -253,6 +256,12 @@ def set_pgsql_writer(self, pgsql_writer: ray.ObjectRef):
"""
self._pgsql_writer = pgsql_writer

def set_message_interceptor(self, message_interceptor: ray.ObjectRef):
"""
Set the PostgreSQL copy writer of the agent.
"""
self._message_interceptor = message_interceptor

@property
def uuid(self):
"""The Agent's UUID"""
Expand Down Expand Up @@ -289,24 +298,24 @@ def memory(self):
f"Memory access before assignment, please `set_memory` first!"
)
return self._memory

@property
def status(self):
"""The Agent's Status Memory"""
if self._memory.status is None:
if self.memory.status is None:
raise RuntimeError(
f"Status access before assignment, please `set_memory` first!"
)
return self._memory.status
return self.memory.status

@property
def stream(self):
"""The Agent's Stream Memory"""
if self._memory.stream is None:
if self.memory.stream is None:
raise RuntimeError(
f"Stream access before assignment, please `set_memory` first!"
)
return self._memory.stream
return self.memory.stream

@property
def simulator(self):
Expand Down Expand Up @@ -335,7 +344,7 @@ def messager(self):
async def messager_ping(self):
if self._messager is None:
raise RuntimeError("Messager is not set")
return await self._messager.ping()
return await self._messager.ping.remote() # type:ignore

async def generate_user_survey_response(self, survey: dict) -> str:
"""生成回答 —— 可重写
Expand Down Expand Up @@ -421,7 +430,7 @@ async def _process_survey(self, survey: dict):
_data_tuples
)
)
await self.messager.send_message.remote(f"exps/{self._exp_id}/user_payback", {"count": 1})
await self.messager.send_message.remote(f"exps/{self._exp_id}/user_payback", {"count": 1})# type:ignore

async def generate_user_chat_response(self, question: str) -> str:
"""生成回答 —— 可重写
Expand Down Expand Up @@ -508,7 +517,7 @@ async def _process_interview(self, payload: dict):
_data
)
)
await self.messager.send_message.remote(f"exps/{self._exp_id}/user_payback", {"count": 1})
await self.messager.send_message.remote(f"exps/{self._exp_id}/user_payback", {"count": 1})# type:ignore
print(f"Sent payback message to {self._exp_id}")

async def process_agent_chat_response(self, payload: dict) -> str:
Expand Down Expand Up @@ -579,7 +588,12 @@ async def _send_message(self, to_agent_uuid: str, payload: dict, sub_topic: str)
if self._messager is None:
raise RuntimeError("Messager is not set")
topic = f"exps/{self._exp_id}/agents/{to_agent_uuid}/{sub_topic}"
await self._messager.send_message.remote(topic, payload)
await self._messager.send_message.remote( # type:ignore
topic,
payload,
self._uuid,
to_agent_uuid,
)

async def send_message_to_agent(
self, to_agent_uuid: str, content: str, type: str = "social"
Expand Down Expand Up @@ -643,6 +657,6 @@ async def run(self) -> None:
当_blocked为True时,不执行forward方法
"""
if self._messager is not None:
await self._messager.ping.remote()
await self._messager.ping.remote() # type:ignore
if not self._blocked:
await self.forward()
2 changes: 1 addition & 1 deletion pycityagent/cityagent/blocks/plan_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,4 +294,4 @@ async def forward(self):
def clean_json_response(self, response: str) -> str:
"""Clean special characters in LLM response"""
response = response.replace('```json', '').replace('```', '')
return response.strip()
return response.strip()
99 changes: 99 additions & 0 deletions pycityagent/cityagent/message_intercept.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import asyncio

from pycityagent.llm import LLM
from pycityagent.message import MessageBlockBase, MessageBlockListenerBase


async def check_message(
from_uuid: str, to_uuid: str, llm_client: LLM, content: str
) -> bool:
"""
使用LLM检查消息是否合规
返回: (是否合规, from_uuid, to_uuid)
"""
print(f"\n检查消息: {from_uuid} -> {to_uuid}: {content}")

prompt = f"""
请判断以下消息是否具有情绪煽动性:
发送者ID: {from_uuid}
接收者ID: {to_uuid}
消息内容: {content}
如果消息具有情绪煽动性,请返回 False;如果消息正常,请返回 True。
"""

if "test" in content.lower():
is_valid = False
else:
is_valid = True
print(f"消息检查结果: {'合规' if is_valid else '不合规'}")
return is_valid


class EdgeMessageBlock(MessageBlockBase):
async def forward( # type:ignore
self,
from_uuid: str,
to_uuid: str,
msg: str,
black_list: list[tuple[str, str]],
):
if (from_uuid, to_uuid) in set(black_list):
# 可选同时返回入队的信息(False,err) 如果只返回bool值则默认报错信息入队
return False
else:
is_valid = await check_message(
from_uuid=from_uuid,
to_uuid=to_uuid,
llm_client=self.llm,
content=msg,
)
if not is_valid:
# 直接添加即可 在框架内部的异步锁保证不会冲突
black_list.append((from_uuid, to_uuid))
return is_valid


class PointMessageBlock(MessageBlockBase):
async def forward( # type:ignore
self,
from_uuid: str,
to_uuid: str,
msg: str,
violation_counts: dict[str, int],
black_list: list[tuple[str, str]],
):
if (from_uuid, to_uuid) in set(black_list):
# 可选同时返回入队的信息(False,err) 如果只返回bool值则默认报错信息入队
return False
else:
# violation count在框架内自动维护 这里不用管
is_valid = await check_message(
from_uuid=from_uuid,
to_uuid=to_uuid,
llm_client=self.llm,
content=msg,
)
if not is_valid and violation_counts[from_uuid] >= 3 - 1:
# 直接添加即可 在框架内部的异步锁保证不会冲突
black_list.append((from_uuid, to_uuid))
return is_valid


class MessageBlockListener(MessageBlockListenerBase):
def __init__(
self, save_queue_values: bool = False, get_queue_period: float = 0.1
) -> None:
super().__init__(save_queue_values, get_queue_period)

async def forward(
self,
):
while True:
if self.has_queue:
value = await self.queue.get_async() # type: ignore
if self._save_queue_values:
self._values_from_queue.append(value)
print(f"get `{value}` from queue")
# do something with the value
await asyncio.sleep(self._get_queue_period)
Loading

0 comments on commit dfb5aba

Please sign in to comment.