Skip to content

Commit

Permalink
temp
Browse files Browse the repository at this point in the history
  • Loading branch information
wzh1994 committed Jan 14, 2025
1 parent edda81b commit 7b4e399
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 15 deletions.
22 changes: 12 additions & 10 deletions lazyllm/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@ class Engine(ABC):
REPORT_URL = ""

def __init__(self):
self._nodes = {'__start__': Node(id='__start__', kind='__start__', name='__start__'),
'__end__': Node(id='__end__', kind='__end__', name='__end__')}
self._nodes: Dict[str, Node] = {
'__start__': Node(id='__start__', kind='__start__', name='__start__'),
'__end__': Node(id='__end__', kind='__end__', name='__end__')}

def __new__(cls):
if cls is not Engine:
Expand Down Expand Up @@ -103,13 +104,14 @@ def impl(f):
def build(self, node: Node):
if node.kind.startswith('__') and node.kind.endswith('__'):
return None
node.arg_names = node.args.pop('_lazyllm_arg_names', None) if isinstance(node.args, dict) else None
node.enable_data_reflow = (node.args.pop('_lazyllm_enable_report', False)
if isinstance(node.args, dict) else False)
node_args = node.args.copy()
node.arg_names = node_args.pop('_lazyllm_arg_names', None) if isinstance(node_args, dict) else None
node.enable_data_reflow = (node_args.pop('_lazyllm_enable_report', False)
if isinstance(node_args, dict) else False)
if node.kind in NodeConstructor.builder_methods:
createf, node.subitem_name = NodeConstructor.builder_methods[node.kind]
node.func = createf(**node.args) if isinstance(node.args, dict) and set(node.args.keys()).issubset(
set(inspect.getfullargspec(createf).args)) else createf(node.args)
node.func = createf(**node_args) if isinstance(node_args, dict) and set(node_args.keys()).issubset(
set(inspect.getfullargspec(createf).args)) else createf(node_args)
self._process_hook(node, node.func)
return node

Expand All @@ -122,7 +124,7 @@ def get_args(cls, key, value, builder_key=None):
return Engine().build_node(value).func
return node_args.getattr_f(value) if node_args.getattr_f else value

for key, value in node.args.items():
for key, value in node_args.items():
if key in node_msgs['init_arguments']:
init_args[key] = get_args('init_arguments', key, value)
elif key in node_msgs['builder_argument']:
Expand Down Expand Up @@ -420,9 +422,9 @@ def make_http_tool(method: Optional[str] = None,
vars_for_code: Optional[Dict[str, Any]] = None,
doc: Optional[str] = None,
outputs: Optional[List[str]] = None,
extract_from_output: Optional[bool] = None):
extract_from_result: Optional[bool] = None):
instance = lazyllm.tools.HttpTool(method, url, params, headers, body, timeout, proxies,
code_str, vars_for_code, outputs, extract_from_output)
code_str, vars_for_code, outputs, extract_from_result)
if doc:
instance.__doc__ = doc
return instance
Expand Down
12 changes: 7 additions & 5 deletions lazyllm/tools/tools/http_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,20 @@ def __init__(self,
code_str: Optional[str] = None,
vars_for_code: Optional[Dict[str, Any]] = None,
outputs: Optional[List[str]] = None,
extract_from_output: Optional[bool] = None):
extract_from_result: Optional[bool] = None):
super().__init__(method, url, '', headers, params, body, timeout, proxies)
self._has_http = True if url else False
self._compiled_func = compile_func(code_str, vars_for_code) if code_str else (lambda x: json.loads(x['content']))
self._outputs, self._extract_from_output = outputs, extract_from_output
if extract_from_output:
self._outputs, self._extract_from_result = outputs, extract_from_result
if extract_from_result:
assert outputs, 'Output information is necessary to extract output parameters'
assert len(outputs) == 1, 'When the number of outputs is greater than 1, no manual setting is required'

def _get_result(self, res):
if self._extract_from_output or (isinstance(res, dict) and len(self._outputs) > 1):
if self._extract_from_result or (isinstance(res, dict) and len(self._outputs) > 1):
assert isinstance(res, dict), 'The result of the tool should be a dict type'
return package(res.get(key) for key in self._outputs)
r = package(res.get(key) for key in self._outputs)
return r[0] if len(r) == 1 else r
if len(self._outputs) > 1:
assert isinstance(res, (tuple, list)), 'The result of the tool should be tuple or list'
assert len(res) == len(self._outputs), 'The number of outputs is inconsistent with expectations'
Expand Down
39 changes: 39 additions & 0 deletions tests/basic_tests/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,45 @@ def test_engine_httptool(self):
assert res['headers']['h1'] == 'baz'
assert res['url'].endswith(f'{url[5:]}?p1=foo&p2=bar')

def test_engine_httptool_with_output(self):
params = {'p1': '{{p1}}', 'p2': '{{p2}}'}
headers = {'h1': '{{h1}}'}
url = 'https://postman-echo.com/get'

nodes = [
dict(id='0', kind='Code', name='code1', args=dict(code='def p1(): return "foo"')),
dict(id='1', kind='Code', name='code2', args=dict(code='def p2(): return "bar"')),
dict(id='2', kind='Code', name='code3', args=dict(code='def h1(): return "baz"')),
dict(id='3', kind='HttpTool', name='http', args=dict(
method='GET', url=url, params=params, headers=headers,
outputs=['headers', 'url'], _lazyllm_arg_names=['p1', 'p2', 'h1']))
]
edges = [dict(iid='__start__', oid='0'), dict(iid='__start__', oid='1'), dict(iid='__start__', oid='2'),
dict(iid='0', oid='3'), dict(iid='1', oid='3'), dict(iid='2', oid='3'), dict(iid='3', oid='__end__')]

engine = LightEngine()
gid = engine.start(nodes, edges, gid='graph-1')
res = engine.run(gid)

assert isinstance(res, lazyllm.package) and len(res) == 2
assert res[0]['h1'] == 'baz'
assert res[1].endswith(f'{url[5:]}?p1=foo&p2=bar')

engine.reset()

nodes[3]['args']['outputs'] = ['output']
gid = engine.start(nodes, edges)
res = engine.run(gid)
assert res['headers']['h1'] == 'baz'

engine.reset()

nodes[3]['args']['outputs'] = ['headers']
nodes[3]['args']['extract_from_result'] = True
gid2 = engine.start(nodes, edges)
res = engine.run(gid2)
assert res['h1'] == 'baz'

def test_engine_httptool_body(self):
body = {'b1': '{{b1}}', 'b2': '{{b2}}'}
headers = {'Content-Type': '{{h1}}'}
Expand Down

0 comments on commit 7b4e399

Please sign in to comment.