Skip to content

Commit

Permalink
fixes embeddings calculation
Browse files Browse the repository at this point in the history
  • Loading branch information
micheleriva committed Jan 7, 2025
1 parent bd6fecb commit 6a0a836
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 9 deletions.
2 changes: 1 addition & 1 deletion src/ai/src/api/app.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
from functools import wraps, partial
from fastapi import FastAPI
from functools import wraps, partial
from fastapi.responses import JSONResponse
from fastapi.middleware.gzip import GZipMiddleware

Expand Down
13 changes: 5 additions & 8 deletions src/ai/src/grpc/server.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from concurrent.futures import ThreadPoolExecutor
import grpc
import service_pb2
import service_pb2_grpc
from grpc_reflection.v1alpha import reflection
from concurrent.futures import ThreadPoolExecutor


import service_pb2
import service_pb2_grpc
from service_pb2 import (
OramaModel as ProtoOramaModel,
OramaIntent as ProtoOramaIntent,
Expand Down Expand Up @@ -33,10 +32,8 @@ def GetEmbedding(self, request, context):
def serve(config, embeddings_service):
print(f"Starting gRPC server on port {config.embeddings_grpc_port}")
server = grpc.server(ThreadPoolExecutor(max_workers=10))

service_pb2_grpc.add_CalculateEmbeddingsServiceServicer_to_server(
CalculateEmbeddingService(embeddings_service), server
)
service = CalculateEmbeddingService(embeddings_service)
service_pb2_grpc.add_CalculateEmbeddingsServiceServicer_to_server(service, server)

SERVICE_NAMES = (
service_pb2.DESCRIPTOR.services_by_name["CalculateEmbeddingsService"].full_name,
Expand Down
4 changes: 4 additions & 0 deletions src/ai/src/service/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ def __init__(self, config):
self.app = create_app(self)

def _initialize_embeddings_service(self):
    """Prepare the runtime environment and build the embeddings models.

    Steps, in order:

    1. Set the ``ONNXRUNTIME_PROVIDERS`` environment variable to
       ``"CPUExecutionProvider"``.
    2. Call ``extend_fastembed_supported_models()``.
    3. Call ``initialize_thread_executor`` with half of
       ``self.config.total_threads`` workers, but never fewer than one.
    4. Build ``EmbeddingsModels`` for the configured default model group.

    Returns:
        The ``EmbeddingsModels`` instance created from ``self.config`` and
        the value of ``ModelGroups[self.config.default_model_group]``.
    """
    # Kept as a function-scope import so this method does not depend on the
    # module's top-level import list.
    import os

    # Must be set before any model is constructed below.
    # NOTE(review): presumably this restricts ONNX Runtime to the CPU
    # provider -- confirm which component actually reads this variable.
    os.environ["ONNXRUNTIME_PROVIDERS"] = "CPUExecutionProvider"

    extend_fastembed_supported_models()

    # ``total_threads // 2`` evaluates to 0 when total_threads is 1 (or 0);
    # clamp so the executor is never asked for zero workers.
    worker_count = max(1, self.config.total_threads // 2)
    initialize_thread_executor(max_workers=worker_count)

    selected_models = ModelGroups[self.config.default_model_group].value
    return EmbeddingsModels(self.config, selected_models=selected_models)
Expand Down
1 change: 1 addition & 0 deletions src/ai/test.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Manual smoke test: request "passage" embeddings for three sample strings
# from the BGESmall model via the GetEmbedding RPC, using a plaintext
# (non-TLS) connection to localhost:50051.
grpcurl \
  -d '{ "model": "BGESmall", "input": ["hello, world!", "hey there", "foo bar"], "intent": "passage" }' \
  -plaintext \
  localhost:50051 \
  orama_ai_service.CalculateEmbeddingsService/GetEmbedding

0 comments on commit 6a0a836

Please sign in to comment.