Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/update with added algorithms #15

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
/target
Cargo.lock
Cargo.lock
2 changes: 1 addition & 1 deletion .rustfmt.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ fn_single_line = false
where_single_line = false
imports_indent = "Block"
imports_layout = "Mixed"
merge_imports = true
imports_granularity="Crate"
reorder_imports = true
reorder_modules = true
reorder_impl_items = false
Expand Down
10 changes: 5 additions & 5 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@ license = "MIT"
edition = "2021"

[dependencies]
base64 = "0.13"
bitflags = "1.2"
base64 = "0.21"
bitflags = "2.4"
generic-array = "0.14"
jsonwebtoken = { version = "8.0", optional = true }
jsonwebtoken = { version = "9.0", optional = true }
num-bigint = { version = "0.4", optional = true }
p256 = { version = "0.10", optional = true, features = ["arithmetic"] }
p256 = { version = "0.13", optional = true, features = ["arithmetic"] }
rand = { version = "0.8", optional = true }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
Expand All @@ -31,7 +31,7 @@ generate = ["p256", "rand"]
thumbprint = ["sha2"]

[dev-dependencies]
jsonwebtoken = "8.0"
jsonwebtoken = "9.0"

[package.metadata.docs.rs]
all-features = true
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

tl;dr: get keys into a format that can be used by other crates; be as safe as possible while doing so.

