Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RAG Engine Enhancements #244

Merged
merged 3 commits into from
Jan 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 140 additions & 13 deletions rag/driver/qdrant/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,16 @@ import (
"google.golang.org/grpc/keepalive"
)

// stringToUint64ID maps an arbitrary string identifier onto a uint64 by
// running the 64-bit FNV-1a hash over the bytes of s. The result is
// deterministic, so the same string always yields the same number; being a
// hash, two different strings are not guaranteed to yield different numbers.
func stringToUint64ID(s string) uint64 {
	const (
		fnvOffsetBasis uint64 = 14695981039346656037
		fnvPrime       uint64 = 1099511628211
	)

	hash := fnvOffsetBasis
	for _, b := range []byte(s) {
		// FNV-1a order: fold the byte in first, then multiply by the prime.
		hash = (hash ^ uint64(b)) * fnvPrime
	}
	return hash
}

// Engine implements the driver.Engine interface using Qdrant as the vector store backend
type Engine struct {
client *qdrant.Client
Expand Down Expand Up @@ -142,19 +152,43 @@ func (e *Engine) IndexDoc(ctx context.Context, indexName string, doc *driver.Doc
}

point := &qdrant.PointStruct{
Id: qdrant.NewID(doc.DocID),
Id: qdrant.NewIDNum(stringToUint64ID(doc.DocID)),
Vectors: qdrant.NewVectors(embeddings...),
Payload: map[string]*qdrant.Value{
"content": qdrant.NewValueString(doc.Content),
"content": qdrant.NewValueString(doc.Content),
"original_id": qdrant.NewValueString(doc.DocID),
},
}

if doc.Metadata != nil {
metadataStruct, err := qdrant.NewStruct(doc.Metadata)
if err != nil {
return fmt.Errorf("failed to convert metadata: %w", err)
payload := make(map[string]*qdrant.Value)
for k, v := range doc.Metadata {
switch val := v.(type) {
case string:
payload[k] = qdrant.NewValueString(val)
case float64:
payload[k] = qdrant.NewValueDouble(val)
case bool:
payload[k] = qdrant.NewValueBool(val)
case []string:
values := make([]*qdrant.Value, len(val))
for i, s := range val {
values[i] = qdrant.NewValueString(s)
}
payload[k] = &qdrant.Value{
Kind: &qdrant.Value_ListValue{
ListValue: &qdrant.ListValue{
Values: values,
},
},
}
case map[string]interface{}:
if nested, err := qdrant.NewStruct(val); err == nil {
payload[k] = qdrant.NewValueStruct(nested)
}
}
}
point.Payload["metadata"] = qdrant.NewValueStruct(metadataStruct)
point.Payload["metadata"] = qdrant.NewValueStruct(&qdrant.Struct{Fields: payload})
}

_, err = e.client.Upsert(ctx, &qdrant.UpsertPoints{
Expand Down Expand Up @@ -205,6 +239,7 @@ func (e *Engine) Search(ctx context.Context, indexName string, vector []float32,
results := make([]driver.SearchResult, len(points))
for i, point := range points {
content := point.Payload["content"].GetStringValue()
originalID := point.Payload["original_id"].GetStringValue()
var metadata map[string]interface{}
if metadataValue := point.Payload["metadata"]; metadataValue != nil {
if metadataStruct := metadataValue.GetStructValue(); metadataStruct != nil {
Expand All @@ -213,7 +248,7 @@ func (e *Engine) Search(ctx context.Context, indexName string, vector []float32,
}

results[i] = driver.SearchResult{
DocID: point.Id.GetUuid(),
DocID: originalID,
Score: float64(point.Score),
Content: content,
Metadata: metadata,
Expand All @@ -226,7 +261,7 @@ func (e *Engine) Search(ctx context.Context, indexName string, vector []float32,
func (e *Engine) GetDocument(ctx context.Context, indexName string, DocID string) (*driver.Document, error) {
points, err := e.client.Get(ctx, &qdrant.GetPoints{
CollectionName: indexName,
Ids: []*qdrant.PointId{qdrant.NewID(DocID)},
Ids: []*qdrant.PointId{qdrant.NewIDNum(stringToUint64ID(DocID))},
WithPayload: qdrant.NewWithPayload(true),
WithVectors: qdrant.NewWithVectors(true),
})
Expand All @@ -243,6 +278,7 @@ func (e *Engine) GetDocument(ctx context.Context, indexName string, DocID string

point := points[0]
content := point.Payload["content"].GetStringValue()
originalID := point.Payload["original_id"].GetStringValue()
var metadata map[string]interface{}
if metadataValue := point.Payload["metadata"]; metadataValue != nil {
if metadataStruct := metadataValue.GetStructValue(); metadataStruct != nil {
Expand All @@ -251,7 +287,7 @@ func (e *Engine) GetDocument(ctx context.Context, indexName string, DocID string
}

return &driver.Document{
DocID: DocID,
DocID: originalID,
Content: content,
Metadata: metadata,
Embeddings: point.Vectors.GetVector().Data,
Expand Down Expand Up @@ -341,6 +377,23 @@ func convertStructToMap(s *qdrant.Struct) map[string]interface{} {
result[k] = x.DoubleValue
case *qdrant.Value_BoolValue:
result[k] = x.BoolValue
case *qdrant.Value_ListValue:
if x.ListValue != nil {
list := make([]interface{}, len(x.ListValue.Values))
for i, lv := range x.ListValue.Values {
switch lx := lv.Kind.(type) {
case *qdrant.Value_StringValue:
list[i] = lx.StringValue
case *qdrant.Value_DoubleValue:
list[i] = lx.DoubleValue
case *qdrant.Value_BoolValue:
list[i] = lx.BoolValue
case *qdrant.Value_StructValue:
list[i] = convertStructToMap(lx.StructValue)
}
}
result[k] = list
}
case *qdrant.Value_StructValue:
result[k] = convertStructToMap(x.StructValue)
}
Expand Down Expand Up @@ -380,10 +433,11 @@ func (e *Engine) IndexBatch(ctx context.Context, indexName string, docs []*drive
}

point := &qdrant.PointStruct{
Id: qdrant.NewID(doc.DocID),
Id: qdrant.NewIDNum(stringToUint64ID(doc.DocID)),
Vectors: qdrant.NewVectors(embeddings...),
Payload: map[string]*qdrant.Value{
"content": qdrant.NewValueString(doc.Content),
"content": qdrant.NewValueString(doc.Content),
"original_id": qdrant.NewValueString(doc.DocID),
},
}

Expand Down Expand Up @@ -421,7 +475,7 @@ func (e *Engine) DeleteDoc(ctx context.Context, indexName string, DocID string)
Points: &qdrant.PointsSelector{
PointsSelectorOneOf: &qdrant.PointsSelector_Points{
Points: &qdrant.PointsIdsList{
Ids: []*qdrant.PointId{qdrant.NewID(DocID)},
Ids: []*qdrant.PointId{qdrant.NewIDNum(stringToUint64ID(DocID))},
},
},
},
Expand All @@ -447,7 +501,7 @@ func (e *Engine) DeleteBatch(ctx context.Context, indexName string, DocIDs []str

pointIDs := make([]*qdrant.PointId, len(DocIDs))
for i, id := range DocIDs {
pointIDs[i] = qdrant.NewID(id)
pointIDs[i] = qdrant.NewIDNum(stringToUint64ID(id))
}

_, err := e.client.Delete(ctx, &qdrant.DeletePoints{
Expand Down Expand Up @@ -537,3 +591,76 @@ func (e *Engine) checkContext(ctx context.Context) error {
return nil
}
}

// HasDocument reports whether a point for DocID is present in the collection
// named indexName. The lookup uses the numeric point ID derived from DocID via
// stringToUint64ID and asks for no payload, since only presence matters.
//
// An error whose message contains "doesn't exist" is reported as (false, nil)
// rather than as a failure; any other error is wrapped and returned.
func (e *Engine) HasDocument(ctx context.Context, indexName string, DocID string) (bool, error) {
	if ctxErr := e.checkContext(ctx); ctxErr != nil {
		return false, ctxErr
	}

	request := &qdrant.GetPoints{
		CollectionName: indexName,
		Ids:            []*qdrant.PointId{qdrant.NewIDNum(stringToUint64ID(DocID))},
		WithPayload:    qdrant.NewWithPayload(false), // existence check only, payload not needed
	}

	found, getErr := e.client.Get(ctx, request)
	switch {
	case getErr == nil:
		return len(found) != 0, nil
	case strings.Contains(getErr.Error(), "doesn't exist"):
		// Treated as "the collection is missing", hence the document is too.
		return false, nil
	default:
		return false, fmt.Errorf("failed to check document existence: %w", getErr)
	}
}

// HasIndex reports whether a collection called name appears in the list of
// collections returned by the client. It returns an error if the context check
// fails or if the collections cannot be listed.
func (e *Engine) HasIndex(ctx context.Context, name string) (bool, error) {
	if ctxErr := e.checkContext(ctx); ctxErr != nil {
		return false, ctxErr
	}

	names, listErr := e.client.ListCollections(ctx)
	if listErr != nil {
		return false, fmt.Errorf("failed to list collections: %w", listErr)
	}

	// Exact, case-sensitive match against each listed collection name.
	for idx := range names {
		if names[idx] == name {
			return true, nil
		}
	}
	return false, nil
}

// GetMetadata fetches the "metadata" payload entry of the document identified
// by DocID in the collection indexName and returns it as a plain map. Vectors
// are not requested.
//
// It returns an error when the context check fails, when the lookup fails
// (with a dedicated message if the error text contains "doesn't exist"), or
// when no point is returned for the ID. A point that has no "metadata" entry,
// or whose entry is not a struct value, yields an empty non-nil map.
func (e *Engine) GetMetadata(ctx context.Context, indexName string, DocID string) (map[string]interface{}, error) {
	if ctxErr := e.checkContext(ctx); ctxErr != nil {
		return nil, ctxErr
	}

	request := &qdrant.GetPoints{
		CollectionName: indexName,
		Ids:            []*qdrant.PointId{qdrant.NewIDNum(stringToUint64ID(DocID))},
		WithPayload:    qdrant.NewWithPayload(true),
		WithVectors:    qdrant.NewWithVectors(false), // skip vectors to save memory
	}

	hits, getErr := e.client.Get(ctx, request)
	if getErr != nil {
		if strings.Contains(getErr.Error(), "doesn't exist") {
			return nil, fmt.Errorf("collection doesn't exist: %w", getErr)
		}
		return nil, fmt.Errorf("failed to get document metadata: %w", getErr)
	}
	if len(hits) == 0 {
		return nil, fmt.Errorf("document not found")
	}

	rawMetadata := hits[0].Payload["metadata"]
	if rawMetadata == nil {
		return make(map[string]interface{}), nil
	}
	fields := rawMetadata.GetStructValue()
	if fields == nil {
		return make(map[string]interface{}), nil
	}
	return convertStructToMap(fields), nil
}
Loading
Loading