Skip to content

Commit

Permalink
feat: Add token validation helper
Browse files Browse the repository at this point in the history
  • Loading branch information
mzaniolo committed Nov 20, 2024
1 parent 0795f7d commit 20142c5
Show file tree
Hide file tree
Showing 3 changed files with 218 additions and 0 deletions.
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ hyper-proxy = { version = "0.9.1", default-features = false, features = [
headers = "0.3"
futures = "0.3.30"
pbjson-types = "0.6.0"
josekit = { git = "ssh://[email protected]/famedly/josekit-rs.git", rev = "5941d0f39d034e9c6d96dc47f391926f5e3038fa" }
cache_control = "0.2.0"

[lints.rust]
dead_code = "warn"
Expand Down
2 changes: 2 additions & 0 deletions src/v2/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
//! Communication with Zitadel using http [v2 API](https://zitadel.com/docs/apis/v2)
mod authentication;
/// Helper client to authenticate tokens
pub mod token;
pub mod users;
use std::{path::PathBuf, sync::Arc};

Expand Down
214 changes: 214 additions & 0 deletions src/v2/token.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
use std::{collections::HashMap, sync::Arc, time::SystemTime};

use cache_control::CacheControl;
use josekit::{jwk::JwkSet, jws::RS256, jwt, jwt::JwtPayload};
use reqwest::{header, Client, Response};
use serde::Deserialize;
use time::OffsetDateTime;
use tokio::sync::RwLock;
use url::Url;

/// Zitadel client to verify a token's validity
#[derive(Debug)]
pub struct ZitadelJWTVerifier {
/// Zitadel domain
domain: Url,
/// Client for performing the requests
client: Client,
/// Key set from Zitadel
jwks: Arc<RwLock<JwkSet>>,
/// Time when the jwks is no longer valid
expires_at: OffsetDateTime,
}

impl ZitadelJWTVerifier {
/// Creates a new verifier to verify with a specif server
#[must_use]
pub fn new(url: Url) -> Self {
let jwks = JwkSet::new();
Self {
domain: url,
client: Client::new(),
jwks: Arc::new(RwLock::new(jwks)),
expires_at: OffsetDateTime::now_utc(),
}
}

/// Verifies if a token is valid and returns the token payload
/// The performed verifications are:
/// - Token signature
/// - Token not expired
/// - Token not used before 'not before'
/// - Token issuer is the expected server
pub async fn verify(&mut self, token: String) -> Result<JwtPayload, TokenValidationError> {
let header = jwt::decode_header(&token)?;
let kid = header
.claim("kid")
.ok_or(TokenValidationError::BadToken("No kid"))?
.as_str()
.ok_or(TokenValidationError::BadToken("kid is not a string"))?;

let mut jwk = self.jwks.read().await.get(kid).first().map(|&jwk| jwk.clone());
if self.expires_at < OffsetDateTime::now_utc() || jwk.is_none() {
let (new_jwks, expires_at) = self.get_jwks().await?;
let mut jwks = self.jwks.write().await;
*jwks = new_jwks;
jwk = jwks.get(kid).first().map(|&jwk| jwk.clone());
self.expires_at = expires_at;
}

let Some(jwk) = jwk else {
return Err(TokenValidationError::KidNotFoundError);
};

let verifier =
RS256.verifier_from_jwk(&jwk).map_err(TokenValidationError::TokenDecodeError)?;
let (payload, _) = jwt::decode_with_verifier(token, &verifier)
.map_err(TokenValidationError::TokenDecodeError)?;

// Url always comes with an '/' at the end. We need to remove it before for
// checking
if !payload.issuer().is_some_and(|issuer| {
issuer == self.domain.as_str().strip_suffix("/").unwrap_or(self.domain.as_str())
}) {
return Err(TokenValidationError::TokenIssuerError(
payload.issuer().unwrap_or_default().to_owned(),
));
}

let now = OffsetDateTime::now_utc();
if !payload.expires_at().is_some_and(|t| t > now) {
return Err(TokenValidationError::TokenExpiredError);
}

if !payload.not_before().is_some_and(|t| t < now) {
return Err(TokenValidationError::TokenNotBeforeError);
}

Ok(payload)
}

/// Gets the jwks and the expiration date for it
async fn get_jwks(&self) -> Result<(JwkSet, OffsetDateTime), RenewJwksError> {
let mut url = self.domain.clone();
url.set_path("oauth/v2/keys");
let response = self.client.get(url).send().await?;

let status_code = response.status();
if !status_code.is_success() {
return Err(RenewJwksError::BadStatusCodeError(status_code));
}

let expires_at = Self::get_cache_control(&response);

let body = response.bytes().await?;
let jwks = JwkSet::from_bytes(body).map_err(RenewJwksError::ParsingTokenError)?;

Ok((jwks, expires_at))
}

/// Retrieves the cache-control information from the header
fn get_cache_control(response: &Response) -> OffsetDateTime {
let cache_control = response
.headers()
.get(header::CACHE_CONTROL)
.map(|c| c.to_str().unwrap_or_default())
.unwrap_or_default();
let Some(cache_control) = CacheControl::from_value(cache_control) else {
return OffsetDateTime::now_utc();
};

if cache_control.no_store {
return OffsetDateTime::now_utc();
}

let max_age = cache_control.max_age.unwrap_or_default();

OffsetDateTime::now_utc() + max_age
}
}

/// Enum for errors that can happen whilst verifying the token
#[derive(Debug, thiserror::Error)]
pub enum TokenValidationError {
/// Bad token error
#[error("Failed to read kid from token header")]
BadToken(&'static str),
/// Error renewing the jwks
#[error("Failed to renew the jwks: {0}")]
RenewJwksError(#[from] RenewJwksError),
/// kid not found at the token
#[error("Token kid not found on jwks")]
KidNotFoundError,
/// Error decoding and verifying the token
#[error("Failed to decode the token with the verifier: {0}")]
TokenDecodeError(#[from] josekit::JoseError),
/// Wrong issuer error
#[error("The token came from a different issuer then the expected. Token issuer: '{0}'")]
TokenIssuerError(String),
/// Token expired error
#[error("The token has expired")]
TokenExpiredError,
/// Token used before the 'not before'
#[error("The token is still not valid")]
TokenNotBeforeError,
}

/// Enum for errors that can happen whilst renewing the jwks
#[derive(Debug, thiserror::Error)]
pub enum RenewJwksError {
/// General error from reqwest request
#[error("Failed to do the reqwest: {0}")]
ReqwestError(#[from] reqwest::Error),
/// Requested returned with a bad status code error
#[error("The request returned with a bad status code: {0}")]
BadStatusCodeError(reqwest::StatusCode),
/// Parsing the body as jwks error
#[error("Failed to parse the token: {0}")]
ParsingTokenError(#[from] josekit::JoseError),
}

/// Struct that represents a JWT from Zitadel
#[derive(Debug)]
pub struct ZitadelJWT {
/// Token Issuer
pub iss: String,
/// Token expiration date
pub exp: OffsetDateTime,
/// Token not before date
pub nbf: Option<OffsetDateTime>,
/// Map of roles to array of projects ids
pub roles: HashMap<ZitadelUserRole, Vec<String>>,
/// Home server
pub home_server: String,
/// Profession oid
pub profession_oid: i64,
}

/// User roles available on Zitadel
#[derive(Debug, Deserialize, PartialEq, Eq, Hash)]
#[allow(missing_docs)]
pub enum ZitadelUserRole {
TimProviderApi,
FederationlistApi,
OrgAdmin,
}

impl From<JwtPayload> for ZitadelJWT {
fn from(value: JwtPayload) -> Self {
let iss = value.issuer().map(ToOwned::to_owned).unwrap_or_default();
let exp: OffsetDateTime = value.expires_at().unwrap_or(SystemTime::UNIX_EPOCH).into();
let nbf: Option<OffsetDateTime> = value.not_before().map(Into::into);

let roles: HashMap<ZitadelUserRole, Vec<String>> = value
.claim("roles")
.and_then(|roles| serde_json::from_value(roles.clone()).ok())
.unwrap_or_default();
let home_server =
value.claim("homeserver").and_then(|v| v.as_str()).unwrap_or_default().to_owned();
let profession_oid =
value.claim("professionOID").and_then(serde_json::Value::as_i64).unwrap_or_default();

Self { iss, exp, nbf, roles, home_server, profession_oid }
}
}

0 comments on commit 20142c5

Please sign in to comment.