Skip to content

Commit

Permalink
collab: Set up database for LLM service (#15882)
Browse files Browse the repository at this point in the history
This PR puts the initial infrastructure for the LLM service's database
in place.

The LLM service will be using a separate Postgres database, with its own
set of migrations.

Currently we only connect to the database in development, as we don't
yet have the database set up for the staging/production environments.

Release Notes:

- N/A
  • Loading branch information
maxdeviant authored Aug 6, 2024
1 parent a649067 commit 7f6d091
Show file tree
Hide file tree
Showing 25 changed files with 627 additions and 74 deletions.
2 changes: 2 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,7 @@ RUN apt-get update; \
WORKDIR app
COPY --from=builder /app/collab /app/collab
COPY --from=builder /app/crates/collab/migrations /app/migrations
COPY --from=builder /app/crates/collab/migrations_llm /app/migrations_llm
ENV MIGRATIONS_PATH=/app/migrations
ENV LLM_DATABASE_MIGRATIONS_PATH=/app/migrations_llm
ENTRYPOINT ["/app/collab"]
2 changes: 2 additions & 0 deletions crates/collab/.env.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ BLOB_STORE_URL = "http://127.0.0.1:9000"
BLOB_STORE_REGION = "the-region"
ZED_CLIENT_CHECKSUM_SEED = "development-checksum-seed"
SEED_PATH = "crates/collab/seed.default.json"
LLM_DATABASE_URL = "postgres://postgres@localhost/zed_llm"
LLM_DATABASE_MAX_CONNECTIONS = 5
LLM_API_SECRET = "llm-secret"

# CLICKHOUSE_URL = ""
Expand Down
16 changes: 16 additions & 0 deletions crates/collab/migrations_llm.sqlite/20240806182921_test_schema.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
-- Test schema for the LLM service database (SQLite dialect).

-- One row per provider; `name` is unique (see uix_providers_on_name).
CREATE TABLE providers (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    name TEXT NOT NULL
);

CREATE UNIQUE INDEX uix_providers_on_name ON providers (name);

-- One row per model; rows are deleted together with their provider.
CREATE TABLE models (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    provider_id INTEGER NOT NULL REFERENCES providers (id) ON DELETE CASCADE,
    name TEXT NOT NULL
);

-- A model name may only appear once per provider.
CREATE UNIQUE INDEX uix_models_on_provider_id_name ON models (provider_id, name);
CREATE INDEX ix_models_on_provider_id ON models (provider_id);
CREATE INDEX ix_models_on_name ON models (name);
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
-- Initial schema for the LLM service database.
-- Every statement is guarded with `if not exists`, so re-running the script
-- against a database that already has these objects does not fail.

-- One row per provider; `name` is unique (see uix_providers_on_name).
create table if not exists providers (
    id serial primary key,
    name text not null
);

create unique index if not exists uix_providers_on_name on providers (name);

-- One row per model; rows are deleted together with their provider.
create table if not exists models (
    id serial primary key,
    provider_id integer not null references providers (id) on delete cascade,
    name text not null
);

-- A model name may only appear once per provider.
create unique index if not exists uix_models_on_provider_id_name on models (provider_id, name);
-- NOTE(review): the unique index above already leads with provider_id, so this
-- index may be redundant -- confirm before dropping it.
create index if not exists ix_models_on_provider_id on models (provider_id);
create index if not exists ix_models_on_name on models (name);
53 changes: 5 additions & 48 deletions crates/collab/src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,12 @@ use sea_orm::{
};
use semantic_version::SemanticVersion;
use serde::{Deserialize, Serialize};
use sqlx::{
migrate::{Migrate, Migration, MigrationSource},
Connection,
};
use std::ops::RangeInclusive;
use std::{
fmt::Write as _,
future::Future,
marker::PhantomData,
ops::{Deref, DerefMut},
path::Path,
rc::Rc,
sync::Arc,
time::Duration,
Expand Down Expand Up @@ -90,54 +85,16 @@ impl Database {
})
}

