From ac207b604e11f43801083b10c2dc24d1fae748a4 Mon Sep 17 00:00:00 2001 From: Jerry Liu Date: Wed, 14 Aug 2024 01:03:46 -0700 Subject: [PATCH] cr --- .../multimodal_report_generation.ipynb | 3 +- .../multimodal_report_generation_agent.ipynb | 182 ++++++++---------- 2 files changed, 86 insertions(+), 99 deletions(-) diff --git a/examples/multimodal/multimodal_report_generation.ipynb b/examples/multimodal/multimodal_report_generation.ipynb index 3ec004a..a789bf4 100644 --- a/examples/multimodal/multimodal_report_generation.ipynb +++ b/examples/multimodal/multimodal_report_generation.ipynb @@ -710,8 +710,7 @@ "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.8" + "pygments_lexer": "ipython3" } }, "nbformat": 4, diff --git a/examples/multimodal/multimodal_report_generation_agent.ipynb b/examples/multimodal/multimodal_report_generation_agent.ipynb index 62be289..24d810b 100644 --- a/examples/multimodal/multimodal_report_generation_agent.ipynb +++ b/examples/multimodal/multimodal_report_generation_agent.ipynb @@ -24,7 +24,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "id": "70ccdd53-e68a-4199-aacb-cfe71ad1ff0b", "metadata": {}, "outputs": [], @@ -86,7 +86,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "id": "16e2071d-bbc2-4707-8ae7-cb4e1fecafd3", "metadata": {}, "outputs": [], @@ -116,7 +116,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "id": "8bce3407-a7d2-47e8-9eaf-ab297a94750c", "metadata": {}, "outputs": [], @@ -166,7 +166,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "id": "3b3fa3bd-c70f-45d5-9377-d81be8160891", "metadata": {}, "outputs": [], @@ -224,7 +224,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "id": "18c24174-05ce-417f-8dd2-79c3f375db03", "metadata": {}, "outputs": [], @@ -235,7 +235,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "id": "8e331dfe-a627-4e23-8c57-70ab1d9342e4", "metadata": {}, "outputs": [], @@ -261,7 +261,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "id": "346fe5ef-171e-4a54-9084-7a7805103a13", "metadata": {}, "outputs": [], @@ -299,7 +299,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "id": "f591669c-5a8e-491d-9cef-0b754abbf26f", "metadata": {}, "outputs": [], @@ -318,11 +318,9 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "id": "a4a85b24-a87d-468f-b235-da4f8b520d96", - "metadata": { - "tags": [] - }, + "metadata": {}, "outputs": [], "source": [ "# SAVE\n", @@ -333,11 +331,9 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "id": "36d6a888-f47d-4882-9dbd-b9a2b95bed8a", - "metadata": { - "tags": [] - }, + "metadata": {}, "outputs": [], "source": [ "# LOAD\n", @@ -348,11 +344,9 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": null, "id": "f9c5ae4b-5a13-4c71-9fe4-fdeb278bfaba", - "metadata": { - "tags": [] - }, + "metadata": {}, "outputs": [], "source": [ "all_text_nodes = []\n", @@ -362,7 +356,7 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": null, "id": "32c13950-c1db-435f-b5b4-89d62b8b7744", "metadata": {}, "outputs": [ @@ -429,7 +423,7 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": null, "id": "bee7bc2e-be3d-4fd3-a1df-b2dcaa66c404", "metadata": {}, "outputs": [], @@ -454,11 +448,10 @@ " # load index\n", " index = load_index_from_storage(storage_context, index_id=\"vector_index\")\n", "\n", - " \n", + "\n", "# Summary Index dictionary - store map from paper path to a summary index around it\n", "paper_summary_indexes = {\n", - " paper_path: SummaryIndex(text_nodes_dict[paper_path])\n", - " for paper_path in papers\n", + " paper_path: SummaryIndex(text_nodes_dict[paper_path]) for paper_path in papers\n", "}" ] }, @@ -474,7 +467,7 @@ }, { "cell_type": "code", - "execution_count": 37, + "execution_count": null, "id": "d1114bdc-3a2a-4703-b367-80e248469a0a", "metadata": {}, "outputs": [], @@ -483,51 +476,58 @@ "from llama_index.core.schema import NodeWithScore\n", "from typing import List\n", "\n", + "\n", "# function tools\n", "def chunk_retriever_fn(query: str) -> List[NodeWithScore]:\n", " \"\"\"Retrieves a small set of relevant document chunks from the corpus.\n", - " \n", + "\n", " ONLY use for research questions that want to look up specific facts from the knowledge corpus,\n", " and don't need entire documents.\n", - " \n", + "\n", " \"\"\"\n", " retriever = index.as_retriever(similarity_top_k=5)\n", " nodes = retriever.retrieve(query)\n", " return nodes\n", "\n", "\n", - "\n", - "def _get_document_nodes(nodes: List[NodeWithScore], top_n: int = 2) -> List[NodeWithScore]:\n", + "def _get_document_nodes(\n", + " nodes: List[NodeWithScore], top_n: int = 2\n", + ") -> List[NodeWithScore]:\n", " \"\"\"Get document nodes from a set of chunk nodes.\n", - " \n", + "\n", " Given chunk nodes, \"de-reference\" into a set of documents, with a simple weighting function (cumulative total) to determine ordering.\n", - " \n", + "\n", " Cutoff by top_n.\n", - " \n", + "\n", " \"\"\"\n", " paper_paths = {n.metadata[\"paper_path\"] for n in nodes}\n", " paper_path_scores = {f: 0 for f in paper_paths}\n", " for n in nodes:\n", " paper_path_scores[n.metadata[\"paper_path\"]] += n.score\n", - " \n", + "\n", " # Sort paper_path_scores by score in descending order\n", - " sorted_paper_paths = sorted(paper_path_scores.items(), key=itemgetter(1), reverse=True)\n", + " sorted_paper_paths = sorted(\n", + " paper_path_scores.items(), key=itemgetter(1), reverse=True\n", + " )\n", " # Take top_n paper paths\n", " top_paper_paths = [path for path, score in sorted_paper_paths[:top_n]]\n", - " \n", + "\n", " # use summary index to get nodes from all paper paths\n", " all_nodes = []\n", " for paper_path in top_paper_paths:\n", - " # NOTE: input to retriever can be blank \n", - " all_nodes.extend(paper_summary_indexes[Path(paper_path).name].as_retriever().retrieve(\"\"))\n", - " \n", + " # NOTE: input to retriever can be blank\n", + " all_nodes.extend(\n", + " paper_summary_indexes[Path(paper_path).name].as_retriever().retrieve(\"\")\n", + " )\n", + "\n", " return all_nodes\n", "\n", + "\n", "def doc_retriever_fn(query: str) -> float:\n", " \"\"\"Document retriever that retrieves entire documents from the corpus.\n", - " \n", + "\n", " ONLY use for research questions that may require searching over entire research reports.\n", - " \n", + "\n", " Will be slower and more expensive than chunk-level retrieval but may be necessary.\n", " \"\"\"\n", " retriever = index.as_retriever(similarity_top_k=5)\n", @@ -561,11 +561,9 @@ }, { "cell_type": "code", - "execution_count": 38, + "execution_count": null, "id": "db8f9cf2-3da2-4174-9981-36283a1f0350", - "metadata": { - "tags": [] - }, + "metadata": {}, "outputs": [], "source": [ "from llama_index.llms.openai import OpenAI\n", @@ -627,7 +625,7 @@ }, { "cell_type": "code", - "execution_count": 44, + "execution_count": null, "id": "7fa654a0-4c43-48e3-85d1-b1796445fe35", "metadata": {}, "outputs": [], @@ -651,15 +649,15 @@ "\n", "class InputEvent(Event):\n", " input: List[ChatMessage]\n", - " \n", + "\n", "\n", "class ChunkRetrievalEvent(Event):\n", " tool_call: ToolSelection\n", - " \n", - " \n", + "\n", + "\n", "class DocRetrievalEvent(Event):\n", " tool_call: ToolSelection\n", - " \n", + "\n", "\n", "class ReportGenerationEvent(Event):\n", " pass\n", @@ -679,12 +677,14 @@ " super().__init__(**kwargs)\n", " self.chunk_retriever_tool = chunk_retriever_tool\n", " self.doc_retriever_tool = doc_retriever_tool\n", - " \n", + "\n", " self.llm = llm or OpenAI()\n", " self.summarizer = CompactAndRefine(llm=self.llm)\n", " assert self.llm.metadata.is_function_calling_model\n", - " \n", - " self.report_gen_sllm = report_gen_sllm or self.llm.as_structured_llm(ReportOutput, system_prompt=report_gen_system_prompt)\n", + "\n", + " self.report_gen_sllm = report_gen_sllm or self.llm.as_structured_llm(\n", + " ReportOutput, system_prompt=report_gen_system_prompt\n", + " )\n", " self.report_gen_summarizer = CompactAndRefine(llm=self.report_gen_sllm)\n", "\n", " self.memory = ChatMemoryBuffer.from_defaults(llm=llm)\n", @@ -694,7 +694,7 @@ " async def prepare_chat_history(self, ctx: Context, ev: StartEvent) -> InputEvent:\n", " # clear sources\n", " self.sources = []\n", - " \n", + "\n", " ctx.data[\"stored_chunks\"] = []\n", " ctx.data[\"query\"] = ev.input\n", "\n", @@ -714,7 +714,8 @@ " chat_history = ev.input\n", "\n", " response = await self.llm.achat_with_tools(\n", - " [self.chunk_retriever_tool, self.doc_retriever_tool], chat_history=chat_history\n", + " [self.chunk_retriever_tool, self.doc_retriever_tool],\n", + " chat_history=chat_history,\n", " )\n", " self.memory.put(response.message)\n", "\n", @@ -751,10 +752,16 @@ "\n", " # synthesize an answer given the query to return to the LLM.\n", " response = self.summarizer.synthesize(query, nodes=retrieved_chunks)\n", - " self.memory.put(ChatMessage(role=\"tool\", content=str(response), additional_kwargs={\n", - " \"tool_call_id\": ev.tool_call.tool_id,\n", - " \"name\": ev.tool_call.tool_name\n", - " }))\n", + " self.memory.put(\n", + " ChatMessage(\n", + " role=\"tool\",\n", + " content=str(response),\n", + " additional_kwargs={\n", + " \"tool_call_id\": ev.tool_call.tool_id,\n", + " \"name\": ev.tool_call.tool_name,\n", + " },\n", + " )\n", + " )\n", "\n", " # send input event back with updated chat history\n", " return InputEvent(input=self.memory.get())\n", @@ -765,18 +772,18 @@ " ) -> StopEvent:\n", " \"\"\"Generate report.\"\"\"\n", " # given all the context, generate query\n", - " response = self.report_gen_summarizer.synthesize(ctx.data[\"query\"], nodes=ctx.data[\"stored_chunks\"])\n", - " \n", + " response = self.report_gen_summarizer.synthesize(\n", + " ctx.data[\"query\"], nodes=ctx.data[\"stored_chunks\"]\n", + " )\n", + "\n", " return StopEvent(result={\"response\": response})" ] }, { "cell_type": "code", - "execution_count": 45, + "execution_count": null, "id": "585893ad-5bb8-493b-93f9-ce4c179d06f6", - "metadata": { - "tags": [] - }, + "metadata": {}, "outputs": [], "source": [ "agent = ReportGenerationAgent(\n", @@ -785,17 +792,15 @@ " llm=llm,\n", " report_gen_sllm=report_gen_sllm,\n", " verbose=True,\n", - " timeout=60.0\n", + " timeout=60.0,\n", ")" ] }, { "cell_type": "code", - "execution_count": 41, + "execution_count": null, "id": "3c33b3bd-fc26-4538-ac30-68509f570833", - "metadata": { - "tags": [] - }, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -822,11 +827,9 @@ }, { "cell_type": "code", - "execution_count": 42, + "execution_count": null, "id": "0642f98d-d071-43dc-89e8-4dfa1d628b43", - "metadata": { - "tags": [] - }, + "metadata": {}, "outputs": [ { "data": { @@ -1018,11 +1021,9 @@ }, { "cell_type": "code", - "execution_count": 46, + "execution_count": null, "id": "96b8f6c6-d47b-42e4-8741-5e786124c10d", - "metadata": { - "tags": [] - }, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -1048,17 +1049,15 @@ "source": [ "ret = await agent.run(\n", " input=\"Help me generate a report comparing LongLoRA vs. LoftQ. \"\n", - " \"What are similarities/differences in terms of techniques and experimental results?\" \n", + " \"What are similarities/differences in terms of techniques and experimental results?\"\n", ")" ] }, { "cell_type": "code", - "execution_count": 47, + "execution_count": null, "id": "15c14174-9c8d-4ed2-a69a-dd266079bc3a", - "metadata": { - "tags": [] - }, + "metadata": {}, "outputs": [ { "data": { @@ -1191,11 +1190,9 @@ }, { "cell_type": "code", - "execution_count": 48, + "execution_count": null, "id": "4169197d-f5da-4d4d-bc3f-82f20e3260a9", - "metadata": { - "tags": [] - }, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -1210,14 +1207,6 @@ "\n", "draw_most_recent_execution(agent)" ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4e311889-2a49-4771-8761-fe3f8424934a", - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { @@ -1235,8 +1224,7 @@ "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.8" + "pygments_lexer": "ipython3" } }, "nbformat": 4,