Skip to content

Commit

Permalink
Embeddings: minimal end-to-end searching with qdrant (sourcegraph#55772)
Browse files Browse the repository at this point in the history
  • Loading branch information
camdencheek authored Aug 15, 2023
1 parent 2036e42 commit c98cac6
Show file tree
Hide file tree
Showing 9 changed files with 135 additions and 16 deletions.
3 changes: 3 additions & 0 deletions enterprise/cmd/frontend/internal/context/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,12 @@ go_library(
"//enterprise/cmd/frontend/internal/context/resolvers",
"//internal/codeintel",
"//internal/codycontext:context",
"//internal/conf",
"//internal/conf/conftypes",
"//internal/database",
"//internal/embeddings",
"//internal/embeddings/db",
"//internal/grpc/defaults",
"//internal/observation",
"//internal/search/client",
],
Expand Down
12 changes: 12 additions & 0 deletions enterprise/cmd/frontend/internal/context/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,12 @@ import (
"github.com/sourcegraph/sourcegraph/enterprise/cmd/frontend/internal/context/resolvers"
"github.com/sourcegraph/sourcegraph/internal/codeintel"
codycontext "github.com/sourcegraph/sourcegraph/internal/codycontext"
"github.com/sourcegraph/sourcegraph/internal/conf"
"github.com/sourcegraph/sourcegraph/internal/conf/conftypes"
"github.com/sourcegraph/sourcegraph/internal/database"
"github.com/sourcegraph/sourcegraph/internal/embeddings"
vdb "github.com/sourcegraph/sourcegraph/internal/embeddings/db"
"github.com/sourcegraph/sourcegraph/internal/grpc/defaults"
"github.com/sourcegraph/sourcegraph/internal/observation"
"github.com/sourcegraph/sourcegraph/internal/search/client"
)
Expand All @@ -24,11 +27,20 @@ func Init(
) error {
embeddingsClient := embeddings.NewDefaultClient()
searchClient := client.New(observationCtx.Logger, db)
qdrantSearcher := vdb.NewDisabledDB()
if addr := conf.ServiceConnections().Qdrant; addr != "" {
conn, err := defaults.Dial(addr, observationCtx.Logger)
if err != nil {
return err
}
qdrantSearcher = vdb.NewQdrantDBFromConn(conn)
}
contextClient := codycontext.NewCodyContextClient(
observationCtx,
db,
embeddingsClient,
searchClient,
qdrantSearcher,
)
enterpriseServices.CodyContextResolver = resolvers.NewResolver(
db,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ func TestContextResolver(t *testing.T) {
db,
mockEmbeddingsClient,
mockSearchClient,
nil,
)

resolver := NewResolver(
Expand Down
2 changes: 1 addition & 1 deletion enterprise/cmd/worker/internal/embeddings/repo/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func (s *repoEmbeddingJob) Routines(_ context.Context, observationCtx *observati
return nil, err
}

qdrantInserter := vdb.NewNoopInserter()
qdrantInserter := vdb.NewNoopDB()
if qdrantAddr := conf.Get().ServiceConnections().Qdrant; qdrantAddr != "" {
conn, err := defaults.Dial(qdrantAddr, observationCtx.Logger)
if err != nil {
Expand Down
4 changes: 4 additions & 0 deletions internal/codycontext/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,12 @@ go_library(
visibility = ["//:__subpackages__"],
deps = [
"//internal/api",
"//internal/conf",
"//internal/database",
"//internal/embeddings",
"//internal/embeddings/db",
"//internal/embeddings/embed",
"//internal/featureflag",
"//internal/metrics",
"//internal/observation",
"//internal/search",
Expand All @@ -18,6 +21,7 @@ go_library(
"//internal/search/result",
"//internal/search/streaming",
"//internal/types",
"//lib/errors",
"@com_github_sourcegraph_conc//pool",
"@com_github_sourcegraph_log//:log",
"@io_opentelemetry_go_otel//attribute",
Expand Down
62 changes: 61 additions & 1 deletion internal/codycontext/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,12 @@ import (
"go.opentelemetry.io/otel/attribute"

"github.com/sourcegraph/sourcegraph/internal/api"
"github.com/sourcegraph/sourcegraph/internal/conf"
"github.com/sourcegraph/sourcegraph/internal/database"
"github.com/sourcegraph/sourcegraph/internal/embeddings"
vdb "github.com/sourcegraph/sourcegraph/internal/embeddings/db"
"github.com/sourcegraph/sourcegraph/internal/embeddings/embed"
"github.com/sourcegraph/sourcegraph/internal/featureflag"
"github.com/sourcegraph/sourcegraph/internal/metrics"
"github.com/sourcegraph/sourcegraph/internal/observation"
"github.com/sourcegraph/sourcegraph/internal/search"
Expand All @@ -25,6 +28,7 @@ import (
"github.com/sourcegraph/sourcegraph/internal/search/result"
"github.com/sourcegraph/sourcegraph/internal/search/streaming"
"github.com/sourcegraph/sourcegraph/internal/types"
"github.com/sourcegraph/sourcegraph/lib/errors"
)

type FileChunkContext struct {
Expand All @@ -36,7 +40,7 @@ type FileChunkContext struct {
EndLine int
}

func NewCodyContextClient(obsCtx *observation.Context, db database.DB, embeddingsClient embeddings.Client, searchClient client.SearchClient) *CodyContextClient {
func NewCodyContextClient(obsCtx *observation.Context, db database.DB, embeddingsClient embeddings.Client, searchClient client.SearchClient, qdrantSearcher vdb.VectorSearcher) *CodyContextClient {
redMetrics := metrics.NewREDMetrics(
obsCtx.Registerer,
"codycontext_client",
Expand All @@ -58,6 +62,7 @@ func NewCodyContextClient(obsCtx *observation.Context, db database.DB, embedding
db: db,
embeddingsClient: embeddingsClient,
searchClient: searchClient,
qdrantSearcher: qdrantSearcher,

obsCtx: obsCtx,
getCodyContextOp: op("getCodyContext"),
Expand All @@ -70,6 +75,7 @@ type CodyContextClient struct {
db database.DB
embeddingsClient embeddings.Client
searchClient client.SearchClient
qdrantSearcher vdb.VectorSearcher

obsCtx *observation.Context
getCodyContextOp *observation.Operation
Expand All @@ -84,6 +90,14 @@ type GetContextArgs struct {
TextResultsCount int32
}

// RepoIDs returns the ID of every entry in a.Repos, in the same order.
// The returned slice is always non-nil and has exactly len(a.Repos) elements.
func (a *GetContextArgs) RepoIDs() []api.RepoID {
	ids := make([]api.RepoID, len(a.Repos))
	for i := range a.Repos {
		ids[i] = a.Repos[i].ID
	}
	return ids
}

func (a *GetContextArgs) Attrs() []attribute.KeyValue {
return []attribute.KeyValue{
attribute.Int("numRepos", len(a.Repos)),
Expand Down Expand Up @@ -170,6 +184,10 @@ func (c *CodyContextClient) getEmbeddingsContext(ctx context.Context, args GetCo
return nil, nil
}

if featureflag.FromContext(ctx).GetBoolOr("qdrant", false) {
return c.getEmbeddingsContextFromQdrant(ctx, args)
}

repoNames := make([]api.RepoName, len(args.Repos))
repoIDs := make([]api.RepoID, len(args.Repos))
for i, repo := range args.Repos {
Expand Down Expand Up @@ -308,6 +326,48 @@ func (c *CodyContextClient) getKeywordContext(ctx context.Context, args GetConte
return append(results[0], results[1]...), nil
}

// getEmbeddingsContextFromQdrant embeds args.Query using the embeddings
// configuration read from site config, searches c.qdrantSearcher with that
// embedding (restricted to args.RepoIDs() and bounded by the code/text result
// counts in args), and converts every returned chunk into a FileChunkContext.
//
// An error is returned when embeddings are not configured, when no vector
// searcher was supplied, when the embeddings client cannot be constructed,
// when the query cannot be embedded, or when the vector search itself fails.
func (c *CodyContextClient) getEmbeddingsContextFromQdrant(ctx context.Context, args GetContextArgs) (_ []FileChunkContext, err error) {
	embeddingsConf := conf.GetEmbeddingsConfig(conf.Get().SiteConfig())
	// Guard on the config that was just fetched, not on the receiver: the
	// receiver is never the thing that is "not configured".
	if embeddingsConf == nil {
		return nil, errors.New("embeddings not configured or disabled")
	}
	// A client may be constructed without a searcher (tests do this); fail
	// with an error instead of panicking on a nil interface call below.
	if c.qdrantSearcher == nil {
		return nil, errors.New("vector searcher not configured")
	}

	client, err := embed.NewEmbeddingsClient(embeddingsConf)
	if err != nil {
		return nil, errors.Wrap(err, "getting embeddings client")
	}

	resp, err := client.GetQueryEmbedding(ctx, args.Query)
	if err != nil {
		return nil, errors.Wrap(err, "getting query embedding")
	}
	// A reported failure comes with a nil err. Wrapping a nil error yields
	// nil, so this case needs its own explicit error rather than a Wrap.
	if len(resp.Failed) > 0 {
		return nil, errors.New("getting query embedding: provider failed to embed the query")
	}

	chunks, err := c.qdrantSearcher.Search(ctx, vdb.SearchParams{
		ModelID:   client.GetModelIdentifier(),
		RepoIDs:   args.RepoIDs(),
		Query:     resp.Embeddings,
		CodeLimit: int(args.CodeResultsCount),
		TextLimit: int(args.TextResultsCount),
	})
	if err != nil {
		return nil, errors.Wrap(err, "searching vector DB")
	}

	res := make([]FileChunkContext, 0, len(chunks))
	for _, chunk := range chunks {
		payload := chunk.Point.Payload
		res = append(res, FileChunkContext{
			RepoName:  payload.RepoName,
			RepoID:    payload.RepoID,
			CommitID:  payload.Revision,
			Path:      payload.FilePath,
			StartLine: int(payload.StartLine),
			EndLine:   int(payload.EndLine),
		})
	}
	return res, nil
}

func fileMatchToContextMatches(fm *result.FileMatch) []FileChunkContext {
if len(fm.ChunkMatches) == 0 {
return nil
Expand Down
45 changes: 36 additions & 9 deletions internal/embeddings/db/noop.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,55 @@ import (
"context"

"github.com/sourcegraph/sourcegraph/internal/api"
"github.com/sourcegraph/sourcegraph/lib/errors"
)

func NewNoopInserter() VectorInserter {
return noopInserter{}
func NewNoopDB() VectorDB {
return noopDB{}
}

var _ VectorDB = noopInserter{}
var _ VectorDB = noopDB{}

type noopInserter struct{}
type noopDB struct{}

func (noopInserter) Search(context.Context, SearchParams) ([]ChunkResult, error) {
func (noopDB) Search(context.Context, SearchParams) ([]ChunkResult, error) {
return nil, nil
}
func (noopInserter) PrepareUpdate(ctx context.Context, modelID string, modelDims uint64) error {
func (noopDB) PrepareUpdate(ctx context.Context, modelID string, modelDims uint64) error {
return nil
}
func (noopInserter) HasIndex(ctx context.Context, modelID string, repoID api.RepoID, revision api.CommitID) (bool, error) {
func (noopDB) HasIndex(ctx context.Context, modelID string, repoID api.RepoID, revision api.CommitID) (bool, error) {
return false, nil
}
func (noopInserter) InsertChunks(context.Context, InsertParams) error {
func (noopDB) InsertChunks(context.Context, InsertParams) error {
return nil
}
func (noopInserter) FinalizeUpdate(context.Context, FinalizeUpdateParams) error {
func (noopDB) FinalizeUpdate(context.Context, FinalizeUpdateParams) error {
return nil
}

// ErrDisabled is the error returned by every method of the VectorDB that
// NewDisabledDB creates.
var ErrDisabled = errors.New("Qdrant is disabled. Enable by setting QDRANT_ENDPOINT")

// disabledDB is a VectorDB implementation whose methods all fail with
// ErrDisabled.
type disabledDB struct{}

// Compile-time check that disabledDB satisfies VectorDB.
var _ VectorDB = disabledDB{}

// NewDisabledDB returns a VectorDB on which every operation fails with
// ErrDisabled.
func NewDisabledDB() VectorDB {
	var db disabledDB
	return db
}

// Search returns no results and ErrDisabled.
func (disabledDB) Search(_ context.Context, _ SearchParams) ([]ChunkResult, error) {
	var none []ChunkResult
	return none, ErrDisabled
}

// PrepareUpdate returns ErrDisabled.
func (disabledDB) PrepareUpdate(_ context.Context, _ string, _ uint64) error {
	return ErrDisabled
}

// HasIndex reports false together with ErrDisabled.
func (disabledDB) HasIndex(_ context.Context, _ string, _ api.RepoID, _ api.CommitID) (bool, error) {
	const indexed = false
	return indexed, ErrDisabled
}

// InsertChunks returns ErrDisabled.
func (disabledDB) InsertChunks(_ context.Context, _ InsertParams) error {
	return ErrDisabled
}

// FinalizeUpdate returns ErrDisabled.
func (disabledDB) FinalizeUpdate(_ context.Context, _ FinalizeUpdateParams) error {
	return ErrDisabled
}
18 changes: 15 additions & 3 deletions internal/embeddings/db/qdrant.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,22 @@ type qdrantDB struct {
var _ VectorDB = (*qdrantDB)(nil)

type SearchParams struct {
ModelID string
RepoIDs []api.RepoID
Query []float32
// RepoIDs is the set of repos to search.
// If empty, all repos are searched.
RepoIDs []api.RepoID

// The ID of the model that the query was embedded with.
// Embeddings for other models will not be searched.
ModelID string

// Query is the embedding for the search query.
// Its dimensions must match the model dimensions.
Query []float32

// The maximum number of code results to return
CodeLimit int

// The maximum number of text results to return
TextLimit int
}

Expand Down
4 changes: 2 additions & 2 deletions internal/embeddings/embed/embed_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func TestEmbedRepo(t *testing.T) {
}
revision := api.CommitID("deadbeef")
embeddingsClient := NewMockEmbeddingsClient()
inserter := db.NewNoopInserter()
inserter := db.NewNoopDB()
contextService := NewMockContextService()
contextService.SplitIntoEmbeddableChunksFunc.SetDefaultHook(defaultSplitter)
splitOptions := codeintelContext.SplitOptions{ChunkTokensThreshold: 8}
Expand Down Expand Up @@ -385,7 +385,7 @@ func TestEmbedRepo_ExcludeChunkOnError(t *testing.T) {
repoIDName := types.RepoIDName{Name: repoName}
embeddingsClient := NewMockEmbeddingsClient()
contextService := NewMockContextService()
inserter := db.NewNoopInserter()
inserter := db.NewNoopDB()
contextService.SplitIntoEmbeddableChunksFunc.SetDefaultHook(defaultSplitter)
splitOptions := codeintelContext.SplitOptions{ChunkTokensThreshold: 8}
mockFiles := map[string][]byte{
Expand Down

0 comments on commit c98cac6

Please sign in to comment.