Skip to content

Commit

Permalink
refactor: constrain sql input
Browse files Browse the repository at this point in the history
This foregos a SQL enum in favor of constraining
the input to the range of the protobuf enum,
this allows the same idea as an enum but
works around some limitations with `prost`
using `#[repr(i32)]` on its enums which messes
with `sqlx`'s derives.
  • Loading branch information
Zoe Spellman committed Feb 20, 2024
1 parent 27d218e commit 739f664
Show file tree
Hide file tree
Showing 7 changed files with 48 additions and 35 deletions.
33 changes: 23 additions & 10 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ edition = "2021"
rust-version = "1.75.0"

[dependencies]
convert_case = "0.6.0"
dotenv = "0.15"
envy = "0.4"
futures = "0.3"
Expand All @@ -28,6 +27,7 @@ sqlx = { version = "0.7", features = [
"postgres",
"time",
] }
strum = { version = "0.26.1", features = ["derive"] }
thiserror = "1.0"
time = { version = "0.3", features = ["macros"] }
tokio = { version = "1.36", features = ["full"] }
Expand Down
5 changes: 5 additions & 0 deletions build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
.build_server(true)
.file_descriptor_set_path(descriptor_set_path)
.out_dir(out_dir)
.type_attribute("Category", "#[derive(sqlx::Type, strum::EnumString)]")
.type_attribute(
"Category",
r#"#[strum(serialize_all = "kebab_case", ascii_case_insensitive)]"#,
)
.compile(files, &["proto"])?;

Ok(())
Expand Down
3 changes: 2 additions & 1 deletion sql/migrations/20240208071037_categories.up.sql
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@
CREATE TABLE snap_categories (
id SERIAL PRIMARY KEY,
snap_id CHAR(32) NOT NULL,
category VARCHAR NOT NULL
category INTEGER NOT NULL,
CONSTRAINT category CHECK (category BETWEEN 0 AND 19)
);
2 changes: 1 addition & 1 deletion src/features/chart/infrastructure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ pub(crate) async fn get_votes_summary(
SELECT snap_categories.snap_id FROM snap_categories
WHERE snap_categories.category = "#,
)
.push_bind(category.to_kebab_case())
.push_bind(category)
.push(")");
}

Expand Down
29 changes: 13 additions & 16 deletions src/features/user/infrastructure.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
//! Infrastructure for user handling
use snapd::{
api::{
convenience::SnapNameFromId,
find::{CategoryName, FindSnapByName},
},
api::{convenience::SnapNameFromId, find::FindSnapByName},
SnapdClient,
};
use sqlx::{Acquire, Executor, Row};
use tracing::error;

use crate::{
app::AppContext,
features::user::{
entities::{User, Vote},
errors::UserError,
features::{
pb::chart::Category,
user::{
entities::{User, Vote},
errors::UserError,
},
},
};

Expand Down Expand Up @@ -190,13 +190,13 @@ pub(crate) async fn save_vote_to_db(app_ctx: &AppContext, vote: Vote) -> Result<
async fn snapd_categories_by_snap_id(
client: &SnapdClient,
snap_id: &str,
) -> Result<Vec<CategoryName<'static>>, UserError> {
) -> Result<Vec<Category>, UserError> {
let snap_name = SnapNameFromId::get_name(snap_id.into(), client).await?;

Ok(FindSnapByName::get_categories(snap_name, client)
.await?
.into_iter()
.map(|v| v.name)
.map(|v| Category::try_from(v.name.as_ref()).expect("got unknown category?"))
.collect())
}

Expand Down Expand Up @@ -227,9 +227,9 @@ pub(crate) async fn update_category(app_ctx: &AppContext, snap_id: &str) -> Resu

for category in categories.iter() {
tx.execute(
sqlx::query("INSERT INTO snap_categories (snap_id, category) VALUES ($1,$2); ")
sqlx::query("INSERT INTO snap_categories (snap_id, category) VALUES ($1, $2); ")
.bind(snap_id)
.bind(category.as_ref()),
.bind(category),
)
.await?;
}
Expand Down Expand Up @@ -302,7 +302,7 @@ pub(crate) async fn find_user_votes(
mod test {
use std::collections::HashSet;

use snapd::{api::find::CategoryName, SnapdClient};
use snapd::SnapdClient;

use crate::features::pb::chart::Category;

Expand All @@ -317,10 +317,7 @@ mod test {
.unwrap();

assert_eq!(
TESTING_SNAP_CATEGORIES
.map(|v| CategoryName::from(v.to_kebab_case()))
.into_iter()
.collect::<HashSet<_>>(),
TESTING_SNAP_CATEGORIES.into_iter().collect::<HashSet<_>>(),
categories.into_iter().collect::<HashSet<_>>()
)
}
Expand Down
9 changes: 3 additions & 6 deletions tests/user_tests/category.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,6 @@ async fn vote_sets_category(
conn: &mut PoolConnection<Postgres>,
expected: &HashSet<Category>,
) {
let expected: HashSet<_> = expected.iter().map(|v| v.to_kebab_case()).collect();
let result = sqlx::query(
r#"
SELECT snap_categories.category
Expand All @@ -143,13 +142,11 @@ async fn vote_sets_category(
.fetch(&mut **conn)
.map(|row| {
row.expect("error when retrieving row")
.try_get::<String, _>("category")
.try_get::<Category, _>("category")
.expect("could not get category field")
.to_lowercase()
})
.collect::<HashSet<String>>()
.collect::<HashSet<_>>()
.await;

assert_eq!(result, expected);
assert_eq!(&result, expected);
}
// 3Iwi803Tk3KQwyD6jFiAJdlq8MLgBIoD

0 comments on commit 739f664

Please sign in to comment.