Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Restructure code into modules #30

Merged
merged 1 commit into from
Jul 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added __init__.py
Empty file.
File renamed without changes.
73 changes: 73 additions & 0 deletions benchmarks/benchmark_ml_kem.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from ml_kem import ML_KEM128, ML_KEM192, ML_KEM256
import cProfile
from time import time


def profile_ml_kem(ML_KEM):
(ek, dk) = ML_KEM.keygen()
(K, c) = ML_KEM.encaps(ek)

gvars = {}
lvars = {"ML_KEM": ML_KEM, "c": c, "ek": ek, "dk": dk}

cProfile.runctx(
"[ML_KEM.keygen() for _ in range(100)]",
globals=gvars,
locals=lvars,
sort=1,
)
cProfile.runctx(
"[ML_KEM.encaps(ek) for _ in range(100)]",
globals=gvars,
locals=lvars,
sort=1,
)
cProfile.runctx(
"[ML_KEM.decaps(c, dk) for _ in range(100)]",
globals=gvars,
locals=lvars,
sort=1,
)


def benchmark_ml_kem(ML_KEM, name, count):
keygen_times = []
enc_times = []
dec_times = []

for _ in range(count):
t0 = time()
ek, dk = ML_KEM.keygen()
keygen_times.append(time() - t0)

t1 = time()
_, c = ML_KEM.encaps(ek)
enc_times.append(time() - t1)

t2 = time()
_ = ML_KEM.decaps(c, dk)
dec_times.append(time() - t2)

avg_keygen = sum(keygen_times) / count
avg_enc = sum(enc_times) / count
avg_dec = sum(dec_times) / count
print(
f" {name:11} |"
f"{avg_keygen*1000:8.2f}ms {1/avg_keygen:11.2f}"
f"{avg_enc*1000:8.2f}ms {1/avg_enc:10.2f}"
f"{avg_dec*1000:8.2f}ms {1/avg_dec:8.2f}"
)


if __name__ == "__main__":
count = 1000
# common banner
print("-" * 80)
print(
" Params | keygen | keygen/s | encap | encap/s "
"| decap | decap/s"
)
print("-" * 80)
benchmark_ml_kem(ML_KEM128, "ML_KEM128", count)
benchmark_ml_kem(ML_KEM192, "ML_KEM192", count)
benchmark_ml_kem(ML_KEM256, "ML_KEM256", count)
8 changes: 2 additions & 6 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,9 @@ API
===

.. toctree::
aes256_ctr_drbg
benchmark_kyber
drbg
kyber
ml_kem
modules
modules_generic
polynomials
polynomials_generic
run_kyber
utils
utilities
4 changes: 2 additions & 2 deletions docs/source/aes256_ctr_drbg.rst → docs/source/drbg.rst
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
aes256\_ctr\_drbg module
drbg module
========================

.. automodule:: aes256_ctr_drbg
.. automodule:: drbg
:members:
:undoc-members:
:show-inheritance:
7 changes: 0 additions & 7 deletions docs/source/modules_generic.rst

This file was deleted.

7 changes: 0 additions & 7 deletions docs/source/polynomials_generic.rst

This file was deleted.

7 changes: 0 additions & 7 deletions docs/source/run_kyber.rst

This file was deleted.

4 changes: 2 additions & 2 deletions docs/source/utils.rst → docs/source/utilities.rst
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
utils module
utilities module
============

.. automodule:: utils
.. automodule:: utilities
:members:
:undoc-members:
:show-inheritance:
Empty file added drbg/__init__.py
Empty file.
4 changes: 2 additions & 2 deletions aes256_ctr_drbg.py → drbg/aes256_ctr_drbg.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from utils import xor_bytes
from utilities.utils import xor_bytes
from Crypto.Cipher import AES


Expand Down Expand Up @@ -88,7 +88,7 @@ def random_bytes(self, num_bytes, additional=None):
if len(additional) > self.seed_length:
raise ValueError(
f"The additional input must be of length at most: "
f"{self.seed_length}. Input has length {len(seed)}"
f"{self.seed_length}. Input has length {len(additional)}"
)
elif len(additional) < self.seed_length:
additional += bytes([0]) * (self.seed_length - len(additional))
Expand Down
7 changes: 3 additions & 4 deletions kyber.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import os
from hashlib import sha3_256, sha3_512, shake_128, shake_256
from polynomials import PolynomialRingKyber
from modules import ModuleKyber
from modules.modules import ModuleKyber

