Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Include ML-KEM KAT vectors #33

Merged
merged 4 commits into from
Jul 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11,000 changes: 11,000 additions & 0 deletions assets/kat_MLKEM_1024.rsp

Large diffs are not rendered by default.

11,000 changes: 11,000 additions & 0 deletions assets/kat_MLKEM_512.rsp

Large diffs are not rendered by default.

11,000 changes: 11,000 additions & 0 deletions assets/kat_MLKEM_768.rsp

Large diffs are not rendered by default.

9 changes: 6 additions & 3 deletions kyber/kyber.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@


class Kyber:
def __init__(self, parameter_set):
def __init__(self, parameter_set, seed=None):
self.k = parameter_set["k"]
self.eta_1 = parameter_set["eta_1"]
self.eta_2 = parameter_set["eta_2"]
Expand All @@ -23,8 +23,11 @@ def __init__(self, parameter_set):
self.M = ModuleKyber()
self.R = self.M.ring

self.drbg = None
self.random_bytes = os.urandom
# NIST approved randomness
if seed is None:
seed = os.urandom(48)
self._drbg = AES256_CTR_DRBG(seed)
self.random_bytes = self._drbg.random_bytes

def set_drbg_seed(self, seed):
"""
Expand Down
2 changes: 1 addition & 1 deletion ml_kem/default_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
DEFAULT_PARAMETERS = {
"ML128": {"k": 2, "eta_1": 3, "eta_2": 2, "du": 10, "dv": 4},
"ML192": {"k": 3, "eta_1": 2, "eta_2": 2, "du": 10, "dv": 4},
"ML256": {"k": 4, "eta_1": 3, "eta_2": 2, "du": 11, "dv": 5},
"ML256": {"k": 4, "eta_1": 2, "eta_2": 2, "du": 11, "dv": 5},
}

ML_KEM128 = ML_KEM(DEFAULT_PARAMETERS["ML128"])
Expand Down
31 changes: 28 additions & 3 deletions ml_kem/ml_kem.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,37 @@ def __init__(self, params, seed=None):
self._drbg = AES256_CTR_DRBG(seed)
self.random_bytes = self._drbg.random_bytes

def set_drbg_seed(self, seed):
"""
Setting the seed switches the entropy source
from os.urandom to AES256 CTR DRBG

Note: requires pycryptodome for AES impl.
(Seemed overkill to code my own AES for Kyber)
"""
self.drbg = AES256_CTR_DRBG(seed)
self.random_bytes = self.drbg.random_bytes

def reseed_drbg(self, seed):
"""
Reseeds the DRBG, errors if a DRBG is not set.

Note: requires pycryptodome for AES impl.
(Seemed overkill to code my own AES for Kyber)
"""
if self.drbg is None:
raise Warning(
"Cannot reseed DRBG without first initialising. Try using `set_drbg_seed`"
)
else:
self.drbg.reseed(seed)

@staticmethod
def xof(bytes32, a, b, length):
def xof(bytes32, i, j, length):
"""
XOF: B^* x B x B -> B*
"""
input_bytes = bytes32 + a + b
input_bytes = bytes32 + i + j
if len(input_bytes) != 34:
raise ValueError(
"Input bytes should be one 32 byte array and 2 single bytes."
Expand Down Expand Up @@ -106,7 +131,7 @@ def pke_keygen(self):

N = 0
s, N = self.generate_vector(sigma, self.eta_1, N)
e, N = self.generate_vector(sigma, self.eta_2, N)
e, N = self.generate_vector(sigma, self.eta_1, N)

# TODO: we could convert to ntt form as we create the data
# and skip this call to compute a new Matrix objects
Expand Down
2 changes: 1 addition & 1 deletion polynomials/polynomials.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def cbd(self, input_bytes, eta, is_ntt=False):
for i in range(256):
a = sum(list_of_bits[2 * i * eta + j] for j in range(eta))
b = sum(list_of_bits[2 * i * eta + eta + j] for j in range(eta))
coefficients[i] = a - b
coefficients[i] = (a - b) % 3329
return self(coefficients, is_ntt=is_ntt)

def decode(self, input_bytes, l=None, is_ntt=False):
Expand Down
88 changes: 88 additions & 0 deletions tests/test_ml_kem.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,39 @@
from ml_kem import ML_KEM128, ML_KEM192, ML_KEM256


def read_kat_data(file_name):
data_blocks = []
with open(file_name) as f:
for _ in range(1000):
data_blocks.append("".join([next(f) for _ in range(11)]))
return data_blocks


def parse_kat_data(data_blocks):
parsed_data = {}

# only test the first 100 for now, running all 1000 is overkill
# for us as it's pretty slow (~165 seconds)
for block in data_blocks[:100]:
block_data = block.split("\n")[:-1]
count, z, d, msg, seed, pk, sk, ct_n, ss_n, ct, ss = [
line.split(" = ")[-1] for line in block_data
]
parsed_data[count] = {
"z": bytes.fromhex(z),
"d": bytes.fromhex(d),
"msg": bytes.fromhex(msg),
"seed": bytes.fromhex(seed),
"pk": bytes.fromhex(pk),
"sk": bytes.fromhex(sk),
"ct_n": bytes.fromhex(ct_n),
"ss_n": bytes.fromhex(ss_n),
"ct": bytes.fromhex(ct),
"ss": bytes.fromhex(ss),
}
return parsed_data


class TestML_KEM(unittest.TestCase):
"""
Test ML_KEM levels for internal
Expand All @@ -27,5 +60,60 @@ def test_ML_KEM256(self):
self.generic_test_ML_KEM(ML_KEM256, 5)


class TestKnownTestValues(unittest.TestCase):
def generic_test_mlkem_known_answer(self, ML_KEM, filename):

kat_data_blocks = read_kat_data(filename)
parsed_data = parse_kat_data(kat_data_blocks)

for data in parsed_data.values():
z, d, msg, seed, pk, sk, ct_n, ss_n, ct, ss = data.values()

# Check that the three chunks of 32 random bytes match
ML_KEM.set_drbg_seed(seed)
_z = ML_KEM.random_bytes(32)
_d = ML_KEM.random_bytes(32)
_msg = ML_KEM.random_bytes(32)
self.assertEqual(z, _z)
self.assertEqual(d, _d)
self.assertEqual(msg, _msg)

# Reset the seed
ML_KEM.set_drbg_seed(seed)

# Assert keygen matches
ek, dk = ML_KEM.keygen()
self.assertEqual(pk, ek)
self.assertEqual(sk, dk)

# Assert encapsulation matches
K, c = ML_KEM.encaps(ek)
self.assertEqual(ct, c)
self.assertEqual(ss, K)

# Assert decapsulation matches
_c = ML_KEM.decaps(c, dk)
self.assertEqual(ss, _c)

# Assert decapsulation with faulty ciphertext
_c_n = ML_KEM.decaps(ct_n, dk)
self.assertEqual(ss_n, _c_n)

def test_mlkem_512_known_answer(self):
return self.generic_test_mlkem_known_answer(
ML_KEM128, "assets/kat_MLKEM_512.rsp"
)

def test_mlkem_768_known_answer(self):
return self.generic_test_mlkem_known_answer(
ML_KEM192, "assets/kat_MLKEM_768.rsp"
)

def test_mlkem_1024_known_answer(self):
return self.generic_test_mlkem_known_answer(
ML_KEM256, "assets/kat_MLKEM_1024.rsp"
)


if __name__ == "__main__":
unittest.main()