-
Notifications
You must be signed in to change notification settings - Fork 170
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
MariaDB introduced vector support in version 11.7, enabling MariaDB Server to function as a relational vector database (https://mariadb.com/kb/en/vectors/). This commit adds support for MariaDB Server, verified against MariaDB Server version 11.7.1:
- Support MariaDB vector search with the HNSW algorithm, including filtered search.
- Support index and search parameters:
  - storage_engine: InnoDB or MyISAM
  - M: the M parameter in MHNSW vector indexing
  - ef_search: minimal number of result candidates to look for in the vector index for ORDER BY ... LIMIT N queries
  - max_cache_size: upper limit for one MHNSW vector index cache
- Support the `vectordbbench mariadbhnsw` CLI command.
- Loading branch information
Showing
8 changed files
with
488 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
from typing import Annotated, Optional, Unpack | ||
|
||
import click | ||
import os | ||
from pydantic import SecretStr | ||
|
||
from ....cli.cli import ( | ||
CommonTypedDict, | ||
HNSWFlavor1, | ||
cli, | ||
click_parameter_decorators_from_typed_dict, | ||
run, | ||
) | ||
from vectordb_bench.backend.clients import DB | ||
|
||
|
||
class MariaDBTypedDict(CommonTypedDict):
    """Connection-level CLI options for the MariaDB benchmark commands.

    Every field is annotated with the ``click.option`` that exposes it on the
    command line, in addition to whatever ``CommonTypedDict`` declares.
    """

    # NOTE(review): the flag is spelled ``--username`` and the command in this
    # file reads the value as ``parameters["username"]``, not ``user_name`` --
    # confirm which key name is intended before relying on this field name.
    user_name: Annotated[
        str,
        click.option(
            "--username",
            type=str,
            help="Username",
            required=True,
        ),
    ]
    password: Annotated[
        str,
        click.option(
            "--password",
            type=str,
            help="Password",
            required=True,
        ),
    ]

    host: Annotated[
        str,
        click.option(
            "--host",
            type=str,
            help="Db host",
            default="127.0.0.1",
        ),
    ]

    port: Annotated[
        int,
        click.option(
            "--port",
            type=int,
            default=3306,
            help="Db Port",
        ),
    ]

    # click.Choice over string choices hands back the selected string, so the
    # value type is ``str`` (the field used to be annotated as ``int``).
    storage_engine: Annotated[
        str,
        click.option(
            "--storage-engine",
            type=click.Choice(["InnoDB", "MyISAM"]),
            help="DB storage engine",
            required=True,
        ),
    ]
|
||
class MariaDBHNSWTypedDict(MariaDBTypedDict):
    """CLI options specific to the MariaDB HNSW command.

    All three options are optional integers layered on top of the options
    declared by ``MariaDBTypedDict``.
    """

    m: Annotated[
        Optional[int],
        click.option(
            "--m",
            required=False,
            type=int,
            help="M parameter in MHNSW vector indexing",
        ),
    ]

    ef_search: Annotated[
        Optional[int],
        click.option(
            "--ef-search",
            required=False,
            type=int,
            help="MariaDB system variable mhnsw_min_limit",
        ),
    ]

    max_cache_size: Annotated[
        Optional[int],
        click.option(
            "--max-cache-size",
            required=False,
            type=int,
            help="MariaDB system variable mhnsw_max_cache_size",
        ),
    ]
|
||
|
||
@cli.command()
@click_parameter_decorators_from_typed_dict(MariaDBHNSWTypedDict)
def MariaDBHNSW(
    **parameters: Unpack[MariaDBHNSWTypedDict],
):
    # No docstring on purpose: click would surface it as the command's help
    # text, which would change the CLI output.

    # The config classes are imported inside the command body, not at module
    # import time.
    from .config import MariaDBConfig, MariaDBHNSWConfig

    # Connection settings; the password is wrapped in SecretStr before it is
    # handed to the config object.
    connection_config = MariaDBConfig(
        db_label=parameters["db_label"],
        user_name=parameters["username"],
        password=SecretStr(parameters["password"]),
        host=parameters["host"],
        port=parameters["port"],
    )

    # Index / search tuning taken straight from the CLI options.
    hnsw_case_config = MariaDBHNSWConfig(
        M=parameters["m"],
        ef_search=parameters["ef_search"],
        storage_engine=parameters["storage_engine"],
        max_cache_size=parameters["max_cache_size"],
    )

    # The full parameter dict is forwarded as well, alongside the two configs.
    run(
        db=DB.MariaDB,
        db_config=connection_config,
        db_case_config=hnsw_case_config,
        **parameters,
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
from pydantic import SecretStr, BaseModel | ||
from typing import TypedDict | ||
from ..api import DBConfig, DBCaseConfig, MetricType, IndexType | ||
|
||
class MariaDBConfigDict(TypedDict):
    """Shape of the keyword arguments used to open a MariaDB connection.

    The keys are passed through unchanged as connection kwargs, so each name
    has to match the MariaDB API exactly.
    """

    user: str  # account name
    password: str  # plain-text password
    host: str  # server host
    port: int  # server port
|
||
|
||
# Connection settings for a MariaDB server. Only the password has no default.
class MariaDBConfig(DBConfig):
    user_name: str = "root"
    password: SecretStr
    host: str = "127.0.0.1"
    port: int = 3306

    def to_dict(self) -> MariaDBConfigDict:
        """Return the connection settings as a dict with the password unwrapped.

        Note that ``user_name`` is emitted under the key ``"user"``.
        """
        return dict(
            host=self.host,
            port=self.port,
            user=self.user_name,
            password=self.password.get_secret_value(),
        )
|
||
|
||
class MariaDBIndexConfig(BaseModel):
    """Base config for MariaDB"""

    # Metric selected for the benchmark case; stays None until it is set.
    metric_type: MetricType | None = None

    def parse_metric(self) -> str:
        """Translate ``metric_type`` into the distance name used for MariaDB.

        Returns:
            ``"euclidean"`` for ``MetricType.L2`` and ``"cosine"`` for
            ``MetricType.COSINE``.

        Raises:
            ValueError: for any other value, including an unset metric.
        """
        metric = self.metric_type
        if metric == MetricType.L2:
            return "euclidean"
        if metric == MetricType.COSINE:
            return "cosine"
        raise ValueError(f"Metric type {metric} is not supported!")
|
||
# HNSW case configuration for MariaDB: index-build and search parameters.
class MariaDBHNSWConfig(MariaDBIndexConfig, DBCaseConfig):
    M: int | None
    ef_search: int | None
    index: IndexType = IndexType.HNSW
    storage_engine: str = "InnoDB"
    max_cache_size: int | None

    def index_param(self) -> dict:
        """Collect the parameters that describe how the index is built.

        Raises:
            ValueError: propagated from ``parse_metric`` when the metric is
                not supported.
        """
        engine = self.storage_engine
        metric_name = self.parse_metric()
        return dict(
            storage_engine=engine,
            metric_type=metric_name,
            index_type=self.index.value,
            M=self.M,
            max_cache_size=self.max_cache_size,
        )

    def search_param(self) -> dict:
        """Collect the parameters that apply at query time.

        Raises:
            ValueError: propagated from ``parse_metric`` when the metric is
                not supported.
        """
        metric_name = self.parse_metric()
        return dict(metric_type=metric_name, ef_search=self.ef_search)
|
||
|
||
# Lookup from index type to the case-config class that handles it; HNSW is
# the only entry.
_mariadb_case_config = {IndexType.HNSW: MariaDBHNSWConfig}
|
||
|
Oops, something went wrong.