Skip to content

Commit

Permalink
Use Arc<[T]> over Vec<T>
Browse files Browse the repository at this point in the history
  • Loading branch information
WKHAllen committed Jan 5, 2024
1 parent 6c6f0fa commit 7191a9a
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 25 deletions.
7 changes: 3 additions & 4 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -847,8 +847,7 @@ where
// Generate AES key
let aes_key = aes_key().await;
// Encrypt AES key with RSA public key
let aes_key_encrypted =
into_generic_io_result(rsa_encrypt(rsa_pub, aes_key.to_vec()).await)?;
let aes_key_encrypted = into_generic_io_result(rsa_encrypt(rsa_pub, aes_key.into()).await)?;
// Create the buffer containing the AES key and its size
let mut aes_key_buffer = encode_message_size(aes_key_encrypted.len()).to_vec();
// Extend the buffer with the AES key
Expand Down Expand Up @@ -942,7 +941,7 @@ where
}

// Decrypt the data
let data_buffer = match aes_decrypt(aes_key, encrypted_data_buffer).await {
let data_buffer = match aes_decrypt(aes_key, encrypted_data_buffer.into()).await {
Ok(val) => Ok(val),
Err(e) => generic_io_error(format!("failed to decrypt data: {}", e)),
}?;
Expand Down Expand Up @@ -987,7 +986,7 @@ where
// Serialize the data
let data_buffer = break_on_err!(into_generic_io_result(serde_json::to_vec(&data)), 'val);
// Encrypt the serialized data
let encrypted_data_buffer = break_on_err!(into_generic_io_result(aes_encrypt(aes_key, data_buffer).await), 'val);
let encrypted_data_buffer = break_on_err!(into_generic_io_result(aes_encrypt(aes_key, data_buffer.into()).await), 'val);
// Encode the message size to a buffer
let size_buffer = encode_message_size(encrypted_data_buffer.len());

Expand Down
9 changes: 5 additions & 4 deletions src/crypto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use aes_gcm::aead::{Aead, KeyInit, OsRng};
use aes_gcm::{Aes256Gcm, Nonce};
use rsa::sha2::Sha256;
use rsa::{Oaep, RsaPrivateKey, RsaPublicKey};
use std::sync::Arc;

/// The number of bits to use for an RSA key.
pub const RSA_KEY_SIZE: usize = 2048;
Expand Down Expand Up @@ -88,7 +89,7 @@ pub async fn rsa_keys() -> Result<(RsaPublicKey, RsaPrivateKey)> {
/// `plaintext`: the data to encrypt.
///
/// Returns a result containing the encrypted data, or the error variant if an error occurred while encrypting.
pub async fn rsa_encrypt(public_key: RsaPublicKey, plaintext: Vec<u8>) -> Result<Vec<u8>> {
pub async fn rsa_encrypt(public_key: RsaPublicKey, plaintext: Arc<[u8]>) -> Result<Vec<u8>> {
tokio::task::spawn_blocking(move || {
let mut rng = rand::thread_rng();
let padding = Oaep::new::<Sha256>();
Expand All @@ -106,7 +107,7 @@ pub async fn rsa_encrypt(public_key: RsaPublicKey, plaintext: Vec<u8>) -> Result
/// `ciphertext`: the data to decrypt.
///
/// Returns a result containing the decrypted data, or the error variant if an error occurred while decrypting.
pub async fn rsa_decrypt(private_key: RsaPrivateKey, ciphertext: Vec<u8>) -> Result<Vec<u8>> {
pub async fn rsa_decrypt(private_key: RsaPrivateKey, ciphertext: Arc<[u8]>) -> Result<Vec<u8>> {
tokio::task::spawn_blocking(move || {
let padding = Oaep::new::<Sha256>();
let plaintext = private_key.decrypt(padding, &ciphertext[..])?;
Expand Down Expand Up @@ -135,7 +136,7 @@ pub async fn aes_key() -> [u8; AES_KEY_SIZE] {
/// `plaintext`: the data to encrypt.
///
/// Returns a result containing the encrypted data with the nonce prepended, or the error variant if an error occurred while encrypting.
pub async fn aes_encrypt(key: [u8; AES_KEY_SIZE], plaintext: Vec<u8>) -> Result<Vec<u8>> {
pub async fn aes_encrypt(key: [u8; AES_KEY_SIZE], plaintext: Arc<[u8]>) -> Result<Vec<u8>> {
tokio::task::spawn_blocking(move || {
let cipher = Aes256Gcm::new_from_slice(&key).unwrap();
let nonce_slice: [u8; AES_NONCE_SIZE] = rand::random();
Expand All @@ -159,7 +160,7 @@ pub async fn aes_encrypt(key: [u8; AES_KEY_SIZE], plaintext: Vec<u8>) -> Result<
/// Returns a result containing the decrypted data, or the error variant if an error occurred while decrypting.
pub async fn aes_decrypt(
key: [u8; AES_KEY_SIZE],
ciphertext_with_nonce: Vec<u8>,
ciphertext_with_nonce: Arc<[u8]>,
) -> Result<Vec<u8>> {
tokio::task::spawn_blocking(move || {
let cipher = Aes256Gcm::new_from_slice(&key).unwrap();
Expand Down
31 changes: 18 additions & 13 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ mod tests {
use async_trait::async_trait;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tokio::sync::mpsc::{channel, Sender};

/// Default amount of time to sleep, in milliseconds.
Expand Down Expand Up @@ -207,35 +208,39 @@ mod tests {
async fn test_crypto() {
let rsa_message = "Hello, RSA!";
let (public_key, private_key) = crypto::rsa_keys().await.unwrap();
let rsa_encrypted =
crypto::rsa_encrypt(public_key.clone(), rsa_message.as_bytes().to_vec())
let rsa_encrypted = Arc::<[u8]>::from(
crypto::rsa_encrypt(public_key.clone(), rsa_message.as_bytes().into())
.await
.unwrap();
let rsa_decrypted = crypto::rsa_decrypt(private_key.clone(), rsa_encrypted.clone())
.unwrap(),
);
let rsa_decrypted = crypto::rsa_decrypt(private_key.clone(), Arc::clone(&rsa_encrypted))
.await
.unwrap();
let rsa_decrypted_message = std::str::from_utf8(&rsa_decrypted).unwrap();
assert_eq!(rsa_decrypted_message, rsa_message);
assert_ne!(rsa_encrypted, rsa_message.as_bytes());
assert_ne!(&*rsa_encrypted, rsa_message.as_bytes());

let aes_message = "Hello, AES!";
let key = crypto::aes_key().await;
let aes_encrypted = crypto::aes_encrypt(key, aes_message.as_bytes().to_vec())
.await
.unwrap();
let aes_decrypted = crypto::aes_decrypt(key, aes_encrypted.clone())
let aes_encrypted = Arc::<[u8]>::from(
crypto::aes_encrypt(key, aes_message.as_bytes().into())
.await
.unwrap(),
);
let aes_decrypted = crypto::aes_decrypt(key, Arc::clone(&aes_encrypted))
.await
.unwrap();
let aes_decrypted_message = std::str::from_utf8(&aes_decrypted).unwrap();
assert_eq!(aes_decrypted_message, aes_message);
assert_ne!(aes_encrypted, aes_message.as_bytes());
assert_ne!(&*aes_encrypted, aes_message.as_bytes());

let encrypted_key = crypto::rsa_encrypt(public_key, key.to_vec()).await.unwrap();
let decrypted_key = crypto::rsa_decrypt(private_key, encrypted_key.clone())
let encrypted_key =
Arc::<[u8]>::from(crypto::rsa_encrypt(public_key, key.into()).await.unwrap());
let decrypted_key = crypto::rsa_decrypt(private_key, Arc::clone(&encrypted_key))
.await
.unwrap();
assert_eq!(decrypted_key, key);
assert_ne!(encrypted_key, key);
assert_ne!(&*encrypted_key, key);
}

/// Test server creation and serving.
Expand Down
8 changes: 4 additions & 4 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1134,7 +1134,7 @@ where
}

// Decrypt the data
let data_buffer = match aes_decrypt(aes_key, encrypted_data_buffer).await {
let data_buffer = match aes_decrypt(aes_key, encrypted_data_buffer.into()).await {
Ok(val) => Ok(val),
Err(e) => generic_io_error(format!("failed to decrypt data: {}", e)),
}?;
Expand Down Expand Up @@ -1163,9 +1163,8 @@ where
match client_command {
ServerClientCommand::Send { data } => {
let value = 'val: {
let data_buffer = data.to_vec();
// Encrypt the serialized data
let encrypted_data_buffer = break_on_err!(into_generic_io_result(aes_encrypt(aes_key, data_buffer).await), 'val);
let encrypted_data_buffer = break_on_err!(into_generic_io_result(aes_encrypt(aes_key, data).await), 'val);
// Encode the message size to a buffer
let size_buffer = encode_message_size(encrypted_data_buffer.len());

Expand Down Expand Up @@ -1314,7 +1313,8 @@ where
}

// Decrypt the AES key
let aes_key_decrypted = into_generic_io_result(rsa_decrypt(rsa_priv, aes_key_buffer).await)?;
let aes_key_decrypted =
into_generic_io_result(rsa_decrypt(rsa_priv, aes_key_buffer.into()).await)?;

// Assert that the AES key is the correct size
let aes_key: [u8; AES_KEY_SIZE] = match aes_key_decrypted.try_into() {
Expand Down

0 comments on commit 7191a9a

Please sign in to comment.