Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify the KAT vector unit tests #38

Merged
merged 2 commits into from
Jul 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 23 additions & 38 deletions tests/test_kyber.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def parse_kat_data(data):
count, seed, pk, sk, ct, ss = [
line.split(" = ")[-1] for line in block_data
]
parsed_data[count] = {
parsed_data[int(count)] = {
"seed": bytes.fromhex(seed),
"pk": bytes.fromhex(pk),
"sk": bytes.fromhex(sk),
Expand Down Expand Up @@ -98,55 +98,40 @@ def test_kyber1024_deterministic(self):
self.generic_test_kyber_deterministic(Kyber1024, 5)


class TestKnownTestValuesDRBG(unittest.TestCase):
"""
We know how the seeds for the KAT are generated, so
let's check against our own implementation.

We only need to test one file, as the seeds are the
same across the three files.
"""

def test_kyber512_known_answer_seed(self):
class TestKnownTestValues(unittest.TestCase):
def generic_test_kyber_known_answer(self, Kyber, filename):
# Set DRBG to generate seeds
entropy_input = bytes([i for i in range(48)])
rng = AES256_CTR_DRBG(entropy_input)

with open("assets/PQCkemKAT_1632.rsp") as f:
# extract data from KAT
kat_data_512 = f.read()
parsed_data = parse_kat_data(kat_data_512)
# Check all seeds match
for data in parsed_data.values():
seed = data["seed"]
self.assertEqual(seed, rng.random_bytes(48))


class TestKnownTestValues(unittest.TestCase):
def generic_test_kyber_known_answer(self, Kyber, filename):
with open(filename) as f:
kat_data = f.read()
parsed_data = parse_kat_data(kat_data)

for data in parsed_data.values():
seed, pk, sk, ct, ss = data.values()
for count in range(100):
# Obtain the kat data for the count
data = parsed_data[count]

# Seed DRBG with KAT seed
Kyber.set_drbg_seed(seed)
# Set the seed and check it matches the KAT
seed = rng.random_bytes(48)
self.assertEqual(seed, data["seed"])

# Assert keygen matches
_pk, _sk = Kyber.keygen()
self.assertEqual(pk, _pk)
self.assertEqual(sk, _sk)
# Seed DRBG with KAT seed
Kyber.set_drbg_seed(seed)

# Assert keygen matches
pk, sk = Kyber.keygen()
self.assertEqual(pk, data["pk"])
self.assertEqual(sk, data["sk"])

# Assert encapsulation matches
_ct, _ss = Kyber.enc(_pk)
self.assertEqual(ct, _ct)
self.assertEqual(ss, _ss)
# Assert encapsulation matches
ct, ss = Kyber.enc(pk)
self.assertEqual(ct, data["ct"])
self.assertEqual(ss, data["ss"])

# Assert decapsulation matches
__ss = Kyber.dec(ct, sk)
self.assertEqual(ss, __ss)
# Assert decapsulation matches
_ss = Kyber.dec(ct, sk)
self.assertEqual(ss, data["ss"])

def test_kyber512_known_answer(self):
return self.generic_test_kyber_known_answer(
Expand Down
56 changes: 0 additions & 56 deletions tests/test_kyber_kat.py

This file was deleted.

55 changes: 34 additions & 21 deletions tests/test_ml_kem.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import unittest
from ml_kem import ML_KEM128, ML_KEM192, ML_KEM256
from drbg.aes256_ctr_drbg import AES256_CTR_DRBG


def read_kat_data(file_name):
Expand All @@ -12,15 +13,12 @@ def read_kat_data(file_name):

def parse_kat_data(data_blocks):
parsed_data = {}

# only test the first 100 for now, running all 1000 is overkill
# for us as it's pretty slow (~165 seconds)
for block in data_blocks[:100]:
for block in data_blocks:
block_data = block.split("\n")[:-1]
count, z, d, msg, seed, pk, sk, ct_n, ss_n, ct, ss = [
line.split(" = ")[-1] for line in block_data
]
parsed_data[count] = {
parsed_data[int(count)] = {
"z": bytes.fromhex(z),
"d": bytes.fromhex(d),
"msg": bytes.fromhex(msg),
Expand Down Expand Up @@ -62,42 +60,57 @@ def test_ML_KEM256(self):

class TestKnownTestValues(unittest.TestCase):
def generic_test_mlkem_known_answer(self, ML_KEM, filename):
# Set DRBG to generate seeds
# https://github.com/post-quantum-cryptography/KAT/tree/main/MLKEM
entropy_input = bytes.fromhex(
"60496cd0a12512800a79161189b055ac3996ad24e578d3c5fc57c1e60fa2eb4e550d08e51e9db7b67f1a616681d9182d"
)
rng = AES256_CTR_DRBG(entropy_input)

# Parse the KAT file data
kat_data_blocks = read_kat_data(filename)
parsed_data = parse_kat_data(kat_data_blocks)

for data in parsed_data.values():
z, d, msg, seed, pk, sk, ct_n, ss_n, ct, ss = data.values()
# Only test the first 100 for now, running all 1000 is overkill
# for us as it's pretty slow (~165 seconds)
for count in range(100):
# Obtain the kat data for the count
data = parsed_data[count]

# Set the seed and check it matches the KAT
seed = rng.random_bytes(48)
self.assertEqual(seed, data["seed"])

# Check that the three chunks of 32 random bytes match
ML_KEM.set_drbg_seed(seed)
_z = ML_KEM.random_bytes(32)
_d = ML_KEM.random_bytes(32)
_msg = ML_KEM.random_bytes(32)
self.assertEqual(z, _z)
self.assertEqual(d, _d)
self.assertEqual(msg, _msg)

z = ML_KEM.random_bytes(32)
d = ML_KEM.random_bytes(32)
msg = ML_KEM.random_bytes(32)
self.assertEqual(z, data["z"])
self.assertEqual(d, data["d"])
self.assertEqual(msg, data["msg"])

# Reset the seed
ML_KEM.set_drbg_seed(seed)

# Assert keygen matches
ek, dk = ML_KEM.keygen()
self.assertEqual(pk, ek)
self.assertEqual(sk, dk)
self.assertEqual(ek, data["pk"])
self.assertEqual(dk, data["sk"])

# Assert encapsulation matches
K, c = ML_KEM.encaps(ek)
self.assertEqual(ct, c)
self.assertEqual(ss, K)
self.assertEqual(K, data["ss"])
self.assertEqual(c, data["ct"])

# Assert decapsulation matches
_c = ML_KEM.decaps(c, dk)
self.assertEqual(ss, _c)
_K = ML_KEM.decaps(c, dk)
self.assertEqual(_K, data["ss"])

# Assert decapsulation with faulty ciphertext
_c_n = ML_KEM.decaps(ct_n, dk)
self.assertEqual(ss_n, _c_n)
ss_n = ML_KEM.decaps(data["ct_n"], dk)
self.assertEqual(ss_n, data["ss_n"])

def test_mlkem_512_known_answer(self):
return self.generic_test_mlkem_known_answer(
Expand Down