Skip to content

Commit

Permalink
adding rag to cli v1
Browse files Browse the repository at this point in the history
  • Loading branch information
MisterXY89 committed Jan 14, 2024
1 parent e934e05 commit 5c47ab1
Showing 1 changed file with 15 additions and 7 deletions.
22 changes: 15 additions & 7 deletions cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import argparse

import chat_doc.rag.main as rag
from chat_doc.app.app import App
from chat_doc.config import logger
from chat_doc.dataset_generation.dataset_factory import DatasetFactory
Expand All @@ -21,19 +22,16 @@
# Subparsers for "generate" and "train" commands
subparsers = parser.add_subparsers(dest="command")

# "generate" subcommand
# GENERATE subcommand
generate_parser = subparsers.add_parser("generate", help="Generate data")
generate_parser.add_argument(
"--dataset",
choices=DatasetFactory.available_datasets,
required=True,
help="Dataset to generate",
)
# generate_parser.add_argument(
# "--output_path", default="./data", help="Output path (default: ./data)"
# )

# "train" subcommand
# TRAIN subcommand
train_parser = subparsers.add_parser("train", help="Train the model")
train_parser.add_argument(
"--dataset",
Expand All @@ -56,13 +54,18 @@
)
train_parser.add_argument("--batch_size", type=int, help="Batch size (default: 2)", default=2)

# "run-app" subcommand
# RUN-APP subcommand
generate_parser = subparsers.add_parser("run-app", help="Run web-app")
generate_parser.add_argument("--port", default=5000, help="Port for the flask app")
generate_parser.add_argument(
"--debug", default=True, type=bool, help="Log-level for the flask app"
)

# RAG subcommand
rag_parser = subparsers.add_parser("run-rag", help="Run RAG")
rag_parser.add_argument("--query", required=True, help="Query string")
rag_parser.add_argument("--use_llm", default=False, help="Use LLM for augmented generation")

args = parser.parse_args()

if args.command == "generate":
Expand Down Expand Up @@ -96,7 +99,12 @@
app = App()
app.run(port=args.port, debug=args.debug)

elif args.command == "run-rag":
logger.info(f"Running RAG on query: {args.query}")
logger.info(f"Use LLM for augmented generation: {args.use_llm}")
rag.retrieve(args.query, args.use_llm)

else:
logger.error("Invalid command. Use 'generate' or 'train' or 'run-app'.")
logger.error("Invalid command. Use 'generate' or 'train', 'run-app' or 'run-rag'.")

args = parser.parse_args()

0 comments on commit 5c47ab1

Please sign in to comment.