diff --git a/dev_notes.md b/dev_notes.md index 52b0b571d..d265c3ba3 100644 --- a/dev_notes.md +++ b/dev_notes.md @@ -2,40 +2,21 @@ ## CURRENT WORK -- [x] migrate `api` -- [x] migrate `api_types` -- [x] migrate `bin` -- [x] migrate `common` -- [x] migrate `error` -- [x] migrate `middlewares` -- [ ] migrate `models` (below-mentioned migration/* and app_state missing) -- [x] migrate `notify` -- [x] migrate `schedulers` -- [x] migrate `service` - -- migrate each model step by step - create proper direct query for users in `src/schedulers/src/passwords.rs` -- `MIGRATE_DB_FROM` for Hiqlite -> implement backup restore from local fs in Hiqlite -- modules left for the end, after main tasks are finished: - - `src/models/src/migration/mod.rs` - - `src/models/src/migration/db_migrate.rs` - - `src/models/src/app_state.rs` ### After finished Hiqlite migration - check changed session invalidation functions - fix `DbType::from_str` -- cleanup `DbPool` creation in `AppState` -- remove the `sqlite` feature from `sqlx` to really make sure nothing has been forgotten - add an index (signature, created_at) to `jwks` -#### Update for the Changelog - -- POST /clients does not return the created client anymore - ## Documentation TODO +- breaking: only a single container from now on - `HealthResponse` response has been changed with Hiqlite -> breaking change +- database backup config has been changed slightly +- restore from backup has changed slightly +- write a small guide on how to migrate from existing sqlite to hiqlite ## Stage 1 - essentials @@ -59,6 +40,4 @@ ## Stage 3 - Possible nice to haves - impl experimental `dilithium` alg for token signing to become quantum safe -- 'rauthy-migrate' project to help migrating to rauthy? probably when doing benchmarks anyway and use it - for dummy data? -- custom event listener template to build own implementation? -> only if NATS will be implemented maybe? +- custom event listener template to build own implementation? diff --git a/justfile b/justfile index 36929b2ce..a88536e13 100644 --- a/justfile +++ b/justfile @@ -266,6 +266,8 @@ test-postgres test="": test-backend-stop postgres-stop postgres-start ./target/debug/rauthy test & echo $! > {{ file_test_pid }} + sleep 1 + if cargo test {{ test }}; then echo "All SQLite tests successful" just test-backend-stop diff --git a/rauthy.test.cfg b/rauthy.test.cfg index 111959ff9..8ced4c522 100644 --- a/rauthy.test.cfg +++ b/rauthy.test.cfg @@ -6,6 +6,8 @@ # !!! DO NOT USE IN PRODUCTION !!! DEV_MODE=true +DATABASE_URL=postgresql://rauthy:123SuperSafe@localhost:5432/rauthy + AUTH_HEADERS_ENABLE=true PASSWORD_RESET_COOKIE_BINDING=true diff --git a/src/api/src/api_keys.rs b/src/api/src/api_keys.rs index 49c59e75f..07227578b 100644 --- a/src/api/src/api_keys.rs +++ b/src/api/src/api_keys.rs @@ -4,7 +4,6 @@ use actix_web_validator::Json; use mime_guess::mime::TEXT_PLAIN_UTF_8; use rauthy_api_types::api_keys::{ApiKeyRequest, ApiKeyResponse, ApiKeysResponse}; use rauthy_error::{ErrorResponse, ErrorResponseType}; -use rauthy_models::app_state::AppState; use rauthy_models::entity::api_keys::ApiKeyEntity; /// Returns all API Keys @@ -22,13 +21,10 @@ use rauthy_models::entity::api_keys::ApiKeyEntity; ), )] #[get("/api_keys")] -pub async fn get_api_keys( - data: web::Data, - principal: ReqPrincipal, -) -> Result { +pub async fn get_api_keys(principal: ReqPrincipal) -> Result { principal.validate_admin_session()?; - let entities = ApiKeyEntity::find_all(&data).await?; + let entities = ApiKeyEntity::find_all().await?; let mut keys = Vec::with_capacity(entities.len()); for entity in entities { let key = entity.into_api_key()?; @@ -86,7 +82,6 @@ pub async fn post_api_key( )] #[put("/api_keys/{name}")] pub async fn put_api_key( - data: web::Data, principal: ReqPrincipal, name: web::Path, payload: Json, @@ -103,7 +98,7 @@ pub async fn put_api_key( } let access = req.access.into_iter().map(|a| a.into()).collect(); - ApiKeyEntity::update(&data, &name, req.exp, access).await?; + ApiKeyEntity::update(&name, req.exp, access).await?; Ok(HttpResponse::Ok().finish()) } @@ -124,14 +119,13 @@ pub async fn put_api_key( )] #[delete("/api_keys/{name}")] pub async fn delete_api_key( - data: web::Data, principal: ReqPrincipal, name: web::Path, ) -> Result { principal.validate_admin_session()?; let name = name.into_inner(); - ApiKeyEntity::delete(&data, &name).await?; + ApiKeyEntity::delete(&name).await?; Ok(HttpResponse::Ok().finish()) } @@ -190,14 +184,13 @@ pub async fn get_api_key_test( )] #[put("/api_keys/{name}/secret")] pub async fn put_api_key_secret( - data: web::Data, principal: ReqPrincipal, name: web::Path, ) -> Result { principal.validate_admin_session()?; let name = name.into_inner(); - let secret = ApiKeyEntity::generate_secret(&data, &name).await?; + let secret = ApiKeyEntity::generate_secret(&name).await?; Ok(HttpResponse::Ok() .content_type(TEXT_PLAIN_UTF_8) diff --git a/src/api/src/auth_providers.rs b/src/api/src/auth_providers.rs index a1c9020d9..325a4c87b 100644 --- a/src/api/src/auth_providers.rs +++ b/src/api/src/auth_providers.rs @@ -35,13 +35,10 @@ use tracing::debug; ), )] #[post("/providers")] -pub async fn post_providers( - data: web::Data, - principal: ReqPrincipal, -) -> Result { +pub async fn post_providers(principal: ReqPrincipal) -> Result { principal.validate_admin_session()?; - let providers = AuthProvider::find_all(&data).await?; + let providers = AuthProvider::find_all().await?; let mut resp = Vec::with_capacity(providers.len()); for provider in providers { resp.push(ProviderResponse::try_from(provider)?); @@ -66,7 +63,6 @@ pub async fn post_providers( )] #[post("/providers/create")] pub async fn post_provider( - data: web::Data, payload: Json, principal: ReqPrincipal, ) -> Result { @@ -79,7 +75,7 @@ pub async fn post_provider( )); } - let provider = AuthProvider::create(&data, payload.into_inner()).await?; + let provider = AuthProvider::create(payload.into_inner()).await?; Ok(HttpResponse::Ok().json(ProviderResponse::try_from(provider)?)) } @@ -133,14 +129,13 @@ pub async fn post_provider_lookup( )] #[post("/providers/login")] pub async fn post_provider_login( - data: web::Data, payload: Json, principal: ReqPrincipal, ) -> Result { principal.validate_session_auth_or_init()?; let payload = payload.into_inner(); - let (cookie, xsrf_token, location) = AuthProviderCallback::login_start(&data, payload).await?; + let (cookie, xsrf_token, location) = AuthProviderCallback::login_start(payload).await?; Ok(HttpResponse::Accepted() .insert_header((LOCATION, location)) @@ -149,11 +144,8 @@ pub async fn post_provider_login( } #[get("/providers/callback")] -pub async fn get_provider_callback_html( - data: web::Data, - req: HttpRequest, -) -> Result { - let colors = ColorEntity::find_rauthy(&data).await?; +pub async fn get_provider_callback_html(req: HttpRequest) -> Result { + let colors = ColorEntity::find_rauthy().await?; let lang = Language::try_from(&req).unwrap_or_default(); let body = ProviderCallbackHtml::build(&colors, &lang); @@ -218,14 +210,11 @@ pub async fn post_provider_callback( ), )] #[delete("/providers/link")] -pub async fn delete_provider_link( - data: web::Data, - principal: ReqPrincipal, -) -> Result { +pub async fn delete_provider_link(principal: ReqPrincipal) -> Result { principal.validate_session_auth()?; let user_id = principal.user_id()?.to_string(); - let user = User::provider_unlink(&data, user_id).await?; + let user = User::provider_unlink(user_id).await?; Ok(HttpResponse::Ok().json(user)) } @@ -242,12 +231,10 @@ pub async fn delete_provider_link( ), )] #[get("/providers/minimal")] -pub async fn get_providers_minimal( - data: web::Data, -) -> Result { +pub async fn get_providers_minimal() -> Result { // unauthorized - does not leak any sensitive information other than shown in the // default login page anyway - match AuthProviderTemplate::get_all_json_template(&data).await? { + match AuthProviderTemplate::get_all_json_template().await? { None => Ok(HttpResponse::Ok().insert_header(HEADER_JSON).body("[]")), Some(tpl) => Ok(HttpResponse::Ok().insert_header(HEADER_JSON).body(tpl)), } @@ -268,7 +255,6 @@ pub async fn get_providers_minimal( )] #[put("/providers/{id}")] pub async fn put_provider( - data: web::Data, id: web::Path, payload: Json, principal: ReqPrincipal, @@ -282,7 +268,7 @@ pub async fn put_provider( )); } - AuthProvider::update(&data, id.into_inner(), payload.into_inner()).await?; + AuthProvider::update(id.into_inner(), payload.into_inner()).await?; Ok(HttpResponse::Ok().finish()) } @@ -300,13 +286,12 @@ pub async fn put_provider( )] #[delete("/providers/{id}")] pub async fn delete_provider( - data: web::Data, id: web::Path, principal: ReqPrincipal, ) -> Result { principal.validate_admin_session()?; - AuthProvider::delete(&data, &id.into_inner()).await?; + AuthProvider::delete(&id.into_inner()).await?; Ok(HttpResponse::Ok().finish()) } @@ -326,13 +311,12 @@ pub async fn delete_provider( )] #[get("/providers/{id}/delete_safe")] pub async fn get_provider_delete_safe( - data: web::Data, id: web::Path, principal: ReqPrincipal, ) -> Result { principal.validate_admin_session()?; - let linked_users = AuthProvider::find_linked_users(&data, &id.into_inner()).await?; + let linked_users = AuthProvider::find_linked_users(&id.into_inner()).await?; if linked_users.is_empty() { Ok(HttpResponse::Ok().json(linked_users)) } else { @@ -350,12 +334,9 @@ pub async fn get_provider_delete_safe( ), )] #[get("/providers/{id}/img")] -pub async fn get_provider_img( - data: web::Data, - id: web::Path, -) -> Result { +pub async fn get_provider_img(id: web::Path) -> Result { let id = id.into_inner(); - let logo = Logo::find_cached(&data, &id, &LogoType::AuthProvider).await?; + let logo = Logo::find_cached(&id, &LogoType::AuthProvider).await?; Ok(HttpResponse::Ok() .insert_header((CONTENT_TYPE, logo.content_type)) @@ -383,7 +364,6 @@ pub async fn get_provider_img( )] #[put("/providers/{id}/img")] pub async fn put_provider_img( - data: web::Data, id: web::Path, principal: ReqPrincipal, mut payload: actix_multipart::Multipart, @@ -417,7 +397,6 @@ pub async fn put_provider_img( // content_type unwrap cannot panic -> checked above Logo::upsert( - &data, id.into_inner(), buf, content_type.unwrap(), @@ -445,7 +424,6 @@ pub async fn put_provider_img( )] #[post("/providers/{id}/link")] pub async fn post_provider_link( - data: web::Data, provider_id: web::Path, principal: ReqPrincipal, payload: Json, @@ -453,7 +431,7 @@ pub async fn post_provider_link( principal.validate_session_auth()?; let user_id = principal.user_id()?.to_string(); - let user = User::find(&data, user_id).await?; + let user = User::find(user_id).await?; // make sure the user is currently un-linked if user.auth_provider_id.is_some() { @@ -472,7 +450,7 @@ pub async fn post_provider_link( // directly redirect to the provider login page let (login_cookie, xsrf_token, location) = - AuthProviderCallback::login_start(&data, payload.into_inner()).await?; + AuthProviderCallback::login_start(payload.into_inner()).await?; Ok(HttpResponse::Accepted() .insert_header((LOCATION, location)) diff --git a/src/api/src/clients.rs b/src/api/src/clients.rs index c46d816c9..bec0429a1 100644 --- a/src/api/src/clients.rs +++ b/src/api/src/clients.rs @@ -38,13 +38,10 @@ use tracing::debug; )] #[tracing::instrument(skip_all)] #[get("/clients")] -pub async fn get_clients( - data: web::Data, - principal: ReqPrincipal, -) -> Result { +pub async fn get_clients(principal: ReqPrincipal) -> Result { principal.validate_api_key_or_admin_session(AccessGroup::Clients, AccessRights::Read)?; - let clients = Client::find_all(&data).await?; + let clients = Client::find_all().await?; let mut res = Vec::new(); clients @@ -73,12 +70,11 @@ pub async fn get_clients( #[get("/clients/{id}")] pub async fn get_client_by_id( path: web::Path, - data: web::Data, principal: ReqPrincipal, ) -> Result { principal.validate_api_key_or_admin_session(AccessGroup::Clients, AccessRights::Read)?; - Client::find(&data, path.into_inner()) + Client::find(path.into_inner()) .await .map(|c| HttpResponse::Ok().json(ClientResponse::from(c))) } @@ -104,13 +100,12 @@ pub async fn get_client_by_id( )] #[post("/clients/{id}/secret")] pub async fn get_client_secret( - data: web::Data, path: web::Path, principal: ReqPrincipal, ) -> Result { principal.validate_api_key_or_admin_session(AccessGroup::Secrets, AccessRights::Read)?; - client::get_client_secret(path.into_inner(), &data) + client::get_client_secret(path.into_inner()) .await .map(|c| HttpResponse::Ok().json(c)) } @@ -135,12 +130,11 @@ pub async fn get_client_secret( #[post("/clients")] pub async fn post_clients( client: actix_web_validator::Json, - data: web::Data, principal: ReqPrincipal, ) -> Result { principal.validate_api_key_or_admin_session(AccessGroup::Clients, AccessRights::Create)?; - let client = Client::create(&data, client.into_inner()).await?; + let client = Client::create(client.into_inner()).await?; Ok(HttpResponse::Ok().json(ClientResponse::from(client))) } @@ -217,10 +211,10 @@ pub async fn get_clients_dyn( let bearer = helpers::get_bearer_token_from_header(req.headers())?; let id = id.into_inner(); - let client_dyn = ClientDyn::find(&data, id.clone()).await?; + let client_dyn = ClientDyn::find(id.clone()).await?; client_dyn.validate_token(&bearer)?; - let client = Client::find(&data, id).await?; + let client = Client::find(id).await?; let resp = client.into_dynamic_client_response(&data, client_dyn, false)?; Ok(HttpResponse::Ok().json(resp)) } @@ -251,7 +245,7 @@ pub async fn put_clients_dyn( let bearer = helpers::get_bearer_token_from_header(req.headers())?; let id = id.into_inner(); - let client_dyn = ClientDyn::find(&data, id.clone()).await?; + let client_dyn = ClientDyn::find(id.clone()).await?; client_dyn.validate_token(&bearer)?; let resp = Client::update_dynamic(&data, payload.into_inner(), client_dyn).await?; @@ -277,14 +271,13 @@ pub async fn put_clients_dyn( )] #[put("/clients/{id}")] pub async fn put_clients( - data: web::Data, client: actix_web_validator::Json, path: web::Path, principal: ReqPrincipal, ) -> Result { principal.validate_api_key_or_admin_session(AccessGroup::Clients, AccessRights::Update)?; - client::update_client(&data, path.into_inner(), client.into_inner()) + client::update_client(path.into_inner(), client.into_inner()) .await .map(|r| HttpResponse::Ok().json(ClientResponse::from(r))) } @@ -305,13 +298,12 @@ pub async fn put_clients( )] #[get("/clients/{id}/colors")] pub async fn get_client_colors( - data: web::Data, id: web::Path, principal: ReqPrincipal, ) -> Result { principal.validate_api_key_or_admin_session(AccessGroup::Clients, AccessRights::Read)?; - ColorEntity::find(&data, id.as_str()) + ColorEntity::find(id.as_str()) .await .map(|c| HttpResponse::Ok().json(c)) } @@ -334,7 +326,6 @@ pub async fn get_client_colors( )] #[put("/clients/{id}/colors")] pub async fn put_client_colors( - data: web::Data, id: web::Path, principal: ReqPrincipal, req_data: actix_web_validator::Json, @@ -343,7 +334,7 @@ pub async fn put_client_colors( let colors = req_data.into_inner(); colors.validate_css()?; - ColorEntity::update(&data, id.as_str(), colors).await?; + ColorEntity::update(id.as_str(), colors).await?; Ok(HttpResponse::Ok().finish()) } @@ -364,13 +355,12 @@ pub async fn put_client_colors( )] #[delete("/clients/{id}/colors")] pub async fn delete_client_colors( - data: web::Data, id: web::Path, principal: ReqPrincipal, ) -> Result { principal.validate_api_key_or_admin_session(AccessGroup::Clients, AccessRights::Delete)?; - ColorEntity::delete(&data, id.as_str()).await?; + ColorEntity::delete(id.as_str()).await?; Ok(HttpResponse::Ok().finish()) } @@ -388,19 +378,16 @@ pub async fn delete_client_colors( ), )] #[get("/clients/{id}/logo")] -pub async fn get_client_logo( - data: web::Data, - id: web::Path, -) -> Result { +pub async fn get_client_logo(id: web::Path) -> Result { let id = id.into_inner(); debug!("Looking up client logo for id {}", id); - let logo = match Logo::find_cached(&data, &id, &LogoType::Client).await { + let logo = match Logo::find_cached(&id, &LogoType::Client).await { Ok(logo) => logo, Err(_) => { debug!("no specific logo for id {} - using Rauthy default", id); // If this client does not have a custom logo, we will always serve // Rauthy's logo as default - Logo::find_cached(&data, "rauthy", &LogoType::Client).await? + Logo::find_cached("rauthy", &LogoType::Client).await? } }; @@ -430,7 +417,6 @@ pub async fn get_client_logo( )] #[put("/clients/{id}/logo")] pub async fn put_client_logo( - data: web::Data, id: web::Path, principal: ReqPrincipal, mut payload: actix_multipart::Multipart, @@ -464,7 +450,6 @@ pub async fn put_client_logo( // content_type unwrap cannot panic -> checked above Logo::upsert( - &data, id.into_inner(), buf, content_type.unwrap(), @@ -490,16 +475,15 @@ delete, )] #[delete("/clients/{id}/logo")] pub async fn delete_client_logo( - data: web::Data, id: web::Path, principal: ReqPrincipal, ) -> Result { principal.validate_api_key_or_admin_session(AccessGroup::Clients, AccessRights::Delete)?; if id.as_str() == "rauthy" { - Logo::upsert_rauthy_default(&data).await?; + Logo::upsert_rauthy_default().await?; } else { - Logo::delete(&data, id.as_str(), &LogoType::Client).await?; + Logo::delete(id.as_str(), &LogoType::Client).await?; } Ok(HttpResponse::Ok().finish()) @@ -526,13 +510,12 @@ pub async fn delete_client_logo( )] #[put("/clients/{id}/secret")] pub async fn put_generate_client_secret( - data: web::Data, id: web::Path, principal: ReqPrincipal, ) -> Result { principal.validate_api_key_or_admin_session(AccessGroup::Secrets, AccessRights::Update)?; - client::generate_new_secret(id.into_inner(), &data) + client::generate_new_secret(id.into_inner()) .await .map(|r| HttpResponse::Ok().json(r)) } @@ -555,7 +538,6 @@ pub async fn put_generate_client_secret( )] #[delete("/clients/{id}")] pub async fn delete_client( - data: web::Data, id: web::Path, principal: ReqPrincipal, ) -> Result { @@ -570,7 +552,7 @@ pub async fn delete_client( )); } - let client = Client::find(&data, id).await?; - client.delete(&data).await?; + let client = Client::find(id).await?; + client.delete().await?; Ok(HttpResponse::Ok().finish()) } diff --git a/src/api/src/events.rs b/src/api/src/events.rs index 30331f945..cc0843694 100644 --- a/src/api/src/events.rs +++ b/src/api/src/events.rs @@ -29,7 +29,6 @@ use validator::Validate; )] #[post("/events")] pub async fn post_events( - data: web::Data, principal: ReqPrincipal, payload: Json, ) -> Result { @@ -39,7 +38,6 @@ pub async fn post_events( let payload = payload.into_inner(); let events = Event::find_all( - &data.db, payload.from, payload.until.unwrap_or_else(|| Utc::now().timestamp()), payload.level.into(), diff --git a/src/api/src/fed_cm.rs b/src/api/src/fed_cm.rs index 8d3cf4ee7..d7fa5d435 100644 --- a/src/api/src/fed_cm.rs +++ b/src/api/src/fed_cm.rs @@ -39,14 +39,11 @@ const HEADER_ALLOW_CREDENTIALS: (&str, &str) = ("access-control-allow-credential )] #[get("/fed_cm/accounts")] #[tracing::instrument(level = "debug", skip_all)] -pub async fn get_fed_cm_accounts( - req: HttpRequest, - data: web::Data, -) -> Result { +pub async fn get_fed_cm_accounts(req: HttpRequest) -> Result { is_fed_cm_enabled()?; is_web_identity_fetch(&req)?; - let (login_status, user_id) = login_status_from_req(&data, &req).await; + let (login_status, user_id) = login_status_from_req(&req).await; if login_status == FedCMLoginStatus::LoggedOut { return Ok(HttpResponse::Unauthorized() .insert_header(FedCMLoginStatus::LoggedOut.as_header_pair()) @@ -55,7 +52,7 @@ pub async fn get_fed_cm_accounts( })); } - let user = User::find_for_fed_cm_validated(&data, user_id).await?; + let user = User::find_for_fed_cm_validated(user_id).await?; let account = FedCMAccount::build(user); let accounts = FedCMAccounts { accounts: vec![account], @@ -96,7 +93,7 @@ pub async fn get_fed_cm_client_meta( )); } - let client = Client::find_maybe_ephemeral(&data, params.client_id).await?; + let client = Client::find_maybe_ephemeral(params.client_id).await?; if !client.enabled { return Err(ErrorResponse::new( ErrorResponseType::WWWAuthenticate("client-disabled".to_string()), @@ -202,13 +199,13 @@ pub async fn get_fed_client_config() -> HttpResponse { )] #[tracing::instrument(level = "debug", skip_all)] #[get("/fed_cm/status")] -pub async fn get_fed_cm_status(req: HttpRequest, data: web::Data) -> HttpResponse { +pub async fn get_fed_cm_status(req: HttpRequest) -> HttpResponse { if is_fed_cm_enabled().is_err() { HttpResponse::Unauthorized() .insert_header(FedCMLoginStatus::LoggedOut.as_header_pair()) .finish() } else { - let (login_status, _) = login_status_from_req(&data, &req).await; + let (login_status, _) = login_status_from_req(&req).await; if login_status == FedCMLoginStatus::LoggedOut { HttpResponse::Unauthorized() .insert_header(FedCMLoginStatus::LoggedOut.as_header_pair()) @@ -246,7 +243,7 @@ pub async fn post_fed_cm_token( is_fed_cm_enabled()?; is_web_identity_fetch(&req)?; - let (login_status, user_id) = login_status_from_req(&data, &req).await; + let (login_status, user_id) = login_status_from_req(&req).await; if login_status == FedCMLoginStatus::LoggedOut { return Ok(HttpResponse::Unauthorized() .insert_header(FedCMLoginStatus::LoggedOut.as_header_pair()) @@ -256,7 +253,7 @@ pub async fn post_fed_cm_token( let payload = payload.into_inner(); // find and check the client - let client = match Client::find_maybe_ephemeral(&data, payload.client_id).await { + let client = match Client::find_maybe_ephemeral(payload.client_id).await { Ok(c) => c, Err(err) => { error!("Error looking up maybe ephemeral client: {:?}", err); @@ -278,7 +275,7 @@ pub async fn post_fed_cm_token( debug!("built origin header for client: {:?}", origin_header.1); // find and check the user - let user = User::find_for_fed_cm_validated(&data, user_id).await?; + let user = User::find_for_fed_cm_validated(user_id).await?; if payload.account_id != user.id { debug!( "payload.account_id != user.id -> {} != {}", @@ -433,17 +430,14 @@ fn client_origin_header( } #[inline(always)] -async fn login_status_from_req( - data: &web::Data, - req: &HttpRequest, -) -> (FedCMLoginStatus, String) { +async fn login_status_from_req(req: &HttpRequest) -> (FedCMLoginStatus, String) { match ApiCookie::from_req(req, COOKIE_SESSION_FED_CM) { None => { debug!("FedCM session cookie not found -> user_id is logged-out",); (FedCMLoginStatus::LoggedOut, String::default()) } Some(sid) => { - let session = match Session::find(data, sid).await { + let session = match Session::find(sid).await { Ok(s) => s, Err(_) => { debug!("FedCM session not found -> user_id is logged-out",); diff --git a/src/api/src/generic.rs b/src/api/src/generic.rs index d9d8e32e5..0331977ed 100644 --- a/src/api/src/generic.rs +++ b/src/api/src/generic.rs @@ -52,11 +52,8 @@ use std::str::FromStr; use tracing::{error, info, warn}; #[get("/")] -pub async fn get_index( - data: web::Data, - req: HttpRequest, -) -> Result { - let colors = ColorEntity::find_rauthy(&data).await?; +pub async fn get_index(req: HttpRequest) -> Result { + let colors = ColorEntity::find_rauthy().await?; let lang = Language::try_from(&req).unwrap_or_default(); let body = IndexHtml::build(&colors, &lang); @@ -122,158 +119,130 @@ pub async fn post_i18n( } #[get("/account")] -pub async fn get_account_html( - data: web::Data, - req: HttpRequest, -) -> Result { - let colors = ColorEntity::find_rauthy(&data).await?; +pub async fn get_account_html(req: HttpRequest) -> Result { + let colors = ColorEntity::find_rauthy().await?; let lang = Language::try_from(&req).unwrap_or_default(); - let providers = AuthProviderTemplate::get_all_json_template(&data).await?; + let providers = AuthProviderTemplate::get_all_json_template().await?; let body = AccountHtml::build(&colors, &lang, providers); Ok(HttpResponse::Ok().insert_header(HEADER_HTML).body(body)) } #[get("/admin")] -pub async fn get_admin_html(data: web::Data) -> Result { - let colors = ColorEntity::find_rauthy(&data).await?; +pub async fn get_admin_html() -> Result { + let colors = ColorEntity::find_rauthy().await?; let body = AdminHtml::build(&colors); Ok(HttpResponse::Ok().insert_header(HEADER_HTML).body(body)) } #[get("/admin/api_keys")] -pub async fn get_admin_api_keys_html( - data: web::Data, -) -> Result { - let colors = ColorEntity::find_rauthy(&data).await?; +pub async fn get_admin_api_keys_html() -> Result { + let colors = ColorEntity::find_rauthy().await?; let body = AdminApiKeysHtml::build(&colors); Ok(HttpResponse::Ok().insert_header(HEADER_HTML).body(body)) } #[get("/admin/attributes")] -pub async fn get_admin_attr_html(data: web::Data) -> Result { - let colors = ColorEntity::find_rauthy(&data).await?; +pub async fn get_admin_attr_html() -> Result { + let colors = ColorEntity::find_rauthy().await?; let body = AdminAttributesHtml::build(&colors); Ok(HttpResponse::Ok().insert_header(HEADER_HTML).body(body)) } #[get("/admin/blacklist")] -pub async fn get_admin_blacklist_html( - data: web::Data, -) -> Result { - let colors = ColorEntity::find_rauthy(&data).await?; +pub async fn get_admin_blacklist_html() -> Result { + let colors = ColorEntity::find_rauthy().await?; let body = AdminBlacklistHtml::build(&colors); Ok(HttpResponse::Ok().insert_header(HEADER_HTML).body(body)) } #[get("/admin/clients")] -pub async fn get_admin_clients_html( - data: web::Data, -) -> Result { - let colors = ColorEntity::find_rauthy(&data).await?; +pub async fn get_admin_clients_html() -> Result { + let colors = ColorEntity::find_rauthy().await?; let body = AdminClientsHtml::build(&colors); Ok(HttpResponse::Ok().insert_header(HEADER_HTML).body(body)) } #[get("/admin/config")] -pub async fn get_admin_config_html( - data: web::Data, -) -> Result { - let colors = ColorEntity::find_rauthy(&data).await?; +pub async fn get_admin_config_html() -> Result { + let colors = ColorEntity::find_rauthy().await?; let body = AdminConfigHtml::build(&colors); Ok(HttpResponse::Ok().insert_header(HEADER_HTML).body(body)) } #[get("/admin/docs")] -pub async fn get_admin_docs_html(data: web::Data) -> Result { - let colors = ColorEntity::find_rauthy(&data).await?; +pub async fn get_admin_docs_html() -> Result { + let colors = ColorEntity::find_rauthy().await?; let body = AdminDocsHtml::build(&colors); Ok(HttpResponse::Ok().insert_header(HEADER_HTML).body(body)) } #[get("/admin/events")] -pub async fn get_admin_events_html( - data: web::Data, -) -> Result { - let colors = ColorEntity::find_rauthy(&data).await?; +pub async fn get_admin_events_html() -> Result { + let colors = ColorEntity::find_rauthy().await?; let body = AdminUsersHtml::build(&colors); Ok(HttpResponse::Ok().insert_header(HEADER_HTML).body(body)) } #[get("/admin/groups")] -pub async fn get_admin_groups_html( - data: web::Data, -) -> Result { - let colors = ColorEntity::find_rauthy(&data).await?; +pub async fn get_admin_groups_html() -> Result { + let colors = ColorEntity::find_rauthy().await?; let body = AdminGroupsHtml::build(&colors); Ok(HttpResponse::Ok().insert_header(HEADER_HTML).body(body)) } #[get("/admin/providers")] -pub async fn get_admin_providers_html( - data: web::Data, -) -> Result { - let colors = ColorEntity::find_rauthy(&data).await?; +pub async fn get_admin_providers_html() -> Result { + let colors = ColorEntity::find_rauthy().await?; let body = ProvidersHtml::build(&colors); Ok(HttpResponse::Ok().insert_header(HEADER_HTML).body(body)) } #[get("/admin/roles")] -pub async fn get_admin_roles_html( - data: web::Data, -) -> Result { - let colors = ColorEntity::find_rauthy(&data).await?; +pub async fn get_admin_roles_html() -> Result { + let colors = ColorEntity::find_rauthy().await?; let body = AdminRolesHtml::build(&colors); Ok(HttpResponse::Ok().insert_header(HEADER_HTML).body(body)) } #[get("/admin/scopes")] -pub async fn get_admin_scopes_html( - data: web::Data, -) -> Result { - let colors = ColorEntity::find_rauthy(&data).await?; +pub async fn get_admin_scopes_html() -> Result { + let colors = ColorEntity::find_rauthy().await?; let body = AdminScopesHtml::build(&colors); Ok(HttpResponse::Ok().insert_header(HEADER_HTML).body(body)) } #[get("/admin/sessions")] -pub async fn get_admin_sessions_html( - data: web::Data, -) -> Result { - let colors = ColorEntity::find_rauthy(&data).await?; +pub async fn get_admin_sessions_html() -> Result { + let colors = ColorEntity::find_rauthy().await?; let body = AdminSessionsHtml::build(&colors); Ok(HttpResponse::Ok().insert_header(HEADER_HTML).body(body)) } #[get("/admin/users")] -pub async fn get_admin_users_html( - data: web::Data, -) -> Result { - let colors = ColorEntity::find_rauthy(&data).await?; +pub async fn get_admin_users_html() -> Result { + let colors = ColorEntity::find_rauthy().await?; let body = AdminUsersHtml::build(&colors); Ok(HttpResponse::Ok().insert_header(HEADER_HTML).body(body)) } #[get("/device")] -pub async fn get_device_html( - data: web::Data, - req: HttpRequest, -) -> Result { - let colors = ColorEntity::find_rauthy(&data).await?; +pub async fn get_device_html(req: HttpRequest) -> Result { + let colors = ColorEntity::find_rauthy().await?; let lang = Language::try_from(&req).unwrap_or_default(); let body = DeviceHtml::build(&colors, &lang); @@ -281,8 +250,8 @@ pub async fn get_device_html( } #[get("/fedcm")] -pub async fn get_fed_cm_html(data: web::Data) -> Result { - let colors = ColorEntity::find_rauthy(&data).await?; +pub async fn get_fed_cm_html() -> Result { + let colors = ColorEntity::find_rauthy().await?; let body = FedCMHtml::build(&colors); Ok(HttpResponse::Ok().insert_header(HEADER_HTML).body(body)) } @@ -472,12 +441,9 @@ pub async fn post_password_hash_times( ), )] #[get("/password_policy")] -pub async fn get_password_policy( - data: web::Data, - principal: ReqPrincipal, -) -> Result { +pub async fn get_password_policy(principal: ReqPrincipal) -> Result { principal.validate_session_auth()?; - let rules = PasswordPolicy::find(&data).await?; + let rules = PasswordPolicy::find().await?; Ok(HttpResponse::Ok().json(PasswordPolicyResponse::from(rules))) } @@ -499,15 +465,14 @@ pub async fn get_password_policy( )] #[put("/password_policy")] pub async fn put_password_policy( - data: web::Data, principal: ReqPrincipal, req_data: actix_web_validator::Json, ) -> Result { principal.validate_api_key_or_admin_session(AccessGroup::Secrets, AccessRights::Update)?; - let mut rules = PasswordPolicy::find(&data).await?; + let mut rules = PasswordPolicy::find().await?; rules.apply_req(req_data.into_inner()); - rules.save(&data).await?; + rules.save().await?; Ok(HttpResponse::Ok().json(PasswordPolicyResponse::from(rules))) } @@ -555,7 +520,6 @@ pub async fn post_pow() -> Result { )] #[get("/search")] pub async fn get_search( - data: web::Data, params: actix_web_validator::Query, principal: ReqPrincipal, ) -> Result { @@ -564,11 +528,11 @@ pub async fn get_search( let limit = params.limit.unwrap_or(100) as i64; match params.ty { SearchParamsType::Session => { - let res = Session::search(&data, ¶ms.idx, ¶ms.q, limit).await?; + let res = Session::search(¶ms.idx, ¶ms.q, limit).await?; Ok(HttpResponse::Ok().json(res)) } SearchParamsType::User => { - let res = User::search(&data, ¶ms.idx, ¶ms.q, limit).await?; + let res = User::search(¶ms.idx, ¶ms.q, limit).await?; Ok(HttpResponse::Ok().json(res)) } } @@ -586,17 +550,16 @@ pub async fn get_search( )] #[post("/update_language")] pub async fn post_update_language( - data: web::Data, principal: ReqPrincipal, req: HttpRequest, ) -> Result { principal.validate_session_auth()?; let user_id = principal.user_id()?; - let mut user = User::find(&data, user_id.to_string()).await?; + let mut user = User::find(user_id.to_string()).await?; user.language = Language::try_from(&req).unwrap_or_default(); - user.update_language(&data).await?; + user.update_language().await?; Ok(HttpResponse::Ok().finish()) } @@ -613,7 +576,7 @@ pub async fn post_update_language( ), )] #[get("/health")] -pub async fn get_health(data: web::Data) -> impl Responder { +pub async fn get_health() -> impl Responder { if Utc::now().sub(*APP_START).num_seconds() < *HEALTH_CHECK_DELAY_SECS as i64 { info!("Early health check within the HEALTH_CHECK_DELAY_SECS timeframe - returning true"); HttpResponse::Ok().json(HealthResponse { @@ -621,7 +584,7 @@ pub async fn get_health(data: web::Data) -> impl Responder { cache_healthy: true, }) } else { - let db_healthy = is_db_alive(&data.db).await; + let db_healthy = is_db_alive().await; let cache_healthy = DB::client().is_healthy_cache().await.is_ok(); let body = HealthResponse { @@ -718,8 +681,8 @@ pub async fn redirect_v1() -> HttpResponse { ), )] #[get("/version")] -pub async fn get_version(data: web::Data) -> Result { - let resp = match LatestAppVersion::find(&data).await { +pub async fn get_version() -> Result { + let resp = match LatestAppVersion::find().await { Some(latest) => { let update_available = match Version::from_str(RAUTHY_VERSION) { Ok(current) => latest.latest_version > current, diff --git a/src/api/src/groups.rs b/src/api/src/groups.rs index 269ef8cdf..8843295c9 100644 --- a/src/api/src/groups.rs +++ b/src/api/src/groups.rs @@ -2,7 +2,6 @@ use crate::ReqPrincipal; use actix_web::{delete, get, post, put, web, HttpResponse}; use rauthy_api_types::groups::NewGroupRequest; use rauthy_error::ErrorResponse; -use rauthy_models::app_state::AppState; use rauthy_models::entity::api_keys::{AccessGroup, AccessRights}; use rauthy_models::entity::groups::Group; @@ -21,13 +20,10 @@ use rauthy_models::entity::groups::Group; ), )] #[get("/groups")] -pub async fn get_groups( - data: web::Data, - principal: ReqPrincipal, -) -> Result { +pub async fn get_groups(principal: ReqPrincipal) -> Result { principal.validate_api_key_or_admin_session(AccessGroup::Groups, AccessRights::Read)?; - Group::find_all(&data) + Group::find_all() .await .map(|rls| HttpResponse::Ok().json(rls)) } @@ -49,13 +45,12 @@ pub async fn get_groups( )] #[post("/groups")] pub async fn post_group( - data: web::Data, group_req: actix_web_validator::Json, principal: ReqPrincipal, ) -> Result { principal.validate_api_key_or_admin_session(AccessGroup::Groups, AccessRights::Create)?; - Group::create(&data, group_req.into_inner()) + Group::create(group_req.into_inner()) .await .map(|r| HttpResponse::Ok().json(r)) } @@ -77,14 +72,13 @@ pub async fn post_group( )] #[put("/groups/{id}")] pub async fn put_group( - data: web::Data, id: web::Path, group_req: actix_web_validator::Json, principal: ReqPrincipal, ) -> Result { principal.validate_api_key_or_admin_session(AccessGroup::Groups, AccessRights::Update)?; - Group::update(&data, id.into_inner(), group_req.group.to_owned()) + Group::update(id.into_inner(), group_req.group.to_owned()) .await .map(|g| HttpResponse::Ok().json(g)) } @@ -107,13 +101,12 @@ pub async fn put_group( )] #[delete("/groups/{id}")] pub async fn delete_group( - data: web::Data, id: web::Path, principal: ReqPrincipal, ) -> Result { principal.validate_api_key_or_admin_session(AccessGroup::Groups, AccessRights::Delete)?; - Group::delete(&data, id.into_inner()) + Group::delete(id.into_inner()) .await .map(|_| HttpResponse::Ok().finish()) } diff --git a/src/api/src/oidc.rs b/src/api/src/oidc.rs index c0a3160c0..7d1f0f662 100644 --- a/src/api/src/oidc.rs +++ b/src/api/src/oidc.rs @@ -71,7 +71,7 @@ pub async fn get_authorize( req_data: actix_web_validator::Query, principal: ReqPrincipal, ) -> Result { - let colors = ColorEntity::find(&data, &req_data.client_id) + let colors = ColorEntity::find(&req_data.client_id) .await .unwrap_or_default(); let lang = Language::try_from(&req).unwrap_or_default(); @@ -116,7 +116,7 @@ pub async fn get_authorize( // check if the user needs to do the Webauthn login each time let mut action = FrontendAction::None; if let Ok(mfa_cookie) = WebauthnCookie::parse_validate(&ApiCookie::from_req(&req, COOKIE_MFA)) { - if let Ok(user) = User::find_by_email(&data, mfa_cookie.email.clone()).await { + if let Ok(user) = User::find_by_email(mfa_cookie.email.clone()).await { // we need to check this, because a user could deactivate MFA in another browser or // be deleted while still having existing mfa cookies somewhere else if user.has_webauthn_enabled() { @@ -142,7 +142,7 @@ pub async fn get_authorize( return Ok(ErrorHtml::response(body, status)); } - let auth_providers_json = AuthProviderTemplate::get_all_json_template(&data).await?; + let auth_providers_json = AuthProviderTemplate::get_all_json_template().await?; let tpl_data = Some(format!( "{}\n{}\n{}", client.name.unwrap_or_default(), @@ -180,7 +180,7 @@ pub async fn get_authorize( Session::new(*SESSION_LIFETIME, Some(real_ip_from_req(&req)?)) }; - if let Err(err) = session.save(&data).await { + if let Err(err) = session.save().await { let status = err.status_code(); let body = Error1Html::build(&colors, &lang, status, Some(err.message)); return Ok(ErrorHtml::response(body, status)); @@ -336,25 +336,17 @@ pub async fn post_authorize_refresh( ) .await?; - let auth_step = authorize::post_authorize_refresh( - &data, - session, - client, - header_origin, - req_data.into_inner(), - ) - .await?; + let auth_step = + authorize::post_authorize_refresh(session, client, header_origin, req_data.into_inner()) + .await?; map_auth_step(auth_step, &req).await } #[get("/oidc/callback")] -pub async fn get_callback_html( - data: web::Data, - principal: ReqPrincipal, -) -> Result { +pub async fn get_callback_html(principal: ReqPrincipal) -> Result { // TODO can we even be more strict and request session auth here? principal.validate_session_auth_or_init()?; - let colors = ColorEntity::find_rauthy(&data).await?; + let colors = ColorEntity::find_rauthy().await?; let body = CallbackHtml::build(&colors); Ok(HttpResponse::Ok().insert_header(HEADER_HTML).body(body)) } @@ -369,8 +361,8 @@ pub async fn get_callback_html( responses((status = 200, description = "Ok")), )] #[get("/oidc/certs")] -pub async fn get_certs(data: web::Data) -> Result { - let jwks = JWKS::find_pk(&data).await?; +pub async fn get_certs() -> Result { + let jwks = JWKS::find_pk().await?; let res = JWKSCerts::from(jwks); Ok(HttpResponse::Ok() .insert_header(( @@ -390,11 +382,8 @@ pub async fn get_certs(data: web::Data) -> Result, - kid: web::Path, -) -> Result { - let kp = JwkKeyPair::find(&data, kid.into_inner()).await?; +pub async fn get_cert_by_kid(kid: web::Path) -> Result { + let kp = JwkKeyPair::find(kid.into_inner()).await?; let pub_key = JWKSPublicKey::from_key_pair(&kp); Ok(HttpResponse::Ok().json(JWKSPublicKeyCerts::from(pub_key))) } @@ -412,7 +401,6 @@ pub async fn get_cert_by_kid( )] #[post("/oidc/device")] pub async fn post_device_auth( - data: web::Data, req: HttpRequest, payload: actix_web_validator::Form, ) -> HttpResponse { @@ -465,7 +453,7 @@ pub async fn post_device_auth( // find and validate the client let payload = payload.into_inner(); - let client = match Client::find(&data, payload.client_id).await { + let client = match Client::find(payload.client_id).await { Ok(client) => client, Err(_) => { return HttpResponse::NotFound().json(OAuth2ErrorResponse { @@ -656,7 +644,6 @@ pub async fn get_logout( )] #[post("/oidc/logout")] pub async fn post_logout( - data: web::Data, req_data: actix_web_validator::Query, principal: ReqPrincipal, ) -> Result { @@ -667,7 +654,7 @@ pub async fn post_logout( 0, SameSite::None, ); - let cookie = session.invalidate(&data).await?; + let cookie = session.invalidate().await?; if req_data.post_logout_redirect_uri.is_some() { let state = if req_data.state.is_some() { @@ -746,7 +733,7 @@ pub async fn post_session( req: HttpRequest, ) -> Result { let session = Session::new(*SESSION_LIFETIME, real_ip_from_req(&req).ok()); - session.save(&data).await?; + session.save().await?; let cookie = session.client_cookie(); let timeout = OffsetDateTime::from_unix_timestamp(session.last_seen) diff --git a/src/api/src/roles.rs b/src/api/src/roles.rs index 9fe8a8860..4d96512e5 100644 --- a/src/api/src/roles.rs +++ b/src/api/src/roles.rs @@ -2,7 +2,6 @@ use crate::ReqPrincipal; use actix_web::{delete, get, post, put, web, HttpResponse}; use rauthy_api_types::roles::NewRoleRequest; use rauthy_error::ErrorResponse; -use rauthy_models::app_state::AppState; use rauthy_models::entity::api_keys::{AccessGroup, AccessRights}; use rauthy_models::entity::roles::Role; @@ -21,13 +20,10 @@ use rauthy_models::entity::roles::Role; ), )] #[get("/roles")] -pub async fn get_roles( - data: web::Data, - principal: ReqPrincipal, -) -> Result { +pub async fn get_roles(principal: ReqPrincipal) -> Result { principal.validate_api_key_or_admin_session(AccessGroup::Roles, AccessRights::Read)?; - Role::find_all(&data) + Role::find_all() .await .map(|rls| HttpResponse::Ok().json(rls)) } @@ -49,13 +45,12 @@ pub async fn get_roles( )] #[post("/roles")] pub async fn post_role( - data: web::Data, role_req: actix_web_validator::Json, principal: ReqPrincipal, ) -> Result { principal.validate_api_key_or_admin_session(AccessGroup::Roles, AccessRights::Create)?; - Role::create(&data, role_req.into_inner()) + Role::create(role_req.into_inner()) .await .map(|r| HttpResponse::Ok().json(r)) } @@ -77,14 +72,13 @@ pub async fn post_role( )] #[put("/roles/{id}")] pub async fn put_role( - data: web::Data, id: web::Path, role_req: actix_web_validator::Json, principal: ReqPrincipal, ) -> Result { principal.validate_api_key_or_admin_session(AccessGroup::Roles, AccessRights::Update)?; - Role::update(&data, id.into_inner(), role_req.role.to_owned()) + Role::update(id.into_inner(), role_req.role.to_owned()) .await .map(|r| HttpResponse::Ok().json(r)) } @@ -107,13 +101,12 @@ pub async fn put_role( )] #[delete("/roles/{id}")] pub async fn delete_role( - data: web::Data, id: web::Path, principal: ReqPrincipal, ) -> Result { principal.validate_api_key_or_admin_session(AccessGroup::Roles, AccessRights::Delete)?; - Role::delete(&data, id.as_str()) + Role::delete(id.as_str()) .await .map(|_| HttpResponse::Ok().finish()) } diff --git a/src/api/src/scopes.rs b/src/api/src/scopes.rs index 84b8622bd..7b1bf737f 100644 --- a/src/api/src/scopes.rs +++ b/src/api/src/scopes.rs @@ -21,13 +21,10 @@ use rauthy_models::entity::scopes::Scope; ), )] #[get("/scopes")] -pub async fn get_scopes( - data: web::Data, - principal: ReqPrincipal, -) -> Result { +pub async fn get_scopes(principal: ReqPrincipal) -> Result { principal.validate_api_key_or_admin_session(AccessGroup::Scopes, AccessRights::Read)?; - Scope::find_all(&data).await.map(|scp| { + Scope::find_all().await.map(|scp| { let res = scp .into_iter() .map(ScopeResponse::from) diff --git a/src/api/src/sessions.rs b/src/api/src/sessions.rs index 2082200bd..392a1e031 100644 --- a/src/api/src/sessions.rs +++ b/src/api/src/sessions.rs @@ -5,7 +5,6 @@ use rauthy_api_types::generic::PaginationParams; use rauthy_api_types::sessions::{SessionResponse, SessionState}; use rauthy_common::constants::SSP_THRESHOLD; use rauthy_error::ErrorResponse; -use rauthy_models::app_state::AppState; use rauthy_models::entity::api_keys::{AccessGroup, AccessRights}; use rauthy_models::entity::continuation_token::ContinuationToken; use rauthy_models::entity::refresh_tokens::RefreshToken; @@ -31,14 +30,13 @@ use rauthy_models::entity::users::User; )] #[get("/sessions")] pub async fn get_sessions( - data: web::Data, principal: ReqPrincipal, params: Query, ) -> Result { principal.validate_api_key_or_admin_session(AccessGroup::Sessions, AccessRights::Read)?; // sessions will be dynamically paginated based on the same setting as users - let user_count = User::count(&data).await?; + let user_count = User::count().await?; if user_count >= *SSP_THRESHOLD as i64 || params.page_size.is_some() { // TODO outsource the setup stuff here or keep it duplicated for better readability? // currently used here and in GET /users @@ -52,8 +50,7 @@ pub async fn get_sessions( }; let (users, continuation_token) = - Session::find_paginated(&data, continuation_token, page_size, offset, backwards) - .await?; + Session::find_paginated(continuation_token, page_size, offset, backwards).await?; let x_page_count = (user_count as f64 / page_size as f64).ceil() as u32; if let Some(token) = continuation_token { @@ -71,7 +68,7 @@ pub async fn get_sessions( .json(users)) } } else { - let sessions = Session::find_all(&data).await?; + let sessions = Session::find_all().await?; let mut resp = Vec::with_capacity(sessions.len()); for s in &sessions { @@ -110,14 +107,11 @@ pub async fn get_sessions( ), )] #[delete("/sessions")] -pub async fn delete_sessions( - data: web::Data, - principal: ReqPrincipal, -) -> Result { +pub async fn delete_sessions(principal: ReqPrincipal) -> Result { principal.validate_api_key_or_admin_session(AccessGroup::Sessions, AccessRights::Delete)?; - Session::invalidate_all(&data).await?; - RefreshToken::invalidate_all(&data).await?; + Session::invalidate_all().await?; + RefreshToken::invalidate_all().await?; Ok(HttpResponse::Ok().finish()) } @@ -140,15 +134,14 @@ pub async fn delete_sessions( )] #[delete("/sessions/{user_id}")] pub async fn delete_sessions_for_user( - data: web::Data, path: web::Path, principal: ReqPrincipal, ) -> Result { principal.validate_api_key_or_admin_session(AccessGroup::Sessions, AccessRights::Delete)?; let uid = path.into_inner(); - Session::invalidate_for_user(&data, &uid).await?; - RefreshToken::invalidate_for_user(&data, &uid).await?; + Session::invalidate_for_user(&uid).await?; + RefreshToken::invalidate_for_user(&uid).await?; Ok(HttpResponse::Ok().finish()) } diff --git a/src/api/src/users.rs b/src/api/src/users.rs index aa62c02f9..cc30d0e68 100644 --- a/src/api/src/users.rs +++ b/src/api/src/users.rs @@ -64,13 +64,12 @@ use tracing::{error, warn}; )] #[get("/users")] pub async fn get_users( - data: web::Data, principal: ReqPrincipal, params: Query, ) -> Result { principal.validate_api_key_or_admin_session(AccessGroup::Users, AccessRights::Read)?; - let user_count = User::count(&data).await?; + let user_count = User::count().await?; if user_count >= *SSP_THRESHOLD as i64 || params.page_size.is_some() { let page_size = params.page_size.unwrap_or(15) as i64; @@ -83,7 +82,7 @@ pub async fn get_users( }; let (users, continuation_token) = - User::find_paginated(&data, continuation_token, page_size, offset, backwards).await?; + User::find_paginated(continuation_token, page_size, offset, backwards).await?; let x_page_count = (user_count as f64 / page_size as f64).ceil() as u32; if let Some(token) = continuation_token { @@ -101,7 +100,7 @@ pub async fn get_users( .json(users)) } } else { - let users = User::find_all_simple(&data).await?; + let users = User::find_all_simple().await?; Ok(HttpResponse::Ok() .insert_header(("x-user-count", user_count)) .json(users)) @@ -165,13 +164,10 @@ pub async fn post_users( ), )] #[get("/users/attr")] -pub async fn get_cust_attr( - data: web::Data, - principal: ReqPrincipal, -) -> Result { +pub async fn get_cust_attr(principal: ReqPrincipal) -> Result { principal.validate_api_key_or_admin_session(AccessGroup::UserAttributes, AccessRights::Read)?; - UserAttrConfigEntity::find_all(&data).await.map(|values| { + UserAttrConfigEntity::find_all().await.map(|values| { HttpResponse::Ok().json(UserAttrConfigResponse { values: values.into_iter().map(|v| v.into()).collect(), }) @@ -191,14 +187,13 @@ pub async fn get_cust_attr( )] #[post("/users/attr")] pub async fn post_cust_attr( - data: web::Data, principal: ReqPrincipal, req_data: Json, ) -> Result { principal .validate_api_key_or_admin_session(AccessGroup::UserAttributes, AccessRights::Create)?; - UserAttrConfigEntity::create(&data, req_data.into_inner()) + UserAttrConfigEntity::create(req_data.into_inner()) .await .map(|attr| HttpResponse::Ok().json(attr)) } @@ -218,7 +213,6 @@ pub async fn post_cust_attr( )] #[put("/users/attr/{name}")] pub async fn put_cust_attr( - data: web::Data, path: web::Path, principal: ReqPrincipal, req_data: Json, @@ -226,7 +220,7 @@ pub async fn put_cust_attr( principal .validate_api_key_or_admin_session(AccessGroup::UserAttributes, AccessRights::Update)?; - UserAttrConfigEntity::update(&data, path.into_inner(), req_data.into_inner()) + UserAttrConfigEntity::update(path.into_inner(), req_data.into_inner()) .await .map(|a| HttpResponse::Ok().json(a)) } @@ -244,14 +238,13 @@ pub async fn put_cust_attr( )] #[delete("/users/attr/{name}")] pub async fn delete_cust_attr( - data: web::Data, path: web::Path, principal: ReqPrincipal, ) -> Result { principal .validate_api_key_or_admin_session(AccessGroup::UserAttributes, AccessRights::Delete)?; - UserAttrConfigEntity::delete(&data, path.into_inner()).await?; + UserAttrConfigEntity::delete(path.into_inner()).await?; Ok(HttpResponse::Ok().finish()) } @@ -267,11 +260,8 @@ pub async fn delete_cust_attr( ), )] #[get("/users/register")] -pub async fn get_users_register( - data: web::Data, - req: HttpRequest, -) -> Result { - let colors = ColorEntity::find_rauthy(&data).await?; +pub async fn get_users_register(req: HttpRequest) -> Result { + let colors = ColorEntity::find_rauthy().await?; let lang = Language::try_from(&req).unwrap_or_default(); if !*OPEN_USER_REG { @@ -369,7 +359,6 @@ pub async fn post_users_register( )] #[get("/users/{id}")] pub async fn get_user_by_id( - data: web::Data, path: web::Path, principal: ReqPrincipal, ) -> Result { @@ -384,8 +373,8 @@ pub async fn get_user_by_id( principal.is_user(&id)?; } - let user = User::find(&data, id).await?; - let values = UserValues::find(&data, &user.id).await?; + let user = User::find(id).await?; + let values = UserValues::find(&user.id).await?; Ok(HttpResponse::Ok().json(user.into_response(values))) } @@ -402,13 +391,12 @@ pub async fn get_user_by_id( )] #[get("/users/{id}/attr")] pub async fn get_user_attr( - data: web::Data, path: web::Path, principal: ReqPrincipal, ) -> Result { principal.validate_api_key_or_admin_session(AccessGroup::UserAttributes, AccessRights::Read)?; - let values = UserAttrValueEntity::find_for_user(&data, &path.into_inner()) + let values = UserAttrValueEntity::find_for_user(&path.into_inner()) .await? .drain(..) .map(UserAttrValueResponse::from) @@ -430,7 +418,6 @@ pub async fn get_user_attr( )] #[put("/users/{id}/attr")] pub async fn put_user_attr( - data: web::Data, path: web::Path, principal: ReqPrincipal, req_data: Json, @@ -438,12 +425,11 @@ pub async fn put_user_attr( principal .validate_api_key_or_admin_session(AccessGroup::UserAttributes, AccessRights::Update)?; - let values = - UserAttrValueEntity::update_for_user(&data, &path.into_inner(), req_data.into_inner()) - .await? - .drain(..) - .map(UserAttrValueResponse::from) - .collect::>(); + let values = UserAttrValueEntity::update_for_user(&path.into_inner(), req_data.into_inner()) + .await? + .drain(..) + .map(UserAttrValueResponse::from) + .collect::>(); Ok(HttpResponse::Ok().json(UserAttrValuesResponse { values })) } @@ -460,14 +446,13 @@ pub async fn put_user_attr( )] #[get("/users/{id}/devices")] pub async fn get_user_devices( - data: web::Data, path: web::Path, principal: ReqPrincipal, ) -> Result { let user_id = path.into_inner(); principal.validate_user_or_admin(&user_id)?; - let resp = DeviceEntity::find_for_user(&data, &user_id) + let resp = DeviceEntity::find_for_user(&user_id) .await? .into_iter() .map(DeviceResponse::from) @@ -489,7 +474,6 @@ pub async fn get_user_devices( )] #[put("/users/{id}/devices")] pub async fn put_user_device_name( - data: web::Data, path: web::Path, principal: ReqPrincipal, payload: actix_web_validator::Json, @@ -499,7 +483,7 @@ pub async fn put_user_device_name( let payload = payload.into_inner(); if let Some(name) = &payload.name { - DeviceEntity::update_name(&data, &payload.device_id, &user_id, name).await?; + DeviceEntity::update_name(&payload.device_id, &user_id, name).await?; } Ok(HttpResponse::Ok().finish()) @@ -520,7 +504,6 @@ pub async fn put_user_device_name( )] #[delete("/users/{id}/devices")] pub async fn delete_user_device( - data: web::Data, path: web::Path, principal: ReqPrincipal, payload: actix_web_validator::Json, @@ -529,7 +512,7 @@ pub async fn delete_user_device( principal.validate_user_or_admin(&user_id)?; let payload = payload.into_inner(); - let device = DeviceEntity::find(&data, &payload.device_id).await?; + let device = DeviceEntity::find(&payload.device_id).await?; if device.user_id.as_deref() != Some(&user_id) { return Err(ErrorResponse::new( ErrorResponseType::Forbidden, @@ -537,7 +520,7 @@ pub async fn delete_user_device( )); } - DeviceEntity::revoke_refresh_tokens(&data, &payload.device_id).await?; + DeviceEntity::revoke_refresh_tokens(&payload.device_id).await?; Ok(HttpResponse::Ok().finish()) } @@ -566,7 +549,7 @@ pub async fn get_user_email_confirm( match User::confirm_email_address(&data, req, user_id, confirm_id).await { Ok(html) => HttpResponse::Ok().insert_header(HEADER_HTML).body(html), Err(err) => { - let colors = ColorEntity::find_rauthy(&data).await.unwrap_or_default(); + let colors = ColorEntity::find_rauthy().await.unwrap_or_default(); let status = err.status_code(); let body = Error3Html::build(&colors, &lang, status, Some(err.message)); ErrorHtml::response(body, status) @@ -590,7 +573,6 @@ pub async fn get_user_email_confirm( )] #[get("/users/{id}/reset/{reset_id}")] pub async fn get_user_password_reset( - data: web::Data, path: web::Path<(String, String)>, req: HttpRequest, ) -> HttpResponse { @@ -603,13 +585,13 @@ pub async fn get_user_password_reset( .unwrap_or("text/html"); let no_html = accept == "application/json"; - match password_reset::handle_get_pwd_reset(&data, req, user_id, reset_id, no_html).await { + match password_reset::handle_get_pwd_reset(req, user_id, reset_id, no_html).await { Ok((content, cookie)) => { if no_html { - let password_policy = match PasswordPolicy::find(&data).await { + let password_policy = match PasswordPolicy::find().await { Ok(policy) => PasswordPolicyResponse::from(policy), Err(err) => { - let colors = ColorEntity::find_rauthy(&data).await.unwrap_or_default(); + let colors = ColorEntity::find_rauthy().await.unwrap_or_default(); let status = err.status_code(); let body = Error3Html::build(&colors, &lang, status, Some(err.message)); return ErrorHtml::response(body, status); @@ -631,7 +613,7 @@ pub async fn get_user_password_reset( } } Err(err) => { - let colors = ColorEntity::find_rauthy(&data).await.unwrap_or_default(); + let colors = ColorEntity::find_rauthy().await.unwrap_or_default(); let status = err.status_code(); let body = Error3Html::build(&colors, &lang, status, Some(err.message)); ErrorHtml::response(body, status) @@ -706,7 +688,6 @@ pub async fn put_user_password_reset( )] #[get("/users/{id}/webauthn")] pub async fn get_user_webauthn_passkeys( - data: web::Data, id: web::Path, principal: ReqPrincipal, ) -> Result { @@ -721,7 +702,7 @@ pub async fn get_user_webauthn_passkeys( principal.is_user(&id)?; } - let pks = PasskeyEntity::find_for_user(&data, &id) + let pks = PasskeyEntity::find_for_user(&id) .await? .into_iter() .map(PasskeyResponse::from) @@ -853,7 +834,6 @@ pub async fn post_webauthn_auth_finish( )] #[delete("/users/{id}/webauthn/delete/{name}")] pub async fn delete_webauthn( - data: web::Data, path: web::Path<(String, String)>, principal: ReqPrincipal, ) -> Result { @@ -876,11 +856,11 @@ pub async fn delete_webauthn( warn!("Passkey delete from admin for user {} for key {}", id, name); } - PasskeyEntity::delete(&data, id, name).await?; + PasskeyEntity::delete(id, name).await?; // // if we delete a passkey, we must check if this is the last existing one for the user // let pks = PasskeyEntity::find_for_user(&data, &id).await?; // - // let mut txn = data.db.begin().await?; + // let mut txn = DB::txn().await?; // // PasskeyEntity::delete_by_id_name(&id, &name, &mut txn).await?; // if pks.len() < 2 { @@ -1039,8 +1019,8 @@ pub async fn get_user_webid( } let id = id.into_inner(); - let webid = WebId::find(&data, id).await?; - let user = User::find(&data, webid.user_id.clone()).await?; + let webid = WebId::find(id).await?; + let user = User::find(webid.user_id.clone()).await?; let resp = WebIdResponse { webid: webid.into(), @@ -1069,7 +1049,6 @@ pub async fn get_user_webid( )] #[get("/users/{id}/webid/data")] pub async fn get_user_webid_data( - data: web::Data, id: web::Path, principal: ReqPrincipal, ) -> Result { @@ -1088,7 +1067,7 @@ pub async fn get_user_webid_data( } // request is valid -> either the user requests own data, or it is an admin - let webid = WebId::find(&data, id).await?; + let webid = WebId::find(id).await?; Ok(HttpResponse::Ok().json(webid)) } @@ -1108,7 +1087,6 @@ pub async fn get_user_webid_data( )] #[put("/users/{id}/webid/data")] pub async fn put_user_webid_data( - data: web::Data, id: web::Path, principal: ReqPrincipal, payload: Json, @@ -1132,7 +1110,7 @@ pub async fn put_user_webid_data( ) })?; - WebId::upsert(&data, web_id).await?; + WebId::upsert(web_id).await?; Ok(HttpResponse::Ok().finish()) } @@ -1168,7 +1146,7 @@ pub async fn post_user_password_request_reset( principal.validate_session_auth_or_init()?; let payload = payload.into_inner(); - match User::find_by_email(&data, payload.email).await { + match User::find_by_email(payload.email).await { Ok(user) => user .request_password_reset(&data, req, payload.redirect_uri) .await @@ -1196,14 +1174,13 @@ pub async fn post_user_password_request_reset( )] #[get("/users/email/{email}")] pub async fn get_user_by_email( - data: web::Data, path: web::Path, principal: ReqPrincipal, ) -> Result { principal.validate_api_key_or_admin_session(AccessGroup::Users, AccessRights::Read)?; - let user = User::find_by_email(&data, path.into_inner()).await?; - let values = UserValues::find(&data, &user.id).await?; + let user = User::find_by_email(path.into_inner()).await?; + let values = UserValues::find(&user.id).await?; Ok(HttpResponse::Ok().json(user.into_response(values))) } @@ -1302,7 +1279,6 @@ pub async fn put_user_self( )] #[post("/users/{id}/self/convert_passkey")] pub async fn post_user_self_convert_passkey( - data: web::Data, id: web::Path, principal: ReqPrincipal, ) -> Result { @@ -1312,7 +1288,7 @@ pub async fn post_user_self_convert_passkey( let id = id.into_inner(); principal.is_user(&id)?; - User::convert_to_passkey(&data, id).await?; + User::convert_to_passkey(id).await?; Ok(HttpResponse::Ok().finish()) } @@ -1332,13 +1308,12 @@ pub async fn post_user_self_convert_passkey( )] #[delete("/users/{id}")] pub async fn delete_user_by_id( - data: web::Data, path: web::Path, principal: ReqPrincipal, ) -> Result { principal.validate_api_key_or_admin_session(AccessGroup::Users, AccessRights::Delete)?; - let user = User::find(&data, path.into_inner()).await?; - user.delete(&data).await?; + let user = User::find(path.into_inner()).await?; + user.delete().await?; Ok(HttpResponse::NoContent().finish()) } diff --git a/src/bin/src/dummy_data.rs b/src/bin/src/dummy_data.rs index 62beadedb..413cf3d8f 100644 --- a/src/bin/src/dummy_data.rs +++ b/src/bin/src/dummy_data.rs @@ -1,14 +1,10 @@ -use actix_web::web; use rauthy_common::utils::get_rand; use rauthy_error::ErrorResponse; -use rauthy_models::{app_state::AppState, entity::users::User, language::Language}; +use rauthy_models::{entity::users::User, language::Language}; use std::time::Duration; use tracing::info; -pub async fn insert_dummy_data( - app_state: web::Data, - amount: u32, -) -> Result<(), ErrorResponse> { +pub async fn insert_dummy_data(amount: u32) -> Result<(), ErrorResponse> { tokio::time::sleep(Duration::from_secs(1)).await; info!( r#" @@ -42,7 +38,7 @@ Will go on in 10 seconds... user_expires: None, ..Default::default() }; - User::insert(&app_state, user).await?; + User::insert(user).await?; } Ok(()) diff --git a/src/bin/src/main.rs b/src/bin/src/main.rs index d1a73837b..32b676608 100644 --- a/src/bin/src/main.rs +++ b/src/bin/src/main.rs @@ -175,7 +175,6 @@ TODO add link to the book after migration tx_events_router, rx_events_router, rx_events, - app_state.db.clone(), )); // spawn password hash limiter @@ -188,10 +187,7 @@ TODO add link to the book after migration // spawn health watcher debug!("Starting health watch"); - tokio::spawn(watch_health( - app_state.db.clone(), - app_state.tx_events.clone(), - )); + tokio::spawn(watch_health(app_state.tx_events.clone())); // schedulers match env::var("SCHED_DISABLE") @@ -222,10 +218,7 @@ TODO add link to the book after migration } else { 100_000 }; - tokio::spawn(crate::dummy_data::insert_dummy_data( - app_state.clone(), - amount, - )); + tokio::spawn(crate::dummy_data::insert_dummy_data(amount)); } actix.join().unwrap().unwrap(); diff --git a/src/middlewares/src/principal.rs b/src/middlewares/src/principal.rs index 1c9e38c17..9404e8d0e 100644 --- a/src/middlewares/src/principal.rs +++ b/src/middlewares/src/principal.rs @@ -66,7 +66,7 @@ where .app_data::>() .expect("Error getting AppData inside session middleware"); - principal.api_key = get_api_key_from_headers(&req, data).await?; + principal.api_key = get_api_key_from_headers(&req).await?; if let Some(s) = get_session_from_cookie(&req, data).await? { principal.roles = s.roles_as_vec().unwrap_or_default(); principal.session = Some(s); @@ -89,10 +89,7 @@ where } #[inline(always)] -async fn get_api_key_from_headers( - req: &ServiceRequest, - data: &web::Data, -) -> Result, ErrorResponse> { +async fn get_api_key_from_headers(req: &ServiceRequest) -> Result, ErrorResponse> { let headers = req.headers(); let auth_header = if let Some(Ok(header)) = headers.get("Authorization").map(|h| h.to_str()) { header @@ -114,7 +111,7 @@ async fn get_api_key_from_headers( }; if let Some(api_key_value) = api_key_value { - ApiKeyEntity::api_key_from_token_validated(data, api_key_value) + ApiKeyEntity::api_key_from_token_validated(api_key_value) .await .map(Some) } else { @@ -134,7 +131,7 @@ async fn get_session_from_cookie( Some(session_id) => session_id, }; - match Session::find(data, session_id).await { + match Session::find(session_id).await { Ok(mut session) => { let remote_ip = if *SESSION_VALIDATE_IP { real_ip_from_svc_req(req).ok() @@ -146,7 +143,7 @@ async fn get_session_from_cookie( // only update the last_seen, if it is older than 10 seconds if session.last_seen < now - 10 { session.last_seen = now; - session.save(data).await?; + session.save().await?; } if req.method() != http::Method::GET && !is_path_csrf_exception(req.path()) { diff --git a/src/models/src/app_state.rs b/src/models/src/app_state.rs index 86958a0ed..4ed9df1b3 100644 --- a/src/models/src/app_state.rs +++ b/src/models/src/app_state.rs @@ -3,18 +3,11 @@ use crate::events::event::Event; use crate::events::ip_blacklist_handler::IpBlacklistReq; use crate::events::listener::EventRouterMsg; use crate::ListenScheme; -use anyhow::Context; -use rauthy_common::constants::{DATABASE_URL, DB_TYPE, PROXY_MODE}; -use rauthy_common::DbType; -use sqlx::pool::PoolOptions; -use sqlx::ConnectOptions; +use rauthy_common::constants::PROXY_MODE; use std::env; -use std::str::FromStr; use std::sync::Arc; -use std::time::Duration; use tokio::sync::mpsc; -use tracing::log::LevelFilter; -use tracing::{debug, error, info}; +use tracing::{debug, info}; use webauthn_rs::prelude::Url; use webauthn_rs::Webauthn; @@ -23,7 +16,6 @@ pub type DbTxn<'a> = sqlx::Transaction<'a, sqlx::Postgres>; #[derive(Debug, Clone)] pub struct AppState { - pub db: DbPool, pub public_url: String, pub argon2_params: Argon2Params, pub issuer: String, @@ -163,12 +155,7 @@ impl AppState { .rp_name(&rp_name); let webauthn = Arc::new(builder.build().expect("Invalid configuration")); - debug!("Creating DB Pool now"); - let db = Self::new_db_pool().await?; - debug!("DB Pool created"); - Ok(Self { - db, public_url, argon2_params, issuer, @@ -187,80 +174,80 @@ impl AppState { }) } - pub async fn new_db_pool() -> anyhow::Result { - let db_max_conn = env::var("DATABASE_MAX_CONN") - .unwrap_or_else(|_| String::from("5")) - .parse::() - .expect("Error parsing DATABASE_MAX_CONN to u32"); - - let pool = { - if *DB_TYPE == DbType::Sqlite { - debug!("DATABASE_URL: {}", *DATABASE_URL); - - let msg = r#" - You are trying to connect to a SQLite instance with the 'Postgres' - version of Rauthy. You need to either change to a SQLite database or use the '*-lite' - container image of Rauthy."#; - error!("{msg}"); - panic!("{msg}"); - } - - info!("Trying to connect to Postgres instance"); - let pool = Self::connect_postgres(&DATABASE_URL, db_max_conn).await?; - info!("Database Connection established"); - - debug!("Migrating data from ../../migrations/postgres"); - sqlx::migrate!("../../migrations/postgres") - .run(&pool) - .await?; + // pub async fn new_db_pool() -> anyhow::Result { + // let db_max_conn = env::var("DATABASE_MAX_CONN") + // .unwrap_or_else(|_| String::from("5")) + // .parse::() + // .expect("Error parsing DATABASE_MAX_CONN to u32"); + // + // let pool = { + // if *DB_TYPE == DbType::Sqlite { + // debug!("DATABASE_URL: {}", *DATABASE_URL); + // + // let msg = r#" + // You are trying to connect to a SQLite instance with the 'Postgres' + // version of Rauthy. You need to either change to a SQLite database or use the '*-lite' + // container image of Rauthy."#; + // error!("{msg}"); + // panic!("{msg}"); + // } + // + // info!("Trying to connect to Postgres instance"); + // let pool = Self::connect_postgres(&DATABASE_URL, db_max_conn).await?; + // info!("Database Connection established"); + // + // debug!("Migrating data from ../../migrations/postgres"); + // sqlx::migrate!("../../migrations/postgres") + // .run(&pool) + // .await?; + // + // pool + // }; + // + // Ok(pool) + // } - pool - }; - - Ok(pool) - } - - pub async fn connect_sqlite( - addr: &str, - max_conn: u32, - // migration_only: bool, - ) -> anyhow::Result { - let opts = sqlx::sqlite::SqliteConnectOptions::from_str(addr)? - .create_if_missing(true) - .busy_timeout(Duration::from_millis(100)) - .foreign_keys(true) - .auto_vacuum(sqlx::sqlite::SqliteAutoVacuum::Incremental) - .synchronous(sqlx::sqlite::SqliteSynchronous::Normal) - .journal_mode(sqlx::sqlite::SqliteJournalMode::Wal); - - let pool = PoolOptions::new() - .min_connections(2) - .max_connections(max_conn) - .acquire_timeout(Duration::from_secs(10)) - .connect_with(opts) - .await - .context("failed to connect to sqlite")?; - - info!("Database Connection Pool created successfully"); - - Ok(pool) - } - - pub async fn connect_postgres(addr: &str, max_conn: u32) -> anyhow::Result { - let opts = sqlx::postgres::PgConnectOptions::from_str(addr)? - .log_slow_statements(LevelFilter::Debug, Duration::from_secs(3)); - let pool = PoolOptions::new() - .min_connections(2) - .max_connections(max_conn) - .acquire_timeout(Duration::from_secs(10)) - .connect_with(opts) - .await - .context("failed to connect to postgres")?; - - info!("Database Connection Pool created successfully"); - - Ok(pool) - } + // pub async fn connect_sqlite( + // addr: &str, + // max_conn: u32, + // // migration_only: bool, + // ) -> anyhow::Result { + // let opts = sqlx::sqlite::SqliteConnectOptions::from_str(addr)? + // .create_if_missing(true) + // .busy_timeout(Duration::from_millis(100)) + // .foreign_keys(true) + // .auto_vacuum(sqlx::sqlite::SqliteAutoVacuum::Incremental) + // .synchronous(sqlx::sqlite::SqliteSynchronous::Normal) + // .journal_mode(sqlx::sqlite::SqliteJournalMode::Wal); + // + // let pool = PoolOptions::new() + // .min_connections(2) + // .max_connections(max_conn) + // .acquire_timeout(Duration::from_secs(10)) + // .connect_with(opts) + // .await + // .context("failed to connect to sqlite")?; + // + // info!("Database Connection Pool created successfully"); + // + // Ok(pool) + // } + // + // pub async fn connect_postgres(addr: &str, max_conn: u32) -> anyhow::Result { + // let opts = sqlx::postgres::PgConnectOptions::from_str(addr)? + // .log_slow_statements(LevelFilter::Debug, Duration::from_secs(3)); + // let pool = PoolOptions::new() + // .min_connections(2) + // .max_connections(max_conn) + // .acquire_timeout(Duration::from_secs(10)) + // .connect_with(opts) + // .await + // .context("failed to connect to postgres")?; + // + // info!("Database Connection Pool created successfully"); + // + // Ok(pool) + // } } /// Holds the `argon2::Params` for the application. diff --git a/src/models/src/entity/api_keys.rs b/src/models/src/entity/api_keys.rs index 49b9d7baa..e304a8ca4 100644 --- a/src/models/src/entity/api_keys.rs +++ b/src/models/src/entity/api_keys.rs @@ -1,6 +1,4 @@ -use crate::app_state::AppState; use crate::database::{Cache, DB}; -use actix_web::web; use chrono::Utc; use cryptr::{EncKeys, EncValue}; use hiqlite::{params, Param}; @@ -79,14 +77,14 @@ VALUES ($1, $2, $3, $4, $5, $6)"#, Ok(secret_fmt) } - pub async fn delete(data: &web::Data, name: &str) -> Result<(), ErrorResponse> { + pub async fn delete(name: &str) -> Result<(), ErrorResponse> { if is_hiqlite() { DB::client() .execute("DELETE FROM api_keys WHERE name = $1", params!(name)) .await?; } else { query!("DELETE FROM api_keys WHERE name = $1", name) - .execute(&data.db) + .execute(DB::conn()) .await?; } @@ -94,39 +92,36 @@ VALUES ($1, $2, $3, $4, $5, $6)"#, Ok(()) } - pub async fn find(data: &web::Data, name: &str) -> Result { + pub async fn find(name: &str) -> Result { let res = if is_hiqlite() { DB::client() .query_as_one("SELECT * FROM api_keys WHERE name = $1", params!(name)) .await? } else { query_as!(Self, "SELECT * FROM api_keys WHERE name = $1", name) - .fetch_one(&data.db) + .fetch_one(DB::conn()) .await? }; Ok(res) } - pub async fn find_all(data: &web::Data) -> Result, ErrorResponse> { + pub async fn find_all() -> Result, ErrorResponse> { let res = if is_hiqlite() { DB::client() .query_as("SELECT * FROM api_keys", params!()) .await? } else { query_as!(Self, "SELECT * FROM api_keys") - .fetch_all(&data.db) + .fetch_all(DB::conn()) .await? }; Ok(res) } - pub async fn generate_secret( - data: &web::Data, - name: &str, - ) -> Result { - let entity = ApiKeyEntity::find(data, name).await?; + pub async fn generate_secret(name: &str) -> Result { + let entity = ApiKeyEntity::find(name).await?; let api_key = entity.into_api_key()?; // generate a new secret @@ -160,7 +155,7 @@ VALUES ($1, $2, $3, $4, $5, $6)"#, access_enc, name, ) - .execute(&data.db) + .execute(DB::conn()) .await?; } @@ -172,12 +167,11 @@ VALUES ($1, $2, $3, $4, $5, $6)"#, /// Updates the API Key. Does NOT update the secret in any way! pub async fn update( - data: &web::Data, name: &str, expires: Option, access: Vec, ) -> Result<(), ErrorResponse> { - let entity = ApiKeyEntity::find(data, name).await?; + let entity = ApiKeyEntity::find(name).await?; let api_key = entity.into_api_key()?; let secret_enc = EncValue::encrypt(&api_key.secret)?.into_bytes().to_vec(); @@ -215,7 +209,7 @@ WHERE name = $5"#, access_enc, name, ) - .execute(&data.db) + .execute(DB::conn()) .await?; } @@ -224,7 +218,7 @@ WHERE name = $5"#, Ok(()) } - pub async fn save(self, data: &web::Data) -> Result<(), ErrorResponse> { + pub async fn save(self) -> Result<(), ErrorResponse> { let name = self.name.clone(); if is_hiqlite() { @@ -255,7 +249,7 @@ WHERE name = $5"#, self.access, self.name, ) - .execute(&data.db) + .execute(DB::conn()) .await?; } @@ -279,10 +273,7 @@ impl ApiKeyEntity { } #[inline(always)] - pub async fn api_key_from_token_validated( - data: &web::Data, - token: &str, - ) -> Result { + pub async fn api_key_from_token_validated(token: &str) -> Result { let (name, secret) = token.split_once('$').ok_or_else(|| { ErrorResponse::new(ErrorResponseType::BadRequest, "Malformed API-Key") })?; @@ -292,7 +283,7 @@ impl ApiKeyEntity { let api_key = if let Some(key) = client.get(Cache::App, &idx).await? { key } else { - let key = Self::find(data, name).await?.into_api_key()?; + let key = Self::find(name).await?.into_api_key()?; client.put(Cache::App, idx, &key, CACHE_TTL_APP).await?; key }; diff --git a/src/models/src/entity/app_version.rs b/src/models/src/entity/app_version.rs index 245f2db02..da52c8ff5 100644 --- a/src/models/src/entity/app_version.rs +++ b/src/models/src/entity/app_version.rs @@ -1,6 +1,4 @@ -use crate::app_state::AppState; use crate::database::{Cache, DB}; -use actix_web::web; use chrono::Utc; use hiqlite::{params, Param}; use rauthy_common::constants::{CACHE_TTL_APP, IDX_APP_VERSION, RAUTHY_VERSION}; @@ -21,7 +19,7 @@ pub struct LatestAppVersion { } impl LatestAppVersion { - pub async fn find(app_state: &web::Data) -> Option { + pub async fn find() -> Option { let client = DB::client(); if let Ok(Some(slf)) = client.get(Cache::App, IDX_APP_VERSION).await { @@ -38,7 +36,7 @@ impl LatestAppVersion { .ok() } else { query!("select data from config where id = 'latest_version'") - .fetch_optional(&app_state.db) + .fetch_optional(DB::conn()) .await .ok()? .map(|r| { @@ -78,7 +76,6 @@ impl LatestAppVersion { } pub async fn upsert( - app_state: &web::Data, latest_version: semver::Version, release_url: String, ) -> Result<(), ErrorResponse> { @@ -105,7 +102,7 @@ INSERT INTO config (id, data) VALUES ('latest_version', $1) ON CONFLICT(id) DO UPDATE SET data = $1"#, data ) - .execute(&app_state.db) + .execute(DB::conn()) .await?; } diff --git a/src/models/src/entity/auth_providers.rs b/src/models/src/entity/auth_providers.rs index 020441997..6c0f29889 100644 --- a/src/models/src/entity/auth_providers.rs +++ b/src/models/src/entity/auth_providers.rs @@ -223,16 +223,13 @@ impl<'r> From> for AuthProvider { } impl AuthProvider { - pub async fn create( - data: &web::Data, - payload: ProviderRequest, - ) -> Result { + pub async fn create(payload: ProviderRequest) -> Result { let mut slf = Self::try_from_id_req(new_store_id(), payload)?; let typ = slf.typ.as_str(); slf = if is_hiqlite() { DB::client() - .execute_returning_map( + .execute_returning_map_one( r#" INSERT INTO auth_providers (id, name, enabled, typ, issuer, authorization_endpoint, token_endpoint, @@ -263,7 +260,6 @@ RETURNING *"#, ), ) .await? - .remove(0)? } else { query!( r#" @@ -292,13 +288,13 @@ VALUES slf.use_pkce, slf.root_pem, ) - .execute(&data.db) + .execute(DB::conn()) .await?; slf }; - Self::invalidate_cache_all(data).await?; + Self::invalidate_cache_all().await?; DB::client() .put(Cache::App, Self::cache_idx(&slf.id), &slf, CACHE_TTL_APP) @@ -307,7 +303,7 @@ VALUES Ok(slf) } - pub async fn find(data: &web::Data, id: &str) -> Result { + pub async fn find(id: &str) -> Result { let client = DB::client(); if let Some(slf) = client.get(Cache::App, Self::cache_idx(id)).await? { return Ok(slf); @@ -319,7 +315,7 @@ VALUES .await? } else { query_as!(Self, "SELECT * FROM auth_providers WHERE id = $1", id) - .fetch_one(&data.db) + .fetch_one(DB::conn()) .await? }; @@ -330,7 +326,7 @@ VALUES Ok(slf) } - pub async fn find_all(data: &web::Data) -> Result, ErrorResponse> { + pub async fn find_all() -> Result, ErrorResponse> { let client = DB::client(); if let Some(res) = client.get(Cache::App, Self::cache_idx("all")).await? { return Ok(res); @@ -342,7 +338,7 @@ VALUES .await? } else { query_as!(Self, "SELECT * FROM auth_providers") - .fetch_all(&data.db) + .fetch_all(DB::conn()) .await? }; @@ -355,7 +351,6 @@ VALUES } pub async fn find_linked_users( - data: &web::Data, id: &str, ) -> Result, ErrorResponse> { let users = if is_hiqlite() { @@ -371,39 +366,35 @@ VALUES "SELECT id, email FROM users WHERE auth_provider_id = $1", id ) - .fetch_all(&data.db) + .fetch_all(DB::conn()) .await? }; Ok(users) } - pub async fn delete(data: &web::Data, id: &str) -> Result<(), ErrorResponse> { + pub async fn delete(id: &str) -> Result<(), ErrorResponse> { if is_hiqlite() { DB::client() .execute("DELETE FROM auth_providers WHERE id = $1", params!(id)) .await?; } else { query!("DELETE FROM auth_providers WHERE id = $1", id) - .execute(&data.db) + .execute(DB::conn()) .await?; } - Self::invalidate_cache_all(data).await?; + Self::invalidate_cache_all().await?; DB::client().delete(Cache::App, Self::cache_idx(id)).await?; Ok(()) } - pub async fn update( - data: &web::Data, - id: String, - payload: ProviderRequest, - ) -> Result<(), ErrorResponse> { - Self::try_from_id_req(id, payload)?.save(data).await + pub async fn update(id: String, payload: ProviderRequest) -> Result<(), ErrorResponse> { + Self::try_from_id_req(id, payload)?.save().await } - pub async fn save(&self, data: &web::Data) -> Result<(), ErrorResponse> { + pub async fn save(&self) -> Result<(), ErrorResponse> { let typ = self.typ.as_str(); if is_hiqlite() { @@ -466,11 +457,11 @@ WHERE id = $18"#, self.root_pem, self.id, ) - .execute(&data.db) + .execute(DB::conn()) .await?; } - Self::invalidate_cache_all(data).await?; + Self::invalidate_cache_all().await?; DB::client() .put(Cache::App, Self::cache_idx(&self.id), self, CACHE_TTL_APP) .await?; @@ -555,14 +546,14 @@ impl AuthProvider { }) } - async fn invalidate_cache_all(data: &web::Data) -> Result<(), ErrorResponse> { + async fn invalidate_cache_all() -> Result<(), ErrorResponse> { DB::client() .delete(Cache::App, Self::cache_idx("all")) .await?; // Directly update the template cache preemptively. // This is needed all the time anyway. - AuthProviderTemplate::update_cache(data).await?; + AuthProviderTemplate::update_cache().await?; Ok(()) } @@ -768,12 +759,11 @@ impl AuthProviderCallback { impl AuthProviderCallback { /// returns (encrypted cookie, xsrf token, location header, optional allowed origins) - pub async fn login_start( - data: &web::Data, + pub async fn login_start<'a>( payload: ProviderLoginRequest, - ) -> Result<(Cookie, String, HeaderValue), ErrorResponse> { - let provider = AuthProvider::find(data, &payload.provider_id).await?; - let client = Client::find(data, payload.client_id).await?; + ) -> Result<(Cookie<'a>, String, HeaderValue), ErrorResponse> { + let provider = AuthProvider::find(&payload.provider_id).await?; + let client = Client::find(payload.client_id).await?; let slf = Self { callback_id: secure_random_alnum(32), @@ -886,7 +876,7 @@ impl AuthProviderCallback { debug!("callback pkce verifier is valid"); // request is valid -> fetch token for the user - let provider = AuthProvider::find(data, &slf.provider_id).await?; + let provider = AuthProvider::find(&slf.provider_id).await?; let client = AuthProvider::build_client( provider.allow_insecure_requests, provider.root_pem.as_deref(), @@ -951,9 +941,7 @@ impl AuthProviderCallback { if let Some(id_token) = ts.id_token { let claims_bytes = AuthProviderIdClaims::self_as_bytes_from_token(&id_token)?; let claims = AuthProviderIdClaims::try_from(claims_bytes.as_slice())?; - claims - .validate_update_user(data, &provider, &link_cookie) - .await? + claims.validate_update_user(&provider, &link_cookie).await? } else if let Some(access_token) = ts.access_token { // the id_token only exists, if we actually have an OIDC provider. // If we only get an access token, we need to do another request to the @@ -970,9 +958,7 @@ impl AuthProviderCallback { let res_bytes = res.bytes().await?; let claims = AuthProviderIdClaims::try_from(res_bytes.as_bytes())?; - claims - .validate_update_user(data, &provider, &link_cookie) - .await? + claims.validate_update_user(&provider, &link_cookie).await? } else { let err = "Neither `access_token` nor `id_token` existed"; error!("{}", err); @@ -1003,7 +989,7 @@ impl AuthProviderCallback { } // validate client values - let client = Client::find_maybe_ephemeral(data, slf.req_client_id).await?; + let client = Client::find_maybe_ephemeral(slf.req_client_id).await?; let force_mfa = client.force_mfa(); if force_mfa { if provider_mfa_login == ProviderMfaLogin::No && !user.has_webauthn_enabled() { @@ -1012,7 +998,7 @@ impl AuthProviderCallback { "MFA is required for this client", )); } - session.set_mfa(data, true).await?; + session.set_mfa(true).await?; } client.validate_redirect_uri(&slf.req_redirect_uri)?; client.validate_code_challenge(&slf.req_code_challenge, &slf.req_code_challenge_method)?; @@ -1096,15 +1082,13 @@ pub struct AuthProviderTemplate { } impl AuthProviderTemplate { - pub async fn get_all_json_template( - data: &web::Data, - ) -> Result, ErrorResponse> { + pub async fn get_all_json_template() -> Result, ErrorResponse> { let client = DB::client(); if let Some(slf) = client.get(Cache::App, IDX_AUTH_PROVIDER_TEMPLATE).await? { return Ok(slf); } - let providers = AuthProvider::find_all(data) + let providers = AuthProvider::find_all() .await? .into_iter() // We don't want to even show disabled providers @@ -1135,9 +1119,9 @@ impl AuthProviderTemplate { Ok(()) } - async fn update_cache(data: &web::Data) -> Result<(), ErrorResponse> { + async fn update_cache() -> Result<(), ErrorResponse> { Self::invalidate_cache().await?; - Self::get_all_json_template(data).await?; + Self::get_all_json_template().await?; Ok(()) } } @@ -1239,7 +1223,6 @@ impl AuthProviderIdClaims<'_> { async fn validate_update_user( &self, - data: &web::Data, provider: &AuthProvider, link_cookie: &Option, ) -> Result<(User, ProviderMfaLogin), ErrorResponse> { @@ -1270,7 +1253,7 @@ impl AuthProviderIdClaims<'_> { // Any json number would become a String too, which is what we need for compatibility. .to_string(); - let user_opt = match User::find_by_federation(data, &provider.id, &claims_user_id).await { + let user_opt = match User::find_by_federation(&provider.id, &claims_user_id).await { Ok(user) => { debug!( "found already existing user by federation lookup: {:?}", @@ -1286,7 +1269,7 @@ impl AuthProviderIdClaims<'_> { // On conflict, the DB would return an error anyway, but the error message is // rather cryptic for a normal user. if let Ok(mut user) = - User::find_by_email(data, self.email.as_ref().unwrap().to_string()).await + User::find_by_email(self.email.as_ref().unwrap().to_string()).await { // TODO check if creating a new link for an existing user is allowed if let Some(link) = link_cookie { @@ -1439,7 +1422,7 @@ impl AuthProviderIdClaims<'_> { user.last_failed_login = Some(now); user.failed_login_attempts = Some(user.failed_login_attempts.unwrap_or_default() + 1); - user.save(data, old_email).await?; + user.save(old_email).await?; return Err(ErrorResponse::new( ErrorResponseType::Forbidden, @@ -1492,7 +1475,7 @@ impl AuthProviderIdClaims<'_> { user.last_failed_login = None; user.failed_login_attempts = None; - user.save(data, old_email).await?; + user.save(old_email).await?; user } else { // Create a new federated user @@ -1521,12 +1504,12 @@ impl AuthProviderIdClaims<'_> { federation_uid: Some(claims_user_id.to_string()), ..Default::default() }; - User::create_federated(data, new_user).await? + User::create_federated(new_user).await? }; // check if we got additional values from the token let mut found_values = false; - let mut user_values = match UserValues::find(data, &user.id).await? { + let mut user_values = match UserValues::find(&user.id).await? { Some(values) => UserValuesRequest { birthdate: values.birthdate, phone: values.phone, @@ -1565,7 +1548,7 @@ impl AuthProviderIdClaims<'_> { found_values = true; } if found_values { - UserValues::upsert(data, user.id.clone(), user_values).await?; + UserValues::upsert(user.id.clone(), user_values).await?; } Ok((user, provider_mfa_login)) diff --git a/src/models/src/entity/clients.rs b/src/models/src/entity/clients.rs index e3573636d..d065dc5f8 100644 --- a/src/models/src/entity/clients.rs +++ b/src/models/src/entity/clients.rs @@ -80,10 +80,7 @@ impl Client { } // have less cloning - pub async fn create( - data: &web::Data, - mut client_req: NewClientRequest, - ) -> Result { + pub async fn create(mut client_req: NewClientRequest) -> Result { let kid = if client_req.confidential { let (_cleartext, enc) = Self::generate_new_secret()?; client_req.secret = Some(enc); @@ -158,7 +155,7 @@ $18, $19, $20)"#, client.client_uri, client.contacts, ) - .execute(&data.db) + .execute(DB::conn()) .await?; } @@ -221,7 +218,7 @@ VALUES ($1, $2, $3, $4)"#, )) ]).await?; } else { - let mut txn = data.db.begin().await?; + let mut txn = DB::txn().await?; sqlx::query!( r#" @@ -282,14 +279,14 @@ VALUES ($1, $2, $3, $4)"#, } // Deletes a client - pub async fn delete(&self, data: &web::Data) -> Result<(), ErrorResponse> { + pub async fn delete(&self) -> Result<(), ErrorResponse> { if is_hiqlite() { DB::client() .execute("DELETE FROM clients WHERE id = $1", params!(&self.id)) .await?; } else { sqlx::query!("DELETE FROM clients WHERE id = $1", self.id,) - .execute(&data.db) + .execute(DB::conn()) .await?; } @@ -316,7 +313,7 @@ VALUES ($1, $2, $3, $4)"#, } // Returns a client by id without its secret. - pub async fn find(data: &web::Data, id: String) -> Result { + pub async fn find(id: String) -> Result { let client = DB::client(); if let Some(slf) = client.get(Cache::App, Self::cache_idx(&id)).await? { return Ok(slf); @@ -329,7 +326,7 @@ VALUES ($1, $2, $3, $4)"#, } else { sqlx::query_as::<_, Self>("SELECT * FROM clients WHERE id = $1") .bind(&id) - .fetch_one(&data.db) + .fetch_one(DB::conn()) .await? }; @@ -340,14 +337,14 @@ VALUES ($1, $2, $3, $4)"#, Ok(slf) } - pub async fn find_all(data: &web::Data) -> Result, ErrorResponse> { + pub async fn find_all() -> Result, ErrorResponse> { let clients = if is_hiqlite() { DB::client() .query_as("SELECT * FROM clients", params!()) .await? } else { sqlx::query_as("SELECT * FROM clients") - .fetch_all(&data.db) + .fetch_all(DB::conn()) .await? }; @@ -358,12 +355,9 @@ VALUES ($1, $2, $3, $4)"#, /// If allowed, it will dynamically build an ephemeral client and cache it, it the client_id /// is a URL. Otherwise, it will do a classic fetch from the database. /// This function should be used in places where we would possibly accept an ephemeral client. - pub async fn find_maybe_ephemeral( - data: &web::Data, - id: String, - ) -> Result { + pub async fn find_maybe_ephemeral(id: String) -> Result { if !*ENABLE_EPHEMERAL_CLIENTS || Url::from_str(&id).is_err() { - return Self::find(data, id).await; + return Self::find(id).await; } let client = DB::client(); @@ -385,10 +379,7 @@ VALUES ($1, $2, $3, $4)"#, Ok(slf) } - pub async fn find_with_scope( - data: &web::Data, - scope_name: &str, - ) -> Result, ErrorResponse> { + pub async fn find_with_scope(scope_name: &str) -> Result, ErrorResponse> { let like = format!("%{scope_name}%"); let clients = if is_hiqlite() { @@ -401,7 +392,7 @@ VALUES ($1, $2, $3, $4)"#, } else { sqlx::query_as("SELECT * FROM clients WHERE scopes = $1 OR default_scopes = $1") .bind(like) - .fetch_all(&data.db) + .fetch_all(DB::conn()) .await? }; @@ -485,7 +476,7 @@ WHERE id = $20"#, Ok(()) } - pub async fn save(&self, data: &web::Data) -> Result<(), ErrorResponse> { + pub async fn save(&self) -> Result<(), ErrorResponse> { if is_hiqlite() { DB::client() .execute( @@ -550,7 +541,7 @@ WHERE id = $20"#, self.contacts, self.id, ) - .execute(&data.db) + .execute(DB::conn()) .await?; } @@ -572,7 +563,7 @@ WHERE id = $20"#, .unwrap_or_else(|| "client_secret_basic".to_string()); let mut new_client = Self::try_from_dyn_reg(client_req)?; - let current = Self::find(data, client_dyn.id.clone()).await?; + let current = Self::find(client_dyn.id.clone()).await?; if !current.is_dynamic() { return Err(ErrorResponse::new( ErrorResponseType::Forbidden, @@ -612,7 +603,7 @@ WHERE id = $4"#, DB::client().txn(txn).await?; } else { - let mut txn = data.db.begin().await?; + let mut txn = DB::txn().await?; new_client.save_txn(&mut txn).await?; sqlx::query!( @@ -830,12 +821,9 @@ impl Client { /// Sanitizes the current scopes and deletes everything, which does not exist in the `scopes` /// table in the database - pub async fn sanitize_scopes( - data: &web::Data, - scps: Vec, - ) -> Result { + pub async fn sanitize_scopes(scps: Vec) -> Result { let mut res = String::with_capacity(scps.len()); - Scope::find_all(data).await?.into_iter().for_each(|s| { + Scope::find_all().await?.into_iter().for_each(|s| { if scps.contains(&s.name) { res.push_str(s.name.as_str()); res.push(','); diff --git a/src/models/src/entity/clients_dyn.rs b/src/models/src/entity/clients_dyn.rs index 7305916f1..572116f47 100644 --- a/src/models/src/entity/clients_dyn.rs +++ b/src/models/src/entity/clients_dyn.rs @@ -30,7 +30,7 @@ impl ClientDyn { Ok(()) } - pub async fn find(data: &web::Data, id: String) -> Result { + pub async fn find(id: String) -> Result { let client = DB::client(); if let Some(slf) = client @@ -46,7 +46,7 @@ impl ClientDyn { .await? } else { query_as!(Self, "SELECT * FROM clients_dyn WHERE id = $1", id) - .fetch_one(&data.db) + .fetch_one(DB::conn()) .await? }; @@ -62,7 +62,7 @@ impl ClientDyn { Ok(slf) } - pub async fn update_used(data: &web::Data, id: &str) -> Result<(), ErrorResponse> { + pub async fn update_used(id: &str) -> Result<(), ErrorResponse> { let now = Utc::now().timestamp(); if is_hiqlite() { @@ -78,7 +78,7 @@ impl ClientDyn { now, id ) - .execute(&data.db) + .execute(DB::conn()) .await?; } diff --git a/src/models/src/entity/colors.rs b/src/models/src/entity/colors.rs index 7588d84d6..527db46b5 100644 --- a/src/models/src/entity/colors.rs +++ b/src/models/src/entity/colors.rs @@ -1,6 +1,4 @@ -use crate::app_state::AppState; use crate::database::{Cache, DB}; -use actix_web::web; use hiqlite::{params, Param}; use rauthy_api_types::clients::ColorsRequest; use rauthy_common::constants::CACHE_TTL_APP; @@ -18,7 +16,7 @@ pub struct ColorEntity { // CRUD impl ColorEntity { - pub async fn delete(data: &web::Data, client_id: &str) -> Result<(), ErrorResponse> { + pub async fn delete(client_id: &str) -> Result<(), ErrorResponse> { if is_hiqlite() { DB::client() .execute( @@ -28,7 +26,7 @@ impl ColorEntity { .await?; } else { sqlx::query!("DELETE FROM colors WHERE client_id = $1", client_id,) - .execute(&data.db) + .execute(DB::conn()) .await?; } @@ -39,10 +37,7 @@ impl ColorEntity { Ok(()) } - pub async fn find( - data: &web::Data, - client_id: &str, - ) -> Result { + pub async fn find(client_id: &str) -> Result { let idx = Self::cache_idx(client_id); let client = DB::client(); @@ -60,7 +55,7 @@ impl ColorEntity { .ok() } else { sqlx::query_as!(Self, "SELECT * FROM colors WHERE client_id = $1", client_id) - .fetch_optional(&data.db) + .fetch_optional(DB::conn()) .await? }; let colors = match res { @@ -73,15 +68,11 @@ impl ColorEntity { Ok(colors) } - pub async fn find_rauthy(data: &web::Data) -> Result { - Self::find(data, "rauthy").await + pub async fn find_rauthy() -> Result { + Self::find("rauthy").await } - pub async fn update( - data: &web::Data, - client_id: &str, - req: ColorsRequest, - ) -> Result<(), ErrorResponse> { + pub async fn update(client_id: &str, req: ColorsRequest) -> Result<(), ErrorResponse> { let cols = Colors::from(req); let col_bytes = cols.as_bytes(); @@ -106,7 +97,7 @@ SET data = $2"#, client_id, col_bytes, ) - .execute(&data.db) + .execute(DB::conn()) .await?; } diff --git a/src/models/src/entity/devices.rs b/src/models/src/entity/devices.rs index e6369f2ab..c48b066f6 100644 --- a/src/models/src/entity/devices.rs +++ b/src/models/src/entity/devices.rs @@ -1,7 +1,5 @@ -use crate::app_state::AppState; use crate::database::{Cache, DB}; use crate::entity::refresh_tokens_devices::RefreshTokenDevice; -use actix_web::web; use chrono::{DateTime, Utc}; use hiqlite::{params, Param}; use rauthy_api_types::users::DeviceResponse; @@ -30,7 +28,7 @@ pub struct DeviceEntity { } impl DeviceEntity { - pub async fn insert(self, data: &web::Data) -> Result<(), ErrorResponse> { + pub async fn insert(self) -> Result<(), ErrorResponse> { if is_hiqlite() { DB::client() .execute( @@ -65,44 +63,41 @@ VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"#, self.peer_ip, self.name, ) - .execute(&data.db) + .execute(DB::conn()) .await?; } Ok(()) } - pub async fn find(data: &web::Data, id: &str) -> Result { + pub async fn find(id: &str) -> Result { let slf = if is_hiqlite() { DB::client() .query_as_one("SELECT * FROM devices WHERE id = $1", params!(id)) .await? } else { query_as!(Self, "SELECT * FROM devices WHERE id = $1", id) - .fetch_one(&data.db) + .fetch_one(DB::conn()) .await? }; Ok(slf) } - pub async fn find_for_user( - data: &web::Data, - user_id: &str, - ) -> Result, ErrorResponse> { + pub async fn find_for_user(user_id: &str) -> Result, ErrorResponse> { let res = if is_hiqlite() { DB::client() .query_as("SELECT * FROM devices WHERE user_id = $1", params!(user_id)) .await? } else { query_as!(Self, "SELECT * FROM devices WHERE user_id = $1", user_id) - .fetch_all(&data.db) + .fetch_all(DB::conn()) .await? }; Ok(res) } /// Deletes all devices where access and refresh token expirations are in the past - pub async fn delete_expired(data: &web::Data) -> Result<(), ErrorResponse> { + pub async fn delete_expired() -> Result<(), ErrorResponse> { let exp = Utc::now() .sub(chrono::Duration::try_hours(1).unwrap()) .timestamp(); @@ -123,7 +118,7 @@ DELETE FROM devices WHERE access_exp < $1 AND (refresh_exp < $1 OR refresh_exp is null)"#, exp ) - .execute(&data.db) + .execute(DB::conn()) .await? .rows_affected(); res as usize @@ -133,14 +128,14 @@ WHERE access_exp < $1 AND (refresh_exp < $1 OR refresh_exp is null)"#, Ok(()) } - pub async fn invalidate(data: &web::Data, id: &str) -> Result<(), ErrorResponse> { + pub async fn invalidate(id: &str) -> Result<(), ErrorResponse> { if is_hiqlite() { DB::client() .execute("DELETE FROM devices WHERE id = $1", params!(id)) .await?; } else { query!("DELETE FROM devices WHERE id = $1", id) - .execute(&data.db) + .execute(DB::conn()) .await?; } @@ -148,11 +143,8 @@ WHERE access_exp < $1 AND (refresh_exp < $1 OR refresh_exp is null)"#, Ok(()) } - pub async fn revoke_refresh_tokens( - data: &web::Data, - device_id: &str, - ) -> Result<(), ErrorResponse> { - RefreshTokenDevice::invalidate_all_for_device(data, device_id).await?; + pub async fn revoke_refresh_tokens(device_id: &str) -> Result<(), ErrorResponse> { + RefreshTokenDevice::invalidate_all_for_device(device_id).await?; if is_hiqlite() { DB::client() @@ -166,7 +158,7 @@ WHERE access_exp < $1 AND (refresh_exp < $1 OR refresh_exp is null)"#, "UPDATE devices SET refresh_exp = null WHERE id = $1", device_id, ) - .execute(&data.db) + .execute(DB::conn()) .await?; } @@ -174,7 +166,6 @@ WHERE access_exp < $1 AND (refresh_exp < $1 OR refresh_exp is null)"#, } pub async fn update_name( - data: &web::Data, device_id: &str, user_id: &str, name: &str, @@ -193,7 +184,7 @@ WHERE access_exp < $1 AND (refresh_exp < $1 OR refresh_exp is null)"#, device_id, user_id ) - .execute(&data.db) + .execute(DB::conn()) .await?; } diff --git a/src/models/src/entity/fed_cm.rs b/src/models/src/entity/fed_cm.rs index ef993e0f6..f4f7e5f3d 100644 --- a/src/models/src/entity/fed_cm.rs +++ b/src/models/src/entity/fed_cm.rs @@ -112,7 +112,7 @@ pub struct FedCMIdPBranding { impl FedCMIdPBranding { async fn new(data: &web::Data) -> Result { - let colors = ColorEntity::find_rauthy(data).await?; + let colors = ColorEntity::find_rauthy().await?; let rauthy_icon = FedCMIdPIcon::rauthy_logo(&data.issuer); Ok(Self { diff --git a/src/models/src/entity/groups.rs b/src/models/src/entity/groups.rs index 160ad02af..a4f5130a1 100644 --- a/src/models/src/entity/groups.rs +++ b/src/models/src/entity/groups.rs @@ -1,7 +1,5 @@ -use crate::app_state::AppState; use crate::database::{Cache, DB}; use crate::entity::users::User; -use actix_web::web; use hiqlite::{params, Param, Params}; use rauthy_api_types::groups::NewGroupRequest; use rauthy_common::constants::{CACHE_TTL_APP, IDX_GROUPS}; @@ -21,11 +19,8 @@ pub struct Group { // CRUD impl Group { // Inserts a new group into the database - pub async fn create( - data: &web::Data, - group_req: NewGroupRequest, - ) -> Result { - let mut groups = Group::find_all(data).await?; + pub async fn create(group_req: NewGroupRequest) -> Result { + let mut groups = Group::find_all().await?; for g in &groups { if g.name == group_req.group { return Err(ErrorResponse::new( @@ -53,7 +48,7 @@ impl Group { new_group.id, new_group.name, ) - .execute(&data.db) + .execute(DB::conn()) .await?; } @@ -66,9 +61,9 @@ impl Group { } // Deletes a group - pub async fn delete(data: &web::Data, id: String) -> Result<(), ErrorResponse> { - let group = Group::find(data, id).await?; - let users = User::find_with_group(data, &group.name).await?; + pub async fn delete(id: String) -> Result<(), ErrorResponse> { + let group = Group::find(id).await?; + let users = User::find_with_group(&group.name).await?; if is_hiqlite() { let mut txn: Vec<(&str, Params)> = Vec::with_capacity(users.len() + 1); @@ -88,7 +83,7 @@ impl Group { debug_assert!(rows_affected == 1); } } else { - let mut txn = data.db.begin().await?; + let mut txn = DB::txn().await?; for mut user in users { user.delete_group(&group.name); @@ -101,7 +96,7 @@ impl Group { txn.commit().await?; } - let groups = Group::find_all(data) + let groups = Group::find_all() .await? .into_iter() .filter(|g| g.id != group.id) @@ -119,14 +114,14 @@ impl Group { } // Returns a single group by id - pub async fn find(data: &web::Data, id: String) -> Result { + pub async fn find(id: String) -> Result { let res = if is_hiqlite() { DB::client() .query_as_one("SELECT * FROM groups WHERE id = $1", params!(id)) .await? } else { sqlx::query_as!(Self, "SELECT * FROM groups WHERE id = $1", id,) - .fetch_one(&data.db) + .fetch_one(DB::conn()) .await? }; @@ -134,7 +129,7 @@ impl Group { } // Returns all existing groups - pub async fn find_all(data: &web::Data) -> Result, ErrorResponse> { + pub async fn find_all() -> Result, ErrorResponse> { let client = DB::client(); if let Some(slf) = client.get(Cache::App, IDX_GROUPS).await? { return Ok(slf); @@ -144,7 +139,7 @@ impl Group { client.query_as("SELECT * FROM groups", params!()).await? } else { sqlx::query_as!(Self, "SELECT * FROM groups") - .fetch_all(&data.db) + .fetch_all(DB::conn()) .await? }; @@ -156,13 +151,9 @@ impl Group { } // Updates a group - pub async fn update( - data: &web::Data, - id: String, - new_name: String, - ) -> Result { - let group = Group::find(data, id).await?; - let users = User::find_with_group(data, &group.name).await?; + pub async fn update(id: String, new_name: String) -> Result { + let group = Group::find(id).await?; + let users = User::find_with_group(&group.name).await?; let new_group = Self { id: group.id.clone(), @@ -192,7 +183,7 @@ impl Group { debug_assert!(rows_affected == 1); } } else { - let mut txn = data.db.begin().await?; + let mut txn = DB::txn().await?; for mut user in users { user.delete_group(&group.name); @@ -209,7 +200,7 @@ impl Group { txn.commit().await?; } - let groups = Group::find_all(data) + let groups = Group::find_all() .await? .into_iter() .map(|mut g| { @@ -234,7 +225,6 @@ impl Group { // Sanitizes any bad data from an API request for adding / modifying groups and silently // dismissed all bad data. pub async fn sanitize( - data: &web::Data, groups_opt: Option>, ) -> Result, ErrorResponse> { if groups_opt.is_none() { @@ -243,7 +233,7 @@ impl Group { let groups = groups_opt.unwrap(); let mut res = String::with_capacity(groups.len()); - Group::find_all(data).await?.into_iter().for_each(|g| { + Group::find_all().await?.into_iter().for_each(|g| { if groups.contains(&g.name) { res.push_str(g.name.as_str()); res.push(','); diff --git a/src/models/src/entity/jwk.rs b/src/models/src/entity/jwk.rs index c5d9026ed..4f72a5fc4 100644 --- a/src/models/src/entity/jwk.rs +++ b/src/models/src/entity/jwk.rs @@ -172,7 +172,7 @@ pub struct JWKS { // CRUD impl JWKS { - pub async fn find_pk(data: &web::Data) -> Result { + pub async fn find_pk() -> Result { let client = DB::client(); if let Some(slf) = client.get(Cache::App, IDX_JWKS).await? { @@ -183,7 +183,7 @@ impl JWKS { client.query_as("SELECT * FROM jwks", params!()).await? } else { sqlx::query_as!(Jwk, "SELECT * FROM jwks") - .fetch_all(&data.db) + .fetch_all(DB::conn()) .await? }; @@ -641,7 +641,7 @@ impl JwkKeyPair { } // Returns a JWK by a given Key Identifier (kid) - pub async fn find(data: &web::Data, kid: String) -> Result { + pub async fn find(kid: String) -> Result { let idx = format!("{}{}", IDX_JWK_KID, kid); let client = DB::client(); @@ -655,7 +655,7 @@ impl JwkKeyPair { .await? } else { sqlx::query_as!(Jwk, "SELECT * FROM jwks WHERE kid = $1", kid,) - .fetch_one(&data.db) + .fetch_one(DB::conn()) .await? }; @@ -668,10 +668,7 @@ impl JwkKeyPair { // TODO add an index (signature, created_at) with the next round of DB migrations // Returns the latest JWK (especially important after a [JWK Rotation](crate::api::rotate_jwk) // by a given algorithm. - pub async fn find_latest( - data: &web::Data, - key_pair_alg: JwkKeyPairAlg, - ) -> Result { + pub async fn find_latest(key_pair_alg: JwkKeyPairAlg) -> Result { let idx = format!("{}{}", IDX_JWK_LATEST, key_pair_alg.as_str()); let client = DB::client(); @@ -703,7 +700,7 @@ LIMIT 1 "#, signature ) - .fetch_one(&data.db) + .fetch_one(DB::conn()) .await? }; diff --git a/src/models/src/entity/logos.rs b/src/models/src/entity/logos.rs index 1ad056bd0..6ccbcbd1c 100644 --- a/src/models/src/entity/logos.rs +++ b/src/models/src/entity/logos.rs @@ -1,4 +1,3 @@ -use crate::app_state::AppState; use crate::database::{Cache, DB}; use actix_web::web; use hiqlite::{params, Param, Row}; @@ -128,11 +127,7 @@ impl<'r> From> for Logo { } impl Logo { - pub async fn delete( - data: &web::Data, - id: &str, - typ: &LogoType, - ) -> Result<(), ErrorResponse> { + pub async fn delete(id: &str, typ: &LogoType) -> Result<(), ErrorResponse> { match typ { LogoType::Client => { if is_hiqlite() { @@ -141,7 +136,7 @@ impl Logo { .await?; } else { query!("DELETE FROM client_logos WHERE client_id = $1", id) - .execute(&data.db) + .execute(DB::conn()) .await?; } } @@ -158,7 +153,7 @@ impl Logo { "DELETE FROM auth_provider_logos WHERE auth_provider_id = $1", id ) - .execute(&data.db) + .execute(DB::conn()) .await?; } } @@ -172,7 +167,6 @@ impl Logo { } pub async fn upsert( - data: &web::Data, id: String, logo: Vec, content_type: mime::Mime, @@ -188,10 +182,8 @@ impl Logo { // technically not do an upsert, but actually delete + insert. match content_type.as_ref() { - "image/svg+xml" => { - Self::upsert_svg(data, id, logo, content_type.to_string(), &typ).await - } - "image/jpeg" | "image/png" => Self::upsert_jpg_png(data.clone(), id, logo, typ).await, + "image/svg+xml" => Self::upsert_svg(id, logo, content_type.to_string(), &typ).await, + "image/jpeg" | "image/png" => Self::upsert_jpg_png(id, logo, typ).await, _ => Err(ErrorResponse::new( ErrorResponseType::BadRequest, "Invalid mime type for auth provider logo", @@ -200,13 +192,12 @@ impl Logo { } async fn upsert_svg( - data: &web::Data, id: String, logo: Vec, content_type: String, typ: &LogoType, ) -> Result<(), ErrorResponse> { - Self::delete(data, &id, typ).await?; + Self::delete(&id, typ).await?; // SVG's don't have a resolution, save them as they are let slf = Self { @@ -215,16 +206,11 @@ impl Logo { content_type, data: logo, }; - slf.upsert_self(data, typ, true).await + slf.upsert_self(typ, true).await } - async fn upsert_jpg_png( - data: web::Data, - id: String, - logo: Vec, - typ: LogoType, - ) -> Result<(), ErrorResponse> { - Self::delete(&data, &id, &typ).await?; + async fn upsert_jpg_png(id: String, logo: Vec, typ: LogoType) -> Result<(), ErrorResponse> { + Self::delete(&id, &typ).await?; // we will save jpg / png in 2 downscaled and optimized resolutions: // - `RES_LATER_USE`px for possible later use @@ -268,7 +254,7 @@ impl Logo { content_type: CONTENT_TYPE_WEBP.to_string(), data: buf.into_inner(), }; - slf_medium.upsert_self(&data, &typ, false).await?; + slf_medium.upsert_self(&typ, false).await?; let img_small = image_medium.resize_to_fill(size_small, size_small, FilterType::Lanczos3); @@ -280,7 +266,7 @@ impl Logo { content_type: slf_medium.content_type, data: buf.into_inner(), } - .upsert_self(&data, &typ, true) + .upsert_self(&typ, true) .await?; Ok::<(), ErrorResponse>(()) @@ -292,9 +278,9 @@ impl Logo { } /// Overwrites the logo for the `rauthy` client with the default logo - pub async fn upsert_rauthy_default(data: &web::Data) -> Result<(), ErrorResponse> { + pub async fn upsert_rauthy_default() -> Result<(), ErrorResponse> { // make sure to delete any possibly existing webp image before inserting the svg - Self::delete(data, "rauthy", &LogoType::Client).await?; + Self::delete("rauthy", &LogoType::Client).await?; Self { id: "rauthy".to_string(), @@ -302,16 +288,11 @@ impl Logo { content_type: mime::IMAGE_SVG.to_string(), data: RAUTHY_DEFAULT_SVG.as_bytes().to_vec(), } - .upsert_self(data, &LogoType::Client, true) + .upsert_self(&LogoType::Client, true) .await } - async fn upsert_self( - &self, - data: &web::Data, - typ: &LogoType, - with_cache: bool, - ) -> Result<(), ErrorResponse> { + async fn upsert_self(&self, typ: &LogoType, with_cache: bool) -> Result<(), ErrorResponse> { let res = self.res.as_str(); if is_hiqlite() { @@ -370,7 +351,7 @@ SET content_type = $3, data = $4"#, ) } } - .execute(&data.db) + .execute(DB::conn()) .await?; } @@ -388,12 +369,7 @@ SET content_type = $3, data = $4"#, Ok(()) } - pub async fn find( - data: &web::Data, - id: &str, - res: LogoRes, - typ: &LogoType, - ) -> Result { + pub async fn find(id: &str, res: LogoRes, typ: &LogoType) -> Result { let res = res.as_str(); let res_svg = LogoRes::Svg.as_str(); @@ -428,7 +404,7 @@ WHERE client_id = $1 AND (res = $2 OR res = $3)"#, res, res_svg, ) - .fetch_one(&data.db) + .fetch_one(DB::conn()) .await? } LogoType::AuthProvider => { @@ -442,7 +418,7 @@ WHERE auth_provider_id = $1 AND (res = $2 OR res = $3)"#, res, res_svg, ) - .fetch_one(&data.db) + .fetch_one(DB::conn()) .await? } } @@ -452,17 +428,13 @@ WHERE auth_provider_id = $1 AND (res = $2 OR res = $3)"#, } // special fn because we would only cache the small logos - pub async fn find_cached( - data: &web::Data, - id: &str, - typ: &LogoType, - ) -> Result { + pub async fn find_cached(id: &str, typ: &LogoType) -> Result { let client = DB::client(); if let Some(slf) = client.get(Cache::App, Self::cache_idx(typ, id)).await? { return Ok(slf); } - let slf = Self::find(data, id, LogoRes::Small, typ).await?; + let slf = Self::find(id, LogoRes::Small, typ).await?; client .put(Cache::App, Self::cache_idx(typ, id), &slf, CACHE_TTL_APP) diff --git a/src/models/src/entity/magic_links.rs b/src/models/src/entity/magic_links.rs index 39c206cee..f74d605b2 100644 --- a/src/models/src/entity/magic_links.rs +++ b/src/models/src/entity/magic_links.rs @@ -1,7 +1,6 @@ use crate::api_cookie::ApiCookie; -use crate::app_state::AppState; use crate::database::DB; -use actix_web::{web, HttpRequest}; +use actix_web::HttpRequest; use hiqlite::{params, Param}; use rauthy_common::constants::{PASSWORD_RESET_COOKIE_BINDING, PWD_CSRF_HEADER, PWD_RESET_COOKIE}; use rauthy_common::is_hiqlite; @@ -100,7 +99,6 @@ pub struct MagicLink { // CRUD impl MagicLink { pub async fn create( - data: &web::Data, user_id: String, lifetime_minutes: i64, usage: MagicLinkUsage, @@ -145,31 +143,28 @@ VALUES ($1, $2, $3, $4, $5, $6)"#, false, link.usage, ) - .execute(&data.db) + .execute(DB::conn()) .await?; } Ok(link) } - pub async fn find(data: &web::Data, id: &str) -> Result { + pub async fn find(id: &str) -> Result { let res = if is_hiqlite() { DB::client() .query_as_one("SELECT * FROM magic_links WHERE id = $1", params!(id)) .await? } else { sqlx::query_as!(Self, "SELECT * FROM magic_links WHERE id = $1", id) - .fetch_one(&data.db) + .fetch_one(DB::conn()) .await? }; Ok(res) } - pub async fn find_by_user( - data: &web::Data, - user_id: String, - ) -> Result { + pub async fn find_by_user(user_id: String) -> Result { let res = if is_hiqlite() { DB::client() .query_as_one( @@ -183,17 +178,14 @@ VALUES ($1, $2, $3, $4, $5, $6)"#, "SELECT * FROM magic_links WHERE user_id = $1", user_id ) - .fetch_one(&data.db) + .fetch_one(DB::conn()) .await? }; Ok(res) } - pub async fn invalidate_all_email_change( - data: &web::Data, - user_id: &str, - ) -> Result<(), ErrorResponse> { + pub async fn invalidate_all_email_change(user_id: &str) -> Result<(), ErrorResponse> { if is_hiqlite() { DB::client() .execute( @@ -206,14 +198,14 @@ VALUES ($1, $2, $3, $4, $5, $6)"#, "DELETE FROM magic_links WHERE user_id = $1 AND USAGE LIKE 'email_change$%'", user_id, ) - .execute(&data.db) + .execute(DB::conn()) .await?; }; Ok(()) } - pub async fn save(&self, data: &web::Data) -> Result<(), ErrorResponse> { + pub async fn save(&self) -> Result<(), ErrorResponse> { if is_hiqlite() { DB::client() .execute( @@ -229,7 +221,7 @@ VALUES ($1, $2, $3, $4, $5, $6)"#, self.used, self.id, ) - .execute(&data.db) + .execute(DB::conn()) .await?; } @@ -238,9 +230,9 @@ VALUES ($1, $2, $3, $4, $5, $6)"#, } impl MagicLink { - pub async fn invalidate(&mut self, data: &web::Data) -> Result<(), ErrorResponse> { + pub async fn invalidate(&mut self) -> Result<(), ErrorResponse> { self.exp = OffsetDateTime::now_utc().unix_timestamp() - 10; - self.save(data).await + self.save().await } pub fn validate( diff --git a/src/models/src/entity/mod.rs b/src/models/src/entity/mod.rs index cf09155a7..8e100e7c6 100644 --- a/src/models/src/entity/mod.rs +++ b/src/models/src/entity/mod.rs @@ -1,4 +1,3 @@ -use crate::app_state::DbPool; use crate::database::DB; use hiqlite::params; use rauthy_common::is_hiqlite; @@ -38,7 +37,7 @@ pub mod webauthn; pub mod webids; pub mod well_known; -pub async fn is_db_alive(db: &DbPool) -> bool { +pub async fn is_db_alive() -> bool { if is_hiqlite() { // execute returning instead of query to make sure the leader is reachable in HA deployment DB::client() @@ -46,6 +45,6 @@ pub async fn is_db_alive(db: &DbPool) -> bool { .await .is_ok() } else { - query("SELECT 1").execute(db).await.is_ok() + query("SELECT 1").execute(DB::conn()).await.is_ok() } } diff --git a/src/models/src/entity/password.rs b/src/models/src/entity/password.rs index a17ecab16..c4572f86d 100644 --- a/src/models/src/entity/password.rs +++ b/src/models/src/entity/password.rs @@ -1,4 +1,3 @@ -use crate::app_state::AppState; use crate::database::{Cache, DB}; use actix_web::web; use argon2::password_hash::SaltString; @@ -125,7 +124,7 @@ pub struct PasswordPolicy { // CRUD impl PasswordPolicy { - pub async fn find(data: &web::Data) -> Result { + pub async fn find() -> Result { let client = DB::client(); if let Some(slf) = client.get(Cache::App, IDX_PASSWORD_RULES).await? { return Ok(slf); @@ -142,7 +141,7 @@ impl PasswordPolicy { .get("data") } else { sqlx::query("SELECT data FROM config WHERE id = 'password_policy'") - .fetch_one(&data.db) + .fetch_one(DB::conn()) .await? .get("data") }; @@ -155,8 +154,8 @@ impl PasswordPolicy { Ok(policy) } - pub async fn save(&self, data: &web::Data) -> Result<(), ErrorResponse> { - let slf = bincode::serialize(&self).unwrap(); + pub async fn save(&self) -> Result<(), ErrorResponse> { + let slf = bincode::serialize(&self)?; if is_hiqlite() { DB::client() @@ -170,7 +169,7 @@ impl PasswordPolicy { "UPDATE config SET data = $1 WHERE id = 'password_policy'", slf ) - .execute(&data.db) + .execute(DB::conn()) .await?; } @@ -218,11 +217,7 @@ pub struct RecentPasswordsEntity { } impl RecentPasswordsEntity { - pub async fn create( - data: &web::Data, - user_id: &str, - passwords: String, - ) -> Result<(), ErrorResponse> { + pub async fn create(user_id: &str, passwords: String) -> Result<(), ErrorResponse> { if is_hiqlite() { DB::client() .execute( @@ -236,14 +231,14 @@ impl RecentPasswordsEntity { user_id, passwords, ) - .execute(&data.db) + .execute(DB::conn()) .await?; } Ok(()) } - pub async fn find(data: &web::Data, user_id: &str) -> Result { + pub async fn find(user_id: &str) -> Result { let res = if is_hiqlite() { DB::client() .query_as_one( @@ -257,13 +252,13 @@ impl RecentPasswordsEntity { "SELECT * FROM recent_passwords WHERE user_id = $1", user_id, ) - .fetch_one(&data.db) + .fetch_one(DB::conn()) .await? }; Ok(res) } - pub async fn save(self, data: &web::Data) -> Result<(), ErrorResponse> { + pub async fn save(self) -> Result<(), ErrorResponse> { if is_hiqlite() { DB::client() .execute( @@ -277,7 +272,7 @@ impl RecentPasswordsEntity { self.passwords, self.user_id, ) - .execute(&data.db) + .execute(DB::conn()) .await?; } Ok(()) diff --git a/src/models/src/entity/refresh_tokens.rs b/src/models/src/entity/refresh_tokens.rs index 74384cbad..c462551bb 100644 --- a/src/models/src/entity/refresh_tokens.rs +++ b/src/models/src/entity/refresh_tokens.rs @@ -1,6 +1,4 @@ -use crate::app_state::AppState; use crate::database::DB; -use actix_web::web; use chrono::{DateTime, Utc}; use hiqlite::{params, Param}; use rauthy_common::is_hiqlite; @@ -21,7 +19,6 @@ pub struct RefreshToken { // CRUD impl RefreshToken { pub async fn create( - data: &web::Data, id: String, user_id: String, nbf: DateTime, @@ -41,37 +38,37 @@ impl RefreshToken { is_mfa, }; - rt.save(data).await?; + rt.save().await?; Ok(rt) } - pub async fn delete(self, data: &web::Data) -> Result<(), ErrorResponse> { + pub async fn delete(self) -> Result<(), ErrorResponse> { if is_hiqlite() { DB::client() .execute("DELETE FROM refresh_tokens WHERE id = $1", params!(self.id)) .await?; } else { sqlx::query!("DELETE FROM refresh_tokens WHERE id = $1", self.id) - .execute(&data.db) + .execute(DB::conn()) .await?; } Ok(()) } - pub async fn find_all(data: &web::Data) -> Result, ErrorResponse> { + pub async fn find_all() -> Result, ErrorResponse> { let res = if is_hiqlite() { DB::client() .query_as("SELECT * FROM refresh_tokens", params!()) .await? } else { sqlx::query_as!(Self, "SELECT * FROM refresh_tokens") - .fetch_all(&data.db) + .fetch_all(DB::conn()) .await? }; Ok(res) } - pub async fn invalidate_all(data: &web::Data) -> Result<(), ErrorResponse> { + pub async fn invalidate_all() -> Result<(), ErrorResponse> { let now = Utc::now().timestamp(); if is_hiqlite() { @@ -83,17 +80,14 @@ impl RefreshToken { .await?; } else { sqlx::query!("UPDATE refresh_tokens SET exp = $1 WHERE exp > $1", now) - .execute(&data.db) + .execute(DB::conn()) .await?; } Ok(()) } - pub async fn invalidate_for_user( - data: &web::Data, - user_id: &str, - ) -> Result<(), ErrorResponse> { + pub async fn invalidate_for_user(user_id: &str) -> Result<(), ErrorResponse> { if is_hiqlite() { DB::client() .execute( @@ -103,13 +97,13 @@ impl RefreshToken { .await?; } else { sqlx::query!("DELETE FROM refresh_tokens WHERE user_id = $1", user_id) - .execute(&data.db) + .execute(DB::conn()) .await?; } Ok(()) } - pub async fn find(data: &web::Data, id: &str) -> Result { + pub async fn find(id: &str) -> Result { let now = Utc::now().timestamp(); let slf = if is_hiqlite() { @@ -129,7 +123,7 @@ impl RefreshToken { id, now ) - .fetch_one(&data.db) + .fetch_one(DB::conn()) .await .map_err(|_| { ErrorResponse::new(ErrorResponseType::NotFound, "Refresh Token does not exist") @@ -138,7 +132,7 @@ impl RefreshToken { Ok(slf) } - pub async fn save(&self, data: &web::Data) -> Result<(), ErrorResponse> { + pub async fn save(&self) -> Result<(), ErrorResponse> { if is_hiqlite() { DB::client() .execute( @@ -169,7 +163,7 @@ ON CONFLICT(id) DO UPDATE SET user_id = $2, nbf = $3, exp = $4, scope = $5"#, self.scope, self.is_mfa, ) - .execute(&data.db) + .execute(DB::conn()) .await?; } @@ -178,17 +172,14 @@ ON CONFLICT(id) DO UPDATE SET user_id = $2, nbf = $3, exp = $4, scope = $5"#, } impl RefreshToken { - pub async fn invalidate_all_for_user( - data: &web::Data, - id: &str, - ) -> Result<(), ErrorResponse> { + pub async fn invalidate_all_for_user(id: &str) -> Result<(), ErrorResponse> { if is_hiqlite() { DB::client() .execute("DELETE FROM refresh_tokens WHERE user_id = $1", params!(id)) .await?; } else { sqlx::query!("DELETE FROM refresh_tokens WHERE user_id = $1", id) - .execute(&data.db) + .execute(DB::conn()) .await?; } diff --git a/src/models/src/entity/refresh_tokens_devices.rs b/src/models/src/entity/refresh_tokens_devices.rs index 4e64800bd..92c948f3f 100644 --- a/src/models/src/entity/refresh_tokens_devices.rs +++ b/src/models/src/entity/refresh_tokens_devices.rs @@ -1,6 +1,4 @@ -use crate::app_state::AppState; use crate::database::DB; -use actix_web::web; use chrono::{DateTime, Utc}; use hiqlite::{params, Param}; use rauthy_common::is_hiqlite; @@ -22,7 +20,6 @@ pub struct RefreshTokenDevice { // CRUD impl RefreshTokenDevice { pub async fn create( - data: &web::Data, id: String, device_id: String, user_id: String, @@ -39,11 +36,11 @@ impl RefreshTokenDevice { scope, }; - rt.save(data).await?; + rt.save().await?; Ok(rt) } - pub async fn delete(self, data: &web::Data) -> Result<(), ErrorResponse> { + pub async fn delete(self) -> Result<(), ErrorResponse> { if is_hiqlite() { DB::client() .execute( @@ -53,26 +50,26 @@ impl RefreshTokenDevice { .await?; } else { sqlx::query!("DELETE FROM refresh_tokens_devices WHERE id = $1", self.id) - .execute(&data.db) + .execute(DB::conn()) .await?; } Ok(()) } - pub async fn find_all(data: &web::Data) -> Result, ErrorResponse> { + pub async fn find_all() -> Result, ErrorResponse> { let res = if is_hiqlite() { DB::client() .query_as("SELECT * FROM refresh_tokens_devices", params!()) .await? } else { sqlx::query_as!(Self, "SELECT * FROM refresh_tokens_devices") - .fetch_all(&data.db) + .fetch_all(DB::conn()) .await? }; Ok(res) } - pub async fn invalidate_all(data: &web::Data) -> Result<(), ErrorResponse> { + pub async fn invalidate_all() -> Result<(), ErrorResponse> { let now = OffsetDateTime::now_utc().unix_timestamp(); if is_hiqlite() { @@ -87,17 +84,14 @@ impl RefreshTokenDevice { "UPDATE refresh_tokens_devices SET exp = $1 WHERE exp > $1", now ) - .execute(&data.db) + .execute(DB::conn()) .await?; } Ok(()) } - pub async fn invalidate_for_user( - data: &web::Data, - user_id: &str, - ) -> Result<(), ErrorResponse> { + pub async fn invalidate_for_user(user_id: &str) -> Result<(), ErrorResponse> { let now = Utc::now().timestamp(); if is_hiqlite() { @@ -113,14 +107,14 @@ impl RefreshTokenDevice { now, user_id ) - .execute(&data.db) + .execute(DB::conn()) .await?; } Ok(()) } - pub async fn find(data: &web::Data, id: &str) -> Result { + pub async fn find(id: &str) -> Result { let now = Utc::now().timestamp(); if is_hiqlite() { @@ -143,7 +137,7 @@ impl RefreshTokenDevice { id, now ) - .fetch_one(&data.db) + .fetch_one(DB::conn()) .await .map_err(|_| { ErrorResponse::new( @@ -154,10 +148,7 @@ impl RefreshTokenDevice { } } - pub async fn invalidate_all_for_device( - data: &web::Data, - device_id: &str, - ) -> Result<(), ErrorResponse> { + pub async fn invalidate_all_for_device(device_id: &str) -> Result<(), ErrorResponse> { if is_hiqlite() { DB::client() .execute( @@ -170,17 +161,14 @@ impl RefreshTokenDevice { "DELETE FROM refresh_tokens_devices WHERE device_id = $1", device_id ) - .execute(&data.db) + .execute(DB::conn()) .await?; } Ok(()) } - pub async fn invalidate_all_for_user( - data: &web::Data, - user_id: &str, - ) -> Result<(), ErrorResponse> { + pub async fn invalidate_all_for_user(user_id: &str) -> Result<(), ErrorResponse> { if is_hiqlite() { DB::client() .execute( @@ -193,14 +181,14 @@ impl RefreshTokenDevice { "DELETE FROM refresh_tokens_devices WHERE user_id = $1", user_id ) - .execute(&data.db) + .execute(DB::conn()) .await?; } Ok(()) } - pub async fn save(&self, data: &web::Data) -> Result<(), ErrorResponse> { + pub async fn save(&self) -> Result<(), ErrorResponse> { if is_hiqlite() { DB::client() .execute( @@ -235,7 +223,7 @@ SET device_id = $2, user_id = $3, nbf = $4, exp = $5, scope = $6"#, self.exp, self.scope, ) - .execute(&data.db) + .execute(DB::conn()) .await?; } diff --git a/src/models/src/entity/roles.rs b/src/models/src/entity/roles.rs index 7ac7daddd..501c13b56 100644 --- a/src/models/src/entity/roles.rs +++ b/src/models/src/entity/roles.rs @@ -1,7 +1,5 @@ -use crate::app_state::AppState; use crate::database::{Cache, DB}; use crate::entity::users::User; -use actix_web::web; use hiqlite::{params, Param, Params}; use rauthy_api_types::roles::NewRoleRequest; use rauthy_common::constants::{CACHE_TTL_APP, IDX_ROLES}; @@ -21,11 +19,8 @@ pub struct Role { // CRUD impl Role { // Inserts a new role into the database - pub async fn create( - data: &web::Data, - role_req: NewRoleRequest, - ) -> Result { - let mut roles = Role::find_all(data).await?; + pub async fn create(role_req: NewRoleRequest) -> Result { + let mut roles = Role::find_all().await?; for s in &roles { if s.name == role_req.role { return Err(ErrorResponse::new( @@ -53,7 +48,7 @@ impl Role { new_role.id, new_role.name, ) - .execute(&data.db) + .execute(DB::conn()) .await?; } @@ -66,8 +61,8 @@ impl Role { } // Deletes a role - pub async fn delete(data: &web::Data, id: &str) -> Result<(), ErrorResponse> { - let role = Role::find(data, id).await?; + pub async fn delete(id: &str) -> Result<(), ErrorResponse> { + let role = Role::find(id).await?; // prevent deletion of 'rauthy_admin' if role.name == "rauthy_admin" { @@ -77,7 +72,7 @@ impl Role { )); } - let users = User::find_with_role(data, &role.name).await?; + let users = User::find_with_role(&role.name).await?; if is_hiqlite() { let mut txn: Vec<(&str, Params)> = Vec::with_capacity(users.len() + 1); @@ -94,7 +89,7 @@ impl Role { debug_assert!(rows_affected == 1); } } else { - let mut txn = data.db.begin().await?; + let mut txn = DB::txn().await?; for mut user in users { user.delete_role(&role.name); @@ -107,7 +102,7 @@ impl Role { txn.commit().await?; } - let roles = Role::find_all(data) + let roles = Role::find_all() .await? .into_iter() .filter(|r| r.id != role.id) @@ -125,14 +120,14 @@ impl Role { } // Returns a single role by id - pub async fn find(data: &web::Data, id: &str) -> Result { + pub async fn find(id: &str) -> Result { let res = if is_hiqlite() { DB::client() .query_as_one("SELECT * FROM roles WHERE id = $1", params!(id)) .await? } else { sqlx::query_as!(Self, "SELECT * FROM roles WHERE id = $1", id) - .fetch_one(&data.db) + .fetch_one(DB::conn()) .await? }; @@ -140,7 +135,7 @@ impl Role { } // Returns all existing roles - pub async fn find_all(data: &web::Data) -> Result, ErrorResponse> { + pub async fn find_all() -> Result, ErrorResponse> { let client = DB::client(); if let Some(slf) = client.get(Cache::App, IDX_ROLES).await? { return Ok(slf); @@ -152,7 +147,7 @@ impl Role { .await? } else { sqlx::query_as!(Self, "SELECT * FROM roles") - .fetch_all(&data.db) + .fetch_all(DB::conn()) .await? }; @@ -163,13 +158,9 @@ impl Role { } // Updates a role - pub async fn update( - data: &web::Data, - id: String, - new_name: String, - ) -> Result { - let role = Role::find(data, &id).await?; - let users = User::find_with_role(data, &role.name).await?; + pub async fn update(id: String, new_name: String) -> Result { + let role = Role::find(&id).await?; + let users = User::find_with_role(&role.name).await?; let new_role = Self { id: role.id.clone(), @@ -194,7 +185,7 @@ impl Role { debug_assert!(rows_affected == 1); } } else { - let mut txn = data.db.begin().await?; + let mut txn = DB::txn().await?; for mut user in users { user.delete_role(&role.name); @@ -211,7 +202,7 @@ impl Role { txn.commit().await?; } - let roles = Role::find_all(data) + let roles = Role::find_all() .await? .into_iter() .map(|mut r| { @@ -233,12 +224,9 @@ impl Role { } impl Role { - pub async fn sanitize( - data: &web::Data, - rls: Vec, - ) -> Result { + pub async fn sanitize(rls: Vec) -> Result { let mut res = String::with_capacity(rls.len()); - Role::find_all(data).await?.into_iter().for_each(|r| { + Role::find_all().await?.into_iter().for_each(|r| { if rls.contains(&r.name) { res.push_str(r.name.as_str()); res.push(','); diff --git a/src/models/src/entity/scopes.rs b/src/models/src/entity/scopes.rs index f06c4571d..f3b3c052f 100644 --- a/src/models/src/entity/scopes.rs +++ b/src/models/src/entity/scopes.rs @@ -38,7 +38,7 @@ impl Scope { scope_req: ScopeRequest, ) -> Result { // check for already existing scope - let mut scopes = Scope::find_all(data).await?; + let mut scopes = Scope::find_all().await?; for s in &scopes { if s.name == scope_req.scope { return Err(ErrorResponse::new( @@ -58,7 +58,7 @@ impl Scope { } // check configured custom attributes and clean them up - let attrs = UserAttrConfigEntity::find_all_as_set(data).await?; + let attrs = UserAttrConfigEntity::find_all_as_set().await?; let attr_include_access = Self::clean_up_attrs(scope_req.attr_include_access, &attrs); let attr_include_id = Self::clean_up_attrs(scope_req.attr_include_id, &attrs); @@ -93,7 +93,7 @@ VALUES ($1, $2, $3, $4)"#, new_scope.attr_include_access, new_scope.attr_include_id, ) - .execute(&data.db) + .execute(DB::conn()) .await?; } @@ -108,7 +108,7 @@ VALUES ($1, $2, $3, $4)"#, } pub async fn delete(data: &web::Data, id: &str) -> Result<(), ErrorResponse> { - let scope = Scope::find(data, id).await?; + let scope = Scope::find(id).await?; if scope.name == "openid" { return Err(ErrorResponse::new( ErrorResponseType::BadRequest, @@ -116,7 +116,7 @@ VALUES ($1, $2, $3, $4)"#, )); } - let mut clients = Client::find_with_scope(data, &scope.name).await?; + let mut clients = Client::find_with_scope(&scope.name).await?; if is_hiqlite() { let mut txn: Vec<(&str, Params)> = Vec::with_capacity(clients.len() + 1); @@ -132,7 +132,7 @@ VALUES ($1, $2, $3, $4)"#, debug_assert!(rows_affected == 1); } } else { - let mut txn = data.db.begin().await?; + let mut txn = DB::txn().await?; for client in &mut clients { client.delete_scope(&scope.name); @@ -145,7 +145,7 @@ VALUES ($1, $2, $3, $4)"#, txn.commit().await?; } - let scopes = Scope::find_all(data) + let scopes = Scope::find_all() .await? .into_iter() .filter(|s| s.id != scope.id) @@ -170,21 +170,21 @@ VALUES ($1, $2, $3, $4)"#, Ok(()) } - pub async fn find(data: &web::Data, id: &str) -> Result { + pub async fn find(id: &str) -> Result { let res = if is_hiqlite() { DB::client() .query_as_one("SELECT * FROM scopes WHERE id = $1", params!(id)) .await? } else { sqlx::query_as!(Self, "SELECT * FROM scopes WHERE id = $1", id) - .fetch_one(&data.db) + .fetch_one(DB::conn()) .await? }; Ok(res) } - pub async fn find_all(data: &web::Data) -> Result, ErrorResponse> { + pub async fn find_all() -> Result, ErrorResponse> { let client = DB::client(); if let Some(slf) = client.get(Cache::App, IDX_SCOPES).await? { return Ok(slf); @@ -196,13 +196,14 @@ VALUES ($1, $2, $3, $4)"#, .await? } else { sqlx::query_as!(Self, "SELECT * FROM scopes") - .fetch_all(&data.db) + .fetch_all(DB::conn()) .await? }; client .put(Cache::App, IDX_SCOPES, &res, CACHE_TTL_APP) .await?; + Ok(res) } @@ -212,7 +213,7 @@ VALUES ($1, $2, $3, $4)"#, id: &str, scope_req: ScopeRequest, ) -> Result { - let scope = Scope::find(data, id).await?; + let scope = Scope::find(id).await?; if scope.name == "openid" { return Err(ErrorResponse::new( ErrorResponseType::BadRequest, @@ -232,7 +233,7 @@ VALUES ($1, $2, $3, $4)"#, // we only need to update clients with pre-computed values if the name // has been changed, but can skip them if it's about the attribute mapping let clients = if scope.name != scope_req.scope { - let clients = Client::find_with_scope(data, &scope.name) + let clients = Client::find_with_scope(&scope.name) .await? .into_iter() .map(|mut c| { @@ -249,7 +250,7 @@ VALUES ($1, $2, $3, $4)"#, debug!("scope_req: {:?}", scope_req); // check configured custom attributes and clean them up - let attrs = UserAttrConfigEntity::find_all_as_set(data).await?; + let attrs = UserAttrConfigEntity::find_all_as_set().await?; let attr_include_access = Self::clean_up_attrs(scope_req.attr_include_access, &attrs); let attr_include_id = Self::clean_up_attrs(scope_req.attr_include_id, &attrs); debug!("attr_include_access: {:?}", attr_include_access); @@ -289,7 +290,7 @@ WHERE id = $4"#, debug_assert!(rows_affected == 1); } } else { - let mut txn = data.db.begin().await?; + let mut txn = DB::txn().await?; if let Some(clients) = &clients { for client in clients { @@ -318,7 +319,7 @@ WHERE id = $4"#, } } - let scopes = Scope::find_all(data) + let scopes = Scope::find_all() .await? .into_iter() .map(|mut s| { diff --git a/src/models/src/entity/sessions.rs b/src/models/src/entity/sessions.rs index 02c868391..d271fe7d3 100644 --- a/src/models/src/entity/sessions.rs +++ b/src/models/src/entity/sessions.rs @@ -1,5 +1,4 @@ use crate::api_cookie::ApiCookie; -use crate::app_state::AppState; use crate::database::{Cache, DB}; use crate::entity::continuation_token::ContinuationToken; use crate::entity::users::User; @@ -39,23 +38,6 @@ pub struct Session { pub remote_ip: Option, } -// impl<'r> From> for Session { -// fn from(mut row: hiqlite::Row<'r>) -> Self { -// Self { -// id: row.get("id"), -// csrf_token: row.get("csrf_token"), -// user_id: row.get("user_id"), -// roles: row.get("roles"), -// groups: row.get("groups"), -// is_mfa: row.get("is_mfa"), -// state: row.get("state"), -// exp: row.get("exp"), -// last_seen: row.get("last_seen"), -// remote_ip: row.get("remote_ip"), -// } -// } -// } - #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub enum SessionState { Open, @@ -118,14 +100,14 @@ impl SessionState { // CRUD impl Session { - pub async fn delete(&self, data: &web::Data) -> Result<(), ErrorResponse> { + pub async fn delete(&self) -> Result<(), ErrorResponse> { if is_hiqlite() { DB::client() .execute("DELETE FROM sessions WHERE id = $1", params!(&self.id)) .await?; } else { sqlx::query!("DELETE FROM sessions WHERE id = $1", self.id) - .execute(&data.db) + .execute(DB::conn()) .await?; } @@ -136,10 +118,7 @@ impl Session { Ok(()) } - pub async fn delete_by_user( - data: &web::Data, - user_id: &str, - ) -> Result<(), ErrorResponse> { + pub async fn delete_by_user(user_id: &str) -> Result<(), ErrorResponse> { let sids: Vec = if is_hiqlite() { let rows = DB::client() .execute_returning( @@ -158,7 +137,7 @@ impl Session { "DELETE FROM sessions WHERE user_id = $1 RETURNING id", user_id ) - .fetch_all(&data.db) + .fetch_all(DB::conn()) .await?; let mut ids = Vec::with_capacity(rows.len()); @@ -179,7 +158,7 @@ impl Session { } // Returns a session by id - pub async fn find(data: &web::Data, id: String) -> Result { + pub async fn find(id: String) -> Result { let idx = Session::cache_idx(&id); let client = DB::client(); @@ -200,7 +179,7 @@ impl Session { "SELECT * FROM sessions WHERE id = $1 ORDER BY exp DESC", id ) - .fetch_one(&data.db) + .fetch_one(DB::conn()) .await? }; @@ -212,21 +191,20 @@ impl Session { } // not cached -> only used in the admin ui and can get very big - pub async fn find_all(data: &web::Data) -> Result, ErrorResponse> { + pub async fn find_all() -> Result, ErrorResponse> { let sessions = if is_hiqlite() { DB::client() .query_as("SELECT * FROM sessions ORDER BY exp DESC", params!()) .await? } else { sqlx::query_as!(Self, "SELECT * FROM sessions ORDER BY exp DESC") - .fetch_all(&data.db) + .fetch_all(DB::conn()) .await? }; Ok(sessions) } pub async fn find_paginated( - data: &web::Data, continuation_token: Option, page_size: i64, mut offset: i64, @@ -269,7 +247,7 @@ OFFSET $4"#, page_size, offset, ) - .fetch_all(&data.db) + .fetch_all(DB::conn()) .await? }; @@ -307,7 +285,7 @@ OFFSET $4"#, page_size, offset, ) - .fetch_all(&data.db) + .fetch_all(DB::conn()) .await? }; @@ -343,7 +321,7 @@ OFFSET $2"#, page_size, offset, ) - .fetch_all(&data.db) + .fetch_all(DB::conn()) .await? }; @@ -377,7 +355,7 @@ OFFSET $2"#, page_size, offset, ) - .fetch_all(&data.db) + .fetch_all(DB::conn()) .await? }; @@ -397,16 +375,16 @@ OFFSET $2"#, } /// Invalidates all sessions by setting the expiry to `now()` - pub async fn invalidate_all(data: &web::Data) -> Result<(), ErrorResponse> { + pub async fn invalidate_all() -> Result<(), ErrorResponse> { let now = OffsetDateTime::now_utc().unix_timestamp(); - let sessions = Session::find_all(data).await?; + let sessions = Session::find_all().await?; let mut removed = Vec::default(); // TODO refactor into single query with `RETURNING` for mut s in sessions { if s.exp > now { s.exp = now; - if let Err(err) = s.save(data).await { + if let Err(err) = s.save().await { error!("Error invalidating session: {}", err); } removed.push(s.id); @@ -424,10 +402,7 @@ OFFSET $2"#, } /// If any sessions have been deleted, `Vec` will be returned for cache invalidation. - pub async fn invalidate_for_user( - data: &web::Data, - uid: &str, - ) -> Result<(), ErrorResponse> { + pub async fn invalidate_for_user(uid: &str) -> Result<(), ErrorResponse> { let sids: Vec = if is_hiqlite() { let rows = DB::client() .execute_returning( @@ -444,7 +419,7 @@ OFFSET $2"#, } else { let rows = sqlx::query("DELETE FROM sessions WHERE user_id = $1 RETURNING id") .bind(uid) - .fetch_all(&data.db) + .fetch_all(DB::conn()) .await?; let mut ids = Vec::with_capacity(rows.len()); @@ -463,7 +438,7 @@ OFFSET $2"#, } /// Saves a Session - pub async fn save(&self, data: &web::Data) -> Result<(), ErrorResponse> { + pub async fn save(&self) -> Result<(), ErrorResponse> { let state_str = &self.state; if is_hiqlite() { @@ -510,7 +485,7 @@ remote_ip = $10"#, self.last_seen, self.remote_ip, ) - .execute(&data.db) + .execute(DB::conn()) .await?; } @@ -528,7 +503,6 @@ remote_ip = $10"#, /// Caution: Uses regex / LIKE on the database -> very costly query pub async fn search( - data: &web::Data, idx: &SearchParamsIdx, q: &str, limit: i64, @@ -549,7 +523,7 @@ remote_ip = $10"#, q, limit ) - .fetch_all(&data.db) + .fetch_all(DB::conn()) .await? } } @@ -568,7 +542,7 @@ remote_ip = $10"#, q, limit ) - .fetch_all(&data.db) + .fetch_all(DB::conn()) .await? } } @@ -585,7 +559,7 @@ remote_ip = $10"#, q, limit ) - .fetch_all(&data.db) + .fetch_all(DB::conn()) .await? } } @@ -723,10 +697,7 @@ impl Session { ) } - pub async fn invalidate( - &mut self, - data: &web::Data, - ) -> Result { + pub async fn invalidate(&mut self) -> Result { let idx = Session::cache_idx(&self.id); self.exp = OffsetDateTime::now_utc().unix_timestamp(); @@ -736,7 +707,7 @@ impl Session { .bind(self.exp) .bind(self.state.as_str()) .bind(&self.id) - .execute(&data.db) + .execute(DB::conn()) .await?; DB::client().delete(Cache::Session, idx).await?; @@ -816,13 +787,9 @@ impl Session { } #[inline] - pub async fn set_mfa( - &mut self, - data: &web::Data, - value: bool, - ) -> Result<(), ErrorResponse> { + pub async fn set_mfa(&mut self, value: bool) -> Result<(), ErrorResponse> { self.is_mfa = value; - self.save(data).await + self.save().await } #[inline(always)] diff --git a/src/models/src/entity/user_attr.rs b/src/models/src/entity/user_attr.rs index f21d4560f..173547a9a 100644 --- a/src/models/src/entity/user_attr.rs +++ b/src/models/src/entity/user_attr.rs @@ -1,8 +1,7 @@ -use crate::app_state::{AppState, DbTxn}; +use crate::app_state::DbTxn; use crate::database::{Cache, DB}; use crate::entity::scopes::Scope; use crate::entity::users::User; -use actix_web::web; use hiqlite::{params, Param, Params}; use rauthy_api_types::users::{ UserAttrConfigRequest, UserAttrConfigValueResponse, UserAttrValueResponse, @@ -38,11 +37,8 @@ impl UserAttrConfigEntity { Ok(()) } - pub async fn create( - data: &web::Data, - new_attr: UserAttrConfigRequest, - ) -> Result { - if Self::find(data, new_attr.name.clone()).await.is_ok() { + pub async fn create(new_attr: UserAttrConfigRequest) -> Result { + if Self::find(new_attr.name.clone()).await.is_ok() { return Err(ErrorResponse::new( ErrorResponseType::BadRequest, "User attribute config does already exist", @@ -62,11 +58,11 @@ impl UserAttrConfigEntity { new_attr.name, new_attr.desc, ) - .execute(&data.db) + .execute(DB::conn()) .await?; }; - let mut attrs = UserAttrConfigEntity::find_all(data).await?; + let mut attrs = UserAttrConfigEntity::find_all().await?; let slf = Self { name: new_attr.name.clone(), desc: new_attr.desc.clone(), @@ -79,12 +75,12 @@ impl UserAttrConfigEntity { Ok(slf) } - pub async fn delete(data: &web::Data, name: String) -> Result<(), ErrorResponse> { + pub async fn delete(name: String) -> Result<(), ErrorResponse> { // we do this empty check beforehand to avoid much more unnecessary work if it does not exist anyway - let slf = Self::find(data, name.clone()).await?; + let slf = Self::find(name.clone()).await?; // delete all possible scope mappings - let scopes = Scope::find_all(data).await?; + let scopes = Scope::find_all().await?; let mut scope_updates = Vec::new(); for s in scopes { let mut needs_update = false; @@ -152,7 +148,7 @@ impl UserAttrConfigEntity { client.txn(txn).await?; } else { - let mut txn = data.db.begin().await?; + let mut txn = DB::txn().await?; for (id, attr_include_access, attr_include_id) in scope_updates { Scope::update_mapping_only(&id, attr_include_access, attr_include_id, &mut txn) @@ -160,7 +156,7 @@ impl UserAttrConfigEntity { } user_attr_cache_cleanup_keys = - UserAttrValueEntity::delete_all_by_key(data, &name, &mut txn).await?; + UserAttrValueEntity::delete_all_by_key(&name, &mut txn).await?; sqlx::query!("DELETE FROM user_attr_config WHERE name = $1", name) .execute(&mut *txn) @@ -182,7 +178,7 @@ impl UserAttrConfigEntity { Ok(()) } - pub async fn find(data: &web::Data, name: String) -> Result { + pub async fn find(name: String) -> Result { let client = DB::client(); if let Some(slf) = client.get(Cache::App, Self::cache_idx(&name)).await? { return Ok(slf); @@ -197,7 +193,7 @@ impl UserAttrConfigEntity { .await? } else { sqlx::query_as!(Self, "SELECT * FROM user_attr_config WHERE name = $1", name) - .fetch_one(&data.db) + .fetch_one(DB::conn()) .await? }; @@ -208,7 +204,7 @@ impl UserAttrConfigEntity { Ok(slf) } - pub async fn find_all(data: &web::Data) -> Result, ErrorResponse> { + pub async fn find_all() -> Result, ErrorResponse> { let client = DB::client(); if let Some(slf) = client.get(Cache::App, IDX_USER_ATTR_CONFIG).await? { return Ok(slf); @@ -220,7 +216,7 @@ impl UserAttrConfigEntity { .await? } else { sqlx::query_as!(Self, "SELECT * FROM user_attr_config") - .fetch_all(&data.db) + .fetch_all(DB::conn()) .await? }; @@ -232,11 +228,10 @@ impl UserAttrConfigEntity { } pub async fn update( - data: &web::Data, name: String, req_data: UserAttrConfigRequest, ) -> Result { - let mut slf = Self::find(data, name.clone()).await?; + let mut slf = Self::find(name.clone()).await?; slf.name.clone_from(&req_data.name); slf.desc.clone_from(&req_data.desc); @@ -259,7 +254,7 @@ impl UserAttrConfigEntity { } else { sqlx::query("SELECT user_id FROM user_attr_values WHERE key = $1") .bind(&name) - .fetch_all(&data.db) + .fetch_all(DB::conn()) .await? .into_iter() .map(|row| row.get("user_id")) @@ -267,7 +262,7 @@ impl UserAttrConfigEntity { }; // update all possible scope mappings - let scopes = Scope::find_all(data).await?; + let scopes = Scope::find_all().await?; for scope in scopes { let mut needs_update = false; @@ -327,7 +322,7 @@ impl UserAttrConfigEntity { client.txn(txn).await?; } else { - let mut txn = data.db.begin().await?; + let mut txn = DB::txn().await?; for (id, attr_include_access, attr_include_id) in scope_updates { Scope::update_mapping_only(&id, attr_include_access, attr_include_id, &mut txn) @@ -361,10 +356,8 @@ impl UserAttrConfigEntity { } impl UserAttrConfigEntity { - pub async fn find_all_as_set( - data: &web::Data, - ) -> Result, ErrorResponse> { - let attrs = Self::find_all(data).await?; + pub async fn find_all_as_set() -> Result, ErrorResponse> { + let attrs = Self::find_all().await?; let mut set = HashSet::with_capacity(attrs.len()); for a in attrs { @@ -410,13 +403,12 @@ impl UserAttrValueEntity { /// You MUST `UserAttrValueEntity::clear_cache()` for returned UserCacheKeys /// after successful txn commit! pub async fn delete_all_by_key( - data: &web::Data, key: &str, txn: &mut DbTxn<'_>, ) -> Result, ErrorResponse> { let cache_idxs = sqlx::query_as!(Self, "SELECT * FROM user_attr_values WHERE key = $1", key) - .fetch_all(&data.db) + .fetch_all(DB::conn()) .await? .into_iter() .map(|a| Self::cache_idx(&a.user_id)) @@ -450,10 +442,7 @@ impl UserAttrValueEntity { Ok(cache_idxs) } - pub async fn find_for_user( - data: &web::Data, - user_id: &str, - ) -> Result, ErrorResponse> { + pub async fn find_for_user(user_id: &str) -> Result, ErrorResponse> { let idx = Self::cache_idx(user_id); let client = DB::client(); @@ -474,7 +463,7 @@ impl UserAttrValueEntity { "SELECT * FROM user_attr_values WHERE user_id = $1", user_id ) - .fetch_all(&data.db) + .fetch_all(DB::conn()) .await? }; @@ -484,13 +473,12 @@ impl UserAttrValueEntity { } pub async fn update_for_user( - data: &web::Data, user_id: &str, req_data: UserAttrValuesUpdateRequest, ) -> Result, ErrorResponse> { // Not necessary for the operation and correctness, but look up the user first and return // an error, if it does not exist at all, for a better user experience. - User::exists(data, user_id.to_string()).await?; + User::exists(user_id.to_string()).await?; let delete_value = |value: &Value| { if let Some(s) = value.as_str() { @@ -531,7 +519,7 @@ ON CONFLICT(user_id, key) DO UPDATE SET value = $3"#, ) .await? } else { - let mut txn = data.db.begin().await?; + let mut txn = DB::txn().await?; for value in req_data.values { if delete_value(&value.value) { @@ -565,7 +553,7 @@ ON CONFLICT(user_id, key) DO UPDATE SET value = $3"#, "SELECT * FROM user_attr_values WHERE user_id = $1", user_id ) - .fetch_all(&data.db) + .fetch_all(DB::conn()) .await? }; diff --git a/src/models/src/entity/users.rs b/src/models/src/entity/users.rs index cd315de88..0dec76941 100644 --- a/src/models/src/entity/users.rs +++ b/src/models/src/entity/users.rs @@ -89,7 +89,6 @@ pub struct User { // CRUD impl User { - /// Invalidates the cache for the given user pub async fn invalidate_cache(user_id: &str, email: &str) -> Result<(), ErrorResponse> { let client = DB::client(); @@ -102,7 +101,7 @@ impl User { Ok(()) } - pub async fn count(data: &web::Data) -> Result { + pub async fn count() -> Result { let client = DB::client(); if let Some(count) = client.get(Cache::App, IDX_USER_COUNT).await? { @@ -117,7 +116,7 @@ impl User { .get("count") } else { sqlx::query!("SELECT COUNT (*) count FROM users") - .fetch_one(&data.db) + .fetch_one(DB::conn()) .await? .count .unwrap_or_default() @@ -130,8 +129,8 @@ impl User { Ok(count) } - async fn count_inc(data: &web::Data) -> Result<(), ErrorResponse> { - let mut count = Self::count(data).await?; + async fn count_inc() -> Result<(), ErrorResponse> { + let mut count = Self::count().await?; // theoretically, we could have overlaps here, but we don't really care // -> used for dynamic pagination only and SQLite has limited query features count += 1; @@ -141,8 +140,8 @@ impl User { Ok(()) } - async fn count_dec(data: &web::Data) -> Result<(), ErrorResponse> { - let mut count = Self::count(data).await?; + async fn count_dec() -> Result<(), ErrorResponse> { + let mut count = Self::count().await?; // theoretically, we could have overlaps here, but we don't really care // -> used for dynamic pagination only and SQLite has limited query features count -= 1; @@ -152,16 +151,14 @@ impl User { Ok(()) } - // Inserts a user into the database pub async fn create( data: &web::Data, new_user: User, post_reset_redirect_uri: Option, ) -> Result { - let slf = Self::insert(data, new_user).await?; + let slf = Self::insert(new_user).await?; let magic_link = MagicLink::create( - data, slf.id.clone(), data.ml_lt_pwd_first as i64, MagicLinkUsage::NewUser(post_reset_redirect_uri), @@ -172,23 +169,19 @@ impl User { Ok(slf) } - pub async fn create_federated( - data: &web::Data, - new_user: User, - ) -> Result { - Self::insert(data, new_user).await + pub async fn create_federated(new_user: User) -> Result { + Self::insert(new_user).await } - // Inserts a user into the database pub async fn create_from_new( data: &web::Data, new_user_req: NewUserRequest, ) -> Result { - let new_user = User::from_new_user_req(data, new_user_req).await?; + let new_user = User::from_new_user_req(new_user_req).await?; User::create(data, new_user, None).await } - // Inserts a user from the open registration endpoint into the database + /// Inserts a user from the open registration endpoint into the database pub async fn create_from_reg( data: &web::Data, req_data: NewUserRegistrationRequest, @@ -206,9 +199,8 @@ impl User { Ok(new_user) } - // Deletes a user - pub async fn delete(&self, data: &web::Data) -> Result<(), ErrorResponse> { - Session::delete_by_user(data, &self.id).await?; + pub async fn delete(&self) -> Result<(), ErrorResponse> { + Session::delete_by_user(&self.id).await?; let client = DB::client(); if is_hiqlite() { @@ -217,18 +209,17 @@ impl User { .await?; } else { sqlx::query!("DELETE FROM users WHERE id = $1", self.id) - .execute(&data.db) + .execute(DB::conn()) .await?; } Self::invalidate_cache(&self.id, &self.email).await?; - Self::count_dec(data).await?; + Self::count_dec().await?; Ok(()) } - // Checks if a user exists in the database without fetching data - pub async fn exists(data: &web::Data, id: String) -> Result<(), ErrorResponse> { + pub async fn exists(id: String) -> Result<(), ErrorResponse> { let idx = format!("{}_{}", IDX_USERS, id); let opt: Option = DB::client().get(Cache::User, idx).await?; @@ -248,14 +239,14 @@ impl User { } } else { sqlx::query!("SELECT id FROM users WHERE id = $1", id) - .fetch_one(&data.db) + .fetch_one(DB::conn()) .await?; } Ok(()) } - pub async fn find(data: &web::Data, id: String) -> Result { + pub async fn find(id: String) -> Result { let idx = format!("{}_{}", IDX_USERS, id); let client = DB::client(); @@ -269,7 +260,7 @@ impl User { .await? } else { sqlx::query_as!(Self, "SELECT * FROM users WHERE id = $1", id) - .fetch_one(&data.db) + .fetch_one(DB::conn()) .await? }; @@ -277,10 +268,7 @@ impl User { Ok(slf) } - pub async fn find_by_email( - data: &web::Data, - email: String, - ) -> Result { + pub async fn find_by_email(email: String) -> Result { let email = email.to_lowercase(); let idx = format!("{}_{}", IDX_USERS, email); @@ -296,7 +284,7 @@ impl User { .await? } else { sqlx::query_as!(Self, "SELECT * FROM users WHERE email = $1", email) - .fetch_one(&data.db) + .fetch_one(DB::conn()) .await? }; @@ -305,7 +293,6 @@ impl User { } pub async fn find_by_federation( - data: &web::Data, auth_provider_id: &str, federation_uid: &str, ) -> Result { @@ -323,30 +310,28 @@ impl User { auth_provider_id, federation_uid ) - .fetch_one(&data.db) + .fetch_one(DB::conn()) .await? }; Ok(slf) } - pub async fn find_all(data: &web::Data) -> Result, ErrorResponse> { + pub async fn find_all() -> Result, ErrorResponse> { let res = if is_hiqlite() { DB::client() .query_as("SELECT * FROM users ORDER BY created_at ASC", params!()) .await? } else { sqlx::query_as!(Self, "SELECT * FROM users ORDER BY created_at ASC") - .fetch_all(&data.db) + .fetch_all(DB::conn()) .await? }; Ok(res) } - pub async fn find_all_simple( - data: &web::Data, - ) -> Result, ErrorResponse> { + pub async fn find_all_simple() -> Result, ErrorResponse> { let res = if is_hiqlite() { DB::client() .query_as( @@ -359,7 +344,7 @@ impl User { UserResponseSimple, "SELECT id, email, created_at, last_login FROM users ORDER BY created_at ASC" ) - .fetch_all(&data.db) + .fetch_all(DB::conn()) .await? }; @@ -367,10 +352,7 @@ impl User { } /// This is a very expensive query using `LIKE`, use only when necessary. - pub async fn find_with_group( - data: &web::Data, - group_name: &str, - ) -> Result, ErrorResponse> { + pub async fn find_with_group(group_name: &str) -> Result, ErrorResponse> { let like = format!("%{group_name}%"); let res = if is_hiqlite() { @@ -379,7 +361,7 @@ impl User { .await? } else { sqlx::query_as!(Self, "SELECT * FROM users WHERE groups LIKE $1", like) - .fetch_all(&data.db) + .fetch_all(DB::conn()) .await? }; @@ -387,10 +369,7 @@ impl User { } /// This is a very expensive query using `LIKE`, use only when necessary. - pub async fn find_with_role( - data: &web::Data, - role_name: &str, - ) -> Result, ErrorResponse> { + pub async fn find_with_role(role_name: &str) -> Result, ErrorResponse> { let like = format!("%{role_name}%"); let res = if is_hiqlite() { @@ -399,14 +378,14 @@ impl User { .await? } else { sqlx::query_as!(Self, "SELECT * FROM users WHERE roles LIKE $1", like) - .fetch_all(&data.db) + .fetch_all(DB::conn()) .await? }; Ok(res) } - pub async fn find_expired(data: &web::Data) -> Result, ErrorResponse> { + pub async fn find_expired() -> Result, ErrorResponse> { let now = Utc::now().add(chrono::Duration::seconds(10)).timestamp(); let res = if is_hiqlite() { @@ -415,21 +394,18 @@ impl User { .await? } else { sqlx::query_as!(Self, "SELECT * FROM users WHERE user_expires < $1", now) - .fetch_all(&data.db) + .fetch_all(DB::conn()) .await? }; Ok(res) } - pub async fn find_for_fed_cm_validated( - data: &web::Data, - user_id: String, - ) -> Result { + pub async fn find_for_fed_cm_validated(user_id: String) -> Result { // We will stick to the WWW-Authenticate header for now and use duplicated code from // some OAuth2 api for now until the spec has settled on an error behavior. debug!("Looking up FedCM user_id {}", user_id); - let slf = Self::find(data, user_id).await.map_err(|_| { + let slf = Self::find(user_id).await.map_err(|_| { debug!("FedCM user not found"); ErrorResponse::new( ErrorResponseType::WWWAuthenticate("user-not-found".to_string()), @@ -450,7 +426,6 @@ impl User { } pub async fn find_paginated( - data: &web::Data, continuation_token: Option, page_size: i64, mut offset: i64, @@ -491,7 +466,7 @@ OFFSET $4"#, page_size, offset, ) - .fetch_all(&data.db) + .fetch_all(DB::conn()) .await?; res.reverse(); @@ -527,7 +502,7 @@ OFFSET $4"#, page_size, offset, ) - .fetch_all(&data.db) + .fetch_all(DB::conn()) .await? } } @@ -562,7 +537,7 @@ OFFSET $2"#, page_size, offset, ) - .fetch_all(&data.db) + .fetch_all(DB::conn()) .await?; res.reverse(); @@ -594,7 +569,7 @@ OFFSET $2"#, page_size, offset, ) - .fetch_all(&data.db) + .fetch_all(DB::conn()) .await? } }; @@ -606,7 +581,7 @@ OFFSET $2"#, Ok((res, token)) } - pub async fn insert(data: &web::Data, new_user: User) -> Result { + pub async fn insert(new_user: User) -> Result { let lang = new_user.language.as_str(); if is_hiqlite() { @@ -657,22 +632,19 @@ VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)"#, new_user.auth_provider_id, new_user.federation_uid, ) - .execute(&data.db) + .execute(DB::conn()) .await?; } - Self::count_inc(data).await?; + Self::count_inc().await?; Ok(new_user) } - pub async fn provider_unlink( - data: &web::Data, - user_id: String, - ) -> Result { + pub async fn provider_unlink(user_id: String) -> Result { // we need to find the user first and validate that it has been set up properly // to work without a provider - let mut slf = Self::find(data, user_id).await?; + let mut slf = Self::find(user_id).await?; if slf.password.is_none() && !slf.has_webauthn_enabled() { return Err(ErrorResponse::new( ErrorResponseType::BadRequest, @@ -682,7 +654,7 @@ VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)"#, slf.auth_provider_id = None; slf.federation_uid = None; - slf.save(data, None).await?; + slf.save(None).await?; Ok(slf) } @@ -770,13 +742,9 @@ WHERE id = $18"#, Ok(()) } - pub async fn save( - &self, - data: &web::Data, - old_email: Option, - ) -> Result<(), ErrorResponse> { + pub async fn save(&self, old_email: Option) -> Result<(), ErrorResponse> { if old_email.is_some() { - User::is_email_free(data, self.email.clone()).await?; + User::is_email_free(self.email.clone()).await?; } let lang = self.language.as_str(); @@ -842,13 +810,13 @@ WHERE id = $18"#, .bind(&self.auth_provider_id) .bind(&self.federation_uid) .bind(&self.id) - .execute(&data.db) + .execute(DB::conn()) .await?; } if !self.enabled { - Session::invalidate_for_user(data, &self.id).await?; - RefreshToken::invalidate_for_user(data, &self.id).await?; + Session::invalidate_for_user(&self.id).await?; + RefreshToken::invalidate_for_user(&self.id).await?; } if let Some(email) = old_email { @@ -867,7 +835,6 @@ WHERE id = $18"#, /// Caution: Uses regex / LIKE on the database -> very costly query pub async fn search( - data: &web::Data, idx: &SearchParamsIdx, q: &str, limit: i64, @@ -900,7 +867,7 @@ LIMIT $2"#, q, limit ) - .fetch_all(&data.db) + .fetch_all(DB::conn()) .await? } } @@ -929,7 +896,7 @@ LIMIT $2"#, q, limit ) - .fetch_all(&data.db) + .fetch_all(DB::conn()) .await? } } @@ -945,7 +912,6 @@ LIMIT $2"#, } pub async fn set_email_verified( - data: &web::Data, user_id: String, email_verified: bool, ) -> Result<(), ErrorResponse> { @@ -962,7 +928,7 @@ LIMIT $2"#, email_verified, user_id ) - .execute(&data.db) + .execute(DB::conn()) .await?; } Ok(()) @@ -975,7 +941,7 @@ LIMIT $2"#, user: Option, ) -> Result<(User, Option, bool), ErrorResponse> { let mut user = match user { - None => User::find(data, id).await?, + None => User::find(id).await?, Some(user) => user, }; upd_user.email = upd_user.email.to_lowercase(); @@ -994,18 +960,18 @@ LIMIT $2"#, } if let Some(password) = &upd_user.password { - user.apply_password_rules(data, password).await?; + user.apply_password_rules(password).await?; } let is_admin_before_update = user.is_admin(); - user.roles = Role::sanitize(data, upd_user.roles).await?; - user.groups = Group::sanitize(data, upd_user.groups).await?; + user.roles = Role::sanitize(upd_user.roles).await?; + user.groups = Group::sanitize(upd_user.groups).await?; user.enabled = upd_user.enabled; user.email_verified = upd_user.email_verified; user.user_expires = upd_user.user_expires; - user.save(data, old_email.clone()).await?; + user.save(old_email.clone()).await?; if upd_user.password.is_some() { data.tx_events @@ -1020,7 +986,7 @@ LIMIT $2"#, if let Some(old_email) = old_email.as_ref() { // if the user was saved successfully and the email was changed, invalidate all existing // sessions with the old address and send out notifications to the users addresses - Session::invalidate_for_user(data, &user.id).await?; + Session::invalidate_for_user(&user.id).await?; // send out confirmation E-Mails to both addresses send_email_confirm_change(data, &user, &user.email, &user.email, true).await; @@ -1037,7 +1003,7 @@ LIMIT $2"#, // finally, update the custom users values let user_values = if let Some(values) = upd_user.user_values { - UserValues::upsert(data, user.id.clone(), values).await? + UserValues::upsert(user.id.clone(), values).await? } else { None }; @@ -1045,7 +1011,7 @@ LIMIT $2"#, Ok((user, user_values, is_new_admin)) } - pub async fn update_language(&self, data: &web::Data) -> Result<(), ErrorResponse> { + pub async fn update_language(&self) -> Result<(), ErrorResponse> { let lang = self.language.as_str(); if is_hiqlite() { @@ -1061,21 +1027,21 @@ LIMIT $2"#, lang, self.id ) - .execute(&data.db) + .execute(DB::conn()) .await?; } Ok(()) } - // Updates a user from himself. This is needed for the account page to make each user able to - // update its own data. + /// Updates a user from himself. This is needed for the account page to make each user able to + /// update its own data. pub async fn update_self_req( data: &web::Data, id: String, upd_user: UpdateUserSelfRequest, ) -> Result<(User, Option, bool), ErrorResponse> { - let user = User::find(data, id.clone()).await?; + let user = User::find(id.clone()).await?; let mut password = None; if let Some(pwd_new) = upd_user.password_new { @@ -1104,10 +1070,9 @@ LIMIT $2"#, // email to old AND new address if email != user.email { // invalidate possibly other existing MagicLinks of the same type - MagicLink::invalidate_all_email_change(data, &user.id).await?; + MagicLink::invalidate_all_email_change(&user.id).await?; let ml = MagicLink::create( - data, user.id.clone(), 60, MagicLinkUsage::EmailChange(email.clone()), @@ -1159,11 +1124,8 @@ LIMIT $2"#, /// Converts a user account from as password account type to passkey only with all necessary /// checks included. - pub async fn convert_to_passkey( - data: &web::Data, - id: String, - ) -> Result<(), ErrorResponse> { - let mut user = User::find(data, id.clone()).await?; + pub async fn convert_to_passkey(id: String) -> Result<(), ErrorResponse> { + let mut user = User::find(id.clone()).await?; if user.account_type() != AccountType::Password { return Err(ErrorResponse::new( @@ -1179,7 +1141,7 @@ LIMIT $2"#, )); } - let pks = PasskeyEntity::find_for_user_with_uv(data, &user.id).await?; + let pks = PasskeyEntity::find_for_user_with_uv(&user.id).await?; if pks.is_empty() { return Err(ErrorResponse::new( ErrorResponseType::NotFound, @@ -1190,7 +1152,7 @@ LIMIT $2"#, user.password = None; user.password_expires = None; - user.save(data, None).await?; + user.save(None).await?; Ok(()) } } @@ -1215,12 +1177,8 @@ impl User { } } - pub async fn apply_password_rules( - &mut self, - data: &web::Data, - plain_pwd: &str, - ) -> Result<(), ErrorResponse> { - let rules = PasswordPolicy::find(data).await?; + pub async fn apply_password_rules(&mut self, plain_pwd: &str) -> Result<(), ErrorResponse> { + let rules = PasswordPolicy::find().await?; if plain_pwd.len() < rules.length_min as usize { return Err(ErrorResponse::new( @@ -1300,7 +1258,7 @@ impl User { let mut new_recent = Vec::new(); if let Some(recent_req) = rules.not_recently_used { - match RecentPasswordsEntity::find(data, &self.id).await { + match RecentPasswordsEntity::find(&self.id).await { Ok(mut most_recent) => { let mut iteration = 1; for old_hash in most_recent.passwords.split('\n') { @@ -1326,11 +1284,11 @@ impl User { } most_recent.passwords = format!("{}\n{}", new_hash, new_recent.join("\n")); - most_recent.save(data).await?; + most_recent.save().await?; } Err(_) => { - RecentPasswordsEntity::create(data, &self.id, new_hash.clone()).await?; + RecentPasswordsEntity::create(&self.id, new_hash.clone()).await?; } } } @@ -1383,7 +1341,7 @@ impl User { user_id: String, confirm_id: String, ) -> Result { - let mut ml = MagicLink::find(data, &confirm_id).await?; + let mut ml = MagicLink::find(&confirm_id).await?; ml.validate(&user_id, &req, false)?; let usage = MagicLinkUsage::try_from(&ml.usage)?; @@ -1398,10 +1356,10 @@ impl User { MagicLinkUsage::EmailChange(email) => email, }; - let mut user = Self::find(data, user_id).await?; + let mut user = Self::find(user_id).await?; // build response HTML - let colors = ColorEntity::find_rauthy(data).await?; + let colors = ColorEntity::find_rauthy().await?; let lang = Language::try_from(&req).unwrap_or_default(); let html = UserEmailChangeConfirmHtml::build(&colors, &lang, &user.email, &new_email); @@ -1409,11 +1367,11 @@ impl User { let old_email = user.email; user.email = new_email; user.email_verified = true; - user.save(data, Some(old_email.clone())).await?; - ml.invalidate(data).await?; + user.save(Some(old_email.clone())).await?; + ml.invalidate().await?; // finally, invalidate all existing sessions with the old email - Session::invalidate_for_user(data, &user.id).await?; + Session::invalidate_for_user(&user.id).await?; // send out confirmation E-Mails to both addresses send_email_confirm_change(data, &user, &user.email, &user.email, false).await; @@ -1485,12 +1443,9 @@ impl User { } } - pub async fn from_new_user_req( - data: &web::Data, - new_user: NewUserRequest, - ) -> Result { - let roles = Role::sanitize(data, new_user.roles).await?; - let groups = Group::sanitize(data, new_user.groups).await?; + pub async fn from_new_user_req(new_user: NewUserRequest) -> Result { + let roles = Role::sanitize(new_user.roles).await?; + let groups = Group::sanitize(new_user.groups).await?; let user = Self { email: new_user.email.to_lowercase(), @@ -1598,8 +1553,8 @@ impl User { self.get_roles().contains(&RAUTHY_ADMIN_ROLE) } - async fn is_email_free(data: &web::Data, email: String) -> Result<(), ErrorResponse> { - match User::find_by_email(data, email).await { + async fn is_email_free(email: String) -> Result<(), ErrorResponse> { + match User::find_by_email(email).await { Ok(_) => Err(ErrorResponse::new( ErrorResponseType::BadRequest, "E-Mail is already in use".to_string(), @@ -1650,7 +1605,7 @@ impl User { return Ok(()); } - let ml_res = MagicLink::find_by_user(data, self.id.clone()).await; + let ml_res = MagicLink::find_by_user(self.id.clone()).await; // if an active magic link already exists - invalidate it. if let Ok(mut ml) = ml_res { if ml.exp > OffsetDateTime::now_utc().unix_timestamp() { @@ -1658,7 +1613,7 @@ impl User { "Password reset request with already existing valid magic link from: {}", real_ip_from_req(&req)? ); - ml.invalidate(data).await?; + ml.invalidate().await?; } } @@ -1667,8 +1622,7 @@ impl User { } else { MagicLinkUsage::PasswordReset(redirect_uri) }; - let new_ml = - MagicLink::create(data, self.id.clone(), data.ml_lt_pwd_reset as i64, usage).await?; + let new_ml = MagicLink::create(self.id.clone(), data.ml_lt_pwd_reset as i64, usage).await?; send_pwd_reset(data, &new_ml, self).await; Ok(()) @@ -1699,7 +1653,6 @@ impl User { // if the given password does match, send out a reset link to set a new one return if self.match_passwords(plain_password.clone()).await? { let magic_link = MagicLink::create( - data, self.id.clone(), data.ml_lt_pwd_reset as i64, MagicLinkUsage::PasswordReset(None), diff --git a/src/models/src/entity/users_values.rs b/src/models/src/entity/users_values.rs index 0d97a9c28..5e4825a6a 100644 --- a/src/models/src/entity/users_values.rs +++ b/src/models/src/entity/users_values.rs @@ -1,6 +1,4 @@ -use crate::app_state::AppState; use crate::database::{Cache, DB}; -use actix_web::web; use hiqlite::{params, Param}; use jwt_simple::prelude::{Deserialize, Serialize}; use rauthy_api_types::users::{UserValuesRequest, UserValuesResponse}; @@ -26,10 +24,7 @@ impl UserValues { format!("{}_{}", IDX_USERS_VALUES, user_id) } - pub async fn find( - data: &web::Data, - user_id: &str, - ) -> Result, ErrorResponse> { + pub async fn find(user_id: &str) -> Result, ErrorResponse> { let idx = Self::cache_idx(user_id); let client = DB::client(); @@ -45,7 +40,7 @@ impl UserValues { } else { sqlx::query_as::<_, Self>("SELECT * FROM users_values WHERE id = $1") .bind(user_id) - .fetch_optional(&data.db) + .fetch_optional(DB::conn()) .await? }; @@ -55,7 +50,6 @@ impl UserValues { } pub async fn upsert( - data: &web::Data, user_id: String, values: UserValuesRequest, ) -> Result, ErrorResponse> { @@ -95,7 +89,7 @@ SET birthdate = $2, phone = $3, street = $4, zip = $5, city = $6, country = $7"# values.city, values.country, ) - .execute(&data.db) + .execute(DB::conn()) .await?; } diff --git a/src/models/src/entity/webauthn.rs b/src/models/src/entity/webauthn.rs index fbe885008..f29d4eb1d 100644 --- a/src/models/src/entity/webauthn.rs +++ b/src/models/src/entity/webauthn.rs @@ -50,7 +50,6 @@ pub struct PasskeyEntity { impl PasskeyEntity { /// If the `User` is `Some(_)`, a `User::save()` will be included in the `txn` pub async fn create( - data: &web::Data, user_id: String, user: Option, passkey_user_id: Uuid, @@ -103,7 +102,7 @@ VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"#, client.txn(txn).await?; } else { - let mut txn = data.db.begin().await?; + let mut txn = DB::txn().await?; if let Some(user) = user { debug_assert!(user.webauthn_user_id.is_some()); @@ -149,10 +148,7 @@ VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"#, Ok(()) } - pub async fn count_for_user( - data: &web::Data, - user_id: String, - ) -> Result { + pub async fn count_for_user(user_id: String) -> Result { let count: i64 = if is_hiqlite() { DB::client() .query_raw_one( @@ -166,7 +162,7 @@ VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"#, "SELECT COUNT (*) AS count FROM passkeys WHERE user_id = $1", user_id ) - .fetch_one(&data.db) + .fetch_one(DB::conn()) .await? .count .unwrap_or_default() @@ -175,24 +171,20 @@ VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"#, Ok(count) } - pub async fn delete( - data: &web::Data, - user_id: String, - name: String, - ) -> Result<(), ErrorResponse> { + pub async fn delete(user_id: String, name: String) -> Result<(), ErrorResponse> { // if we delete a passkey, we must check if this is the last existing one for the user - let pk_count = Self::count_for_user(data, user_id.clone()).await?; + let pk_count = Self::count_for_user(user_id.clone()).await?; let mut user_to_save: Option = None; let mut user_email: Option = None; if pk_count < 2 { - let mut user = User::find(data, user_id.clone()).await?; + let mut user = User::find(user_id.clone()).await?; user.webauthn_user_id = None; // in this case, we need to check against the current password policy, // if the password should expire again - let policy = PasswordPolicy::find(data).await?; + let policy = PasswordPolicy::find().await?; if let Some(valid_days) = policy.valid_days { if user.password.is_some() { user.password_expires = Some( @@ -219,7 +211,7 @@ VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"#, DB::client().txn(txn).await?; } else { - let mut txn = data.db.begin().await?; + let mut txn = DB::txn().await?; if let Some(user) = user_to_save { user.save_txn(&mut txn).await?; @@ -286,11 +278,7 @@ VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"#, Ok(()) } - pub async fn find( - data: &web::Data, - user_id: &str, - name: &str, - ) -> Result { + pub async fn find(user_id: &str, name: &str) -> Result { let idx = Self::cache_idx_single(user_id, name); let client = DB::client(); @@ -312,7 +300,7 @@ VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"#, user_id, name, ) - .fetch_one(&data.db) + .fetch_one(DB::conn()) .await? }; @@ -323,10 +311,7 @@ VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"#, Ok(slf) } - pub async fn find_cred_ids_for_user( - data: &web::Data, - user_id: &str, - ) -> Result, ErrorResponse> { + pub async fn find_cred_ids_for_user(user_id: &str) -> Result, ErrorResponse> { let idx = Self::cache_idx_creds(user_id); let client = DB::client(); @@ -350,7 +335,7 @@ VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"#, "SELECT credential_id FROM passkeys WHERE user_id = $1", user_id ) - .fetch_all(&data.db) + .fetch_all(DB::conn()) .await? .into_iter() .map(|row| row.credential_id) @@ -364,10 +349,7 @@ VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"#, Ok(creds.into_iter().map(CredentialID::from).collect()) } - pub async fn find_for_user( - data: &web::Data, - user_id: &str, - ) -> Result, ErrorResponse> { + pub async fn find_for_user(user_id: &str) -> Result, ErrorResponse> { let idx = Self::cache_idx_user(user_id); let client = DB::client(); @@ -384,7 +366,7 @@ VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"#, .await? } else { sqlx::query_as!(Self, "SELECT * FROM passkeys WHERE user_id = $1", user_id) - .fetch_all(&data.db) + .fetch_all(DB::conn()) .await? }; @@ -395,10 +377,7 @@ VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"#, Ok(pks) } - pub async fn find_for_user_with_uv( - data: &web::Data, - user_id: &str, - ) -> Result, ErrorResponse> { + pub async fn find_for_user_with_uv(user_id: &str) -> Result, ErrorResponse> { let idx = Self::cache_idx_user_with_uv(user_id); let client = DB::client(); @@ -419,7 +398,7 @@ VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"#, "SELECT * FROM passkeys WHERE user_id = $1 AND user_verified = true", user_id ) - .fetch_all(&data.db) + .fetch_all(DB::conn()) .await? }; @@ -759,17 +738,17 @@ pub async fn auth_start( MfaPurpose::Test => WebauthnAdditionalData::Test, }; - let user = User::find(data, user_id).await?; + let user = User::find(user_id).await?; let force_uv = user.account_type() == AccountType::Passkey || *WEBAUTHN_FORCE_UV; let pks = if force_uv { // in this case, filter out all presence only keys - PasskeyEntity::find_for_user_with_uv(data, &user.id) + PasskeyEntity::find_for_user_with_uv(&user.id) .await? .iter() .map(|pk_entity| pk_entity.get_pk()) .collect::>() } else { - PasskeyEntity::find_for_user(data, &user.id) + PasskeyEntity::find_for_user(&user.id) .await? .iter() .map(|pk_entity| pk_entity.get_pk()) @@ -826,10 +805,10 @@ pub async fn auth_finish( let auth_data = WebauthnData::find(req.code).await?; let auth_state = serde_json::from_str(&auth_data.auth_state_json)?; - let mut user = User::find(data, user_id).await?; + let mut user = User::find(user_id).await?; let force_uv = user.account_type() == AccountType::Passkey || *WEBAUTHN_FORCE_UV; - let pks = PasskeyEntity::find_for_user(data, &user.id).await?; + let pks = PasskeyEntity::find_for_user(&user.id).await?; match data .webauthn @@ -860,7 +839,7 @@ pub async fn auth_finish( user.last_failed_login = None; user.failed_login_attempts = None; - let mut txn = data.db.begin().await?; + let mut txn = DB::txn().await?; pk_entity.update_passkey(&mut txn).await?; user.save_txn(&mut txn).await?; txn.commit().await?; @@ -893,13 +872,13 @@ pub async fn reg_start( user_id: String, req: WebauthnRegStartRequest, ) -> Result { - let user = User::find(data, user_id).await?; + let user = User::find(user_id).await?; let passkey_user_id = if let Some(id) = &user.webauthn_user_id { Uuid::from_str(id).expect("corrupted database: user.webauthn_user_id") } else { Uuid::new_v4() }; - let cred_ids = PasskeyEntity::find_cred_ids_for_user(data, &user.id).await?; + let cred_ids = PasskeyEntity::find_cred_ids_for_user(&user.id).await?; match data.webauthn.start_passkey_registration( passkey_user_id, @@ -955,7 +934,7 @@ pub async fn reg_finish( id: String, req: WebauthnRegFinishRequest, ) -> Result<(), ErrorResponse> { - let mut user = User::find(data, id).await?; + let mut user = User::find(id).await?; let idx = format!("reg_{:?}_{}", req.passkey_name, user.id); let client = DB::client(); @@ -1004,7 +983,6 @@ pub async fn reg_finish( }; PasskeyEntity::create( - data, user_id.clone(), create_user, reg_data.passkey_user_id, diff --git a/src/models/src/entity/webids.rs b/src/models/src/entity/webids.rs index 9f0689166..93dbdf379 100644 --- a/src/models/src/entity/webids.rs +++ b/src/models/src/entity/webids.rs @@ -1,6 +1,4 @@ -use crate::app_state::AppState; use crate::database::{Cache, DB}; -use actix_web::web; use hiqlite::{params, Param}; use rauthy_api_types::users::WebIdResponse; use rauthy_common::constants::{CACHE_TTL_USER, PUB_URL_WITH_SCHEME}; @@ -33,7 +31,7 @@ impl WebId { } /// Returns the WebId from the database, if it exists, and a default otherwise. - pub async fn find(data: &web::Data, user_id: String) -> Result { + pub async fn find(user_id: String) -> Result { let client = DB::client(); if let Some(slf) = client.get(Cache::User, Self::cache_idx(&user_id)).await? { return Ok(slf); @@ -50,7 +48,7 @@ impl WebId { }) } else { query_as!(Self, "SELECT * FROM webids WHERE user_id = $1", user_id) - .fetch_one(&data.db) + .fetch_one(DB::conn()) .await .unwrap_or(Self { user_id, @@ -71,7 +69,7 @@ impl WebId { Ok(slf) } - pub async fn upsert(data: &web::Data, web_id: WebId) -> Result<(), ErrorResponse> { + pub async fn upsert(web_id: WebId) -> Result<(), ErrorResponse> { if is_hiqlite() { DB::client() .execute( @@ -94,7 +92,7 @@ SET custom_triples = $2, expose_email = $3"#, web_id.custom_triples, web_id.expose_email, ) - .execute(&data.db) + .execute(DB::conn()) .await?; } diff --git a/src/models/src/entity/well_known.rs b/src/models/src/entity/well_known.rs index d4b5f4f50..b1894213c 100644 --- a/src/models/src/entity/well_known.rs +++ b/src/models/src/entity/well_known.rs @@ -45,13 +45,13 @@ impl WellKnown { return Ok(slf); } - let scopes = Scope::find_all(data) + let scopes = Scope::find_all() .await? .into_iter() .map(|s| s.name) .collect::>(); let slf = Self::new(&data.issuer, scopes); - let json = serde_json::to_string(&slf).unwrap(); + let json = serde_json::to_string(&slf)?; client.put(Cache::App, IDX, &json, CACHE_TTL_APP).await?; @@ -61,13 +61,13 @@ impl WellKnown { /// Rebuilds the WellKnown, serializes it as json and updates it inside the cache. /// Should be called after any update on the Scopes. pub async fn rebuild(data: &web::Data) -> Result<(), ErrorResponse> { - let scopes = Scope::find_all(data) + let scopes = Scope::find_all() .await? .into_iter() .map(|s| s.name) .collect::>(); let slf = Self::new(&data.issuer, scopes); - let json = serde_json::to_string(&slf).unwrap(); + let json = serde_json::to_string(&slf)?; DB::client() .put(Cache::App, IDX, &json, CACHE_TTL_APP) diff --git a/src/models/src/events/event.rs b/src/models/src/events/event.rs index fd6eb395a..7d5d7f32b 100644 --- a/src/models/src/events/event.rs +++ b/src/models/src/events/event.rs @@ -1,4 +1,3 @@ -use crate::app_state::DbPool; use crate::database::DB; use crate::events::{ EVENT_LEVEL_FAILED_LOGIN, EVENT_LEVEL_FAILED_LOGINS_10, EVENT_LEVEL_FAILED_LOGINS_15, @@ -440,7 +439,7 @@ impl Display for Event { } impl Event { - pub async fn insert(&self, db: &DbPool) -> Result<(), ErrorResponse> { + pub async fn insert(&self) -> Result<(), ErrorResponse> { let level = self.level.value(); let typ = self.typ.value(); @@ -474,7 +473,7 @@ VALUES ($1, $2, $3, $4, $5, $6, $7)"#, self.data, self.text, ) - .execute(db) + .execute(DB::conn()) .await?; } @@ -482,7 +481,6 @@ VALUES ($1, $2, $3, $4, $5, $6, $7)"#, } pub async fn find_all( - db: &DbPool, mut from: i64, mut until: i64, level: EventLevel, @@ -519,7 +517,7 @@ ORDER BY timestamp DESC"#, level, typ, ) - .fetch_all(db) + .fetch_all(DB::conn()) .await? } } else if is_hiqlite() { @@ -543,14 +541,14 @@ ORDER BY timestamp DESC"#, until, level, ) - .fetch_all(db) + .fetch_all(DB::conn()) .await? }; Ok(res) } - pub async fn find_latest(db: &DbPool, limit: i64) -> Result, ErrorResponse> { + pub async fn find_latest(limit: i64) -> Result, ErrorResponse> { let res = if is_hiqlite() { DB::client() .query_map( @@ -564,7 +562,7 @@ ORDER BY timestamp DESC"#, "SELECT * FROM events ORDER BY timestamp DESC LIMIT $1", limit ) - .fetch_all(db) + .fetch_all(DB::conn()) .await? }; diff --git a/src/models/src/events/health_watch.rs b/src/models/src/events/health_watch.rs index 8c98764f7..c46087dd3 100644 --- a/src/models/src/events/health_watch.rs +++ b/src/models/src/events/health_watch.rs @@ -1,11 +1,10 @@ -use crate::app_state::DbPool; use crate::database::DB; use crate::entity::is_db_alive; use crate::events::event::Event; use std::time::Duration; use tracing::debug; -pub async fn watch_health(db: DbPool, tx_events: flume::Sender) { +pub async fn watch_health(tx_events: flume::Sender) { debug!("Rauthy health watcher started"); let mut interval = tokio::time::interval(Duration::from_secs(30)); @@ -17,12 +16,12 @@ pub async fn watch_health(db: DbPool, tx_events: flume::Sender) { let cache_healthy = DB::client().is_healthy_cache().await.is_ok(); - let db_healthy = if !is_db_alive(&db).await { + let db_healthy = if !is_db_alive().await { // wait for a few seconds and try again before alerting tokio::time::sleep(Duration::from_secs(10)).await; // do not send - if !is_db_alive(&db).await && was_healthy_after_startup { + if !is_db_alive().await && was_healthy_after_startup { tx_events .send_async(Event::rauthy_unhealthy_db()) .await diff --git a/src/models/src/events/listener.rs b/src/models/src/events/listener.rs index e2698ae31..07b757366 100644 --- a/src/models/src/events/listener.rs +++ b/src/models/src/events/listener.rs @@ -1,4 +1,3 @@ -use crate::app_state::DbPool; use crate::database::DB; use crate::events::event::{Event, EventLevel, EventType}; use crate::events::ip_blacklist_handler::{IpBlacklist, IpBlacklistReq, IpLoginFailedSet}; @@ -34,25 +33,24 @@ impl EventListener { tx_router: flume::Sender, rx_router: flume::Receiver, rx_event: flume::Receiver, - db: DbPool, ) -> Result<(), ErrorResponse> { debug!("EventListener::listen has been started"); - tokio::spawn(Self::router(db.clone(), rx_router, tx_ip_blacklist)); + tokio::spawn(Self::router(rx_router, tx_ip_blacklist)); tokio::spawn(Self::raft_events_listener(tx_router)); while let Ok(event) = rx_event.recv_async().await { - tokio::spawn(Self::handle_event(event, db.clone())); + tokio::spawn(Self::handle_event(event)); } Ok(()) } #[tracing::instrument(level = "debug", skip_all)] - async fn handle_event(event: Event, db: DbPool) { + async fn handle_event(event: Event) { // insert into DB if &event.level.value() >= EVENT_PERSIST_LEVEL.get().unwrap() { - while let Err(err) = event.insert(&db).await { + while let Err(err) = event.insert().await { error!("Inserting Event into Database: {:?}", err); time::sleep(Duration::from_secs(1)).await; } @@ -145,7 +143,6 @@ impl EventListener { /// format and forward them to all registered clients. #[tracing::instrument(level = "debug", skip_all)] async fn router( - db: DbPool, rx: flume::Receiver, tx_ip_blacklist: flume::Sender, ) { @@ -155,7 +152,7 @@ impl EventListener { HashMap::with_capacity(4); let mut ips_to_remove = Vec::with_capacity(1); // Event::find_latest returns the latest events ordered by timestamp desc - let mut events = Event::find_latest(&db, EVENTS_LATEST_LIMIT as i64) + let mut events = Event::find_latest(EVENTS_LATEST_LIMIT as i64) .await .unwrap_or_default() .into_iter() diff --git a/src/schedulers/src/app_version.rs b/src/schedulers/src/app_version.rs index ab687c485..a80cf216e 100644 --- a/src/schedulers/src/app_version.rs +++ b/src/schedulers/src/app_version.rs @@ -47,9 +47,7 @@ async fn check_app_version( match LatestAppVersion::lookup().await { Ok((latest_version, url)) => { - if let Err(err) = - LatestAppVersion::upsert(data, latest_version.clone(), url.clone()).await - { + if let Err(err) = LatestAppVersion::upsert(latest_version.clone(), url.clone()).await { error!("Inserting LatestAppVersion into database: {:?}", err); } @@ -65,7 +63,7 @@ async fn check_app_version( info!("A new Rauthy App Version is available: {}", latest_version); if let Err(err) = - LatestAppVersion::upsert(data, latest_version.clone(), url.clone()).await + LatestAppVersion::upsert(latest_version.clone(), url.clone()).await { error!("Saving LatestAppVersion into DB: {:?}", err); } diff --git a/src/schedulers/src/backup.rs b/src/schedulers/src/backup.rs deleted file mode 100644 index 4edfe3da9..000000000 --- a/src/schedulers/src/backup.rs +++ /dev/null @@ -1,41 +0,0 @@ -use crate::sleep_schedule_next; -use rauthy_common::constants::DB_TYPE; -use rauthy_common::DbType; -use rauthy_models::app_state::DbPool; -use rauthy_models::migration::backup_db; -use std::env; -use std::str::FromStr; -use tracing::{debug, error, info}; - -/// Creates a backup of the data store -pub async fn db_backup(db: DbPool) { - if *DB_TYPE == DbType::Postgres { - debug!("Using Postgres as the main database - automatic backups disabled"); - return; - } - - let mut cron_task = env::var("BACKUP_TASK").unwrap_or_else(|_| "0 0 4 * * * *".to_string()); - - // sec min hour day_of_month month day_of_week year - let schedule = cron::Schedule::from_str(&cron_task).unwrap_or_else(|err| { - error!( - "Error creating a cron scheduler with the given BACKUP_TASK input: {} - using default \"0 0 4 * * * *\": {}", - cron_task, err - ); - cron_task = "0 0 4 * * * *".to_string(); - cron::Schedule::from_str(&cron_task).unwrap() - }); - - info!("Database backups are scheduled for: {}", cron_task); - - loop { - sleep_schedule_next(&schedule).await; - debug!("Running db_backup scheduler"); - - // VACUUM main INTO 'data/rauthy.db.backup-' - - if let Err(err) = backup_db(&db).await { - error!("{}", err.message); - } - } -} diff --git a/src/schedulers/src/devices.rs b/src/schedulers/src/devices.rs index 5b4475603..92836ee49 100644 --- a/src/schedulers/src/devices.rs +++ b/src/schedulers/src/devices.rs @@ -1,7 +1,6 @@ use chrono::Utc; use hiqlite::{params, Param}; use rauthy_common::is_hiqlite; -use rauthy_models::app_state::DbPool; use rauthy_models::database::DB; use std::ops::Sub; use std::time::Duration; @@ -9,7 +8,7 @@ use tracing::{debug, error}; /// Cleans up fully expired devices. These need to do a full re-authentication anyway. /// All devices that are expired for at least 1 day will be removed. -pub async fn devices_cleanup(db: DbPool) { +pub async fn devices_cleanup() { let mut interval = tokio::time::interval(Duration::from_secs(24 * 3600)); loop { @@ -50,7 +49,7 @@ AND (refresh_exp is null OR refresh_exp < $1)"#, AND (refresh_exp is null OR refresh_exp < $1)"#, threshold ) - .execute(&db) + .execute(DB::conn()) .await; match res { diff --git a/src/schedulers/src/dyn_clients.rs b/src/schedulers/src/dyn_clients.rs index 9b0976db3..0d0c65a63 100644 --- a/src/schedulers/src/dyn_clients.rs +++ b/src/schedulers/src/dyn_clients.rs @@ -1,4 +1,3 @@ -use actix_web::web; use chrono::Utc; use hiqlite::params; use rauthy_common::constants::{ @@ -6,7 +5,6 @@ use rauthy_common::constants::{ ENABLE_DYN_CLIENT_REG, }; use rauthy_common::is_hiqlite; -use rauthy_models::app_state::AppState; use rauthy_models::database::DB; use rauthy_models::entity::clients::Client; use rauthy_models::entity::clients_dyn::ClientDyn; @@ -15,7 +13,7 @@ use std::time::Duration; use tracing::{debug, error, info}; /// Cleans up unused dynamically registered clients -pub async fn dyn_client_cleanup(data: web::Data) { +pub async fn dyn_client_cleanup() { if !*ENABLE_DYN_CLIENT_REG { info!( "Dynamic client registration is not enabled - exiting dynamic_client_cleanup scheduler" @@ -55,7 +53,7 @@ pub async fn dyn_client_cleanup(data: web::Data) { ClientDyn, "SELECT * FROM clients_dyn WHERE last_used = null" ) - .fetch_all(&data.db) + .fetch_all(DB::conn()) .await .map_err(|err| err.to_string()) }; @@ -72,9 +70,9 @@ pub async fn dyn_client_cleanup(data: web::Data) { for client in clients { if client.created < threshold { info!("Cleaning up unused dynamic client {}", client.id); - match Client::find(&data, client.id).await { + match Client::find(client.id).await { Ok(c) => { - if let Err(err) = c.delete(&data).await { + if let Err(err) = c.delete().await { error!("Error deleting unused client: {:?}", err); continue; } diff --git a/src/schedulers/src/events.rs b/src/schedulers/src/events.rs index 0c535ac26..2831f487a 100644 --- a/src/schedulers/src/events.rs +++ b/src/schedulers/src/events.rs @@ -1,7 +1,6 @@ use chrono::Utc; use hiqlite::{params, Param}; use rauthy_common::is_hiqlite; -use rauthy_models::app_state::DbPool; use rauthy_models::database::DB; use std::env; use std::ops::Sub; @@ -9,7 +8,7 @@ use std::time::Duration; use tracing::{debug, error}; /// Cleans up all Events that exceed the configured EVENT_CLEANUP_DAYS -pub async fn events_cleanup(db: DbPool) { +pub async fn events_cleanup() { let mut interval = tokio::time::interval(Duration::from_secs(3600)); let cleanup_days = env::var("EVENT_CLEANUP_DAYS") @@ -47,7 +46,7 @@ pub async fn events_cleanup(db: DbPool) { } } else { let res = sqlx::query!("DELETE FROM events WHERE timestamp < $1", threshold) - .execute(&db) + .execute(DB::conn()) .await; match res { diff --git a/src/schedulers/src/jwks.rs b/src/schedulers/src/jwks.rs index 0de75b214..53e3d05b2 100644 --- a/src/schedulers/src/jwks.rs +++ b/src/schedulers/src/jwks.rs @@ -33,7 +33,7 @@ pub async fn jwks_auto_rotate(data: web::Data) { } /// Cleans up old / expired JWKSs -pub async fn jwks_cleanup(data: web::Data) { +pub async fn jwks_cleanup() { let mut interval = tokio::time::interval(Duration::from_secs(3600 * 24)); loop { @@ -57,7 +57,7 @@ pub async fn jwks_cleanup(data: web::Data) { .map_err(|err| err.to_string()) } else { sqlx::query_as::<_, Jwk>("SELECT * FROM jwks ORDER BY created_at asc") - .fetch_all(&data.db) + .fetch_all(DB::conn()) .await .map_err(|err| err.to_string()) }; @@ -105,7 +105,7 @@ pub async fn jwks_cleanup(data: web::Data) { } } else if let Err(err) = sqlx::query("DELETE FROM jwks WHERE kid = $1") .bind(&kid) - .execute(&data.db) + .execute(DB::conn()) .await { error!("Cannot clean up JWK {} in jwks_cleanup: {}", kid, err); diff --git a/src/schedulers/src/lib.rs b/src/schedulers/src/lib.rs index 358fc1fd1..0bf0e20c4 100644 --- a/src/schedulers/src/lib.rs +++ b/src/schedulers/src/lib.rs @@ -5,7 +5,6 @@ use tokio::time; use tracing::info; mod app_version; -mod backup; mod devices; mod dyn_clients; mod events; @@ -20,21 +19,16 @@ mod users; pub async fn spawn(data: web::Data) { info!("Starting schedulers"); - // TODO remove after hiqlite migration - has backup functionality built in - // initialize and possibly panic early if anything is mis-configured regarding the s3 storage - // s3_backup_init_test().await; - // tokio::spawn(backup::db_backup(data.db.clone())); - - tokio::spawn(dyn_clients::dyn_client_cleanup(data.clone())); - tokio::spawn(events::events_cleanup(data.db.clone())); - tokio::spawn(devices::devices_cleanup(data.db.clone())); - tokio::spawn(magic_links::magic_link_cleanup(data.db.clone())); - tokio::spawn(tokens::refresh_tokens_cleanup(data.db.clone())); - tokio::spawn(sessions::sessions_cleanup(data.db.clone())); + tokio::spawn(dyn_clients::dyn_client_cleanup()); + tokio::spawn(events::events_cleanup()); + tokio::spawn(devices::devices_cleanup()); + tokio::spawn(magic_links::magic_link_cleanup()); + tokio::spawn(tokens::refresh_tokens_cleanup()); + tokio::spawn(sessions::sessions_cleanup()); tokio::spawn(jwks::jwks_auto_rotate(data.clone())); - tokio::spawn(jwks::jwks_cleanup(data.clone())); + tokio::spawn(jwks::jwks_cleanup()); tokio::spawn(passwords::password_expiry_checker(data.clone())); - tokio::spawn(users::user_expiry_checker(data.clone())); + tokio::spawn(users::user_expiry_checker()); tokio::spawn(app_version::app_version_check(data)); } diff --git a/src/schedulers/src/magic_links.rs b/src/schedulers/src/magic_links.rs index 060afbb38..2661077b2 100644 --- a/src/schedulers/src/magic_links.rs +++ b/src/schedulers/src/magic_links.rs @@ -2,7 +2,6 @@ use chrono::Utc; use hiqlite::{params, Param}; use rauthy_common::is_hiqlite; use rauthy_error::ErrorResponse; -use rauthy_models::app_state::DbPool; use rauthy_models::database::DB; use std::ops::Sub; use std::time::Duration; @@ -11,7 +10,7 @@ use tracing::{debug, error}; /// Cleans up old / expired magic links and deletes users, that have never used their /// 'set first ever password' magic link to keep the database clean in case of an open user registration. /// Runs every 6 hours. -pub async fn magic_link_cleanup(db: DbPool) { +pub async fn magic_link_cleanup() { let mut interval = tokio::time::interval(Duration::from_secs(3600 * 6)); loop { @@ -36,7 +35,7 @@ pub async fn magic_link_cleanup(db: DbPool) { if let Err(err) = cleanup_hiqlite(exp).await { error!("{:?}", err); } - } else if let Err(err) = cleanup_sqlx(&db, exp).await { + } else if let Err(err) = cleanup_sqlx(exp).await { error!("{:?}", err); } } @@ -69,7 +68,7 @@ AND password IS NULL"#, Ok(()) } -async fn cleanup_sqlx(db: &DbPool, exp: i64) -> Result<(), ErrorResponse> { +async fn cleanup_sqlx(exp: i64) -> Result<(), ErrorResponse> { let res = sqlx::query( r#" DELETE FROM users @@ -80,7 +79,7 @@ WHERE id IN ( AND password IS NULL"#, ) .bind(exp) - .execute(db) + .execute(DB::conn()) .await?; debug!( "Cleaned up {} users which did not use their initial password reset magic link", @@ -90,7 +89,7 @@ AND password IS NULL"#, // now we can just delete all expired magic links let res = sqlx::query("DELETE FROM magic_links WHERE exp < $1") .bind(exp) - .execute(db) + .execute(DB::conn()) .await?; debug!( "Cleaned up {} expired and used magic links", diff --git a/src/schedulers/src/passwords.rs b/src/schedulers/src/passwords.rs index 7d98ae1ea..890030bea 100644 --- a/src/schedulers/src/passwords.rs +++ b/src/schedulers/src/passwords.rs @@ -33,7 +33,7 @@ pub async fn password_expiry_checker(data: web::Data) { let lower = now.add(chrono::Duration::days(9)).timestamp(); let upper = now.add(chrono::Duration::days(10)).timestamp(); - match User::find_all(&data).await { + match User::find_all().await { Ok(users) => { // TODO convert into proper query directly after hiqlite migration let users_to_notify = users diff --git a/src/schedulers/src/sessions.rs b/src/schedulers/src/sessions.rs index 1569b4daf..b49342175 100644 --- a/src/schedulers/src/sessions.rs +++ b/src/schedulers/src/sessions.rs @@ -1,14 +1,13 @@ use chrono::Utc; use hiqlite::{params, Param}; use rauthy_common::is_hiqlite; -use rauthy_models::app_state::DbPool; use rauthy_models::database::DB; use std::ops::Sub; use std::time::Duration; use tracing::{debug, error}; // Cleans up old / expired Sessions -pub async fn sessions_cleanup(db: DbPool) { +pub async fn sessions_cleanup() { let mut interval = tokio::time::interval(Duration::from_secs(3595 * 2)); loop { @@ -34,7 +33,7 @@ pub async fn sessions_cleanup(db: DbPool) { } } else if let Err(err) = sqlx::query("DELETE FROM sessions WHERE exp < $1") .bind(thres) - .execute(&db) + .execute(DB::conn()) .await { error!("Session Cleanup Error: {:?}", err) diff --git a/src/schedulers/src/tokens.rs b/src/schedulers/src/tokens.rs index f160adad5..cece41935 100644 --- a/src/schedulers/src/tokens.rs +++ b/src/schedulers/src/tokens.rs @@ -1,12 +1,11 @@ use chrono::Utc; use hiqlite::{params, Param}; use rauthy_common::is_hiqlite; -use rauthy_models::app_state::DbPool; use rauthy_models::database::DB; use std::time::Duration; use tracing::{debug, error}; -pub async fn refresh_tokens_cleanup(db: DbPool) { +pub async fn refresh_tokens_cleanup() { let mut interval = tokio::time::interval(Duration::from_secs(3600 * 3)); loop { @@ -30,7 +29,7 @@ pub async fn refresh_tokens_cleanup(db: DbPool) { } } else if let Err(err) = sqlx::query("DELETE FROM refresh_tokens WHERE exp < $1") .bind(now) - .execute(&db) + .execute(DB::conn()) .await { error!("Refresh Token Cleanup Error: {:?}", err) diff --git a/src/schedulers/src/users.rs b/src/schedulers/src/users.rs index b69747037..ad0998f54 100644 --- a/src/schedulers/src/users.rs +++ b/src/schedulers/src/users.rs @@ -1,6 +1,4 @@ -use actix_web::web; use chrono::Utc; -use rauthy_models::app_state::AppState; use rauthy_models::database::DB; use rauthy_models::entity::refresh_tokens::RefreshToken; use rauthy_models::entity::sessions::Session; @@ -9,8 +7,7 @@ use std::env; use std::time::Duration; use tracing::{debug, error, info}; -// Checks for expired users -pub async fn user_expiry_checker(data: web::Data) { +pub async fn user_expiry_checker() { let secs = env::var("SCHED_USER_EXP_MINS") .unwrap_or_else(|_| "60".to_string()) .parse::() @@ -39,7 +36,7 @@ pub async fn user_expiry_checker(data: web::Data) { debug!("Running user_expiry_checker scheduler"); - match User::find_expired(&data).await { + match User::find_expired().await { Ok(users) => { let now = Utc::now().timestamp(); @@ -58,7 +55,7 @@ pub async fn user_expiry_checker(data: web::Data) { }; // invalidate all sessions - if let Err(err) = Session::invalidate_for_user(&data, &user.id).await { + if let Err(err) = Session::invalidate_for_user(&user.id).await { error!( "Error invalidating sessions for user {}: {:?}", user.id, err @@ -66,7 +63,7 @@ pub async fn user_expiry_checker(data: web::Data) { } // invalidate all refresh tokens - if let Err(err) = RefreshToken::invalidate_for_user(&data, &user.id).await { + if let Err(err) = RefreshToken::invalidate_for_user(&user.id).await { error!( "Error invalidating refresh tokens for user {}: {:?}", user.id, err @@ -82,7 +79,7 @@ pub async fn user_expiry_checker(data: web::Data) { user.id, expired_since_secs / 60 ); - if let Err(err) = user.delete(&data).await { + if let Err(err) = user.delete().await { error!( "Error during auto cleanup - deleting user {}: {:?}", user.id, err diff --git a/src/service/src/client.rs b/src/service/src/client.rs index 30de992d8..dce0b5e30 100644 --- a/src/service/src/client.rs +++ b/src/service/src/client.rs @@ -1,20 +1,12 @@ -use actix_web::web; use rauthy_api_types::clients::{ClientSecretResponse, UpdateClientRequest}; use rauthy_error::{ErrorResponse, ErrorResponseType}; -use rauthy_models::app_state::AppState; use rauthy_models::entity::clients::Client; -// Updates a client.
-// A client secret will be automatically generated if the -// [UpdateClientRequest](crate::models::request::UpdateClientRequest) is set to be confidential -// while the currently existing client does not have it. It will be skipped, if it was -// `confidential` already. pub async fn update_client( - data: &web::Data, id: String, client_req: UpdateClientRequest, ) -> Result { - let mut client = Client::find(data, id).await?; + let mut client = Client::find(id).await?; if client.id != client_req.id { return Err(ErrorResponse::new( ErrorResponseType::BadRequest, @@ -49,8 +41,8 @@ pub async fn update_client( client.auth_code_lifetime = client_req.auth_code_lifetime; client.access_token_lifetime = client_req.access_token_lifetime; - client.scopes = Client::sanitize_scopes(data, client_req.scopes).await?; - client.default_scopes = Client::sanitize_scopes(data, client_req.default_scopes).await?; + client.scopes = Client::sanitize_scopes(client_req.scopes).await?; + client.default_scopes = Client::sanitize_scopes(client_req.default_scopes).await?; client.challenge = client_req.challenges.map(|c| c.join(",")); client.force_mfa = client_req.force_mfa; @@ -58,16 +50,13 @@ pub async fn update_client( client.contacts = client_req.contacts.map(|c| c.join(",")); client.client_uri = client_req.client_uri; - client.save(data).await?; + client.save().await?; Ok(client) } /// Returns the clients secret in cleartext. -pub async fn get_client_secret( - id: String, - data: &web::Data, -) -> Result { - let client = Client::find(data, id).await?; +pub async fn get_client_secret(id: String) -> Result { + let client = Client::find(id).await?; if !client.confidential { return Err(ErrorResponse::new( @@ -84,18 +73,13 @@ pub async fn get_client_secret( }) } -/// Generates a new client secret and returns it then as clear text wrapped in a -/// [ClientSecretResponse](crate::models::response::ClientSecretResponse) -pub async fn generate_new_secret( - id: String, - data: &web::Data, -) -> Result { - let mut client = Client::find(data, id).await?; +pub async fn generate_new_secret(id: String) -> Result { + let mut client = Client::find(id).await?; let (clear, enc) = Client::generate_new_secret()?; client.confidential = true; client.secret = Some(enc); - client.save(data).await?; + client.save().await?; Ok(ClientSecretResponse { id: client.id, diff --git a/src/service/src/encryption.rs b/src/service/src/encryption.rs index 6e3ee78de..b38084114 100644 --- a/src/service/src/encryption.rs +++ b/src/service/src/encryption.rs @@ -23,7 +23,7 @@ pub async fn migrate_encryption_alg( // migrate clients info!("Starting client secrets migration to key id: {}", new_kid); - let clients = Client::find_all(data) + let clients = Client::find_all() .await? .into_iter() // filter out all clients that already use the new key @@ -41,7 +41,7 @@ pub async fn migrate_encryption_alg( client.secret = Some(enc); client.secret_kid = Some(new_kid.to_string()); - client.save(data).await?; + client.save().await?; modified += 1; } info!("Finished clients secrets migration to key id: {}", new_kid); @@ -51,7 +51,7 @@ pub async fn migrate_encryption_alg( // migrate ApiKey's info!("Starting ApiKeys migration to key id: {}", new_kid); - let api_keys = ApiKeyEntity::find_all(data) + let api_keys = ApiKeyEntity::find_all() .await? .into_iter() // filter out all keys that already use the new key @@ -72,13 +72,13 @@ pub async fn migrate_encryption_alg( api_key.enc_key_id = new_kid.to_string(); - api_key.save(data).await?; + api_key.save().await?; modified += 1; } info!("Finished ApiKeys migration to key id: {}", new_kid); // migrate auth providers - let providers = AuthProvider::find_all(data).await?; + let providers = AuthProvider::find_all().await?; for mut provider in providers { match AuthProvider::get_secret_cleartext(&provider.secret) { Ok(plaintext_opt) => { @@ -88,7 +88,7 @@ pub async fn migrate_encryption_alg( .into_bytes() .to_vec(), ); - provider.save(data).await?; + provider.save().await?; modified += 1; }; diff --git a/src/service/src/oidc/authorize.rs b/src/service/src/oidc/authorize.rs index fa9858328..64c65fbb2 100644 --- a/src/service/src/oidc/authorize.rs +++ b/src/service/src/oidc/authorize.rs @@ -26,17 +26,15 @@ pub async fn post_authorize( add_login_delay: &mut bool, user_needs_mfa: &mut bool, ) -> Result { - let mut user = User::find_by_email(data, req_data.email) - .await - .inspect_err(|_| { - // The UI does not show the password input form when there is no user yet. - // To prevent username enumeration, we should not add a login delay if a user does not - // even exist, when the UI is in that phase where the user does not provide any - // password. - if req_data.password.is_none() { - *add_login_delay = false; - } - })?; + let mut user = User::find_by_email(req_data.email).await.inspect_err(|_| { + // The UI does not show the password input form when there is no user yet. + // To prevent username enumeration, we should not add a login delay if a user does not + // even exist, when the UI is in that phase where the user does not provide any + // password. + if req_data.password.is_none() { + *add_login_delay = false; + } + })?; let mfa_cookie = if let Ok(c) = WebauthnCookie::parse_validate(&ApiCookie::from_req(req, COOKIE_MFA)) { @@ -94,11 +92,11 @@ pub async fn post_authorize( user.last_login = Some(Utc::now().timestamp()); user.last_failed_login = None; user.failed_login_attempts = None; - user.save(data, None).await?; + user.save(None).await?; } // client validations - let client = Client::find_maybe_ephemeral(data, req_data.client_id).await?; + let client = Client::find_maybe_ephemeral(req_data.client_id).await?; client.validate_mfa(&user).inspect_err(|_| { // in this case, we do not want to add a login delay // the user password was correct, we only need a passkey being added to the account @@ -138,7 +136,7 @@ pub async fn post_authorize( // TODO should we allow to skip this step if set so in the config? // check if we need to validate the 2nd factor if user.has_webauthn_enabled() { - session.set_mfa(data, true).await?; + session.set_mfa(true).await?; let step = AuthStepAwaitWebauthn { code: get_rand(48), @@ -175,7 +173,6 @@ pub async fn post_authorize( } pub async fn post_authorize_refresh( - data: &web::Data, session: &Session, client: Client, header_origin: Option<(HeaderName, HeaderValue)>, @@ -187,7 +184,7 @@ pub async fn post_authorize_refresh( "No linked user_id for already validated session", ) })?; - let user = User::find(data, user_id.clone()).await?; + let user = User::find(user_id.clone()).await?; user.check_enabled()?; user.check_expired()?; diff --git a/src/service/src/oidc/grant_types/authorization_code.rs b/src/service/src/oidc/grant_types/authorization_code.rs index 194f51d9e..99d732231 100644 --- a/src/service/src/oidc/grant_types/authorization_code.rs +++ b/src/service/src/oidc/grant_types/authorization_code.rs @@ -38,7 +38,7 @@ pub async fn grant_type_authorization_code( // check the client for external origin and oidc flow let (client_id, client_secret) = req_data.try_get_client_id_secret(&req)?; - let client = Client::find_maybe_ephemeral(data, client_id.clone()) + let client = Client::find_maybe_ephemeral(client_id.clone()) .await .map_err(|_| { ErrorResponse::new( @@ -143,7 +143,7 @@ pub async fn grant_type_authorization_code( // // An additional check at this point does not provide any security benefit but only uses resources. - let user = User::find(data, code.user_id.clone()).await?; + let user = User::find(code.user_id.clone()).await?; let token_set = TokenSet::from_user( &user, data, @@ -160,7 +160,7 @@ pub async fn grant_type_authorization_code( // update session metadata if code.session_id.is_some() { let sid = code.session_id.as_ref().unwrap().clone(); - let mut session = Session::find(data, sid).await?; + let mut session = Session::find(sid).await?; session.last_seen = Utc::now().timestamp(); session.state = SessionState::Auth.as_str().to_string(); @@ -172,13 +172,13 @@ pub async fn grant_type_authorization_code( session.user_id = Some(user.id); session.roles = Some(user.roles); session.groups = user.groups; - session.save(data).await?; + session.save().await?; } code.delete().await?; // update timestamp if it is a dynamic client if client.is_dynamic() { - ClientDyn::update_used(data, &client.id).await?; + ClientDyn::update_used(&client.id).await?; } Ok((token_set, headers)) diff --git a/src/service/src/oidc/grant_types/client_credentials.rs b/src/service/src/oidc/grant_types/client_credentials.rs index e13af47c5..1a35a68b3 100644 --- a/src/service/src/oidc/grant_types/client_credentials.rs +++ b/src/service/src/oidc/grant_types/client_credentials.rs @@ -24,7 +24,7 @@ pub async fn grant_type_credentials( } let (client_id, client_secret) = req_data.try_get_client_id_secret(&req)?; - let client = Client::find(data, client_id).await?; + let client = Client::find(client_id).await?; if !client.confidential { return Err(ErrorResponse::new( ErrorResponseType::BadRequest, @@ -62,7 +62,7 @@ pub async fn grant_type_credentials( // update timestamp if it is a dynamic client if client.is_dynamic() { - ClientDyn::update_used(data, &client.id).await?; + ClientDyn::update_used(&client.id).await?; } let ts = TokenSet::for_client_credentials(data, &client, dpop_fingerprint).await?; diff --git a/src/service/src/oidc/grant_types/device_code.rs b/src/service/src/oidc/grant_types/device_code.rs index b4c788ff5..9a6dfbcc6 100644 --- a/src/service/src/oidc/grant_types/device_code.rs +++ b/src/service/src/oidc/grant_types/device_code.rs @@ -95,7 +95,7 @@ pub async fn grant_type_device_code( // check validation if let Some(verified_by) = &code.verified_by { - let user = match User::find(data, verified_by.clone()).await { + let user = match User::find(verified_by.clone()).await { Ok(user) => user, Err(err) => { // at this point, this should never fail - only if the DB went down in the meantime @@ -107,7 +107,7 @@ pub async fn grant_type_device_code( } }; - let client = match Client::find(data, code.client_id.clone()).await { + let client = match Client::find(code.client_id.clone()).await { Ok(client) => client, Err(err) => { // at this point, this should never fail - only if the DB went down in the meantime @@ -151,7 +151,7 @@ pub async fn grant_type_device_code( // TODO add an optional `name` param to the initial device request? name: id.clone(), }; - if let Err(err) = device.insert(data).await { + if let Err(err) = device.insert().await { error!("{:?}", err); return HttpResponse::InternalServerError().json(OAuth2ErrorResponse { error: OAuth2ErrorTypeResponse::InvalidRequest, diff --git a/src/service/src/oidc/grant_types/password.rs b/src/service/src/oidc/grant_types/password.rs index 169d80dd0..b4e6f41cb 100644 --- a/src/service/src/oidc/grant_types/password.rs +++ b/src/service/src/oidc/grant_types/password.rs @@ -38,7 +38,7 @@ pub async fn grant_type_password( let email = req_data.username.as_ref().unwrap(); let password = req_data.password.unwrap(); - let client = Client::find(data, client_id).await?; + let client = Client::find(client_id).await?; let header_origin = client.validate_origin(&req, &data.listen_scheme, &data.public_url)?; if client.confidential { let secret = client_secret.ok_or_else(|| { @@ -67,7 +67,7 @@ pub async fn grant_type_password( // This Error must be the same if user does not exist AND passwords do not match to prevent // username enumeration - let mut user = User::find_by_email(data, String::from(email)).await?; + let mut user = User::find_by_email(String::from(email)).await?; user.check_enabled()?; user.check_expired()?; @@ -86,11 +86,11 @@ pub async fn grant_type_password( user.password = Some(new_hash); } - user.save(data, None).await?; + user.save(None).await?; // update timestamp if it is a dynamic client if client.is_dynamic() { - ClientDyn::update_used(data, &client.id).await?; + ClientDyn::update_used(&client.id).await?; } let ts = TokenSet::from_user( @@ -117,7 +117,7 @@ pub async fn grant_type_password( user.last_failed_login = Some(Utc::now().timestamp()); user.failed_login_attempts = Some(&user.failed_login_attempts.unwrap_or(0) + 1); - user.save(data, None).await?; + user.save(None).await?; // TODO add expo increasing sleeps after failed login attempts here? Err(err) diff --git a/src/service/src/oidc/grant_types/refresh_token.rs b/src/service/src/oidc/grant_types/refresh_token.rs index 14234ea23..bd3b6f3a3 100644 --- a/src/service/src/oidc/grant_types/refresh_token.rs +++ b/src/service/src/oidc/grant_types/refresh_token.rs @@ -22,7 +22,7 @@ pub async fn grant_type_refresh( )); } let (client_id, client_secret) = req_data.try_get_client_id_secret(&req)?; - let client = Client::find_maybe_ephemeral(data, client_id).await?; + let client = Client::find_maybe_ephemeral(client_id).await?; let header_origin = client.validate_origin(&req, &data.listen_scheme, &data.public_url)?; diff --git a/src/service/src/oidc/logout.rs b/src/service/src/oidc/logout.rs index 4d4c52712..a4f8cf115 100644 --- a/src/service/src/oidc/logout.rs +++ b/src/service/src/oidc/logout.rs @@ -17,7 +17,7 @@ pub async fn get_logout_html( data: &web::Data, lang: &Language, ) -> Result { - let colors = ColorEntity::find_rauthy(data).await?; + let colors = ColorEntity::find_rauthy().await?; if logout_request.id_token_hint.is_none() { return Ok(LogoutHtml::build(&session.csrf_token, false, &colors, lang)); @@ -39,7 +39,7 @@ pub async fn get_logout_html( if logout_request.post_logout_redirect_uri.is_some() { // unwrap is safe since the token is valid already let client_id = claims.custom.azp; - let client = Client::find(data, client_id).await?; + let client = Client::find(client_id).await?; if client.post_logout_redirect_uris.is_none() { return Err(ErrorResponse::new( ErrorResponseType::BadRequest, diff --git a/src/service/src/oidc/token_info.rs b/src/service/src/oidc/token_info.rs index 0d7b67e69..dc3b622b1 100644 --- a/src/service/src/oidc/token_info.rs +++ b/src/service/src/oidc/token_info.rs @@ -73,7 +73,7 @@ async fn check_client_auth( }; let header = header_value.to_str().unwrap_or_default(); - let client = Client::find(data, client_id).await.map_err(|_| { + let client = Client::find(client_id).await.map_err(|_| { ErrorResponse::new( ErrorResponseType::WWWAuthenticate("client-not-found".to_string()), "client does not exist anymore".to_string(), diff --git a/src/service/src/oidc/userinfo.rs b/src/service/src/oidc/userinfo.rs index c77cc647b..ba89b0885 100644 --- a/src/service/src/oidc/userinfo.rs +++ b/src/service/src/oidc/userinfo.rs @@ -34,7 +34,7 @@ pub async fn get_userinfo( "Token without 'sub' - could not extract the Principal", ) })?; - let user = User::find(data, uid).await.map_err(|_| { + let user = User::find(uid).await.map_err(|_| { ErrorResponse::new( ErrorResponseType::WWWAuthenticate("user-not-found".to_string()), "The user has not been found".to_string(), @@ -53,7 +53,7 @@ pub async fn get_userinfo( // if the token has been issued to a device, make sure it still exists and is valid if let Some(device_id) = claims.custom.did { // just make sure it still exists - DeviceEntity::find(data, &device_id).await.map_err(|_| { + DeviceEntity::find(&device_id).await.map_err(|_| { ErrorResponse::new( ErrorResponseType::WWWAuthenticate("user-device-not-found".to_string()), "The user device has not been found".to_string(), @@ -65,7 +65,7 @@ pub async fn get_userinfo( // skip this check if the client is ephemeral if !(claims.custom.azp.starts_with("http://") || claims.custom.azp.starts_with("https://")) { - let client = Client::find(data, claims.custom.azp).await.map_err(|_| { + let client = Client::find(claims.custom.azp).await.map_err(|_| { ErrorResponse::new( ErrorResponseType::WWWAuthenticate("client-not-found".to_string()), "The client has not been found".to_string(), @@ -131,7 +131,7 @@ pub async fn get_userinfo( userinfo.family_name = Some(user.family_name.clone()); userinfo.locale = Some(user.language.to_string()); - user_values = UserValues::find(data, &user.id).await?; + user_values = UserValues::find(&user.id).await?; user_values_fetched = true; if let Some(values) = &user_values { @@ -143,7 +143,7 @@ pub async fn get_userinfo( if scope.contains("address") { if !user_values_fetched { - user_values = UserValues::find(data, &user.id).await?; + user_values = UserValues::find(&user.id).await?; user_values_fetched = true; } @@ -154,7 +154,7 @@ pub async fn get_userinfo( if scope.contains("phone") { if !user_values_fetched { - user_values = UserValues::find(data, &user.id).await?; + user_values = UserValues::find(&user.id).await?; // user_values_fetched = true; } diff --git a/src/service/src/oidc/validation.rs b/src/service/src/oidc/validation.rs index 9408e8125..74594b28a 100644 --- a/src/service/src/oidc/validation.rs +++ b/src/service/src/oidc/validation.rs @@ -30,7 +30,7 @@ pub async fn validate_auth_req_param( code_challenge_method: &Option, ) -> Result<(Client, Option<(HeaderName, HeaderValue)>), ErrorResponse> { // client exists - let client = Client::find_maybe_ephemeral(data, String::from(client_id)).await?; + let client = Client::find_maybe_ephemeral(String::from(client_id)).await?; // allowed origin let header = client.validate_origin(req, &data.listen_scheme, &data.public_url)?; @@ -75,7 +75,7 @@ pub async fn validate_token ::serde::Deserialize< let kid = JwkKeyPair::kid_from_token(token)?; // retrieve jwk for kid - let kp = JwkKeyPair::find(data, kid).await?; + let kp = JwkKeyPair::find(kid).await?; validate_jwt!(T, kp, token, options) // TODO check roles if we add more users / roles @@ -99,7 +99,7 @@ pub async fn validate_refresh_token( let kid = JwkKeyPair::kid_from_token(refresh_token)?; // retrieve jwk for kid - let kp = JwkKeyPair::find(data, kid).await?; + let kp = JwkKeyPair::find(kid).await?; let claims: JWTClaims = validate_jwt!(JwtRefreshClaims, kp, refresh_token, options)?; @@ -118,7 +118,7 @@ pub async fn validate_refresh_token( let client = if let Some(c) = client_opt { c } else { - Client::find(data, claims.custom.azp.clone()).await? + Client::find(claims.custom.azp.clone()).await? }; if client.id != claims.custom.azp { return Err(ErrorResponse::new( @@ -151,7 +151,7 @@ pub async fn validate_refresh_token( (None, None) }; - let mut user = User::find(data, uid).await?; + let mut user = User::find(uid).await?; user.check_enabled()?; user.check_expired()?; @@ -160,7 +160,7 @@ pub async fn validate_refresh_token( let now = Utc::now().timestamp(); let exp_at_secs = now + data.refresh_grace_time as i64; let rt_scope = if let Some(device_id) = &claims.custom.did { - let mut rt = RefreshTokenDevice::find(data, validation_str).await?; + let mut rt = RefreshTokenDevice::find(validation_str).await?; if &rt.device_id != device_id { return Err(ErrorResponse::new( @@ -177,14 +177,14 @@ pub async fn validate_refresh_token( if rt.exp > exp_at_secs + 1 { rt.exp = exp_at_secs; - rt.save(data).await?; + rt.save().await?; } rt.scope } else { - let mut rt = RefreshToken::find(data, validation_str).await?; + let mut rt = RefreshToken::find(validation_str).await?; if rt.exp > exp_at_secs + 1 { rt.exp = exp_at_secs; - rt.save(data).await?; + rt.save().await?; } rt.scope }; @@ -194,7 +194,7 @@ pub async fn validate_refresh_token( // set last login user.last_login = Some(Utc::now().timestamp()); - user.save(data, None).await?; + user.save(None).await?; let auth_time = if let Some(ts) = claims.custom.auth_time { AuthTime::given(ts) diff --git a/src/service/src/password_reset.rs b/src/service/src/password_reset.rs index 89b67d47c..4baa2456e 100644 --- a/src/service/src/password_reset.rs +++ b/src/service/src/password_reset.rs @@ -21,20 +21,19 @@ use rauthy_models::templates::PwdResetHtml; use tracing::{debug, error}; pub async fn handle_get_pwd_reset<'a>( - data: &web::Data, req: HttpRequest, user_id: String, reset_id: String, no_html: bool, ) -> Result<(String, cookie::Cookie<'a>), ErrorResponse> { - let mut ml = MagicLink::find(data, &reset_id).await?; + let mut ml = MagicLink::find(&reset_id).await?; ml.validate(&user_id, &req, false)?; - let user = User::find(data, ml.user_id.clone()).await?; + let user = User::find(ml.user_id.clone()).await?; // get the html and insert values - let rules = PasswordPolicy::find(data).await?; - let colors = ColorEntity::find_rauthy(data).await?; + let rules = PasswordPolicy::find().await?; + let colors = ColorEntity::find_rauthy().await?; let lang = Language::try_from(&req).unwrap_or_default(); let content = if no_html { @@ -52,7 +51,7 @@ pub async fn handle_get_pwd_reset<'a>( // generate a cookie value and save it to the magic link let cookie_val = get_rand(48); ml.cookie = Some(cookie_val); - ml.save(data).await?; + ml.save().await?; let age_secs = ml.exp - Utc::now().timestamp(); let cookie = ApiCookie::build(PWD_RESET_COOKIE, ml.cookie.unwrap(), age_secs); @@ -69,12 +68,12 @@ pub async fn handle_put_user_passkey_start<'a>( ) -> Result { // validate user_id / given email address debug!("getting user"); - let user = User::find(data, user_id).await?; + let user = User::find(user_id).await?; debug!("getting magic link"); // unwrap is safe -> checked in API endpoint already let ml_id = req_data.magic_link_id.as_ref().unwrap(); - let ml = MagicLink::find(data, ml_id).await?; + let ml = MagicLink::find(ml_id).await?; ml.validate(&user.id, &req, true)?; // if we register a new passkey, we need to make sure that the magic link is for a new user @@ -102,7 +101,7 @@ pub async fn handle_put_user_passkey_finish<'a>( ) -> Result { // unwrap is safe -> checked in API endpoint already let ml_id = req_data.magic_link_id.as_ref().unwrap(); - let mut ml = MagicLink::find(data, ml_id).await?; + let mut ml = MagicLink::find(ml_id).await?; ml.validate(&user_id, &req, true)?; // finish webauthn request -> always force UV for passkey only accounts @@ -129,8 +128,8 @@ pub async fn handle_put_user_passkey_finish<'a>( debug!("invalidating magic link pwd"); // all good - ml.invalidate(data).await?; - User::set_email_verified(data, user_id, true).await?; + ml.invalidate().await?; + User::set_email_verified(user_id, true).await?; // delete the cookie let cookie = ApiCookie::build(PWD_RESET_COOKIE, "", 0); @@ -146,7 +145,7 @@ pub async fn handle_put_user_password_reset<'a>( req_data: PasswordResetRequest, ) -> Result<(cookie::Cookie<'a>, Option), ErrorResponse> { // validate user_id - let mut user = User::find(data, user_id).await?; + let mut user = User::find(user_id).await?; // check MFA code if user.has_webauthn_enabled() { @@ -173,16 +172,16 @@ pub async fn handle_put_user_password_reset<'a>( } } - let mut ml = MagicLink::find(data, &req_data.magic_link_id).await?; + let mut ml = MagicLink::find(&req_data.magic_link_id).await?; ml.validate(&user.id, &req, true)?; // validate password - user.apply_password_rules(data, &req_data.password).await?; + user.apply_password_rules(&req_data.password).await?; // all good - ml.invalidate(data).await?; + ml.invalidate().await?; user.email_verified = true; - user.save(data, None).await?; + user.save(None).await?; let ip = match real_ip_from_req(&req).ok() { None => { @@ -200,7 +199,7 @@ pub async fn handle_put_user_password_reset<'a>( .unwrap(); // delete all existing user sessions to have a clean flow - Session::invalidate_for_user(data, &user.id).await?; + Session::invalidate_for_user(&user.id).await?; // check if we got a custom `redirect_uri` during registration let redirect_uri = match MagicLinkUsage::try_from(&ml.usage)? { diff --git a/src/service/src/token_set.rs b/src/service/src/token_set.rs index cd7743ec3..38e5a6527 100644 --- a/src/service/src/token_set.rs +++ b/src/service/src/token_set.rs @@ -204,7 +204,7 @@ impl TokenSet { // sign the token let key_pair_alg = JwkKeyPairAlg::from_str(&client.access_token_alg)?; - let kp = JwkKeyPair::find_latest(data, key_pair_alg).await?; + let kp = JwkKeyPair::find_latest(key_pair_alg).await?; sign_jwt!(kp, claims) } @@ -267,7 +267,7 @@ impl TokenSet { custom_claims.family_name = Some(user.family_name.clone()); custom_claims.locale = Some(user.language.to_string()); - user_values = UserValues::find(data, &user.id).await?; + user_values = UserValues::find(&user.id).await?; user_values_fetched = true; if let Some(values) = &user_values { @@ -279,7 +279,7 @@ impl TokenSet { if scope.contains("address") { if !user_values_fetched { - user_values = UserValues::find(data, &user.id).await?; + user_values = UserValues::find(&user.id).await?; user_values_fetched = true; } @@ -290,7 +290,7 @@ impl TokenSet { if scope.contains("phone") { if !user_values_fetched { - user_values = UserValues::find(data, &user.id).await?; + user_values = UserValues::find(&user.id).await?; // user_values_fetched = true; } @@ -349,7 +349,7 @@ impl TokenSet { // sign the token let key_pair_alg = JwkKeyPairAlg::from_str(&client.id_token_alg)?; - let kp = JwkKeyPair::find_latest(data, key_pair_alg).await?; + let kp = JwkKeyPair::find_latest(key_pair_alg).await?; sign_jwt!(kp, claims) } @@ -393,7 +393,7 @@ impl TokenSet { // sign the token let token = { - let kp = JwkKeyPair::find_latest(data, JwkKeyPairAlg::default()).await?; + let kp = JwkKeyPair::find_latest(JwkKeyPairAlg::default()).await?; sign_jwt!(kp, claims) }?; @@ -405,7 +405,6 @@ impl TokenSet { *DEVICE_GRANT_REFRESH_TOKEN_LIFETIME as i64, )); RefreshTokenDevice::create( - data, validation_string, device_id, user.id.clone(), @@ -417,7 +416,6 @@ impl TokenSet { } else { let exp = nbf.add(chrono::Duration::hours(*REFRESH_TOKEN_LIFETIME as i64)); RefreshToken::create( - data, validation_string, user.id.clone(), nbf, @@ -488,7 +486,7 @@ impl TokenSet { let scps; let attrs; let (customs_access, customs_id) = if !cust.is_empty() { - scps = Some(Scope::find_all(data).await?); + scps = Some(Scope::find_all().await?); let mut customs_access = Vec::with_capacity(cust.len()); let mut customs_id = Vec::with_capacity(cust.len()); @@ -506,7 +504,7 @@ impl TokenSet { // if there was any custom mapping, we need the additional user attributes attrs = if !customs_access.is_empty() || !customs_id.is_empty() { - let attrs = UserAttrValueEntity::find_for_user(data, &user.id).await?; + let attrs = UserAttrValueEntity::find_for_user(&user.id).await?; let mut res = HashMap::with_capacity(attrs.len()); attrs.iter().for_each(|a| { res.insert(a.key.clone(), a.value.clone());