diff --git a/proconfig/core/block.py b/proconfig/core/block.py index 95918c2c..e78cf475 100755 --- a/proconfig/core/block.py +++ b/proconfig/core/block.py @@ -11,7 +11,7 @@ class Block(BaseModel): type: BlockType = Field(..., description="type of the block") name: str = Field("", description="name of the block") - display_name: str = Field(None, description="the display name of the block") + display_name: str = Field("", description="the display name of the block") properties: Dict[str, Any] = Field({}, description="to specify other properties") inputs: Dict[CustomKey, Union[Input, Value]] = {} outputs: Dict[Union[CustomKey, ContextCustomKey], Union[Variable, Value]] = {} @@ -64,7 +64,7 @@ def sanity_check(self): class BaseProp(BaseModel): - cache: bool = None + cache: bool = False class TaskBase(Block): type: Literal["task"] diff --git a/proconfig/core/common.py b/proconfig/core/common.py index fb83bb72..2753637b 100755 --- a/proconfig/core/common.py +++ b/proconfig/core/common.py @@ -17,7 +17,7 @@ 'transitions', 'states', 'context', - 'payload', + # 'payload', 'condition', 'blocks', 'buttons', diff --git a/proconfig/core/render.py b/proconfig/core/render.py index b9fa02ce..a639e435 100755 --- a/proconfig/core/render.py +++ b/proconfig/core/render.py @@ -1,8 +1,8 @@ from typing import Union, Dict, List, Optional from pydantic import BaseModel, Field -from .common import CustomKey, CustomEventName -from .variables import URLString, Expression, Value +from proconfig.core.common import CustomKey, CustomEventName +from proconfig.core.variables import URLString, Expression, Value class EventPayload(BaseModel): @@ -14,7 +14,7 @@ class Button(BaseModel): content: str = "" description: Optional[str] = Field(default=None, description="Tooltip when hovering button") on_click: Union[CustomEventName, EventPayload] = Field(..., description="event name triggered") - style: Dict[str, str] = None # TODO + style: Dict[str, str] | None = None # TODO class RenderConfig(BaseModel): diff --git a/proconfig/core/state.py b/proconfig/core/state.py index b41d1333..d1a0de51 100755 --- a/proconfig/core/state.py +++ b/proconfig/core/state.py @@ -12,7 +12,7 @@ class StateProp(BaseProp): is_final: bool = False class State(Block): - type: Literal["state"] = "state" + type: Literal["state", "dispatcher"] = "state" properties: StateProp = StateProp() blocks: BlockChildren[Union[Task, Workflow]] = Field({}) render: RenderConfig = RenderConfig() @@ -38,6 +38,7 @@ def sanity_check(self): def check_input_IM(self): IM_count = 0 for k, v in self.inputs.items(): + print(k, v) if v.type in ["text", "string"] and v.source == "IM": IM_count += 1 if IM_count > 1: diff --git a/proconfig/core/variables.py b/proconfig/core/variables.py index d20c1902..06901b4f 100755 --- a/proconfig/core/variables.py +++ b/proconfig/core/variables.py @@ -13,7 +13,7 @@ class VariableBase(BaseModel): type: str - name: str = None + name: str = "" value: Any = None @@ -73,7 +73,7 @@ class InputVariableBase(VariableBase, Generic[T]): default_value: Union[T, Expression] | None = None value: Union[T, Expression] | None = None user_input: bool = False - source: Literal["IM", "form"] = None + source: Literal["IM", "form"] = "form" class InputTextVar(TextVar, InputVariableBase[str], ChoicesBase[str]): type: Literal["text", "string"] diff --git a/proconfig/runners/runner.py b/proconfig/runners/runner.py index 818a8747..92f21059 100755 --- a/proconfig/runners/runner.py +++ b/proconfig/runners/runner.py @@ -220,13 +220,18 @@ def run_block_task(self, task, environ, local_vars): local_vars[task.name] = outputs def process_comfy_extra_inputs(self, task): - comfy_extra_inputs = { "api": task.api, "comfy_workflow_id": task.comfy_workflow_id, "location": task.location, } task.inputs["comfy_extra_inputs"] = comfy_extra_inputs + + def process_myshell_extra_inputs(self, task): + task.inputs = { + "widget_id": task.widget_name.split("/")[1], + "inputs": task.inputs + } def run_task(self, container, task, environ, local_vars): @@ -234,6 +239,8 @@ def run_task(self, container, task, environ, local_vars): if task.widget_class_name == "ComfyUIWidget": # hasattr(task, "comfy_workflow_id") assert hasattr(task, "comfy_workflow_id"), "no comfy_workflow_id is founded" self.process_comfy_extra_inputs(task) + elif task.widget_class_name == "MyShellAnyWidgetCallerWidget": + self.process_myshell_extra_inputs(task) if task.mode in ["widget", "comfy_workflow"]: return self.run_widget_task(container, task, environ, local_vars) @@ -564,7 +571,7 @@ def run_automata(self, automata: Automata, sess_state: SessionState, payload: di local_transitions = automata.blocks[current_state_name].transitions or {} global_transitions = automata.transitions or {} - events = [] + events = {} for button_id, button in enumerate(render.get("buttons", [])): event_name = button["on_click"] if isinstance(event_name, dict): @@ -572,19 +579,32 @@ def run_automata(self, automata: Automata, sess_state: SessionState, payload: di event_name = event_name["event"] else: event_payload = {} - events.append({ - "event_key": f"BUTTON_{button_id}", + event_key = f"BUTTON_{button_id}" + events[event_key] = { "event_name": event_name, - "payload": event_payload - }) + "payload": event_payload, + "is_button": True + } - events.append({ - "event_key": "CHAT", - "event_name": "CHAT", - "payload": {} - }) + events["CHAT"] = { + "event_name": "CHAT", + "payload": {} + } + + skip_evaluate_events = ["X.MENTIONED", "X.PINNED.REPLIED"] + + # final event mapping, not target_inputs + for key, transition in local_transitions.items(): + # evaluate the transitions + if key not in events: + events[key] = { + "event_name": key, + "payload": {} + } + if key in skip_evaluate_events: + events[key]["skip_evaluate_target_inputs"] = True - for event_item in events: + for event_key, event_item in events.items(): event_name = event_item["event_name"] event_payload = event_item["payload"] for payload_key, payload_value in event_payload.items(): @@ -609,7 +629,10 @@ def run_automata(self, automata: Automata, sess_state: SessionState, payload: di # get the target_inputs from transition target_inputs_transition = {} for k, v in transition_case.target_inputs.items(): - target_inputs_transition[k] = calc_expression(v, {'payload': event_payload, **local_vars}) # the target_inputs defined by the transition + if event_item.get("skip_evaluate_target_inputs", False): + target_inputs_transition[k] = v + else: + target_inputs_transition[k] = calc_expression(v, {'payload': event_payload, **local_vars}) # the target_inputs defined by the transition # get the target inputs target_inputs = automata.blocks[target_state].inputs @@ -625,14 +648,16 @@ def run_automata(self, automata: Automata, sess_state: SessionState, payload: di setattr(input_var, k, tree_map(lambda x: calc_expression(x, visible_variables), getattr(input_var, k))) # target_inputs = tree_map(lambda x: calc_expression(x, local_vars), target_inputs) - event_mapping[event_item["event_key"]] = EventItem(**{ + event_mapping[event_key] = EventItem(**{ "target_state": target_state, "target_inputs": target_inputs, "target_inputs_transition": target_inputs_transition, # the target_inputs from the transition }) + + message_count = sess_state.message_count - event_mapping = {f'MESSAGE_{message_count}_{k}' if k != "CHAT" else k: v for k, v in event_mapping.items()} + event_mapping = {f'MESSAGE_{message_count}_{k}' if events[k].get("is_button", False) else k: v for k, v in event_mapping.items()} sess_state.event_mapping.update(event_mapping) return sess_state, render \ No newline at end of file diff --git a/proconfig/utils/x_bot.py b/proconfig/utils/x_bot.py new file mode 100644 index 00000000..ac2a4997 --- /dev/null +++ b/proconfig/utils/x_bot.py @@ -0,0 +1,141 @@ +import json +from pydantic import BaseModel +from typing import List, Literal +from proconfig.core import Automata, State +from proconfig.core.automata import StateWithTransitions +from proconfig.core.variables import InputTextVar +from proconfig.core.common import TransitionCase +from proconfig.core.render import Button +from proconfig.utils.misc import tree_map +from datetime import date + + +class TwitterUser(BaseModel): + creation_date: str = str(date.today()) + user_id: str = "user_id_123" + username: str = "username_default" + name: str = "Default Name" + is_private: bool = False + is_verified: bool = False + is_blue_verified: bool = False + bot: bool = False + verified_type: Literal['blue', 'business', 'government'] = 'blue' + + +class TweetDetail(BaseModel): + tweet_id: str = "tweet_id_123" + creation_date: str = str(date.today()) + text: str = "This is an example text" + user: TwitterUser = TwitterUser() + media_url: List[str] = [] + + +x_special_events = { + "X.MENTIONED": { + "button_name": "Reply to mentioned", + }, + "X.PINNED.REPLIED": { + "button_name": "Reply to pinned tweet" + }, +} + + # content: str = "" + # description: Optional[str] = Field(default=None, description="Tooltip when hovering button") + # on_click: Union[CustomEventName, EventPayload] = Field(..., description="event name triggered") + # style: Dict[str, str] = None # TODO + +def replace_payload_to_payload1(x): + if type(x) == str: + x = x.replace("payload.", "payload1.") + return x + +def convert_x_bot_to_automata(automata_x): + # build default_payload + default_payload = TweetDetail().model_dump_json(indent=2) + # user = default_payload["user"] + # default_payload["user"] = f"{{{{{user}}}}}" + + # automata_x = Automata.model_validate(automata_x) + + + new_states = {} + automata_x["blocks"] = automata_x.get("blocks", {}) + for state_name, state in automata_x["blocks"].items(): + timers = {} + state["properties"] = state.get("properties", {}) + if "timers" in state["properties"]: + for timer_key, timer_info in state["properties"]["timers"].items(): + timers[timer_info["event"]] = timer_key + + state["render"] = state.get("render", {}) + buttons = state["render"].get("buttons", []) + # buttons = state.render.buttons + # print("buttons:", buttons) + # first, handle the special events + count = 0 + state["transitions"] = state.get("transitions") or {} + for event_name, transition_cases in state["transitions"].items(): + if event_name in x_special_events: + new_button = { + "content": x_special_events[event_name]["button_name"], + "on_click": { + "event": event_name, + "payload": {} + } + } + # add a pseudo state + new_state = { + "display_name": "", + "inputs": { + "payload_str": { + "type": "text", + "name": "payload_str", + "default_value": default_payload, + "user_input": True, + "source": "form" + } + }, + "outputs": { + "payload1": "{{json.loads(payload_str)}}" + }, + "transitions": { + "ALWAYS": tree_map(replace_payload_to_payload1, transition_cases) + } + } + new_state_name = f"{state_name}_event_{count}" + # state.transitions[event_name] = TransitionCase(target=new_state_name) + # new_states[new_state_name] = StateWithTransitions.model_validate(new_state) + + state["transitions"][event_name] = { + "target": new_state_name + } + new_states[new_state_name] = new_state + elif event_name in timers: + new_button = { + "content": f"Timer [{timers[event_name]}] triggered", + "on_click": { + "event": event_name, + "payload": {} + } + } + else: + continue + count += 1 + buttons.append(new_button) + # buttons.append(Button.model_validate(new_button)) + state["render"]["buttons"] = buttons + automata_x["blocks"].update(new_states) + return automata_x + return automata_x.model_dump() + + +if __name__ == "__main__": + automata_x_hello = json.load(open("temp/automata_x.json")) + automata_hello = convert_x_bot_to_automata(automata_x_hello) + # import pdb; pdb.set_trace() + # InputTextVar.model_validate(automata_hello["blocks"]["idle_event_0"]["inputs"]["payload_str"]) + # automata_hello_load = Automata.model_validate(automata_hello) + json.dump(automata_hello, open("temp/automata_x_converted.json", "w"), indent=2) + + + import pdb; pdb.set_trace() \ No newline at end of file diff --git a/proconfig/widgets/imagen_widgets/comfyui_widget.py b/proconfig/widgets/imagen_widgets/comfyui_widget.py index a8d9a8e6..ffcde839 100644 --- a/proconfig/widgets/imagen_widgets/comfyui_widget.py +++ b/proconfig/widgets/imagen_widgets/comfyui_widget.py @@ -383,7 +383,16 @@ def comfyui_run_myshell(workflow_id, inputs, extra_headers): response = requests.post(url, headers=headers, json=data) # Parse the JSON response if response.status_code == 200: - return json.loads(response.json()['result']) + response_json = response.json() + if response_json["success "]: + return json.loads(response.json()['result']) + else: + error = { + 'error_code': 'SHELL-1109', + 'error_head': 'HTTP Request Error', + 'msg': response_json["error_message"], + } + raise ShellException(**error) else: try: response_content = response.json() diff --git a/proconfig/widgets/language_models/claude_widgets.py b/proconfig/widgets/language_models/claude_widgets.py index c39a2650..cfe97a2a 100755 --- a/proconfig/widgets/language_models/claude_widgets.py +++ b/proconfig/widgets/language_models/claude_widgets.py @@ -45,6 +45,7 @@ class MemoryItem(BaseModel): class ClaudeWidget(BaseWidget): NAME = "Claude 3.5 Sonnet" CATEGORY = "Large Language Model/Claude" + MYSHELL_WIDGET_NAME = "@myshell_llm/1744218088699596812" class InputsSchema(BaseWidget.InputsSchema): system_prompt: str = "" diff --git a/proconfig/widgets/myshell_widgets/__init__.py b/proconfig/widgets/myshell_widgets/__init__.py index cb961f94..5ea96e11 100755 --- a/proconfig/widgets/myshell_widgets/__init__.py +++ b/proconfig/widgets/myshell_widgets/__init__.py @@ -1,4 +1,5 @@ from proconfig.widgets.myshell_widgets.tools.image_text_fuser import ImageTextFuserWidget from proconfig.widgets.myshell_widgets.tools.image_canvas import ImageCanvasWidget from proconfig.widgets.myshell_widgets.tools.twitter_search import XWidget -from proconfig.widgets.myshell_widgets.tools.html2img import Html2ImgWidget \ No newline at end of file +from proconfig.widgets.myshell_widgets.tools.html2img import Html2ImgWidget +from proconfig.widgets.myshell_widgets.myshell_widget_caller import MyShellAnyWidgetCallerWidget \ No newline at end of file diff --git a/proconfig/widgets/myshell_widgets/myshell_widget_caller.py b/proconfig/widgets/myshell_widgets/myshell_widget_caller.py new file mode 100644 index 00000000..49f3c169 --- /dev/null +++ b/proconfig/widgets/myshell_widgets/myshell_widget_caller.py @@ -0,0 +1,90 @@ +from typing import Any, Literal, Optional, List +from pydantic import Field, BaseModel + +from proconfig.widgets.base import BaseWidget, WIDGETS +from proconfig.core.exception import ShellException + +# import instructor +import os +from pydantic import BaseModel, Field +import requests +import json + +@WIDGETS.register_module() +class MyShellAnyWidgetCallerWidget(BaseWidget): + NAME = "MyShellAnyWidgetCallerWidget" + # CATEGORY = "Myshell Widgets/Widget Caller" + dynamic_schema = True + + class InputsSchema(BaseWidget.InputsSchema): + widget_id: str = Field(..., description="the widget id to call") + inputs: dict = {} + + class OutputsSchema(BaseWidget.OutputsSchema): # useless + data: str | list + + def execute(self, environ, config): + # API endpoint URL + url = "https://openapi.myshell.ai/public/v1/widget/run" + widget_id = config["widget_id"] + # no myshell-test + + # Headers for the API request + headers = { + "x-myshell-openapi-key": os.environ["MYSHELL_API_KEY"], + "Content-Type": "application/json", + **environ.get("MYSHELL_HEADERS", {}) + } + + # Request payload + data = { + "widget_id": widget_id, + "input": json.dumps(config["inputs"]) + } + + # Send POST request to the API + response = requests.post(url, headers=headers, json=data) + + # Parse the JSON response + json_response = response.json() + + # Extract the 'result' field and return it as a string + if json_response.get('success') and 'result' in json_response: + result = json.loads(json_response['result']) + # handle the _url + if "_url" in result: + response = requests.get(result["_url"]) + if response.status_code == 200: + return response.json() + else: + return {"error_message": "error when retrieve the result"} + return result + else: + error = { + 'error_code': 'SHELL-1102', + 'error_head': 'Widget Execution Error', + 'msg': f'widget {widget_id} failed to execute', + 'traceback': json.dumps(json_response), + } + exception = ShellException(**error) + raise exception + + +if __name__ == "__main__": + from dotenv import load_dotenv + load_dotenv() + caller = MyShellAnyWidgetCallerWidget() + config = { + "widget_id": "1800948886988423168", + "inputs": { + "cfg_scale": 3.5, + "height": 1024, + "output_name": "result", + "prompt": "a beautiful girl", + "width": 1024 + } + } + config = caller.InputsSchema.model_validate(config) + + outputs = caller({}, config) + import pdb; pdb.set_trace() diff --git a/proconfig/widgets/myshell_widgets/tools/html2img.py b/proconfig/widgets/myshell_widgets/tools/html2img.py index 4dba3ff9..80657582 100644 --- a/proconfig/widgets/myshell_widgets/tools/html2img.py +++ b/proconfig/widgets/myshell_widgets/tools/html2img.py @@ -12,7 +12,8 @@ @WIDGETS.register_module() class Html2ImgWidget(BaseWidget): NAME = "HTML to Image" - CATEGORY = "Myshell Widgets/Tools" + CATEGORY = "Image Processing/HTML to Image" + MYSHELL_WIDGET_NAME = '@myshell/1850127954369142784' class InputsSchema(BaseWidget.InputsSchema): html_str: str = Field(default="", description="The HTML string to convert to an image") diff --git a/proconfig/widgets/myshell_widgets/tools/image_canvas.py b/proconfig/widgets/myshell_widgets/tools/image_canvas.py index 8c4215f3..11b36133 100644 --- a/proconfig/widgets/myshell_widgets/tools/image_canvas.py +++ b/proconfig/widgets/myshell_widgets/tools/image_canvas.py @@ -132,7 +132,7 @@ def get_next_image_filename(directory, prefix="ImageCanvas_", suffix=".png"): @WIDGETS.register_module() class ImageCanvasWidget(BaseWidget): NAME = "Image Canvas" - CATEGORY = 'Myshell Widgets/Tools' + CATEGORY = 'Image Processing/Image Canvas' class InputsSchema(BaseWidget.InputsSchema): config: str diff --git a/proconfig/widgets/myshell_widgets/tools/image_text_fuser.py b/proconfig/widgets/myshell_widgets/tools/image_text_fuser.py index 7b1c3981..1e2d3c1b 100755 --- a/proconfig/widgets/myshell_widgets/tools/image_text_fuser.py +++ b/proconfig/widgets/myshell_widgets/tools/image_text_fuser.py @@ -326,7 +326,7 @@ def add_content_to_image(template_path, content): @WIDGETS.register_module() class ImageTextFuserWidget(BaseWidget): NAME = "Image Text Fuser" - CATEGORY = 'Myshell Widgets/Tools' + # CATEGORY = 'Myshell Widgets/Tools' class InputsSchema(BaseWidget.InputsSchema): config: str = "" diff --git a/proconfig/widgets/myshell_widgets/tools/twitter_search.py b/proconfig/widgets/myshell_widgets/tools/twitter_search.py index d7eeae27..46247b4c 100644 --- a/proconfig/widgets/myshell_widgets/tools/twitter_search.py +++ b/proconfig/widgets/myshell_widgets/tools/twitter_search.py @@ -12,7 +12,8 @@ @WIDGETS.register_module() class XWidget(BaseWidget): NAME = "Twitter Search" - CATEGORY = "Myshell Widgets/Tools" + CATEGORY = "Tools/Twitter Search" + MYSHELL_WIDGET_NAME = '@myshell/1784206090390036480' class InputsSchema(BaseWidget.InputsSchema): action: Literal["search_tweets", "scrape_tweets", "scrape_profile"] = Field("scrape_tweets", description="The action to perform") diff --git a/proconfig/widgets/tools/code_runner.py b/proconfig/widgets/tools/code_runner.py index 8dfbe716..31e1abc8 100644 --- a/proconfig/widgets/tools/code_runner.py +++ b/proconfig/widgets/tools/code_runner.py @@ -5,7 +5,7 @@ @WIDGETS.register_module() class CodeRunnerWidget(BaseWidget): NAME = "Code Runner" - CATEGORY = "Tools/Code Runner" + # CATEGORY = "Tools/Code Runner" # hidden for now class InputsSchema(BaseWidget.InputsSchema): language: Literal["python", "javascript"] = "python" diff --git a/servers/automata.py b/servers/automata.py index b574eade..eb4854b6 100755 --- a/servers/automata.py +++ b/servers/automata.py @@ -18,8 +18,10 @@ from proconfig.widgets.tools.dependency_checker import check_dependency from proconfig.core import Automata +from proconfig.utils.expressions import calc_expression from proconfig.runners.runner import Runner from proconfig.utils.misc import convert_unserializable_display, process_local_file_path_async +from proconfig.utils.x_bot import convert_x_bot_to_automata from proconfig.utils.pytree import tree_map from proconfig.core.chat import ( ServerMessage, MessageComponentsContainer, MessageComponentsButton, @@ -404,7 +406,9 @@ class RuntimeData(BaseModel): trace_id_to_runtime_data_map: Dict[str, RuntimeData] = {} class MyShellUserInput(BaseModel): - type: str + type: str # "1": TEXT, "2": VOICE, "15": BUTTON_INTERACTION, "-1": 初始值, + # "20": X.MENTIONED, "21": "X.PINNED.REPLIED", "22": timer + form: Dict[str, Any] = {} button_id: str = "" text: str = "" @@ -412,21 +416,30 @@ class MyShellUserInput(BaseModel): def prepare_payload(automata: Automata, event_data: MyShellUserInput, sess_state: SessionState): # Decide the event_name - if event_data.type == "15": - event_name = event_data.button_id - elif event_data.type == "1": - event_name = "CHAT" - # Build the form data - event_data.form[CHAT_MESSAGE] = event_data.text - if len(event_data.embeded_objects) > 0: - # collect images - chat_images = [ - item["url"] for item in event_data.embeded_objects if item["type"] == "MESSAGE_METADATA_TYPE_IMAGE_FILE" - ] - if len(chat_images) > 0: - event_data.form[CHAT_IMAGES] = chat_images - else: - event_name = None + match event_data.type: + case "15": + event_name = event_data.button_id + case "1": + event_name = "CHAT" + # Build the form data + event_data.form[CHAT_MESSAGE] = event_data.text + if len(event_data.embeded_objects) > 0: + # collect images + chat_images = [ + item["url"] for item in event_data.embeded_objects if item["type"] == "MESSAGE_METADATA_TYPE_IMAGE_FILE" + ] + if len(chat_images) > 0: + event_data.form[CHAT_IMAGES] = chat_images + case "20" | "21": + event_name = "X.MENTIONED" if event_data.type == "20" else "X.PINNED.REPLIED" + ctx = {"payload": event_data.form} + target_inputs = getattr(sess_state.event_mapping[event_name], "target_inputs_transition", {}) + event_data.form = tree_map(lambda x: calc_expression(x, ctx), target_inputs) + case "22": + event_name = event_data.button_id + pass + case _: + event_name = None print("event_name:", event_name) target_state = automata.initial if event_name is None else sess_state.event_mapping[event_name].target_state @@ -873,7 +886,6 @@ async def run_automata_stateless(request: MyShellRunAppRequest): thread.start() thread.join() result = result_queue.get() - # result = run_automata_stateless_impl(request) return result @@ -907,4 +919,15 @@ async def get_intro_display(data: Dict): "text": data['server_message']['text'], "images": [item["url"] for item in data['server_message']['embedObjs'] if item["type"] == "MESSAGE_METADATA_TYPE_IMAGE_FILE"] } - return return_dict \ No newline at end of file + return return_dict + + +@app.post('/api/app/convert_x_bot') +async def convert_x_bot(data: Dict): + # convert x bot to normal automata + automata_x = data["automata_x"] + automata = convert_x_bot_to_automata(automata_x) + return { + "automata": automata + } + \ No newline at end of file diff --git a/servers/main.py b/servers/main.py index fe11f0d1..58f3fb87 100755 --- a/servers/main.py +++ b/servers/main.py @@ -23,6 +23,7 @@ import servers.settings import servers.comfy_runner import servers.helper + import servers.widgets diff --git a/servers/widgets.py b/servers/widgets.py new file mode 100644 index 00000000..42ed79fc --- /dev/null +++ b/servers/widgets.py @@ -0,0 +1,116 @@ +import json +from fastapi import FastAPI, HTTPException +from fastapi.responses import JSONResponse, StreamingResponse +from typing import Dict +from servers.base import app, APP_SAVE_ROOT, WORKFLOW_SAVE_ROOT, APP_RUNS_SAVE_ROOT, PROJECT_ROOT + +from proconfig.widgets.base import WIDGETS + +myshell_widget_list = json.load(open("assets/myshell_widget_list.json")) +@app.get("/api/widget/get_myshell_widget_list") +async def get_myshell_widget_list(): + # collect usage + usage_to_widget_map = {} + for widget_id, item in myshell_widget_list.items(): + item["widget_id"] = widget_id + usage = item["usage"] + if usage not in usage_to_widget_map: + usage_to_widget_map[usage] = [] + usage_to_widget_map[item["usage"]].append(item) + + data = [] + for usage, usage_widget_list in usage_to_widget_map.items(): + usage_data = { + "title": " ".join(word.capitalize() for word in usage.split("_")), + "items": [] + } + for widget_info in usage_widget_list: + widget_item = { + "name": "MyShellAnyWidgetCallerWidget", + "display_name": widget_info["name"], + "type": "widget", + "widget_name": f"@myshell/{widget_info['widget_id']}" + } + usage_data["items"].append(widget_item) + data.append(usage_data) + + response = { + "widget_list": data + } + return response + + +# @app.get("/api/widget/get_myshell_widget_schema") +def get_myshell_widget_schema(widget_id): + # widget_id = data["widget_id"] + response = { + "input_schema": myshell_widget_list[widget_id]["inputs"], + "output_schema": myshell_widget_list[widget_id]["outputs"], + } + return response + +@app.post("/api/workflow/get_widget_schema") +async def get_widget_schema(data: Dict[str, str]) -> Dict: + widget_name = data["widget_name"] + if widget_name == "MyShellAnyWidgetCallerWidget": + widget_id = data["myshell_widget_name"].split("/")[1] + print("get widget id:", widget_id) + return get_myshell_widget_schema(widget_id) + + WidgetClass = WIDGETS.module_dict[widget_name] + + response = {} + if len(WidgetClass.InputsSchemaDict) > 0: # multiple inputs schema + response["input_schema"] = { + k: v.model_json_schema() + for k, v in WidgetClass.InputsSchemaDict.items() + } + response["multi_input_schema"] = True + else: + response["input_schema"] = WidgetClass.InputsSchema.model_json_schema() + response["multi_input_schema"] = False + + response["output_schema"] = WidgetClass.OutputsSchema.model_json_schema() + + return response + + +@app.get("/api/widget/get_local_widget_list") +async def get_widget_category_list(): + # Step 1: Get categories + usage_to_categories_map = {} + + for category, widget_names in WIDGETS.category_dict.items(): + usage_name = category.split('/')[0] + category_name = category.split('/')[1] if isinstance(category, str) else category.value + if category_name == '': + continue + if usage_name not in usage_to_categories_map: + usage_to_categories_map[usage_name] = [] + + sub_widget_items = [ + { + "name": widget_name, + "display_name": WIDGETS._name_mapping.get(widget_name, "Warning: Display Name not Specified"), + "widget_name": getattr(WIDGETS._module_dict[widget_name], "MYSHELL_WIDGET_NAME", "") + } + for widget_name in widget_names + ] + + usage_to_categories_map[usage_name].append( + { + "name": category_name, + "children": sub_widget_items + } + ) + + response = { + "widget_list": [ + { + "title": usage, + "items": usage_to_categories_map[usage] + } + for usage in usage_to_categories_map + ] + } + return response \ No newline at end of file diff --git a/servers/workflow.py b/servers/workflow.py index c7535d85..cf083d46 100755 --- a/servers/workflow.py +++ b/servers/workflow.py @@ -84,25 +84,25 @@ async def get_widgets_by_category(data: Dict) -> Dict: } return response -@app.post("/api/workflow/get_widget_schema") -async def get_widget_schema(data: Dict[str, str]) -> Dict: - widget_name = data["widget_name"] - WidgetClass = WIDGETS.module_dict[widget_name] +# @app.post("/api/workflow/get_widget_schema") +# async def get_widget_schema(data: Dict[str, str]) -> Dict: +# widget_name = data["widget_name"] +# WidgetClass = WIDGETS.module_dict[widget_name] - response = {} - if len(WidgetClass.InputsSchemaDict) > 0: # multiple inputs schema - response["input_schema"] = { - k: v.model_json_schema() - for k, v in WidgetClass.InputsSchemaDict.items() - } - response["multi_input_schema"] = True - else: - response["input_schema"] = WidgetClass.InputsSchema.model_json_schema() - response["multi_input_schema"] = False +# response = {} +# if len(WidgetClass.InputsSchemaDict) > 0: # multiple inputs schema +# response["input_schema"] = { +# k: v.model_json_schema() +# for k, v in WidgetClass.InputsSchemaDict.items() +# } +# response["multi_input_schema"] = True +# else: +# response["input_schema"] = WidgetClass.InputsSchema.model_json_schema() +# response["multi_input_schema"] = False - response["output_schema"] = WidgetClass.OutputsSchema.model_json_schema() +# response["output_schema"] = WidgetClass.OutputsSchema.model_json_schema() - return response +# return response @app.post("/api/workflow/save")