diff --git a/.github/workflows/nightly-test.yml b/.github/workflows/nightly-test.yml
index a0571d1..2816c92 100644
--- a/.github/workflows/nightly-test.yml
+++ b/.github/workflows/nightly-test.yml
@@ -16,12 +16,16 @@ jobs:
runs-on: ubuntu-latest
outputs:
notebooks: ${{ steps.get_nbs.outputs.notebooks }}
+ has_notebooks: ${{ steps.get_nbs.outputs.has_notebooks }}
steps:
- - uses: actions/checkout@v2
+ - uses: actions/checkout@v3
- id: get_nbs
run: |
- # 1) Read ignore patterns from .github/ignore-notebooks.txt
+ # 1) Find all available notebooks
+ NBS=$(find python-recipes -name '*.ipynb')
+
+ # 2) Load notebooks to ignore
IGNORE_LIST=()
while IFS= read -r skip_nb || [ -n "$skip_nb" ]; do
# Skip empty lines or comment lines
@@ -29,9 +33,6 @@ jobs:
IGNORE_LIST+=("$skip_nb")
done < .github/ignore-notebooks.txt
- # 2) Find all .ipynb in python-recipes (or your path)
- NBS=$(find python-recipes -name '*.ipynb')
-
# 3) Filter out notebooks that match anything in IGNORE_LIST
FILTERED_NBS=()
for nb in $NBS; do
@@ -42,29 +43,36 @@ jobs:
break
fi
done
-
if [ "$skip" = false ]; then
FILTERED_NBS+=("$nb")
fi
done
- # 4) Convert the final array to compact JSON for GitHub Actions
+ # 4) Stuff into a single-line JSON array
NB_JSON=$(printf '%s\n' "${FILTERED_NBS[@]}" \
| jq -R . \
| jq -s -c .)
- # 5) Default to an empty array if there's nothing left
if [ -z "$NB_JSON" ] || [ "$NB_JSON" = "[]" ]; then
NB_JSON="[]"
fi
echo "All valid notebooks: $NB_JSON"
+
+ # 5) Check if there's anything in FILTERED_NBS
+ if [ "${#FILTERED_NBS[@]}" -gt 0 ]; then
+ echo "has_notebooks=true" >> $GITHUB_OUTPUT
+ else
+ echo "has_notebooks=false" >> $GITHUB_OUTPUT
+ fi
+
echo "notebooks=$NB_JSON" >> $GITHUB_OUTPUT
# ---------------------------------------------------------
# 2) Test all notebooks in parallel
# ---------------------------------------------------------
test_all_notebooks:
+ if: ${{ needs.gather_all_notebooks.outputs.has_notebooks == 'true' }}
needs: gather_all_notebooks
runs-on: ubuntu-latest
strategy:
@@ -79,7 +87,7 @@ jobs:
- 6379:6379
steps:
- - uses: actions/checkout@v2
+ - uses: actions/checkout@v3
# Setup Python
- uses: actions/setup-python@v4
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index d93f897..17af1c4 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -17,16 +17,18 @@ jobs:
runs-on: ubuntu-latest
outputs:
notebooks: ${{ steps.get_nbs.outputs.notebooks }}
+ has_notebooks: ${{ steps.get_nbs.outputs.has_notebooks }}
steps:
- - uses: actions/checkout@v2
+ - uses: actions/checkout@v3
- - id: get_nbs
+ - name: Gather notebooks
+ id: get_nbs
run: |
- # Compare this commit/PR to 'main' and list changed .ipynb files
+ # 1) Compare this commit/PR to 'main' and list changed notebooks
git fetch --depth=1 origin main
CHANGED_NOTEBOOKS=$(git diff --name-only origin/main | grep '\.ipynb$' || true)
- # 1) Read ignore patterns from .github/ignore-notebooks.txt
+ # 2) Load notebooks to ignore
IGNORE_LIST=()
while IFS= read -r skip_nb || [ -n "$skip_nb" ]; do
# Skip empty lines or comment lines
@@ -34,11 +36,10 @@ jobs:
IGNORE_LIST+=("$skip_nb")
done < .github/ignore-notebooks.txt
- # 2) Filter out notebooks in CHANGED_NOTEBOOKS that match ignore patterns
+ # 3) Filter out ignored notebooks
FILTERED_NBS=()
for nb in $CHANGED_NOTEBOOKS; do
skip=false
-
# Check if in ignore list
for ignore_nb in "${IGNORE_LIST[@]}"; do
# Partial match:
@@ -47,33 +48,31 @@ jobs:
break
fi
done
-
if [ "$skip" = false ]; then
FILTERED_NBS+=("$nb")
fi
done
- # 3) Build a single-line JSON array
+ # 4) Stuff into a single-line JSON array
NB_JSON=$(printf '%s\n' "${FILTERED_NBS[@]}" \
| jq -R . \
| jq -s -c .)
- # 4) Fallback to an empty array if there's nothing left
if [ -z "$NB_JSON" ] || [ "$NB_JSON" = "[]" ]; then
NB_JSON="[]"
fi
echo "All valid notebooks: $NB_JSON"
- # 5) Write to $GITHUB_OUTPUT
- if [ "$NB_JSON" != "[]" ]; then
+ # 5) Check if there's anything in FILTERED_NBS
+ if [ "${#FILTERED_NBS[@]}" -gt 0 ]; then
echo "has_notebooks=true" >> $GITHUB_OUTPUT
else
echo "has_notebooks=false" >> $GITHUB_OUTPUT
fi
echo "notebooks=$NB_JSON" >> $GITHUB_OUTPUT
-
+
# ---------------------------------------------------------
# 2) Test each changed notebook in parallel
# ---------------------------------------------------------
@@ -93,7 +92,7 @@ jobs:
- 6379:6379
steps:
- - uses: actions/checkout@v2
+ - uses: actions/checkout@v3
# Setup Python
- uses: actions/setup-python@v4
diff --git a/.gitignore b/.gitignore
index 1a701dd..5dced20 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,4 +1,222 @@
-.env
-node_modules/
+# Created by https://www.toptal.com/developers/gitignore/api/python,venv,macos
+# Edit at https://www.toptal.com/developers/gitignore?templates=python,venv,macos
+
+### macOS ###
+# General
.DS_Store
-.pytest_cache/
\ No newline at end of file
+.AppleDouble
+.LSOverride
+
+# Icon must end with two \r
+Icon
+
+
+# Thumbnails
+._*
+
+# Files that might appear in the root of a volume
+.DocumentRevisions-V100
+.fseventsd
+.Spotlight-V100
+.TemporaryItems
+.Trashes
+.VolumeIcon.icns
+.com.apple.timemachine.donotpresent
+
+# Directories potentially created on remote AFP share
+.AppleDB
+.AppleDesktop
+Network Trash Folder
+Temporary Items
+.apdisk
+
+### macOS Patch ###
+# iCloud generated files
+*.icloud
+
+### Python ###
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+.python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
+
+### Python Patch ###
+# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
+poetry.toml
+
+# ruff
+.ruff_cache/
+
+# LSP config files
+pyrightconfig.json
+
+### venv ###
+# Virtualenv
+# http://iamzed.com/2009/05/07/a-primer-on-virtualenv/
+[Bb]in
+[Ii]nclude
+[Ll]ib
+[Ll]ib64
+[Ll]ocal
+pyvenv.cfg
+pip-selfcheck.json
+
+libs/redis/docs/.Trash*
+.python-version
+.idea/*
diff --git a/.python-version b/.python-version
index 2419ad5..2c07333 100644
--- a/.python-version
+++ b/.python-version
@@ -1 +1 @@
-3.11.9
+3.11
diff --git a/python-recipes/finetuning/00_text_finetuning.ipynb b/python-recipes/finetuning/00_text_finetuning.ipynb
new file mode 100644
index 0000000..224df6f
--- /dev/null
+++ b/python-recipes/finetuning/00_text_finetuning.ipynb
@@ -0,0 +1,741 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Fine tuning text embedding models using sentence_transformers\n",
+ "\n",
+ "If you're building an LLM application your system will likely include a text embedding model that transforms written text into vector embeddings. These may be used for classification, routing, document retrieval, semantic caching or search.\n",
+ "\n",
+ "One of the key measures of an embedding model is how well it can group semantically equivalent statements together, and similarly, how well it can distinguish between similar, but not equivalent statements.\n",
+ "\n",
+ "Because embedding models are not performing logical reasoning, but instead are often used to perform vector similarity calculations, we're not guaranteed that every pair of similar vectors will be relevant or equivalent, or that embeddings that are far apart in vector space aren't relevant to each other. This is why using the correct text embedding model is critical. Using a text embedding model specifically fine tuned to correctly match queries for your system can improve your overall app performance."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "This notebook uses the [sentence_transformers](https://sbert.net/) library to fine tune a text embedding model on a custom dataset.\n",
+ "The training method used is [contrastive fine tuning](https://arxiv.org/abs/2408.00690), where two statements are assigned a label as either being similar {label=1.0} or dissimilar {label=0.0}.\n",
+ "Training then proceeds to minimize the cosine distance between similar statements, and maximize the cosine distance between dissimilar statements.\n",
+ "\n",
+ "This contrastive loss function is well suited to applications where we care about the metrics true positive, true negative, false positive, and false negative.\n",
+ "\n",
+ "## Let's Begin!\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m24.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m25.0.1\u001b[0m\n",
+ "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n"
+ ]
+ }
+ ],
+ "source": [
+ "!pip install --quiet torch datasets sentence_transformers 'transformers[torch]' redisvl matplotlib seaborn scikit-learn ipywidgets"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Select our starting model and dataset to fine tune on\n",
+ "To perform finetuning you'll need a dataset that ideally is specific to your use case. For the type of training we'll be doing - contrastive fine tuning - you'll need to structure your dataset as a set of pairs of questions or statements and a corresponding label indicating if they're equivalent or not.\n",
+ "\n",
+ "An example of what this looks like is in `sample_dataset.csv`\n",
+ "\n",
+ "| question_1 | question_2 | label |\n",
+ "|------------|------------|-------|\n",
+ "| What is AI? | What is artificial intelligence? | 1.0 |\n",
+ "| How to bake a cake? | How to make a sandwich? | 0.0 |\n",
+ "| Define machine learning. | Explain machine learning. | 1.0 |"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# select the datasets to train and test on\n",
+ "# we've provided examples in the datasets directory of our public S3 bucket for what these files should look like\n",
+ "train_data = 'sample_dataset.csv'\n",
+ "test_data = 'sample_testset.csv'"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import requests\n",
+ "import os\n",
+ "\n",
+ "if not (os.path.exists(f\"./datasets/{train_data}\") and os.path.exists(f\"./datasets/{test_data}\")):\n",
+ " if not os.path.exists('./datasets/'):\n",
+ " os.mkdir('./datasets/')\n",
+ "\n",
+ " # download the files and save them locally\n",
+ " for file in [train_data, test_data]:\n",
+ " url = f'https://redis-ai-resources.s3.us-east-2.amazonaws.com/finetuning/datasets/{file}'\n",
+ " r = requests.get(url)\n",
+ " with open(f'./datasets/{file}', 'wb') as f:\n",
+ " f.write(r.content)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from datasets import load_dataset\n",
+ "from sentence_transformers import SentenceTransformer\n",
+ "from sentence_transformers.losses import ContrastiveLoss\n",
+ "import copy\n",
+ "\n",
+ "# load a model to train/finetune\n",
+ "model_name = 'sentence-transformers/all-MiniLM-L6-v2'\n",
+ "\n",
+ "model = SentenceTransformer(model_name)\n",
+ "\n",
+ "# make a copy of the weights before training if we want to compare how much they've changed\n",
+ "before_training = copy.deepcopy(model.state_dict())\n",
+ "\n",
+ "# this loss requires pairs of text and a floating point similarity score as a label\n",
+ "# we'll use 'hard labels' of 1.0 or 0.0 as that is shown to lead to the best separation\n",
+ "loss = ContrastiveLoss(model)\n",
+ "\n",
+ "# load an example training dataset that works with our loss function:\n",
+ "train_dataset = load_dataset(\"csv\", data_files=f\"datasets/{train_data}\", split='train')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Define our training arguments"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from sentence_transformers.training_args import SentenceTransformerTrainingArguments\n",
+ "from sentence_transformers.training_args import BatchSamplers\n",
+ "\n",
+ "args = SentenceTransformerTrainingArguments(\n",
+ " # required parameters\n",
+ " output_dir=f\"models/trained_on_{train_data}\",\n",
+ " # optional training parameters\n",
+ " num_train_epochs=1,\n",
+ " per_device_train_batch_size=16,\n",
+ " per_device_eval_batch_size=16,\n",
+ " warmup_ratio=0.1,\n",
+ " fp16=False, # set to False if your GPU can't handle FP16\n",
+ " bf16=False, # set to True if your GPU supports BF16\n",
+ " batch_sampler=BatchSamplers.NO_DUPLICATES, # losses using \"in-batch negatives\" benefit from no duplicates\n",
+ " # optional tracking/debugging parameters\n",
+ " eval_strategy=\"steps\",\n",
+ " eval_steps=100,\n",
+ " save_strategy=\"steps\",\n",
+ " save_steps=100,\n",
+ " save_total_limit=2,\n",
+ " logging_steps=100,\n",
+ " run_name=f\"model-base-{train_data}\", # used in Weights & Biases if `wandb` is installed\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Split your dataset to perform training validation\n",
+ "While our model is training both the training loss and validation loss will be recorded. These are printed to `stdout`, and also logged in\n",
+ "`models/trained_on_<train_data>/checkpoint-<step>/trainer_state.json`.\n",
+ "\n",
+ "sentence_transformers uses the term 'evaluation' rather than 'validation'."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "train Dataset({\n",
+ " features: ['question_1', 'question_2', 'label'],\n",
+ " num_rows: 41\n",
+ "})\n",
+ "validation Dataset({\n",
+ " features: ['question_1', 'question_2', 'label'],\n",
+ " num_rows: 11\n",
+ "})\n"
+ ]
+ }
+ ],
+ "source": [
+ "from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator, SimilarityFunction\n",
+ "\n",
+ "# split the dataset into training and validation sets\n",
+ "train_dataset = train_dataset.train_test_split(train_size=0.8, seed=42)\n",
+ "\n",
+ "validation_dataset = train_dataset['test']\n",
+ "train_dataset = train_dataset['train']\n",
+ "\n",
+ "print('train', train_dataset)\n",
+ "print('validation', validation_dataset)\n",
+ "\n",
+ "# initialize the evaluator\n",
+ "dev_evaluator = EmbeddingSimilarityEvaluator(\n",
+ " sentences1=validation_dataset[\"question_1\"],\n",
+ " sentences2=validation_dataset[\"question_2\"],\n",
+ " scores=validation_dataset[\"label\"],\n",
+ " main_similarity=SimilarityFunction.COSINE,\n",
+ " name=f\"{train_data}-dev\",\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Train our model\n",
+ "This cell performs the full training for the number of epochs defined in our `SentenceTransformerTrainingArguments`, args. Losses are periodically printed out."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "aafb575008c049f391e1d074a59e91dd",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ " 0%| | 0/3 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "8c4ce9fb5ce24bbc80fae2f63dc0950c",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Computing widget examples: 0%| | 0/1 [00:00, ?example/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'train_runtime': 2.2703, 'train_samples_per_second': 18.06, 'train_steps_per_second': 1.321, 'train_loss': 0.03720299402872721, 'epoch': 1.0}\n"
+ ]
+ }
+ ],
+ "source": [
+ "from sentence_transformers import SentenceTransformerTrainer\n",
+ "\n",
+ "trainer = SentenceTransformerTrainer(\n",
+ " model=model,\n",
+ " args=args,\n",
+ " train_dataset=train_dataset,\n",
+ " eval_dataset=validation_dataset,\n",
+ " loss=loss,\n",
+ " evaluator=dev_evaluator,\n",
+ ")\n",
+ "trainer.train()\n",
+ "\n",
+ "# make a copy of the weights after training\n",
+ "after_training = copy.deepcopy(model.state_dict())"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## That's it!\n",
+ "That is all it takes to produce a finetuned model on your dataset. Every application is different and you'll want to know how well this model can do with your system, and how much better it is now that you've tuned it. Here we compute some metrics to see the impact of fine tuning.\n",
+ "These will also help you choose the best similarity threshold for your app based on your criteria."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Evaluate the trained model\n",
+ "As the ultimate goal of this fine tuning is to improve performance for distinguishing equivalent and not equivalent statements, we'll evaluate our trained model to see how well it correctly identifies pairs of statements as either equivalent or not. We have a labeled dataset for this, where statements are known to be either equivalent, and labelled `1.0`, or not equivalent and labelled `0.0`. If you have a background in data science or statistics you may recognize this as a binary classification problem. Pairs of embeddings will be classified as similar, aka positive, or dissimilar, aka negative. Comparing to our known labels that means that every pair of interest will fall into one of four groups:\n",
+ "- true positives (TP): our model correctly determines two vector embeddings are equivalent\n",
+ "- true negatives (TN): our model correctly determines two vector embeddings are different\n",
+ "- false positives (FP): our model incorrectly determines two vector embeddings are equivalent\n",
+ "- false negatives (FN): our model incorrectly determines two vector embeddings are different\n",
+ "\n",
+ "Here we can also evaluate on different datasets we haven't trained on. We've again provided a `sample_testset.csv` file to illustrate this, but you should replace this with data relevant to your app."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "\n",
+ "def evaluate_model(model, test_dataset):\n",
+ " q1_embeddings = [model.encode(pair['question_1']) for pair in test_dataset]\n",
+ " q2_embeddings = [model.encode(pair['question_2']) for pair in test_dataset]\n",
+ " labels = [pair['label'] for pair in test_dataset]\n",
+ "\n",
+ " # compute all the distances between all the vectors for quicker reference later\n",
+ " distances = np.empty(shape=(len(q1_embeddings), len(q2_embeddings)), dtype=np.float32, order='C')\n",
+ " for index_1, embedding_1 in enumerate(q1_embeddings):\n",
+ " for index_2, embedding_2 in enumerate(q2_embeddings):\n",
+ " # compute cosine distance between embeddings\n",
+ " cosine_distance = 1 - np.dot(embedding_1, embedding_2) / (np.linalg.norm(embedding_1) * np.linalg.norm(embedding_2))\n",
+ " distances[index_1, index_2] = cosine_distance\n",
+ "\n",
+ " # for our range of thresholds see which embeddings fall within our threshold and so would be consindered equivalent\n",
+ " metrics = {}\n",
+ " thresholds = np.linspace(0.01, 0.6, 60)\n",
+ " for threshold in thresholds:\n",
+ " TP = 0\n",
+ " FP = 0\n",
+ " TN = 0\n",
+ " FN = 0\n",
+ " for index, label in enumerate(labels):\n",
+ " # for question N find the most similar embedding, aka the one with the lowest distance\n",
+ " distance_of_nearest = np.min(distances[index, :])\n",
+ " index_of_nearest = np.argmin(distances[index, :])\n",
+ " if distances[index, :][index_of_nearest] <= threshold: # if the distance is below our threshold our model thinks they're equivalent\n",
+ " if label == 1: # check the label to see if it really is equivalent (label == 1) or if they're actually different (label == 0)\n",
+ " if index_of_nearest == index: # verify that we hit the correct matched pair, and not some other question\n",
+ " TP += 1 # we correctly found a matching entry\n",
+ " else:\n",
+ " FP += 1 # we found something we think is equivalent, but it's not what we should have found\n",
+ " else:\n",
+ " FP += 1 # we think we found an equivalent statement, but shouldn't have\n",
+ " else: # we didn't find anything\n",
+ " if label == 1: # check it should be a miss\n",
+ " FN += 1 # we failed to find a matching pair\n",
+ " else:\n",
+ " TN += 1 # correctly did not match any other embeddings\n",
+ "\n",
+ " F1 = (2 * TP) / (2 * TP + FP + FN)\n",
+ " accuracy = (TP + TN) / len(test_dataset)\n",
+ " metrics[threshold] = {\"TP\": TP, \"TN\": TN, \"FP\": FP, \"FN\": FN, \"accuracy\": accuracy, 'F1': F1}\n",
+ "\n",
+ " return metrics"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Perform our final comparison on our models before and after fine tuning"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# load the test dataset\n",
+ "test_dataset = load_dataset(\"csv\", data_files=f\"datasets/{test_data}\", split='train')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model.load_state_dict(before_training)\n",
+ "metrics_before_training = evaluate_model(model, test_dataset)\n",
+ "\n",
+ "model.load_state_dict(after_training)\n",
+ "metrics_after_training = evaluate_model(model, test_dataset)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Visualize metrics\n",
+ "Since we're tracking multiple metrics - true & false positives & negatives, accuracy and F1 score - we want a way to quickly visually compare all of these.\n",
+ "We'll plot these metrics to see how they change after fine tuning, and also how we can influence them by selecting the best cosine similarity threshold."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from matplotlib import pyplot as plt\n",
+ "\n",
+ "def display_AUC(metrics_before, metrics_after):\n",
+ " precision_before = [m['TP'] / (m['TP'] + m['FP']) if (m['TP'] + m['FP']) > 0 else 1 for m in metrics_before.values()]\n",
+ " precision_after = [m['TP'] / (m['TP'] + m['FP']) if (m['TP'] + m['FP']) > 0 else 1 for m in metrics_after.values()]\n",
+ "\n",
+ " recall_before = [m['TP'] / (m['TP'] + m['FN']) if (m['TP'] + m['FN']) > 0 else 1 for m in metrics_before.values()]\n",
+ " recall_after = [m['TP'] / (m['TP'] + m['FN']) if (m['TP'] + m['FN']) > 0 else 1 for m in metrics_after.values()]\n",
+ "\n",
+ " from sklearn.metrics import roc_auc_score\n",
+ " y_true_before = []\n",
+ " y_score_before = []\n",
+ " y_true_after = []\n",
+ " y_score_after = []\n",
+ "\n",
+ " for m in metrics_before.values():\n",
+ " y_true_before.extend([1] * m['TP'] + [0] * m['FN'] + [0] * m['TN'] + [1] * m['FP'])\n",
+ " y_score_before.extend([1] * m['TP'] + [1] * m['FN'] + [0] * m['TN'] + [0] * m['FP'])\n",
+ "\n",
+ " for m in metrics_after.values():\n",
+ " y_true_after.extend([1] * m['TP'] + [0] * m['FN'] + [0] * m['TN'] + [1] * m['FP'])\n",
+ " y_score_after.extend([1] * m['TP'] + [1] * m['FN'] + [0] * m['TN'] + [0] * m['FP'])\n",
+ "\n",
+ " auc_before = roc_auc_score(y_true_before, y_score_before)\n",
+ " auc_after = roc_auc_score(y_true_after, y_score_after)\n",
+ "\n",
+ " plt.figure()\n",
+ " plt.plot(recall_before, precision_before, scalex=False, scaley=False)\n",
+ " plt.plot(recall_after, precision_after, scalex=False, scaley=False)\n",
+ " plt.title(f'trained on {train_data}, test on {test_data}\\n Precision Recall curves with finetuning')\n",
+ " plt.xlabel('Recall')\n",
+ " plt.ylabel('Precision')\n",
+ " plt.ylim([0,1.1])\n",
+ " plt.legend([f'before finetuning auc={auc_before :.4f}', f'after finetuning auc={auc_after :.4f}'])\n",
+ " plt.show()\n",
+ "\n",
+ "\n",
+ "def display_accuracy(metrics_before, metrics_after):\n",
+ " accuracy_before = [m['accuracy'] for m in metrics_before.values()]\n",
+ " accuracy_after = [m['accuracy'] for m in metrics_after.values()]\n",
+ " plt.figure()\n",
+ " plt.plot(list(metrics_before.keys()), accuracy_before)\n",
+ " plt.plot(list(metrics_after.keys()), accuracy_after)\n",
+ " plt.title(f'trained on {train_data}, test on {test_data}\\n Accuracy')\n",
+ " plt.xlabel('Threshold')\n",
+ " plt.ylabel('Accuracy')\n",
+ " plt.ylim([0,1.1])\n",
+ " plt.legend(['before finetuning', 'after finetuning'])\n",
+ " plt.show()\n",
+ "\n",
+ "\n",
+ "def display_f1_score(metrics_before, metrics_after):\n",
+ " F1_before = [m[\"F1\"] for m in metrics_before.values()]\n",
+ " F1_after = [m[\"F1\"] for m in metrics_after.values()]\n",
+ "\n",
+ " plt.figure()\n",
+ " plt.plot(list(metrics_before.keys()), F1_before)\n",
+ " plt.plot(list(metrics_after.keys()), F1_after)\n",
+ " plt.title(f'trained on {train_data}, test on {test_data}\\n F1 Score')\n",
+ " plt.xlabel('Threshold')\n",
+ " plt.ylabel('F1 Score')\n",
+ " plt.legend(['before finetuning', 'after finetuning'])\n",
+ " plt.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "display_AUC(metrics_before_training, metrics_after_training)\n",
+ "display_accuracy(metrics_before_training, metrics_after_training)\n",
+ "display_f1_score(metrics_before_training, metrics_after_training)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Every use case is different\n",
+ "With vector embeddings we always have to keep in mind there is a tradeoff between true and false positives and negatives. You can cast a wide net with a large threshold and grab many seemingly similar vectors at the risk of getting some irrelevant ones, or you can be conservative and match only highly similar embeddings, and risk missing something important. You can control this tradeoff by selecting the similarity threshold that makes sense for your system.\n",
+ "\n",
+ "Where you set this threshold depends on your own use case and system, and your tolerance for different types of errors. Choosing the threshold that maximizes F1 score or accuracy is a good place to start. Ultimately you'll want to optimize for your specific use case, and we have a [retrieval optimizer tool](https://github.com/redis-applied-ai/retrieval-optimizer) to help with that when you're ready for the next level of system improvements."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Choosing your threshold\n",
+ "To get a sense of how the choice of similarity threshold changes cache performance here's an interactive tool that lets you change the threshold and immediately see how the tradeoff between true and false positives and negatives balances out."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix, precision_recall_curve\n",
+ "\n",
+ "def compute_metrics_at_threshold(\n",
+ " scores: np.ndarray,\n",
+ " labels: np.ndarray,\n",
+ " threshold: float,\n",
+ " high_score_more_similar: bool = True\n",
+ "):\n",
+ " if high_score_more_similar:\n",
+ " predictions = (scores >= threshold).astype(int)\n",
+ " else:\n",
+ " predictions = (scores <= threshold).astype(int)\n",
+ "\n",
+ " print(predictions)\n",
+ " precision = precision_score(labels, predictions)\n",
+ " recall = recall_score(labels, predictions)\n",
+ " f1 = f1_score(labels, predictions)\n",
+ " cm = confusion_matrix(labels, predictions)\n",
+ "\n",
+ " return {'precision': precision, 'recall': recall, 'f1_score': f1, 'confusion_matrix': cm}\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from sklearn.metrics.pairwise import cosine_similarity\n",
+ "\n",
+ "q1_embeddings = [model.encode(pair['question_1']) for pair in test_dataset]\n",
+ "q2_embeddings = [model.encode(pair['question_2']) for pair in test_dataset]\n",
+ "cosine_similarities = np.array([cosine_similarity([emb1], [emb2])[0][0] for emb1, emb2 in zip(q1_embeddings, q2_embeddings)])\n",
+ "labels = np.array(test_dataset[\"label\"])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "663bbe8f3bd34492a26b59566de2a926",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "interactive(children=(FloatSlider(value=0.8114206194877625, continuous_update=False, description='Cosine Simil…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 15,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import seaborn as sns\n",
+ "from ipywidgets import FloatSlider, Layout, interact\n",
+ "from IPython.display import display, HTML\n",
+ "\n",
+ "\n",
+ "def update_plots(threshold):\n",
+ "    \"\"\"Redraw the precision-recall curve and confusion matrix for the given cosine-similarity threshold.\"\"\"\n",
+ "    plt.rcParams.update({'font.size': 16})  # update global font sizes\n",
+ "\n",
+ "    metrics = compute_metrics_at_threshold(cosine_similarities, labels, threshold, high_score_more_similar=True)\n",
+ "    precision = metrics['precision']\n",
+ "    recall_val = metrics['recall']\n",
+ "    f1 = metrics['f1_score']\n",
+ "    cm = metrics['confusion_matrix']\n",
+ "\n",
+ "    precision_curve, recall_curve, _ = precision_recall_curve(labels, cosine_similarities)\n",
+ "\n",
+ "    # close figures left over from earlier updates; unlike plt.clf(), this never creates a stray empty figure\n",
+ "    plt.close('all')\n",
+ "\n",
+ "    # create subplots with a larger figure size for better readability\n",
+ "    fig, axs = plt.subplots(1, 2, figsize=(12, 6))\n",
+ "\n",
+ "    # Precision-Recall curve plot, with the current operating point marked in red\n",
+ "    axs[0].plot(recall_curve, precision_curve, color='blue', linewidth=2, label='Precision-Recall Curve')\n",
+ "    axs[0].scatter(recall_val, precision, color='red', s=100, zorder=5,\n",
+ "                   label=(f'Threshold = {threshold:.4f}\\n'\n",
+ "                          f'Precision = {precision:.2f}\\n'\n",
+ "                          f'Recall = {recall_val:.2f}'))\n",
+ "    axs[0].set_title('Precision-Recall Curve', fontsize=20, fontweight='bold')\n",
+ "    axs[0].set_xlabel('Recall', fontsize=18)\n",
+ "    axs[0].set_ylabel('Precision', fontsize=18)\n",
+ "    axs[0].tick_params(axis='both', labelsize=16)\n",
+ "    axs[0].legend(fontsize=14)\n",
+ "    axs[0].grid(True, linestyle='--', alpha=0.7)\n",
+ "\n",
+ "    # confusion matrix heatmap\n",
+ "    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=axs[1],\n",
+ "                cbar=True, annot_kws={'size': 16})\n",
+ "    axs[1].set_title('Confusion Matrix', fontsize=20, fontweight='bold')\n",
+ "    axs[1].set_xlabel('Predicted Label', fontsize=18)\n",
+ "    axs[1].set_ylabel('True Label', fontsize=18)\n",
+ "    axs[1].set_xticklabels(['Dissimilar (0)', 'Similar (1)'], fontsize=16)\n",
+ "    axs[1].set_yticklabels(['Dissimilar (0)', 'Similar (1)'], fontsize=16, rotation=0)\n",
+ "\n",
+ "    # overall figure title with metrics\n",
+ "    fig.suptitle(\n",
+ "        (f'Cosine Similarity Threshold: {threshold:.4f}\\n'\n",
+ "         f'Precision: {precision:.2f}, Recall: {recall_val:.2f}, F1 Score: {f1:.2f}'),\n",
+ "        fontsize=12, fontweight='bold'\n",
+ "    )\n",
+ "\n",
+ "    plt.tight_layout(rect=[0, 0.03, 1, 0.95])\n",
+ "    plt.show()\n",
+ "\n",
+ "# add some CSS to increase the font size for the slider's description and readout\n",
+ "display(HTML(\"\"\"\n",
+ "\n",
+ "\"\"\"))\n",
+ "\n",
+ "# add a slider with the new description and custom styling\n",
+ "threshold_slider = FloatSlider(\n",
+ " value=np.median(cosine_similarities),\n",
+ " min=np.min(cosine_similarities),\n",
+ " max=np.max(cosine_similarities),\n",
+ " step=0.001,\n",
+ " description='Cosine Similarity Threshold:',\n",
+ " readout=True,\n",
+ " readout_format='.4f',\n",
+ " continuous_update=False,\n",
+ " style={'description_width': 'initial'},\n",
+ " layout=Layout(width='80%', margin='20px 0px 20px 0px')\n",
+ ")\n",
+ "\n",
+ "# add a custom class to the slider for our CSS targeting\n",
+ "threshold_slider.add_class(\"custom-slider\")\n",
+ "interact(update_plots, threshold=threshold_slider)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "redis-ai-res",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.9"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}