- Serialization and deserialization of _Required_ and _Recommended_ key types (HS256, RS256, ES256)
- Serialization and deserialization of _Required_ and _Recommended_ key types (HS256, HS384, HS512 RS256, RS384, RS512, ES256, ES384)
- Conversion to PEM for interop with existing JWT libraries (e.g., [jsonwebtoken](https://crates.io/crates/jsonwebtoken))
- Key generation (particularly useful for testing)

Expand Down
2 changes: 1 addition & 1 deletion src/key_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use serde::{
macro_rules! impl_key_ops {
($(($key_op:ident, $const_name:ident, $i:literal)),+,) => {
bitflags::bitflags! {
#[derive(Default)]
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct KeyOps: u16 {
$(const $const_name = $i;)*
}
Expand Down
182 changes: 149 additions & 33 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ mod utils;

use std::{borrow::Cow, fmt};

use generic_array::typenum::U32;
use generic_array::typenum::{U32, U48};
use serde::{Deserialize, Serialize};

pub use byte_array::ByteArray;
Expand Down Expand Up @@ -145,7 +145,7 @@ impl JsonWebKey {
}

pub fn set_algorithm(&mut self, alg: Algorithm) -> Result<(), Error> {
Self::validate_algorithm(alg, &*self.key)?;
Self::validate_algorithm(alg, &self.key)?;
self.algorithm = Some(alg);
Ok(())
}
Expand All @@ -161,11 +161,22 @@ impl JsonWebKey {
(
ES256,
EC {
curve: Curve::P256, ..
curve: Curve::P256 { .. },
..
},
)
| (
ES384,
EC {
curve: Curve::P384 { .. },
},
)
| (RS256, RSA { .. })
| (RS384, RSA { .. })
| (RS512, RSA { .. })
| (HS256, Symmetric { .. }) => Ok(()),
(HS384, Symmetric { .. }) => Ok(()),
(HS512, Symmetric { .. }) => Ok(()),
_ => Err(Error::MismatchedAlgorithm),
}
}
Expand All @@ -180,7 +191,7 @@ impl std::str::FromStr for JsonWebKey {
Some(alg) => alg,
None => return Ok(jwk),
};
Self::validate_algorithm(alg, &*jwk.key).map(|_| jwk)
Self::validate_algorithm(alg, &jwk.key).map(|_| jwk)
}
}

Expand All @@ -200,14 +211,8 @@ impl std::fmt::Display for JsonWebKey {
pub enum Key {
/// An elliptic curve, as per [RFC 7518 §6.2](https://tools.ietf.org/html/rfc7518#section-6.2).
EC {
#[serde(rename = "crv")]
#[serde(flatten)]
curve: Curve,
#[serde(skip_serializing_if = "Option::is_none")]
d: Option<ByteArray<U32>>,
/// The curve point x coordinate.
x: ByteArray<U32>,
/// The curve point y coordinate.
y: ByteArray<U32>,
},
/// An elliptic curve, as per [RFC 7518 §6.3](https://tools.ietf.org/html/rfc7518#section-6.3).
/// See also: [RFC 3447](https://tools.ietf.org/html/rfc3447).
Expand Down Expand Up @@ -241,7 +246,19 @@ impl Key {
use serde::ser::{SerializeStruct, Serializer};
let mut s = serde_json::Serializer::new(Vec::new());
match self {
Self::EC { curve, x, y, .. } => {
Self::EC {
curve: curve @ Curve::P256 { x, y, .. },
} => {
let mut ss = s.serialize_struct("", 4)?;
ss.serialize_field("crv", curve.name())?;
ss.serialize_field("kty", "EC")?;
ss.serialize_field("x", x)?;
ss.serialize_field("y", y)?;
ss.end()?;
}
Self::EC {
curve: curve @ Curve::P384 { x, y, .. },
} => {
let mut ss = s.serialize_struct("", 4)?;
ss.serialize_field("crv", curve.name())?;
ss.serialize_field("kty", "EC")?;
Expand Down Expand Up @@ -277,7 +294,14 @@ impl Key {
matches!(
self,
Self::Symmetric { .. }
| Self::EC { d: Some(_), .. }
| Self::EC {
curve: Curve::P256 { d: Some(_), .. },
..
}
| Self::EC {
curve: Curve::P384 { d: Some(_), .. },
..
}
| Self::RSA {
private: Some(_),
..
Expand All @@ -292,11 +316,23 @@ impl Key {
}
Some(Cow::Owned(match self {
Self::Symmetric { .. } => return None,
Self::EC { curve, x, y, .. } => Self::EC {
curve: *curve,
x: x.clone(),
y: y.clone(),
d: None,
Self::EC {
curve: Curve::P256 { x, y, .. },
} => Self::EC {
curve: Curve::P256 {
x: x.clone(),
y: y.clone(),
d: None,
},
},
Self::EC {
curve: Curve::P384 { x, y, .. },
} => Self::EC {
curve: Curve::P384 {
x: x.clone(),
y: y.clone(),
d: None,
},
},
Self::RSA { public, .. } => Self::RSA {
public: public.clone(),
Expand All @@ -318,7 +354,9 @@ impl Key {
}

Ok(match self {
Self::EC { d, x, y, .. } => {
Self::EC {
curve: Curve::P256 { d, x, y },
} => {
let ec_public_oid = ObjectIdentifier::from_slice(&[1, 2, 840, 10045, 2, 1]);
let prime256v1_oid = ObjectIdentifier::from_slice(&[1, 2, 840, 10045, 3, 1, 7]);
let oids = &[Some(&ec_public_oid), Some(&prime256v1_oid)];
Expand All @@ -337,7 +375,7 @@ impl Key {
Some(private_point) => {
pkcs8::write_private(oids, |writer: &mut DERWriterSeq<'_>| {
writer.next().write_i8(1); // version
writer.next().write_bytes(&**private_point);
writer.next().write_bytes(private_point);
// The following tagged value is optional. OpenSSL produces it,
// but many tools, including jwt.io and `jsonwebtoken`, don't like it,
// so we don't include it.
Expand All @@ -350,6 +388,34 @@ impl Key {
None => pkcs8::write_public(oids, write_public),
}
}
Self::EC {
curve: Curve::P384 { d, x, y },
} => {
let ec_public_oid = ObjectIdentifier::from_slice(&[1, 2, 840, 10045, 2, 1]);
let prime384v1_oid = ObjectIdentifier::from_slice(&[1, 3, 132, 0, 34]);
let oids = &[Some(&ec_public_oid), Some(&prime384v1_oid)];

let write_public = |writer: DERWriter<'_>| {
let public_bytes: Vec<u8> = [0x04 /* uncompressed */]
.iter()
.chain(x.iter())
.chain(y.iter())
.copied()
.collect();
writer.write_bitvec_bytes(&public_bytes, 8 * (48 * 2 + 1));
};

match d {
Some(private_point) => {
pkcs8::write_private(oids, |writer: &mut DERWriterSeq<'_>| {
writer.next().write_i8(1); // version
writer.next().write_bytes(&**private_point);
writer.next().write_tagged(Tag::context(1), write_public);
})
}
None => pkcs8::write_public(oids, write_public),
}
}
Self::RSA { public, private } => {
let rsa_encryption_oid = ObjectIdentifier::from_slice(&[
1, 2, 840, 113549, 1, 1, 1, // rsaEncryption
Expand Down Expand Up @@ -411,8 +477,9 @@ impl Key {
/// If this key is asymmetric, encodes it as PKCS#8 with PEM armoring.
#[cfg(feature = "pkcs-convert")]
pub fn try_to_pem(&self) -> Result<String, ConversionError> {
use base64::{engine::general_purpose::STANDARD, Engine};
use std::fmt::Write;
let der_b64 = base64::encode(self.try_to_der()?);
let der_b64 = STANDARD.encode(self.try_to_der()?);
let key_ty = if self.is_private() {
"PRIVATE"
} else {
Expand Down Expand Up @@ -468,32 +535,65 @@ impl Key {
let (x_bytes, y_bytes) = pk_bytes.split_at(32);

Self::EC {
curve: Curve::P256,
d: Some(sk_scalar.to_bytes().into()),
x: ByteArray::from_slice(x_bytes),
y: ByteArray::from_slice(y_bytes),
curve: Curve::P256 {
d: Some(sk_scalar.to_bytes().into()),
x: ByteArray::from_slice(x_bytes),
y: ByteArray::from_slice(y_bytes),
},
}
}
}

#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "crv")]
pub enum Curve {
/// The prime256v1 (P256) curve.
/// Parameters of the prime256v1 (P256) curve.
#[serde(rename = "P-256")]
P256,
P256 {
/// The private scalar.
#[serde(skip_serializing_if = "Option::is_none")]
d: Option<ByteArray<U32>>,
/// The curve point x coordinate.
x: ByteArray<U32>,
/// The curve point y coordinate.
y: ByteArray<U32>,
},
/// Parameters of the prime384v1 (P384) curve.
#[serde(rename = "P-384")]
P384 {
/// The private scalar.
#[serde(skip_serializing_if = "Option::is_none")]
d: Option<ByteArray<U48>>,
/// The curve point x coordinate.
x: ByteArray<U48>,
/// The curve point y coordinate.
y: ByteArray<U48>,
},
}

impl Curve {
pub fn name(&self) -> &'static str {
match self {
Self::P256 => "P-256",
Self::P256 { .. } => "P-256",
Self::P384 { .. } => "P-256",
}
}
}

impl fmt::Display for Curve {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(self.name())
match self {
Self::P256 { x, y, .. } => f
.debug_struct("Curve::P256")
.field("x", x)
.field("y", y)
.finish(),
Self::P384 { x, y, .. } => f
.debug_struct("Curve::P384")
.field("x", x)
.field("y", y)
.finish(),
}
}
}

Expand Down Expand Up @@ -560,28 +660,38 @@ impl fmt::Debug for RsaPrivate {
}
}

#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize, Hash)]
pub enum KeyUse {
#[serde(rename = "sig")]
Signing,
#[serde(rename = "enc")]
Encryption,
}

#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize, Hash)]
#[allow(clippy::upper_case_acronyms)]
pub enum Algorithm {
HS256,
HS384,
HS512,
RS256,
RS384,
RS512,
ES256,
ES384,
}

impl Algorithm {
pub fn name(&self) -> &'static str {
match self {
Self::HS256 => "hs256",
Self::HS384 => "hs384",
Self::HS512 => "hs512",
Self::RS256 => "rs256",
Self::RS384 => "rs384",
Self::RS512 => "rs512",
Self::ES256 => "es256",
Self::ES384 => "es384",
}
}
}
Expand All @@ -594,8 +704,13 @@ const _IMPL_JWT_CONVERSIONS: () = {
fn from(alg: Algorithm) -> Self {
match alg {
Algorithm::HS256 => Self::HS256,
Algorithm::HS384 => Self::HS384,
Algorithm::HS512 => Self::HS512,
Algorithm::ES256 => Self::ES256,
Algorithm::ES384 => Self::ES384,
Algorithm::RS256 => Self::RS256,
Algorithm::RS384 => Self::RS384,
Algorithm::RS512 => Self::RS512,
}
}
}
Expand Down Expand Up @@ -635,7 +750,8 @@ const _IMPL_JWT_CONVERSIONS: () = {
.unwrap()
}
Self::RSA { .. } => {
jwt::DecodingKey::from_rsa_pem(self.to_pem().as_bytes()).unwrap()
jwt::DecodingKey::from_rsa_pem(self.to_public().unwrap().to_pem().as_bytes())
.unwrap()
}
}
}
Expand Down
Loading