diff --git a/src/kyber_py/modules/modules.py b/src/kyber_py/modules/modules.py index 67b2e51..3a69e76 100644 --- a/src/kyber_py/modules/modules.py +++ b/src/kyber_py/modules/modules.py @@ -8,23 +8,20 @@ def __init__(self): self.matrix = MatrixKyber def decode_vector(self, input_bytes, k, d, is_ntt=False): - - if self.ring.n * d * k > len(input_bytes) * 8: - raise ValueError("Byte length is too short for given l") + # Ensure the input bytes are the correct length to create k elements with + # d bits used for each coefficient + if self.ring.n * d * k != len(input_bytes) * 8: + raise ValueError( + "Byte length is the wrong length for given k, d values" + ) # Bytes needed to decode a polynomial - chunk_length = 32 * d - - # Break input_bytes into blocks of length chunk_length - poly_bytes = [ - input_bytes[i : i + chunk_length] - for i in range(0, len(input_bytes), chunk_length) - ] + n = 32 * d - # Encode each chunk of bytes as a polynomial, we iterate only the first k elements in case we've - # been sent too many bytes to decode for the vector + # Encode each chunk of bytes as a polynomial and create the vector elements = [ - self.ring.decode(poly_bytes[i], d, is_ntt=is_ntt) for i in range(k) + self.ring.decode(input_bytes[i : i + n], d, is_ntt=is_ntt) + for i in range(0, len(input_bytes), n) ] return self.vector(elements)