Skip to content

Commit

Permalink
Add Vertex embeddings to RAG package. (#33593)
Browse files Browse the repository at this point in the history
Co-authored-by: Claude <[email protected]>
  • Loading branch information
claudevdm and Claude authored Jan 15, 2025
1 parent 15f973f commit b5fa883
Show file tree
Hide file tree
Showing 3 changed files with 211 additions and 0 deletions.
4 changes: 4 additions & 0 deletions sdks/python/apache_beam/ml/rag/embeddings/huggingface_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

"""Tests for apache_beam.ml.rag.embeddings.huggingface."""

import shutil
import tempfile
import unittest

Expand Down Expand Up @@ -73,6 +74,9 @@ def setUp(self):
})
]

def tearDown(self) -> None:
  """Remove the directory tree stored in ``self.artifact_location``."""
  artifact_dir = self.artifact_location
  shutil.rmtree(artifact_dir)

def test_embedding_pipeline(self):
expected = [
Chunk(
Expand Down
97 changes: 97 additions & 0 deletions sdks/python/apache_beam/ml/rag/embeddings/vertex_ai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# Vertex AI Python SDK is required for this module.
# Follow https://cloud.google.com/vertex-ai/docs/python-sdk/use-vertex-ai-python-sdk # pylint: disable=line-too-long
# to install Vertex AI Python SDK.

"""RAG-specific embedding implementations using Vertex AI models."""

from typing import Optional

from google.auth.credentials import Credentials

import apache_beam as beam
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.rag.embeddings.base import create_rag_adapter
from apache_beam.ml.rag.types import Chunk
from apache_beam.ml.transforms.base import EmbeddingsManager
from apache_beam.ml.transforms.base import _TextEmbeddingHandler
from apache_beam.ml.transforms.embeddings.vertex_ai import DEFAULT_TASK_TYPE
from apache_beam.ml.transforms.embeddings.vertex_ai import _VertexAITextEmbeddingHandler

try:
import vertexai
except ImportError:
vertexai = None


class VertexAITextEmbeddings(EmbeddingsManager):
  """Embeddings manager pairing a Vertex AI text embedding handler with the
  RAG type adapter."""
  def __init__(
      self,
      model_name: str,
      *,
      title: Optional[str] = None,
      task_type: str = DEFAULT_TASK_TYPE,
      project: Optional[str] = None,
      location: Optional[str] = None,
      credentials: Optional[Credentials] = None,
      **kwargs):
    """Configure Vertex AI text embeddings for semantic search and RAG
    pipelines.

    Args:
      model_name: Name of the Vertex AI text embedding model.
      title: Optional title for the text content.
      task_type: Task type for embeddings. Defaults to
        ``DEFAULT_TASK_TYPE``.
      project: GCP project ID.
      location: GCP location.
      credentials: Optional GCP credentials.
      **kwargs: Forwarded unchanged to ``EmbeddingsManager.__init__``
        (for example ModelHandler inference_args).

    Raises:
      ImportError: If the ``vertexai`` package could not be imported.
    """
    # Fail before touching the base class when the optional SDK is absent.
    if not vertexai:
      missing_sdk_message = (
          "vertexai is required to use VertexAITextEmbeddings. "
          "Please install it with `pip install google-cloud-aiplatform`")
      raise ImportError(missing_sdk_message)

    rag_adapter = create_rag_adapter()
    super().__init__(type_adapter=rag_adapter, **kwargs)

    # Stored as given; they are only read again in get_model_handler().
    self.model_name = model_name
    self.title = title
    self.task_type = task_type
    self.project = project
    self.location = location
    self.credentials = credentials

  def get_model_handler(self):
    """Build a ``_VertexAITextEmbeddingHandler`` from the stored settings."""
    handler_settings = {
        'model_name': self.model_name,
        'title': self.title,
        'task_type': self.task_type,
        'project': self.project,
        'location': self.location,
        'credentials': self.credentials,
    }
    return _VertexAITextEmbeddingHandler(**handler_settings)

  def get_ptransform_for_processing(
      self, **kwargs
  ) -> beam.PTransform[beam.PCollection[Chunk], beam.PCollection[Chunk]]:
    """Return a ``RunInference`` transform typed to output ``Chunk``.

    The transform wraps this manager in a ``_TextEmbeddingHandler`` and
    passes ``self.inference_args`` through. ``**kwargs`` is accepted but
    not used here.
    """
    embedding_handler = _TextEmbeddingHandler(self)
    inference = RunInference(
        model_handler=embedding_handler, inference_args=self.inference_args)
    return inference.with_output_types(Chunk)
110 changes: 110 additions & 0 deletions sdks/python/apache_beam/ml/rag/embeddings/vertex_ai_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for apache_beam.ml.rag.embeddings.vertex_ai."""

import shutil
import tempfile
import unittest

import apache_beam as beam
from apache_beam.ml.rag.types import Chunk
from apache_beam.ml.rag.types import Content
from apache_beam.ml.rag.types import Embedding
from apache_beam.ml.transforms.base import MLTransform
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to

# pylint: disable=ungrouped-imports
try:
import vertexai # pylint: disable=unused-import
from apache_beam.ml.rag.embeddings.vertex_ai import VertexAITextEmbeddings
VERTEX_AI_AVAILABLE = True
except ImportError:
VERTEX_AI_AVAILABLE = False


def chunk_approximately_equals(expected, actual):
  """Check that two chunks match structurally, ignoring embedding values.

  Embedding values returned by a remote model cannot be pinned in a test,
  so only the shape of the dense vector is verified, never its contents.

  Args:
    expected: The reference ``Chunk``; only the length of its dense
      embedding is used.
    actual: The ``Chunk`` under test.

  Returns:
    True when both arguments are ``Chunk`` instances with equal ``id``,
    ``metadata`` and ``content``, both carry a dense embedding of the same
    length, and every entry of the actual dense embedding is a ``float``.
    False otherwise, including when either chunk has no embedding or no
    dense vector.
  """
  if not isinstance(expected, Chunk) or not isinstance(actual, Chunk):
    return False

  if (expected.id != actual.id or expected.metadata != actual.metadata or
      expected.content != actual.content):
    return False

  # getattr tolerates an embedding that is None, so a chunk that was never
  # embedded is reported as a mismatch instead of raising AttributeError.
  expected_vector = getattr(expected.embedding, 'dense_embedding', None)
  actual_vector = getattr(actual.embedding, 'dense_embedding', None)
  if expected_vector is None or actual_vector is None:
    return False

  return (
      len(expected_vector) == len(actual_vector) and
      all(isinstance(x, float) for x in actual_vector))


@unittest.skipIf(
    not VERTEX_AI_AVAILABLE, "Vertex AI dependencies not available")
class VertexAITextEmbeddingsTest(unittest.TestCase):
  """Pipeline test for ``VertexAITextEmbeddings``; skipped when
  ``VERTEX_AI_AVAILABLE`` is false."""
  def setUp(self):
    """Create a scratch artifact directory and two input chunks."""
    self.artifact_location = tempfile.mkdtemp(prefix='vertex_ai_')
    sample_texts = (
        ("1", "This is a test sentence."),
        ("2", "Another example."),
    )
    # Each chunk gets its own metadata dict, as in a literal listing.
    self.test_chunks = [
        Chunk(
            content=Content(text=text),
            id=chunk_id,
            metadata={
                "source": "test.txt", "language": "en"
            }) for chunk_id,
        text in sample_texts
    ]

  def tearDown(self) -> None:
    """Delete the scratch artifact directory created in setUp."""
    artifact_dir = self.artifact_location
    shutil.rmtree(artifact_dir)

  def test_embedding_pipeline(self):
    """Embed the input chunks and compare them with placeholder results
    using ``chunk_approximately_equals``."""
    # gecko@002 produces 768-dimensional embeddings
    embedding_size = 768
    sample_texts = (
        ("1", "This is a test sentence."),
        ("2", "Another example."),
    )
    # Zero vectors stand in for the real values; the comparator is the one
    # that decides which embedding properties are actually checked.
    expected = [
        Chunk(
            id=chunk_id,
            embedding=Embedding(dense_embedding=[0.0] * embedding_size),
            metadata={
                "source": "test.txt", "language": "en"
            },
            content=Content(text=text)) for chunk_id,
        text in sample_texts
    ]

    embedder = VertexAITextEmbeddings(model_name="textembedding-gecko@002")

    with TestPipeline() as p:
      ml_transform = MLTransform(
          write_artifact_location=self.artifact_location)
      embed_step = ml_transform.with_transform(embedder)
      embeddings = p | beam.Create(self.test_chunks) | embed_step

      # Must be attached inside the ``with`` block: the pipeline runs when
      # the block exits.
      matcher = equal_to(expected, equals_fn=chunk_approximately_equals)
      assert_that(embeddings, matcher)


# Run this module's tests when the file is executed directly as a script.
if __name__ == '__main__':
  unittest.main()

0 comments on commit b5fa883

Please sign in to comment.