Skip to content

Commit

Permalink
feat: Add token validation helper
Browse files Browse the repository at this point in the history
  • Loading branch information
mzaniolo committed Nov 26, 2024
1 parent 0795f7d commit 095afea
Show file tree
Hide file tree
Showing 3 changed files with 182 additions and 0 deletions.
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ hyper-proxy = { version = "0.9.1", default-features = false, features = [
headers = "0.3"
futures = "0.3.30"
pbjson-types = "0.6.0"
josekit = { git = "https://github.com/famedly/josekit-rs.git", rev = "5941d0f39d034e9c6d96dc47f391926f5e3038fa" }
cache_control = "0.2.0"

[lints.rust]
dead_code = "warn"
Expand Down
2 changes: 2 additions & 0 deletions src/v2/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
//! Communication with Zitadel using http [v2 API](https://zitadel.com/docs/apis/v2)
mod authentication;
/// Helper client to authenticate tokens
pub mod token;
pub mod users;
use std::{path::PathBuf, sync::Arc};

Expand Down
178 changes: 178 additions & 0 deletions src/v2/token.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
use std::sync::Arc;

use cache_control::CacheControl;
use josekit::{jwk::JwkSet, jws::RS256, jwt, jwt::JwtPayload};
use reqwest::{header, Client, Response};
use time::OffsetDateTime;
use tokio::sync::RwLock;
use url::Url;

/// Zitadel client to verify a token's validity
#[derive(Debug, Clone)]
pub struct ZitadelJWTVerifier {
/// Zitadel domain
domain: Url,
/// Client for performing the requests
client: Client,
/// Key set cache from Zitadel
jwks_cache: Arc<RwLock<JwkSetCache>>,
}

/// Cache of Zitadel jwks
#[derive(Debug, Clone)]
struct JwkSetCache {
/// Key set from Zitadel
jwks: JwkSet,
/// Time when the jwks is no longer valid
expires_at: OffsetDateTime,
}

impl ZitadelJWTVerifier {
/// Creates a new verifier to verify with a specific server
#[must_use]
pub fn new(url: Url) -> Self {
let jwks = JwkSet::new();
Self {
domain: url,
client: Client::new(),
jwks_cache: Arc::new(RwLock::new(JwkSetCache {
jwks,
expires_at: OffsetDateTime::now_utc(),
})),
}
}

/// Verifies if a token is valid and returns the token payload
/// The performed verifications are:
/// - Token signature
/// - Token not expired
/// - Token not used before 'not before'
/// - Token issuer is the expected server
pub async fn verify(&self, token: String) -> Result<JwtPayload, TokenValidationError> {
use TokenValidationError::*;

let header = jwt::decode_header(&token)?;
let kid = header
.claim("kid")
.ok_or(BadToken("No kid"))?
.as_str()
.ok_or(BadToken("kid is not a string"))?;

let (mut jwk, expires_at) = {
let jwks_cache = self.jwks_cache.read().await;
(jwks_cache.jwks.get(kid).first().copied().cloned(), jwks_cache.expires_at)
};
if expires_at < OffsetDateTime::now_utc() || jwk.is_none() {
let mut jwks_cache = self.jwks_cache.write().await;
*jwks_cache = self.get_jwks().await?;
jwk = jwks_cache.jwks.get(kid).first().map(|&jwk| jwk.clone());
tracing::debug!("Updated JWKs");
}

let jwk = jwk.ok_or(KidNotFoundError(kid.to_owned()))?;

let verifier = RS256.verifier_from_jwk(&jwk).map_err(TokenDecodeError)?;
let (payload, _) = jwt::decode_with_verifier(token, &verifier).map_err(TokenDecodeError)?;

// Url always comes with an '/' at the end. We need to remove it before for
// checking
if !payload.issuer().is_some_and(|issuer| {
issuer == self.domain.as_str().strip_suffix("/").unwrap_or(self.domain.as_str())
}) {
return Err(TokenIssuerError(payload.issuer().unwrap_or_default().to_owned()));
}

let now = OffsetDateTime::now_utc();
(payload.expires_at().ok_or(MissingClaim("exp"))? > now)
.then_some(())
.ok_or(TokenExpiredError)?;

(payload.not_before().ok_or(MissingClaim("nbf"))? < now)
.then_some(())
.ok_or(TokenNotBeforeError)?;

Ok(payload)
}

/// Gets the jwks and the expiration date for it
async fn get_jwks(&self) -> Result<JwkSetCache, RenewJwksError> {
let mut url = self.domain.clone();
url.set_path("oauth/v2/keys");
let response = self.client.get(url).send().await?;

let status_code = response.status();
if !status_code.is_success() {
return Err(RenewJwksError::BadStatusCodeError(status_code));
}

let expires_at = Self::get_cache_control(&response);

let body = response.bytes().await?;
let jwks = JwkSet::from_bytes(body).map_err(RenewJwksError::ParsingTokenError)?;

Ok(JwkSetCache { jwks, expires_at })
}

/// Retrieves the cache-control information from the header
fn get_cache_control(response: &Response) -> OffsetDateTime {
let cache_control = response
.headers()
.get(header::CACHE_CONTROL)
.map(|c| c.to_str().unwrap_or_default())
.unwrap_or_default();
let Some(cache_control) = CacheControl::from_value(cache_control) else {
return OffsetDateTime::now_utc();
};

if cache_control.no_store {
return OffsetDateTime::now_utc();
}

let max_age = cache_control.max_age.unwrap_or_default();

OffsetDateTime::now_utc() + max_age
}
}

/// Enum for errors that can happen whilst verifying the token
#[derive(Debug, thiserror::Error)]
pub enum TokenValidationError {
/// Bad token error
#[error("Failed to read kid from token header. {0}")]
BadToken(&'static str),
/// Error renewing the jwks
#[error("Failed to renew the jwks: {0}")]
RenewJwksError(#[from] RenewJwksError),
/// kid not found at the token
#[error("Token kid not found on jwks. kid: {0}")]
KidNotFoundError(String),
/// Error decoding and verifying the token
#[error("Failed to decode the token with the verifier: {0}")]
TokenDecodeError(#[from] josekit::JoseError),
/// Wrong issuer error
#[error("The token came from a different issuer then the expected. Token issuer: '{0}'")]
TokenIssuerError(String),
/// Token expired error
#[error("The token has expired")]
TokenExpiredError,
/// Token used before the 'not before'
#[error("The token is still not valid")]
TokenNotBeforeError,
/// Missing token claim error
#[error("Token missing the claim '{0}'")]
MissingClaim(&'static str),
}

/// Enum for errors that can happen whilst renewing the jwks
#[derive(Debug, thiserror::Error)]
pub enum RenewJwksError {
/// General error from reqwest request
#[error("Failed to do the reqwest: {0}")]
ReqwestError(#[from] reqwest::Error),
/// Requested returned with a bad status code error
#[error("The request returned with a bad status code: {0}")]
BadStatusCodeError(reqwest::StatusCode),
/// Parsing the body as jwks error
#[error("Failed to parse the token: {0}")]
ParsingTokenError(#[from] josekit::JoseError),
}

0 comments on commit 095afea

Please sign in to comment.