/// Returns a reference to the [`ConnectOptions`] stored on this database handle.
pub fn options(&self) -> &ConnectOptions {
&self.options
}

/// Test-only helper: calls `clear()` on the `rooms` and `projects` fields.
#[cfg(test)]
pub fn reset(&self) {
self.rooms.clear();
self.projects.clear();
}

/// Runs the database migrations found at `migrations_path`.
///
/// Opens a dedicated connection using this database's URL, makes sure the
/// migrations table exists, and applies every migration whose version has not
/// been applied yet. Returns the migrations applied by this call, each paired
/// with the time it took to apply.
///
/// # Errors
///
/// Fails if the migrations cannot be loaded, if the connection or any
/// migration step fails, or if an already-applied migration has a different
/// checksum and `ignore_checksum_mismatch` is `false`.
pub async fn migrate(
    &self,
    migrations_path: &Path,
    ignore_checksum_mismatch: bool,
) -> anyhow::Result<Vec<(Migration, Duration)>> {
    let available = MigrationSource::resolve(migrations_path)
        .await
        .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;

    let mut connection = sqlx::AnyConnection::connect(self.options.get_url()).await?;
    connection.ensure_migrations_table().await?;

    // Index what the database has already recorded, keyed by migration version.
    let already_applied: HashMap<_, _> = connection
        .list_applied_migrations()
        .await?
        .into_iter()
        .map(|applied| (applied.version, applied))
        .collect();

    let mut newly_applied = Vec::new();
    for migration in available {
        // Not recorded yet: apply it and remember how long it took.
        let Some(previous) = already_applied.get(&migration.version) else {
            let elapsed = connection.apply(&migration).await?;
            newly_applied.push((migration, elapsed));
            continue;
        };

        // Already recorded: only verify that the file has not changed since.
        let checksum_changed = migration.checksum != previous.checksum;
        if checksum_changed && !ignore_checksum_mismatch {
            return Err(anyhow!(
                "checksum mismatch for applied migration {}",
                migration.description
            ));
        }
    }

    Ok(newly_applied)
}

/// Transaction runs things in a transaction. If you want to call other methods
/// and pass the transaction around you need to reborrow the transaction at each
/// call site with: `&*tx`.
Expand Down Expand Up @@ -453,7 +410,7 @@ fn is_serialization_error(error: &Error) -> bool {
}

/// A handle to a [`DatabaseTransaction`].
pub struct TransactionHandle(Arc<Option<DatabaseTransaction>>);
pub struct TransactionHandle(pub(crate) Arc<Option<DatabaseTransaction>>);

