
Commit

Merge pull request #264 from tigergraph/GML-1852-performant-doc-ingestion-pipeline-in-support-ai

Gml 1852 performant doc ingestion pipeline in support ai
luzhoutg authored Aug 20, 2024
2 parents 1b8b7ba + 1c1a893 commit da05a4a
Showing 10 changed files with 930 additions and 264 deletions.
287 changes: 35 additions & 252 deletions copilot/docs/notebooks/SupportAIDemo.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion docker-compose.yml
@@ -21,7 +21,7 @@ services:

   eventual-consistency-service:
     image: tigergraphml/ecc:latest
-    # container_name: eventual-consistency-service
+    container_name: eventual-consistency-service
     build:
       context: .
       dockerfile: eventual-consistency-service/Dockerfile
2 changes: 1 addition & 1 deletion eventual-consistency-service/app/graphrag/graph_rag.py
@@ -55,7 +55,7 @@ async def stream_docs(
                     # continue to the next doc.
                     # This doc will not be marked as processed, so the ecc will process it eventually.
                     continue
-                logger.info("steam_docs writes to docs")
+                logger.info("stream_docs writes to docs")
                 await docs_chan.put(res.json()["results"][0]["DocContent"][0])
             except Exception as e:
                 exc = traceback.format_exc()
2 changes: 2 additions & 0 deletions eventual-consistency-service/app/graphrag/util.py
@@ -28,6 +28,7 @@

tg_sem = asyncio.Semaphore(100)

tg_sem = asyncio.Semaphore(100)

async def install_queries(
requried_queries: list[str],
@@ -211,6 +212,7 @@ async def upsert_vertex(
res.raise_for_status()



async def check_vertex_exists(conn, v_id: str):
headers = make_headers(conn)
async with httpx.AsyncClient(timeout=http_timeout) as client:
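The shared tg_sem semaphore added above is how the service caps the number of concurrent requests it sends to TigerGraph. Below is a minimal, self-contained sketch of that throttling pattern; the URL and request count are placeholders, not values from the repo.

import asyncio
import httpx

tg_sem = asyncio.Semaphore(100)  # at most 100 requests in flight at once

async def fetch(client: httpx.AsyncClient, url: str) -> int:
    # Acquire a slot before issuing the request; release it when done.
    async with tg_sem:
        res = await client.get(url)
    return res.status_code

async def main():
    async with httpx.AsyncClient(timeout=10) as client:
        urls = ["https://example.com/"] * 250  # placeholder workload
        codes = await asyncio.gather(*(fetch(client, u) for u in urls))
        print(codes.count(200), "requests succeeded")

if __name__ == "__main__":
    asyncio.run(main())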
14 changes: 4 additions & 10 deletions eventual-consistency-service/app/main.py
@@ -2,13 +2,15 @@

 os.environ["ECC"] = "true"
 import json
+import time
 import logging
 from contextlib import asynccontextmanager
 from threading import Thread
 from typing import Annotated, Callable

 import ecc_util
 import graphrag
+import supportai
 from eventual_consistency_checker import EventualConsistencyChecker
 from fastapi import BackgroundTasks, Depends, FastAPI, Response, status
 from fastapi.security.http import HTTPBase
@@ -179,19 +181,11 @@ def consistency_status(
     )
     match ecc_method:
         case SupportAIMethod.SUPPORTAI:
-            if graphname in consistency_checkers:
-                ecc = consistency_checkers[graphname]
-                ecc_status = json.dumps(ecc.get_status())
-            else:
-                start_ecc_in_thread(graphname, conn)
-                ecc_status = (
-                    f"Eventual consistency checker started for graph {graphname}"
-                )
+            background.add_task(supportai.run, graphname, conn)

-            LogWriter.info(f"Returning consistency status for {graphname}: {status}")
+            ecc_status = f"SupportAI initialization on {graphname} {time.ctime()}"
         case SupportAIMethod.GRAPHRAG:
             background.add_task(graphrag.run, graphname, conn)
-            import time

             ecc_status = f"GraphRAG initialization on {conn.graphname} {time.ctime()}"
         case _:
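For reference, here is a minimal sketch of the FastAPI BackgroundTasks pattern the SUPPORTAI case now uses: the endpoint returns immediately and the scheduled coroutine runs after the response is sent. The route path and run_pipeline function are illustrative stand-ins, not names from main.py.

import asyncio
import time
from fastapi import BackgroundTasks, FastAPI

app = FastAPI()

async def run_pipeline(graphname: str):
    # Stand-in for supportai.run(graphname, conn): the real ingestion work.
    await asyncio.sleep(1)
    print(f"pipeline finished for {graphname}")

@app.get("/trigger/{graphname}")
def trigger(graphname: str, background: BackgroundTasks):
    # Schedule the coroutine; it is awaited after the response is sent.
    background.add_task(run_pipeline, graphname)
    return {"status": f"SupportAI initialization on {graphname} {time.ctime()}"}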
1 change: 1 addition & 0 deletions eventual-consistency-service/app/supportai/__init__.py
@@ -0,0 +1 @@
from .supportai_init import *
213 changes: 213 additions & 0 deletions eventual-consistency-service/app/supportai/supportai_init.py
@@ -0,0 +1,213 @@
import asyncio
import logging
import time
import traceback
import httpx

from aiochannel import Channel
from pyTigerGraph import TigerGraphConnection

from common.config import embedding_service
from common.embeddings.milvus_embedding_store import MilvusEmbeddingStore
from common.extractors.BaseExtractor import BaseExtractor
from supportai import workers
from supportai.util import (
init,
make_headers,
http_timeout,
stream_ids,
tg_sem
)

logger = logging.getLogger(__name__)

consistency_checkers = {}


async def stream_docs(
conn: TigerGraphConnection,
docs_chan: Channel,
ttl_batches: int = 10
):
"""
Streams the document contents into the docs_chan
"""
logger.info("streaming docs")
headers = make_headers(conn)
async with httpx.AsyncClient(timeout=http_timeout) as client:
for i in range(ttl_batches):
doc_ids = await stream_ids(conn, "Document", i, ttl_batches)
if doc_ids["error"]:
continue

for d in doc_ids["ids"]:
try:
async with tg_sem:
res = await client.get(
f"{conn.restppUrl}/query/{conn.graphname}/StreamDocContent/",
params={"doc": d},
headers=headers,
)
if res.status_code != 200:
continue
logger.info("stream_docs writes to docs")
await docs_chan.put(res.json()["results"][0]["DocContent"][0])
except Exception as e:
exc = traceback.format_exc()
logger.error(f"Error retrieveing doc: {d} --> {e}\n{exc}")
continue
logger.info("stream_docs done")
logger.info("closing docs chan")
docs_chan.close()


async def chunk_docs(
conn: TigerGraphConnection,
docs_chan: Channel,
embed_chan: Channel,
upsert_chan: Channel,
extract_chan: Channel
):
"""
Creates and starts one worker for each document
in the docs channel.
"""
logger.info("Reading form docs channel")
doc_task = []
async with asyncio.TaskGroup() as sp:
async for content in docs_chan:
# v_id = content["v_id"]
# txt = content["attributes"]["text"]

logger.info("chunk writes to extract")
# await embed_chan.put((v_id, txt, "Document"))

task = sp.create_task(
workers.chunk_doc(conn, content, upsert_chan, embed_chan, extract_chan)
)
doc_task.append(task)

logger.info("chunk_docs done")
logger.info("closing extract_chan")
extract_chan.close()


async def upsert(
upsert_chan: Channel
):
"""
Creates and starts one worker for each upsert job
chan expects:
(func, args) <- q.get()
"""

logger.info("Reading from upsert channel")
# consume task queue
async with asyncio.TaskGroup() as sp:
async for func, args in upsert_chan:
logger.info(f"{func.__name__}, {args[1]}")
# execute the task
sp.create_task(func(*args))

logger.info(f"upsert done")


async def embed(
embed_chan: Channel,
index_stores: dict[str, MilvusEmbeddingStore],
graphname: str
):
"""
Creates and starts one worker for each embed job
chan expects:
(v_id, content, index_name) <- q.get()
"""
logger.info("Reading from embed channel")
async with asyncio.TaskGroup() as sp:
# consume task queue
async for v_id, content, index_name in embed_chan:
embedding_store = index_stores[f"{graphname}_{index_name}"]
logger.info(f"Embed to {graphname}_{index_name}: {v_id}")
sp.create_task(
workers.embed(
embedding_service,
embedding_store,
v_id,
content,
)
)

logger.info(f"embed done")


async def extract(
extract_chan: Channel,
upsert_chan: Channel,
embed_chan: Channel,
extractor: BaseExtractor,
conn: TigerGraphConnection
):
"""
Creates and starts one worker for each extract job
chan expects:
(chunk , chunk_id) <- q.get()
"""
logger.info("Reading from extract channel")
# consume task queue
async with asyncio.TaskGroup() as sp:
async for item in extract_chan:
sp.create_task(
workers.extract(upsert_chan, embed_chan, extractor, conn, *item)
)

logger.info(f"extract done")

logger.info("closing upsert and embed chan")
upsert_chan.close()
embed_chan.close()


async def run(
graphname: str,
conn: TigerGraphConnection,
upsert_limit=100
):
"""
Set up SupportAI:
- Install necessary queries.
- Process the documents into:
- chunks
- embeddings
- entities/relationships (and their embeddings)
- upsert everything to the graph
"""

extractor, index_stores = await init(conn)
init_start = time.perf_counter()

doc_process_switch = True

if doc_process_switch:
logger.info("Doc Processing Start")
docs_chan = Channel(1)
embed_chan = Channel(100)
upsert_chan = Channel(100)
extract_chan = Channel(100)
async with asyncio.TaskGroup() as sp:
# Get docs
sp.create_task(stream_docs(conn, docs_chan, 10))
# Process docs
sp.create_task(
chunk_docs(conn, docs_chan, embed_chan, upsert_chan, extract_chan)
)
# Upsert chunks
sp.create_task(upsert(upsert_chan))
# Embed
sp.create_task(embed(embed_chan, index_stores, graphname))
# Extract entities
sp.create_task(
extract(extract_chan, upsert_chan, embed_chan, extractor, conn)
)
init_end = time.perf_counter()
logger.info("Doc Processing End")
logger.info(f"DONE. supportai system initializer dT: {init_end-init_start}")

0 comments on commit da05a4a