try:
from aes256_ctr_drbg import AES256_CTR_DRBG
from drbg.aes256_ctr_drbg import AES256_CTR_DRBG
except ImportError as e:
print(
"Error importing AES CTR DRBG. Have you tried installing requirements?"
Expand Down Expand Up @@ -47,8 +46,8 @@ def __init__(self, parameter_set):
self.du = parameter_set["du"]
self.dv = parameter_set["dv"]

self.R = PolynomialRingKyber()
self.M = ModuleKyber()
self.R = self.M.ring

self.drbg = None
self.random_bytes = os.urandom
Expand Down
7 changes: 3 additions & 4 deletions ml_kem.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import os
from hashlib import sha3_256, sha3_512, shake_128, shake_256
from polynomials import PolynomialRingKyber
from modules import ModuleKyber
from modules.modules import ModuleKyber

try:
from aes256_ctr_drbg import AES256_CTR_DRBG
from drbg.aes256_ctr_drbg import AES256_CTR_DRBG
except ImportError as e:
print(
"Error importing AES CTR DRBG. Have you tried installing requirements?"
Expand Down Expand Up @@ -32,8 +31,8 @@ def __init__(self, params, seed=None):
self.du = params["du"]
self.dv = params["dv"]

self.R = PolynomialRingKyber()
self.M = ModuleKyber()
self.R = self.M.ring

# NIST approved randomness
if seed is None:
Expand Down
Empty file added modules/__init__.py
Empty file.
4 changes: 2 additions & 2 deletions modules.py → modules/modules.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from polynomials import PolynomialRingKyber
from modules_generic import Module, Matrix
from polynomials.polynomials import PolynomialRingKyber
from modules.modules_generic import Module, Matrix


class ModuleKyber(Module):
Expand Down
File renamed without changes.
Empty file added polynomials/__init__.py
Empty file.
4 changes: 2 additions & 2 deletions polynomials.py → polynomials/polynomials.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from polynomials_generic import PolynomialRing, Polynomial
from utils import bytes_to_bits, bitstring_to_bytes
from polynomials.polynomials_generic import PolynomialRing, Polynomial
from utilities.utils import bytes_to_bits, bitstring_to_bytes


class PolynomialRingKyber(PolynomialRing):
Expand Down
File renamed without changes.
8 changes: 0 additions & 8 deletions run_kyber.py

This file was deleted.

28 changes: 0 additions & 28 deletions test_ml_kem.py

This file was deleted.

Empty file added tests/__init__.py
Empty file.
2 changes: 1 addition & 1 deletion test_kyber.py → tests/test_kyber.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import unittest
import os
from kyber import Kyber512, Kyber768, Kyber1024
from aes256_ctr_drbg import AES256_CTR_DRBG
from drbg.aes256_ctr_drbg import AES256_CTR_DRBG


def parse_kat_data(data):
Expand Down
7 changes: 6 additions & 1 deletion test_kyber_kat.py → tests/test_kyber_kat.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
"""
An alternative way of checking the Kyber KAT
Does nothing which isn't already checked in test_kyber.py
"""

from kyber import Kyber512, Kyber768, Kyber1024
from hashlib import sha256
from aes256_ctr_drbg import AES256_CTR_DRBG
from drbg.aes256_ctr_drbg import AES256_CTR_DRBG


def generate_kat_hash(kyber):
Expand Down
31 changes: 31 additions & 0 deletions tests/test_ml_kem.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import unittest
from ml_kem import ML_KEM128, ML_KEM192, ML_KEM256


class TestML_KEM(unittest.TestCase):
"""
Test ML_KEM levels for internal
consistency by generating key pairs
and shared secrets.
"""

def generic_test_ML_KEM(self, ML_KEM, count):
for _ in range(count):
(ek, dk) = ML_KEM.keygen()
for _ in range(count):
(K, c) = ML_KEM.encaps(ek)
K_prime = ML_KEM.decaps(c, dk)
self.assertEqual(K, K_prime)

def test_ML_KEM128(self):
self.generic_test_ML_KEM(ML_KEM128, 5)

def test_ML_KEM192(self):
self.generic_test_ML_KEM(ML_KEM192, 5)

def test_ML_KEM256(self):
self.generic_test_ML_KEM(ML_KEM256, 5)


if __name__ == "__main__":
unittest.main()
Empty file added utilities/__init__.py
Empty file.
File renamed without changes.