impl Deref for TransactionHandle {
type Target = DatabaseTransaction;
Expand Down
1 change: 1 addition & 0 deletions crates/collab/src/db/ids.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use rpc::proto;
use sea_orm::{entity::prelude::*, DbErr};
use serde::{Deserialize, Serialize};

#[macro_export]
macro_rules! id_type {
($name:ident) => {
#[derive(
Expand Down
6 changes: 5 additions & 1 deletion crates/collab/src/db/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ mod feature_flag_tests;
mod message_tests;
mod processed_stripe_event_tests;

use crate::migrations::run_database_migrations;

use super::*;
use gpui::BackgroundExecutor;
use parking_lot::Mutex;
Expand Down Expand Up @@ -91,7 +93,9 @@ impl TestDb {
.await
.unwrap();
let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
db.migrate(Path::new(migrations_path), false).await.unwrap();
run_database_migrations(db.options(), migrations_path, false)
.await
.unwrap();
db.initialize_notification_kinds().await.unwrap();
db
});
Expand Down
7 changes: 7 additions & 0 deletions crates/collab/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ pub mod db;
pub mod env;
pub mod executor;
pub mod llm;
pub mod migrations;
mod rate_limiter;
pub mod rpc;
pub mod seed;
Expand Down Expand Up @@ -150,6 +151,9 @@ pub struct Config {
pub live_kit_server: Option<String>,
pub live_kit_key: Option<String>,
pub live_kit_secret: Option<String>,
pub llm_database_url: Option<String>,
pub llm_database_max_connections: Option<u32>,
pub llm_database_migrations_path: Option<PathBuf>,
pub llm_api_secret: Option<String>,
pub rust_log: Option<String>,
pub log_json: Option<bool>,
Expand Down Expand Up @@ -197,6 +201,9 @@ impl Config {
live_kit_server: None,
live_kit_key: None,
live_kit_secret: None,
llm_database_url: None,
llm_database_max_connections: None,
llm_database_migrations_path: None,
llm_api_secret: None,
rust_log: None,
log_json: None,
Expand Down
25 changes: 24 additions & 1 deletion crates/collab/src/llm.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
mod authorization;
pub mod db;
mod token;

use crate::api::CloudflareIpCountryHeader;
use crate::llm::authorization::authorize_access_to_language_model;
use crate::llm::db::LlmDatabase;
use crate::{executor::Executor, Config, Error, Result};
use anyhow::Context as _;
use anyhow::{anyhow, Context as _};
use axum::TypedHeader;
use axum::{
body::Body,
Expand All @@ -24,11 +26,31 @@ pub use token::*;
/// State held by the LLM service.
pub struct LlmState {
/// The configuration the service was started with.
pub config: Config,
/// The executor handed to `LlmState::new`; also passed to the LLM database.
pub executor: Executor,
/// Connection to the LLM database. Currently only populated when
/// `config.is_development()` is true; `None` otherwise.
pub db: Option<Arc<LlmDatabase>>,
/// HTTP client built in `LlmState::new` with a `Zed Server/<version>` User-Agent header.
pub http_client: IsahcHttpClient,
}

impl LlmState {
pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
// TODO: This is temporary until we have the LLM database stood up.
let db = if config.is_development() {
let database_url = config
.llm_database_url
.as_ref()
.ok_or_else(|| anyhow!("missing LLM_DATABASE_URL"))?;
let max_connections = config
.llm_database_max_connections
.ok_or_else(|| anyhow!("missing LLM_DATABASE_MAX_CONNECTIONS"))?;

let mut db_options = db::ConnectOptions::new(database_url);
db_options.max_connections(max_connections);
let db = LlmDatabase::new(db_options, executor.clone()).await?;

Some(Arc::new(db))
} else {
None
};

let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
let http_client = IsahcHttpClient::builder()
.default_header("User-Agent", user_agent)
Expand All @@ -38,6 +60,7 @@ impl LlmState {
let this = Self {
config,
executor,
db,
http_client,
};

Expand Down
118 changes: 118 additions & 0 deletions crates/collab/src/llm/db.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
mod ids;
mod queries;
mod tables;

#[cfg(test)]
mod tests;

pub use ids::*;
pub use tables::*;

#[cfg(test)]
pub use tests::TestLlmDb;

use std::future::Future;
use std::sync::Arc;

use anyhow::anyhow;
use sea_orm::prelude::*;
pub use sea_orm::ConnectOptions;
use sea_orm::{
ActiveValue, DatabaseConnection, DatabaseTransaction, IsolationLevel, TransactionTrait,
};

use crate::db::TransactionHandle;
use crate::executor::Executor;
use crate::Result;

/// The database for the LLM service.
pub struct LlmDatabase {
options: ConnectOptions,
pool: DatabaseConnection,
#[allow(unused)]
executor: Executor,
#[cfg(test)]
runtime: Option<tokio::runtime::Runtime>,
}

impl LlmDatabase {
    /// Connects to the database with the given options.
    ///
    /// Installs sqlx's default `Any` drivers before opening the connection.
    ///
    /// # Errors
    ///
    /// Returns an error if the connection cannot be established.
    pub async fn new(options: ConnectOptions, executor: Executor) -> Result<Self> {
        sqlx::any::install_default_drivers();
        Ok(Self {
            // Keep a copy of the options: `connect` consumes them.
            options: options.clone(),
            pool: sea_orm::Database::connect(options).await?,
            executor,
            #[cfg(test)]
            runtime: None,
        })
    }

    /// Returns the options this database was connected with.
    pub fn options(&self) -> &ConnectOptions {
        &self.options
    }

    /// Runs `f` inside a database transaction.
    ///
    /// The transaction is committed when `f` returns `Ok` and rolled back when
    /// it returns `Err`.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `f`, or the error from beginning,
    /// committing, or rolling back the transaction. If the rollback itself
    /// fails, the rollback error is returned instead of the error from `f`.
    pub async fn transaction<F, Fut, T>(&self, f: F) -> Result<T>
    where
        F: Send + Fn(TransactionHandle) -> Fut,
        Fut: Send + Future<Output = Result<T>>,
    {
        let body = async {
            let (tx, result) = self.with_transaction(&f).await?;
            match result {
                Ok(value) => {
                    tx.commit().await?;
                    Ok(value)
                }
                Err(error) => {
                    tx.rollback().await?;
                    Err(error)
                }
            }
        };

        self.run(body).await
    }

    /// Begins a `ReadCommitted` transaction, runs `f` with a handle to it, and
    /// hands back the transaction together with the result of `f`, so the
    /// caller can decide whether to commit or roll back.
    ///
    /// # Errors
    ///
    /// Returns an error if the transaction cannot be started, or if the handle
    /// passed to `f` is still referenced after `f`'s future has completed.
    async fn with_transaction<F, Fut, T>(&self, f: &F) -> Result<(DatabaseTransaction, Result<T>)>
    where
        F: Send + Fn(TransactionHandle) -> Fut,
        Fut: Send + Future<Output = Result<T>>,
    {
        let tx = self
            .pool
            .begin_with_config(Some(IsolationLevel::ReadCommitted), None)
            .await?;

        let mut tx = Arc::new(Some(tx));
        let result = f(TransactionHandle(tx.clone())).await;

        // `Arc::get_mut` only succeeds when ours is the last remaining
        // reference, i.e. the closure did not keep its handle alive.
        let Some(tx) = Arc::get_mut(&mut tx).and_then(|tx| tx.take()) else {
            return Err(anyhow!("couldn't complete transaction because it's still in use").into());
        };

        Ok((tx, result))
    }

    /// Drives `future` to completion.
    ///
    /// In test builds this first simulates a random delay when running on the
    /// deterministic executor, then blocks on the test runtime. In non-test
    /// builds it simply awaits the future.
    ///
    /// # Panics
    ///
    /// In test builds, panics if `runtime` has not been set.
    async fn run<F, T>(&self, future: F) -> Result<T>
    where
        F: Future<Output = Result<T>>,
    {
        #[cfg(test)]
        {
            if let Executor::Deterministic(executor) = &self.executor {
                executor.simulate_random_delay().await;
            }

            self.runtime
                .as_ref()
                .expect("LlmDatabase::runtime must be set before running queries in tests")
                .block_on(future)
        }

        #[cfg(not(test))]
        {
            future.await
        }
    }
}
7 changes: 7 additions & 0 deletions crates/collab/src/llm/db/ids.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
use sea_orm::{entity::prelude::*, DbErr};
use serde::{Deserialize, Serialize};

use crate::id_type;

// ID types for the LLM database, generated by the crate's `id_type!` macro.
// Presumably these identify rows in the `providers` and `models` tables —
// TODO confirm against the table definitions.
id_type!(ProviderId);
id_type!(ModelId);
3 changes: 3 additions & 0 deletions crates/collab/src/llm/db/queries.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
use super::*;

pub mod providers;
Loading

0 comments on commit 7f6d091

Please sign in to comment.