Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
eightfilms committed Jun 29, 2024
1 parent b6f90f3 commit b4cefae
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 38 deletions.
50 changes: 31 additions & 19 deletions src/encryption/symmetric/aes/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! This module contains the implementation for the Advanced Encryption Standard (AES) encryption
//! and decryption.
#![doc = include_str!("./README.md")]
#![cfg_attr(not(doctest), doc = include_str!("./README.md"))]

use itertools::Itertools;

Expand All @@ -23,6 +23,13 @@ where [(); N / 8]: {
inner: [u8; N / 8],
}

impl<const N: usize> Key<N>
where [(); N / 8]:
{
/// Creates a new `Key` of size `N` bits.
pub fn new(key_bytes: [u8; N / 8]) -> Self { Self { inner: key_bytes } }
}

impl<const N: usize> std::ops::Deref for Key<N>
where [(); N / 8]:
{
Expand All @@ -31,26 +38,31 @@ where [(); N / 8]:
fn deref(&self) -> &Self::Target { &self.inner }
}

impl<const K: usize> SymmetricEncryption for AES<K>
where [(); K / 8]:
impl<const N: usize> SymmetricEncryption for AES<N>
where [(); N / 8]:
{
type Block = Block;
type Key = Key<K>;
type Key = Key<N>;

/// Encrypt a message of size [`Block`] with a [`Key`] of size `K`-bits.
/// Encrypt a message of size [`Block`] with a [`Key`] of size `N`-bits.
///
/// ## Example
/// ```rust
/// #![feature(generic_const_exprs)]
///
/// use rand::{thread_rng, Rng};
/// use ronkathon::encryption::symmetric::{aes::AES, SymmetricEncryption};
/// use ronkathon::encryption::symmetric::{
/// aes::{Key, AES},
/// SymmetricEncryption,
/// };
///
/// let mut rng = thread_rng();
/// let key = Key::<128> { inner: rng.gen() };
/// let key = Key::<128>::new(rng.gen());
/// let plaintext = rng.gen();
/// let encrypted = AES::encrypt(&key, &plaintext);
/// ```
fn encrypt(key: &Self::Key, plaintext: &Self::Block) -> Self::Block {
let num_rounds = match K {
let num_rounds = match N {
128 => 10,
192 => 12,
256 => 14,
Expand All @@ -63,7 +75,10 @@ where [(); K / 8]:
fn decrypt(_key: &Self::Key, _ciphertext: &Self::Block) -> Self::Block { unimplemented!() }
}

/// https://en.wikipedia.org/wiki/AES_key_schedule#Round_constants
/// Contains the values given by [x^(i-1), {00}, {00}, {00}], with x^(i-1)
/// being powers of x in the field GF(2^8).
///
/// NOTE: i starts at 1, not 0.
const ROUND_CONSTANTS: [[u8; 4]; 10] = [
[0x01, 0x00, 0x00, 0x00],
[0x02, 0x00, 0x00, 0x00],
Expand All @@ -79,7 +94,7 @@ const ROUND_CONSTANTS: [[u8; 4]; 10] = [

/// A struct containing an instance of an AES encryption/decryption.
#[derive(Clone)]
pub struct AES<const K: usize> {}
pub struct AES<const N: usize> {}

/// Instead of arranging its bytes in a line (array),
/// AES operates on a grid, specifically a 4x4 column-major array:
Expand All @@ -94,15 +109,15 @@ pub struct AES<const K: usize> {}
#[derive(Debug, Default, Clone, Copy, PartialEq)]
struct State([[u8; 4]; 4]);

impl<const K: usize> AES<K>
where [(); K / 8]:
impl<const N: usize> AES<N>
where [(); N / 8]:
{
/// Performs the cipher, with key size of `K` (in bits), as seen in Figure 5 of the document
/// Performs the cipher, with key size of `N` (in bits), as seen in Figure 5 of the document
/// linked in the front-page.
fn aes_encrypt(plaintext: &[u8; 16], key: &Key<K>, num_rounds: usize) -> Block {
fn aes_encrypt(plaintext: &[u8; 16], key: &Key<N>, num_rounds: usize) -> Block {
assert!(!key.is_empty(), "Key is not instantiated");

let key_len_words = K / 32;
let key_len_words = N / 32;
let mut expanded_key = Vec::with_capacity(key_len_words * (num_rounds + 1));
Self::key_expansion(*key, &mut expanded_key, key_len_words, num_rounds);
let mut expanded_key_chunks = expanded_key.chunks_exact(4);
Expand Down Expand Up @@ -222,16 +237,13 @@ where [(); K / 8]:
word
}

fn key_expansion(key: Key<K>, expanded_key: &mut Vec<Word>, key_len: usize, num_rounds: usize) {
fn key_expansion(key: Key<N>, expanded_key: &mut Vec<Word>, key_len: usize, num_rounds: usize) {
let block_num_words = 128 / 32;

let out_len = block_num_words * (num_rounds + 1);
let key_words: Vec<Word> = key.chunks(4).map(|c| c.try_into().unwrap()).collect();
expanded_key.extend(key_words);

// key len (Nk words)
// block size (Nb words)
// num rounds (Nr)
for i in key_len..(block_num_words * (num_rounds + 1)) {
let mut last = *expanded_key.last().unwrap();

Expand Down
30 changes: 11 additions & 19 deletions src/encryption/symmetric/aes/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,9 @@ use super::*;
#[test]
fn test_aes_128() {
const KEY_LEN: usize = 128;
let key = Key::<KEY_LEN> {
inner: [
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e,
0x0f,
],
};
let key = Key::<KEY_LEN>::new([
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
]);

let plaintext = [
0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff,
Expand All @@ -27,12 +24,10 @@ fn test_aes_128() {
#[test]
fn test_aes_192() {
const KEY_LEN: usize = 192;
let key = Key::<KEY_LEN> {
inner: [
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e,
0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
],
};
let key = Key::<KEY_LEN>::new([
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
]);

let plaintext = [
0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff,
Expand All @@ -50,13 +45,10 @@ fn test_aes_192() {
#[test]
fn test_aes_256() {
const KEY_LEN: usize = 256;
let key = Key::<KEY_LEN> {
inner: [
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e,
0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d,
0x1e, 0x1f,
],
};
let key = Key::<KEY_LEN>::new([
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f,
]);

let plaintext = [
0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff,
Expand Down

0 comments on commit b4cefae

Please sign in to comment.