From 739f664599fbf1ca561c6cb9f529551ce38d9805 Mon Sep 17 00:00:00 2001 From: Zoe Spellman Date: Wed, 14 Feb 2024 22:38:12 -0800 Subject: [PATCH] refactor: constrain sql input This foregos a SQL enum in favor of constraining the input to the range of the protobuf enum, this allows the same idea as an enum but works around some limitations with `prost` using `#[repr(i32)]` on its enums which messes with `sqlx`'s derives. --- Cargo.lock | 33 +++++++++++++------ Cargo.toml | 2 +- build.rs | 5 +++ .../20240208071037_categories.up.sql | 3 +- src/features/chart/infrastructure.rs | 2 +- src/features/user/infrastructure.rs | 29 ++++++++-------- tests/user_tests/category.rs | 9 ++--- 7 files changed, 48 insertions(+), 35 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 7ea3af51..8ac9772c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -244,15 +244,6 @@ version = "0.9.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8" -[[package]] -name = "convert_case" -version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec182b0ca2f35d8fc196cf3404988fd8b8c739a4d270ff118a398feb0cbec1ca" -dependencies = [ - "unicode-segmentation", -] - [[package]] name = "core-foundation" version = "0.9.4" @@ -1557,7 +1548,6 @@ dependencies = [ name = "ratings" version = "0.2.0" dependencies = [ - "convert_case", "dotenv", "envy", "futures", @@ -1573,6 +1563,7 @@ dependencies = [ "sha2", "snapd", "sqlx", + "strum", "thiserror", "time", "tokio", @@ -2226,6 +2217,28 @@ dependencies = [ "unicode-normalization", ] +[[package]] +name = "strum" +version = "0.26.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "723b93e8addf9aa965ebe2d11da6d7540fa2283fcea14b3371ff055f7ba13f5f" +dependencies = [ + "strum_macros", +] + +[[package]] +name = "strum_macros" +version = "0.26.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a3417fc93d76740d974a01654a09777cb500428cc874ca9f45edfe0c4d4cd18" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "rustversion", + "syn 2.0.48", +] + [[package]] name = "subtle" version = "2.5.0" diff --git a/Cargo.toml b/Cargo.toml index 734e01e0..4d19ecb0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,7 +8,6 @@ edition = "2021" rust-version = "1.75.0" [dependencies] -convert_case = "0.6.0" dotenv = "0.15" envy = "0.4" futures = "0.3" @@ -28,6 +27,7 @@ sqlx = { version = "0.7", features = [ "postgres", "time", ] } +strum = { version = "0.26.1", features = ["derive"] } thiserror = "1.0" time = { version = "0.3", features = ["macros"] } tokio = { version = "1.36", features = ["full"] } diff --git a/build.rs b/build.rs index 2bb6c963..ac3dedd1 100644 --- a/build.rs +++ b/build.rs @@ -21,6 +21,11 @@ fn main() -> Result<(), Box> { .build_server(true) .file_descriptor_set_path(descriptor_set_path) .out_dir(out_dir) + .type_attribute("Category", "#[derive(sqlx::Type, strum::EnumString)]") + .type_attribute( + "Category", + r#"#[strum(serialize_all = "kebab_case", ascii_case_insensitive)]"#, + ) .compile(files, &["proto"])?; Ok(()) diff --git a/sql/migrations/20240208071037_categories.up.sql b/sql/migrations/20240208071037_categories.up.sql index ec5a8ddc..145acedd 100644 --- a/sql/migrations/20240208071037_categories.up.sql +++ b/sql/migrations/20240208071037_categories.up.sql @@ -3,5 +3,6 @@ CREATE TABLE snap_categories ( id SERIAL PRIMARY KEY, snap_id CHAR(32) NOT NULL, - category VARCHAR NOT NULL + category INTEGER NOT NULL, + CONSTRAINT category CHECK (category BETWEEN 0 AND 19) ); diff --git a/src/features/chart/infrastructure.rs b/src/features/chart/infrastructure.rs index d6d33530..17bded77 100644 --- a/src/features/chart/infrastructure.rs +++ b/src/features/chart/infrastructure.rs @@ -52,7 +52,7 @@ pub(crate) async fn get_votes_summary( SELECT snap_categories.snap_id FROM snap_categories WHERE snap_categories.category = "#, ) - .push_bind(category.to_kebab_case()) + .push_bind(category) .push(")"); } diff --git a/src/features/user/infrastructure.rs b/src/features/user/infrastructure.rs index 6c774e8a..a50c7769 100644 --- a/src/features/user/infrastructure.rs +++ b/src/features/user/infrastructure.rs @@ -1,9 +1,6 @@ //! Infrastructure for user handling use snapd::{ - api::{ - convenience::SnapNameFromId, - find::{CategoryName, FindSnapByName}, - }, + api::{convenience::SnapNameFromId, find::FindSnapByName}, SnapdClient, }; use sqlx::{Acquire, Executor, Row}; @@ -11,9 +8,12 @@ use tracing::error; use crate::{ app::AppContext, - features::user::{ - entities::{User, Vote}, - errors::UserError, + features::{ + pb::chart::Category, + user::{ + entities::{User, Vote}, + errors::UserError, + }, }, }; @@ -190,13 +190,13 @@ pub(crate) async fn save_vote_to_db(app_ctx: &AppContext, vote: Vote) -> Result< async fn snapd_categories_by_snap_id( client: &SnapdClient, snap_id: &str, -) -> Result>, UserError> { +) -> Result, UserError> { let snap_name = SnapNameFromId::get_name(snap_id.into(), client).await?; Ok(FindSnapByName::get_categories(snap_name, client) .await? .into_iter() - .map(|v| v.name) + .map(|v| Category::try_from(v.name.as_ref()).expect("got unknown category?")) .collect()) } @@ -227,9 +227,9 @@ pub(crate) async fn update_category(app_ctx: &AppContext, snap_id: &str) -> Resu for category in categories.iter() { tx.execute( - sqlx::query("INSERT INTO snap_categories (snap_id, category) VALUES ($1,$2); ") + sqlx::query("INSERT INTO snap_categories (snap_id, category) VALUES ($1, $2); ") .bind(snap_id) - .bind(category.as_ref()), + .bind(category), ) .await?; } @@ -302,7 +302,7 @@ pub(crate) async fn find_user_votes( mod test { use std::collections::HashSet; - use snapd::{api::find::CategoryName, SnapdClient}; + use snapd::SnapdClient; use crate::features::pb::chart::Category; @@ -317,10 +317,7 @@ mod test { .unwrap(); assert_eq!( - TESTING_SNAP_CATEGORIES - .map(|v| CategoryName::from(v.to_kebab_case())) - .into_iter() - .collect::>(), + TESTING_SNAP_CATEGORIES.into_iter().collect::>(), categories.into_iter().collect::>() ) } diff --git a/tests/user_tests/category.rs b/tests/user_tests/category.rs index 593f07d3..2e55ca04 100644 --- a/tests/user_tests/category.rs +++ b/tests/user_tests/category.rs @@ -131,7 +131,6 @@ async fn vote_sets_category( conn: &mut PoolConnection, expected: &HashSet, ) { - let expected: HashSet<_> = expected.iter().map(|v| v.to_kebab_case()).collect(); let result = sqlx::query( r#" SELECT snap_categories.category @@ -143,13 +142,11 @@ async fn vote_sets_category( .fetch(&mut **conn) .map(|row| { row.expect("error when retrieving row") - .try_get::("category") + .try_get::("category") .expect("could not get category field") - .to_lowercase() }) - .collect::>() + .collect::>() .await; - assert_eq!(result, expected); + assert_eq!(&result, expected); } -// 3Iwi803Tk3KQwyD6jFiAJdlq8MLgBIoD