Skip to content

Commit

Permalink
server: improve error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
conradoplg committed Dec 26, 2024
1 parent f5cb068 commit 33f8017
Show file tree
Hide file tree
Showing 7 changed files with 109 additions and 100 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ uuid = { version = "1.11.0", features = ["v4", "fast-rng", "serde"] }
xeddsa = "1.0.2"
futures-util = "0.3.31"
futures = "0.3.31"
thiserror = "2.0.3"
hex = "0.4.3"

[dev-dependencies]
Expand Down
101 changes: 33 additions & 68 deletions server/src/functions.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use axum::{extract::State, http::StatusCode, Json};
use eyre::eyre;
use axum::{extract::State, Json};
use uuid::Uuid;
use xeddsa::{xed25519, Verify as _};

Expand Down Expand Up @@ -33,35 +32,21 @@ pub(crate) async fn login(
) -> Result<Json<KeyLoginOutput>, AppError> {
// Check if the user sent the credentials
if args.signature.is_empty() || args.pubkey.is_empty() {
return Err(AppError(
StatusCode::INTERNAL_SERVER_ERROR,
eyre!("empty args").into(),
));
return Err(AppError::InvalidArgument("signature or pubkey".into()));
}

let pubkey = TryInto::<[u8; 32]>::try_into(args.pubkey.clone()).map_err(|_| {
AppError(
StatusCode::INTERNAL_SERVER_ERROR,
eyre!("invalid pubkey").into(),
)
})?;
let pubkey = TryInto::<[u8; 32]>::try_into(args.pubkey.clone())
.map_err(|_| AppError::InvalidArgument("pubkey".into()))?;
let pubkey = xed25519::PublicKey(pubkey);
let signature = TryInto::<[u8; 64]>::try_into(args.signature).map_err(|_| {
AppError(
StatusCode::INTERNAL_SERVER_ERROR,
eyre!("invalid signature").into(),
)
})?;
let signature = TryInto::<[u8; 64]>::try_into(args.signature)
.map_err(|_| AppError::InvalidArgument("signature".into()))?;
pubkey
.verify(args.uuid.as_bytes(), &signature)
.map_err(|_| AppError(StatusCode::UNAUTHORIZED, eyre!("invalid signature").into()))?;
.map_err(|_| AppError::Unauthorized)?;

let mut challenges = state.challenges.write().unwrap();
if !challenges.remove(&args.uuid) {
return Err(AppError(
StatusCode::UNAUTHORIZED,
eyre!("invalid challenge").into(),
));
return Err(AppError::Unauthorized);
}
drop(challenges);

Expand Down Expand Up @@ -97,10 +82,7 @@ pub(crate) async fn create_new_session(
Json(args): Json<CreateNewSessionArgs>,
) -> Result<Json<CreateNewSessionOutput>, AppError> {
if args.message_count == 0 {
return Err(AppError(
StatusCode::INTERNAL_SERVER_ERROR,
eyre!("invalid message_count").into(),
));
return Err(AppError::InvalidArgument("message_count".into()));
}

// Create new session object.
Expand Down Expand Up @@ -157,22 +139,17 @@ pub(crate) async fn get_session_info(
let sessions = state.sessions.sessions.read().unwrap();
let sessions_by_pubkey = state.sessions.sessions_by_pubkey.read().unwrap();

let user_sessions = sessions_by_pubkey.get(&user.pubkey).ok_or(AppError(
StatusCode::NOT_FOUND,
eyre!("user is not in any session").into(),
))?;
let user_sessions = sessions_by_pubkey
.get(&user.pubkey)
.ok_or(AppError::SessionNotFound)?;

if !user_sessions.contains(&args.session_id) {
return Err(AppError(
StatusCode::NOT_FOUND,
eyre!("session ID not found").into(),
));
return Err(AppError::SessionNotFound);
}

let session = sessions.get(&args.session_id).ok_or(AppError(
StatusCode::NOT_FOUND,
eyre!("session ID not found").into(),
))?;
let session = sessions
.get(&args.session_id)
.ok_or(AppError::SessionNotFound)?;

Ok(Json(GetSessionInfoOutput {
num_signers: session.num_signers,
Expand All @@ -195,10 +172,9 @@ pub(crate) async fn send(

// TODO: change to get_mut and modify in-place, if HashMapDelay ever
// adds support to it
let mut session = sessions.remove(&args.session_id).ok_or(AppError(
StatusCode::NOT_FOUND,
eyre!("session ID not found").into(),
))?;
let mut session = sessions
.remove(&args.session_id)
.ok_or(AppError::SessionNotFound)?;

let recipients = if args.recipients.is_empty() {
vec![Vec::new()]
Expand All @@ -221,7 +197,6 @@ pub(crate) async fn send(
}

/// Implement the recv API
// TODO: get identifier from channel rather from arguments
#[tracing::instrument(ret, err(Debug), skip(state, user))]
pub(crate) async fn receive(
State(state): State<SharedState>,
Expand All @@ -235,10 +210,9 @@ pub(crate) async fn receive(
// adds support to it. This will also simplify the code since
// we have to do a workaround in order to not renew the timeout if there
// are no messages. See https://github.com/AgeManning/delay_map/issues/26
let session = sessions.get(&args.session_id).ok_or(AppError(
StatusCode::NOT_FOUND,
eyre!("session ID not found").into(),
))?;
let session = sessions
.get(&args.session_id)
.ok_or(AppError::SessionNotFound)?;

let pubkey = if user.pubkey == session.coordinator_pubkey && args.as_coordinator {
Vec::new()
Expand All @@ -252,10 +226,9 @@ pub(crate) async fn receive(
let msgs = if session.queue.contains_key(&pubkey) {
drop(sessions);
let mut sessions = state.sessions.sessions.write().unwrap();
let mut session = sessions.remove(&args.session_id).ok_or(AppError(
StatusCode::NOT_FOUND,
eyre!("session ID not found").into(),
))?;
let mut session = sessions
.remove(&args.session_id)
.ok_or(AppError::SessionNotFound)?;
let msgs = session.queue.entry(pubkey).or_default().drain(..).collect();
sessions.insert(args.session_id, session);
msgs
Expand All @@ -276,28 +249,20 @@ pub(crate) async fn close_session(
let mut sessions = state.sessions.sessions.write().unwrap();
let mut sessions_by_pubkey = state.sessions.sessions_by_pubkey.write().unwrap();

let user_sessions = sessions_by_pubkey.get(&user.pubkey).ok_or(AppError(
StatusCode::NOT_FOUND,
eyre!("user is not in any session").into(),
))?;
let user_sessions = sessions_by_pubkey
.get(&user.pubkey)
.ok_or(AppError::SessionNotFound)?;

if !user_sessions.contains(&args.session_id) {
return Err(AppError(
StatusCode::NOT_FOUND,
eyre!("session ID not found").into(),
));
return Err(AppError::SessionNotFound);
}

let session = sessions.get(&args.session_id).ok_or(AppError(
StatusCode::INTERNAL_SERVER_ERROR,
eyre!("invalid session ID").into(),
))?;
let session = sessions
.get(&args.session_id)
.ok_or(AppError::SessionNotFound)?;

if session.coordinator_pubkey != user.pubkey {
return Err(AppError(
StatusCode::NOT_FOUND,
eyre!("user is not the coordinator of the session").into(),
));
return Err(AppError::NotCoordinator);
}

for username in session.pubkeys.clone() {
Expand Down
49 changes: 44 additions & 5 deletions server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use std::net::SocketAddr;
use axum_server::tls_rustls::RustlsConfig;
use eyre::OptionExt;
pub use state::{AppState, SharedState};
use thiserror::Error;
use tower_http::trace::TraceLayer;
pub use types::*;

Expand All @@ -17,7 +18,7 @@ use axum::{
http::StatusCode,
response::{IntoResponse, Response},
routing::post,
Router,
Json, Router,
};

/// Create the axum Router for the server.
Expand Down Expand Up @@ -74,12 +75,50 @@ pub async fn run(args: &Args) -> Result<(), Box<dyn std::error::Error>> {

/// An error. Wraps a StatusCode which is returned by the server when the
/// error happens during a API call, and a generic eyre::Report.
// TODO: create an enum with specific errors
#[derive(Debug)]
pub struct AppError(StatusCode, Box<dyn std::error::Error>);
#[derive(Debug, Error)]
pub(crate) enum AppError {
#[error("invalid or missing argument: {0}")]
InvalidArgument(String),
#[error("client did not provide proper authorization credentials")]
Unauthorized,
#[error("session was not found")]
SessionNotFound,
#[error("user is not the coordinator")]
NotCoordinator,
}

// These make it easier to clients to tell which error happened.
pub const INVALID_ARGUMENT: usize = 1;
pub const UNAUTHORIZED: usize = 2;
pub const SESSION_NOT_FOUND: usize = 3;
pub const NOT_COORDINATOR: usize = 4;

impl AppError {
pub fn error_code(&self) -> usize {
match &self {
AppError::InvalidArgument(_) => INVALID_ARGUMENT,
AppError::Unauthorized => UNAUTHORIZED,
AppError::SessionNotFound => SESSION_NOT_FOUND,
AppError::NotCoordinator => NOT_COORDINATOR,
}
}
}

impl From<AppError> for types::Error {
fn from(err: AppError) -> Self {
types::Error {
code: err.error_code(),
msg: err.to_string(),
}
}
}

impl IntoResponse for AppError {
fn into_response(self) -> Response {
(self.0, format!("{}", self.1)).into_response()
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(Into::<types::Error>::into(self)),
)
.into_response()
}
}
6 changes: 6 additions & 0 deletions server/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@ use frost_rerandomized::Randomizer;
use serde::{Deserialize, Serialize};
pub use uuid::Uuid;

#[derive(Debug, Serialize, Deserialize)]
pub struct Error {
pub code: usize,
pub msg: String,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct RegisterArgs {
pub username: String,
Expand Down
29 changes: 5 additions & 24 deletions server/src/user.rs
Original file line number Diff line number Diff line change
@@ -1,24 +1,18 @@
use std::str::FromStr;

use axum::{
async_trait,
extract::FromRequestParts,
http::{request::Parts, StatusCode},
RequestPartsExt,
};
use axum::{async_trait, extract::FromRequestParts, http::request::Parts, RequestPartsExt};
use axum_extra::{
headers::{authorization::Bearer, Authorization},
TypedHeader,
};
use eyre::eyre;
use uuid::Uuid;

use crate::{state::SharedState, AppError};

/// An User
#[derive(Debug)]
#[allow(dead_code)]
pub struct User {
pub(crate) struct User {
pub(crate) pubkey: Vec<u8>,
pub(crate) current_token: Uuid,
}
Expand All @@ -42,19 +36,9 @@ impl FromRequestParts<SharedState> for User {
let TypedHeader(Authorization(bearer)) = parts
.extract::<TypedHeader<Authorization<Bearer>>>()
.await
.map_err(|_| {
AppError(
StatusCode::INTERNAL_SERVER_ERROR,
eyre!("Bearer token missing").into(),
)
})?;
.map_err(|_| AppError::Unauthorized)?;
// Decode the user data
let access_token = Uuid::from_str(bearer.token()).map_err(|_| {
AppError(
StatusCode::INTERNAL_SERVER_ERROR,
eyre!("invalid access token").into(),
)
})?;
let access_token = Uuid::from_str(bearer.token()).map_err(|_| AppError::Unauthorized)?;

let pubkey = state
.access_tokens
Expand All @@ -69,10 +53,7 @@ impl FromRequestParts<SharedState> for User {
current_token: access_token,
})
} else {
return Err(AppError(
StatusCode::INTERNAL_SERVER_ERROR,
eyre!("user not found").into(),
));
return Err(AppError::Unauthorized);
}
}
}
22 changes: 19 additions & 3 deletions server/tests/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use server::{
};

use frost_core as frost;
use uuid::Uuid;
use xeddsa::{xed25519, Sign, Verify};

#[tokio::test]
Expand Down Expand Up @@ -450,7 +451,7 @@ async fn test_http() -> Result<(), Box<dyn std::error::Error>> {
.send()
.await?;
if r.status() != reqwest::StatusCode::OK {
panic!("{}", r.text().await?)
panic!("{:?}", r.json::<server::Error>().await?)
}
let r = r.json::<server::ChallengeOutput>().await?;
let alice_challenge = r.challenge;
Expand All @@ -469,7 +470,7 @@ async fn test_http() -> Result<(), Box<dyn std::error::Error>> {
.send()
.await?;
if r.status() != reqwest::StatusCode::OK {
panic!("{}", r.text().await?)
panic!("{:?}", r.json::<server::Error>().await?)
}
let r = r.json::<server::KeyLoginOutput>().await?;
let access_token = r.access_token;
Expand All @@ -489,12 +490,27 @@ async fn test_http() -> Result<(), Box<dyn std::error::Error>> {
.send()
.await?;
if r.status() != reqwest::StatusCode::OK {
panic!("{}", r.text().await?)
panic!("{:?}", r.json::<server::Error>().await?)
}
let r = r.json::<server::CreateNewSessionOutput>().await?;
let session_id = r.session_id;
println!("Session ID: {}", session_id);

// Error test

let wrong_session_id = Uuid::new_v4();
let r = client
.post("http://127.0.0.1:2744/get_session_info")
.bearer_auth(access_token)
.json(&server::GetSessionInfoArgs {
session_id: wrong_session_id,
})
.send()
.await?;
assert_eq!(r.status(), reqwest::StatusCode::INTERNAL_SERVER_ERROR);
let r = r.json::<server::Error>().await?;
assert_eq!(r.code, server::SESSION_NOT_FOUND);

Ok(())
}

Expand Down

0 comments on commit 33f8017

Please sign in to comment.