generated from famedly/rust-library-template
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
218 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -37,6 +37,8 @@ hyper-proxy = { version = "0.9.1", default-features = false, features = [ | |
headers = "0.3" | ||
futures = "0.3.30" | ||
pbjson-types = "0.6.0" | ||
josekit = { git = "ssh://[email protected]/famedly/josekit-rs.git", rev = "5941d0f39d034e9c6d96dc47f391926f5e3038fa" } | ||
cache_control = "0.2.0" | ||
|
||
[lints.rust] | ||
dead_code = "warn" | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,214 @@ | ||
use std::{collections::HashMap, sync::Arc, time::SystemTime}; | ||
|
||
use cache_control::CacheControl; | ||
use josekit::{jwk::JwkSet, jws::RS256, jwt, jwt::JwtPayload}; | ||
use reqwest::{header, Client, Response}; | ||
use serde::Deserialize; | ||
use time::OffsetDateTime; | ||
use tokio::sync::RwLock; | ||
use url::Url; | ||
|
||
/// Zitadel client to verify a token's validity | ||
#[derive(Debug)] | ||
pub struct ZitadelJWTVerifier { | ||
/// Zitadel domain | ||
domain: Url, | ||
/// Client for performing the requests | ||
client: Client, | ||
/// Key set from Zitadel | ||
jwks: Arc<RwLock<JwkSet>>, | ||
/// Time when the jwks is no longer valid | ||
expires_at: OffsetDateTime, | ||
} | ||
|
||
impl ZitadelJWTVerifier { | ||
/// Creates a new verifier to verify with a specif server | ||
#[must_use] | ||
pub fn new(url: Url) -> Self { | ||
let jwks = JwkSet::new(); | ||
Self { | ||
domain: url, | ||
client: Client::new(), | ||
jwks: Arc::new(RwLock::new(jwks)), | ||
expires_at: OffsetDateTime::now_utc(), | ||
} | ||
} | ||
|
||
/// Verifies if a token is valid and returns the token payload | ||
/// The performed verifications are: | ||
/// - Token signature | ||
/// - Token not expired | ||
/// - Token not used before 'not before' | ||
/// - Token issuer is the expected server | ||
pub async fn verify(&mut self, token: String) -> Result<JwtPayload, TokenValidationError> { | ||
let header = jwt::decode_header(&token)?; | ||
let kid = header | ||
.claim("kid") | ||
.ok_or(TokenValidationError::BadToken("No kid"))? | ||
.as_str() | ||
.ok_or(TokenValidationError::BadToken("kid is not a string"))?; | ||
|
||
let mut jwk = self.jwks.read().await.get(kid).first().map(|&jwk| jwk.clone()); | ||
if self.expires_at < OffsetDateTime::now_utc() || jwk.is_none() { | ||
let (new_jwks, expires_at) = self.get_jwks().await?; | ||
let mut jwks = self.jwks.write().await; | ||
*jwks = new_jwks; | ||
jwk = jwks.get(kid).first().map(|&jwk| jwk.clone()); | ||
self.expires_at = expires_at; | ||
} | ||
|
||
let Some(jwk) = jwk else { | ||
return Err(TokenValidationError::KidNotFoundError); | ||
}; | ||
|
||
let verifier = | ||
RS256.verifier_from_jwk(&jwk).map_err(TokenValidationError::TokenDecodeError)?; | ||
let (payload, _) = jwt::decode_with_verifier(token, &verifier) | ||
.map_err(TokenValidationError::TokenDecodeError)?; | ||
|
||
// Url always comes with an '/' at the end. We need to remove it before for | ||
// checking | ||
if !payload.issuer().is_some_and(|issuer| { | ||
issuer == self.domain.as_str().strip_suffix("/").unwrap_or(self.domain.as_str()) | ||
}) { | ||
return Err(TokenValidationError::TokenIssuerError( | ||
payload.issuer().unwrap_or_default().to_owned(), | ||
)); | ||
} | ||
|
||
let now = OffsetDateTime::now_utc(); | ||
if !payload.expires_at().is_some_and(|t| t > now) { | ||
return Err(TokenValidationError::TokenExpiredError); | ||
} | ||
|
||
if !payload.not_before().is_some_and(|t| t < now) { | ||
return Err(TokenValidationError::TokenNotBeforeError); | ||
} | ||
|
||
Ok(payload) | ||
} | ||
|
||
/// Gets the jwks and the expiration date for it | ||
async fn get_jwks(&self) -> Result<(JwkSet, OffsetDateTime), RenewJwksError> { | ||
let mut url = self.domain.clone(); | ||
url.set_path("oauth/v2/keys"); | ||
let response = self.client.get(url).send().await?; | ||
|
||
let status_code = response.status(); | ||
if !status_code.is_success() { | ||
return Err(RenewJwksError::BadStatusCodeError(status_code)); | ||
} | ||
|
||
let expires_at = Self::get_cache_control(&response); | ||
|
||
let body = response.bytes().await?; | ||
let jwks = JwkSet::from_bytes(body).map_err(RenewJwksError::ParsingTokenError)?; | ||
|
||
Ok((jwks, expires_at)) | ||
} | ||
|
||
/// Retrieves the cache-control information from the header | ||
fn get_cache_control(response: &Response) -> OffsetDateTime { | ||
let cache_control = response | ||
.headers() | ||
.get(header::CACHE_CONTROL) | ||
.map(|c| c.to_str().unwrap_or_default()) | ||
.unwrap_or_default(); | ||
let Some(cache_control) = CacheControl::from_value(cache_control) else { | ||
return OffsetDateTime::now_utc(); | ||
}; | ||
|
||
if cache_control.no_store { | ||
return OffsetDateTime::now_utc(); | ||
} | ||
|
||
let max_age = cache_control.max_age.unwrap_or_default(); | ||
|
||
OffsetDateTime::now_utc() + max_age | ||
} | ||
} | ||
|
||
/// Enum for errors that can happen whilst verifying the token | ||
#[derive(Debug, thiserror::Error)] | ||
pub enum TokenValidationError { | ||
/// Bad token error | ||
#[error("Failed to read kid from token header")] | ||
BadToken(&'static str), | ||
/// Error renewing the jwks | ||
#[error("Failed to renew the jwks: {0}")] | ||
RenewJwksError(#[from] RenewJwksError), | ||
/// kid not found at the token | ||
#[error("Token kid not found on jwks")] | ||
KidNotFoundError, | ||
/// Error decoding and verifying the token | ||
#[error("Failed to decode the token with the verifier: {0}")] | ||
TokenDecodeError(#[from] josekit::JoseError), | ||
/// Wrong issuer error | ||
#[error("The token came from a different issuer then the expected. Token issuer: '{0}'")] | ||
TokenIssuerError(String), | ||
/// Token expired error | ||
#[error("The token has expired")] | ||
TokenExpiredError, | ||
/// Token used before the 'not before' | ||
#[error("The token is still not valid")] | ||
TokenNotBeforeError, | ||
} | ||
|
||
/// Enum for errors that can happen whilst renewing the jwks | ||
#[derive(Debug, thiserror::Error)] | ||
pub enum RenewJwksError { | ||
/// General error from reqwest request | ||
#[error("Failed to do the reqwest: {0}")] | ||
ReqwestError(#[from] reqwest::Error), | ||
/// Requested returned with a bad status code error | ||
#[error("The request returned with a bad status code: {0}")] | ||
BadStatusCodeError(reqwest::StatusCode), | ||
/// Parsing the body as jwks error | ||
#[error("Failed to parse the token: {0}")] | ||
ParsingTokenError(#[from] josekit::JoseError), | ||
} | ||
|
||
/// Struct that represents a JWT from Zitadel | ||
#[derive(Debug)] | ||
pub struct ZitadelJWT { | ||
/// Token Issuer | ||
pub iss: String, | ||
/// Token expiration date | ||
pub exp: OffsetDateTime, | ||
/// Token not before date | ||
pub nbf: Option<OffsetDateTime>, | ||
/// Map of roles to array of projects ids | ||
pub roles: HashMap<ZitadelUserRole, Vec<String>>, | ||
/// Home server | ||
pub home_server: String, | ||
/// Profession oid | ||
pub profession_oid: i64, | ||
} | ||
|
||
/// User roles available on Zitadel | ||
#[derive(Debug, Deserialize, PartialEq, Eq, Hash)] | ||
#[allow(missing_docs)] | ||
pub enum ZitadelUserRole { | ||
TimProviderApi, | ||
FederationlistApi, | ||
OrgAdmin, | ||
} | ||
|
||
impl From<JwtPayload> for ZitadelJWT { | ||
fn from(value: JwtPayload) -> Self { | ||
let iss = value.issuer().map(ToOwned::to_owned).unwrap_or_default(); | ||
let exp: OffsetDateTime = value.expires_at().unwrap_or(SystemTime::UNIX_EPOCH).into(); | ||
let nbf: Option<OffsetDateTime> = value.not_before().map(Into::into); | ||
|
||
let roles: HashMap<ZitadelUserRole, Vec<String>> = value | ||
.claim("roles") | ||
.and_then(|roles| serde_json::from_value(roles.clone()).ok()) | ||
.unwrap_or_default(); | ||
let home_server = | ||
value.claim("homeserver").and_then(|v| v.as_str()).unwrap_or_default().to_owned(); | ||
let profession_oid = | ||
value.claim("professionOID").and_then(serde_json::Value::as_i64).unwrap_or_default(); | ||
|
||
Self { iss, exp, nbf, roles, home_server, profession_oid } | ||
} | ||
} |