Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add packets for caching_sha2_password public key retrieval #119

Merged
merged 3 commits into from
Jan 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 77 additions & 0 deletions src/packets/caching_sha2_password.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
use std::io;

use crate::{
io::ParseBuf,
proto::{MyDeserialize, MySerialize},
};

define_header!(
PublicKeyRequestHeader,
InvalidPublicKeyRequest("Invalid PublicKeyRequest header"),
0x02
);

/// A client request for a server public RSA key, used by some authentication mechanisms
/// to add a layer of protection to an unsecured channel (see [`PublicKeyResponse`]).
///
/// [`PublicKeyResponse`]: crate::packets::PublicKeyResponse
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct PublicKeyRequest {
__header: PublicKeyRequestHeader,
}

impl PublicKeyRequest {
pub fn new() -> Self {
Self {
__header: PublicKeyRequestHeader::new(),
}
}
}

impl<'de> MyDeserialize<'de> for PublicKeyRequest {
const SIZE: Option<usize> = None;
type Ctx = ();

fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
Ok(Self {
__header: buf.parse(())?,
})
}
}

impl MySerialize for PublicKeyRequest {
fn serialize(&self, buf: &mut Vec<u8>) {
self.__header.serialize(&mut *buf);
}
}

#[cfg(test)]
mod tests {
use crate::{
io::ParseBuf,
packets::caching_sha2_password::PublicKeyRequest,
proto::{MyDeserialize, MySerialize},
};

#[test]
fn should_parse_rsa_public_key_request_packet() {
const RSA_PUBLIC_KEY_REQUEST: &[u8] = b"\x02";

let public_rsa_key_request =
PublicKeyRequest::deserialize((), &mut ParseBuf(RSA_PUBLIC_KEY_REQUEST));

assert!(public_rsa_key_request.is_ok());
}

#[test]
fn should_build_rsa_public_key_request_packet() {
let rsa_public_key_request = PublicKeyRequest::new();

let mut actual = Vec::new();
rsa_public_key_request.serialize(&mut actual);

let expected: Vec<u8> = [0x02].to_vec();

assert_eq!(expected, actual);
}
}
76 changes: 74 additions & 2 deletions src/packets/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,13 @@ macro_rules! define_header {
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, thiserror::Error)]
#[error($msg)]
pub struct $err;
pub type $name = ConstU8<$err, $val>;
pub type $name = crate::misc::raw::int::ConstU8<$err, $val>;
};
($name:ident, $cmd:ident, $err:ident) => {
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, thiserror::Error)]
#[error("Invalid header for {}", stringify!($cmd))]
pub struct $err;
pub type $name = ConstU8<$err, { Command::$cmd as u8 }>;
pub type $name = crate::misc::raw::int::ConstU8<$err, { Command::$cmd as u8 }>;
};
}

Expand Down Expand Up @@ -92,6 +92,7 @@ macro_rules! define_const_bytes {
}

pub mod binlog_request;
pub mod caching_sha2_password;
pub mod session_state_change;

define_const_bytes!(
Expand Down Expand Up @@ -1242,6 +1243,54 @@ impl MySerialize for AuthMoreData<'_> {
}
}

define_header!(
PublicKeyResponseHeader,
InvalidPublicKeyResponse("Invalid PublicKeyResponse header"),
0x01
);

/// A server response to a [`PublicKeyRequest`] containing a public RSA key for authentication protection.
///
/// [`PublicKeyRequest`]: crate::packets::caching_sha2_password::PublicKeyRequest
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct PublicKeyResponse<'a> {
__header: PublicKeyResponseHeader,
rsa_key: RawBytes<'a, EofBytes>,
}

impl<'a> PublicKeyResponse<'a> {
pub fn new(rsa_key: impl Into<Cow<'a, [u8]>>) -> Self {
Self {
__header: PublicKeyResponseHeader::new(),
rsa_key: RawBytes::new(rsa_key),
}
}

/// The server's RSA public key in PEM format.
pub fn rsa_key(&self) -> Cow<'_, str> {
self.rsa_key.as_str()
}
}

impl<'de> MyDeserialize<'de> for PublicKeyResponse<'de> {
const SIZE: Option<usize> = None;
type Ctx = ();

fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
Ok(Self {
__header: buf.parse(())?,
rsa_key: buf.parse(())?,
})
}
}

impl MySerialize for PublicKeyResponse<'_> {
fn serialize(&self, buf: &mut Vec<u8>) {
self.__header.serialize(&mut *buf);
self.rsa_key.serialize(buf);
}
}

define_header!(
AuthSwitchRequestHeader,
InvalidAuthSwithRequestHeader("Invalid auth switch request header"),
Expand Down Expand Up @@ -3889,4 +3938,27 @@ mod test {
"start(4) >= end(4) in GnoInterval".to_string()
);
}

#[test]
fn should_parse_rsa_public_key_response_packet() {
const PUBLIC_RSA_KEY_RESPONSE: &[u8] = b"\x01test";

let rsa_public_key_response =
PublicKeyResponse::deserialize((), &mut ParseBuf(PUBLIC_RSA_KEY_RESPONSE));

assert!(rsa_public_key_response.is_ok());
assert_eq!(rsa_public_key_response.unwrap().rsa_key(), "test");
}

#[test]
fn should_build_rsa_public_key_response_packet() {
let rsa_public_key_response = PublicKeyResponse::new("test".as_bytes());

let mut actual = Vec::new();
rsa_public_key_response.serialize(&mut actual);

let expected = b"\x01test".to_vec();

assert_eq!(expected, actual);
}
}