Skip to content

Commit

Permalink
Support for ML-DSA public key generation from private key (#2142)
Browse files Browse the repository at this point in the history
### Issues:
Resolves #CryptoAlg-2868

### Description of changes: 

It is often useful when serializing asymmetric key pairs to populate
both the public and private elements, given only the private element.
For this to be possible, an algorithm utility function is often provided
to derive key material. ML-DSA does not support this in the reference
implementation.

#### Background ML-DSA keypairs 

An ML-DSA private key is constructed of the following elements: (ref
https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.204.pdf)
```
sk = (
      rho, // public random seed (32-bytes)
      tr,  // public key hash (64-bytes)
      key, // private random seed (32-bytes) (utilized during sign)
      t0,  // polynomial vector: encodes the least significant bits of public-key polynomial t, facilitating certain computational efficiencies.
      s1,  // secret polynomial vectors. These vectors contain polynomials with coefficients in a specified range, 
      s2.  // serving as the secret components in the lattice-based structure of ML-DSA.
)
```

An ML-DSA public key is constructed of the following elements:
```
pk = (
      rho, // public random seed (32-bytes)
      t1.  // compressed representation of the public key polynomial 
)
```

- The vector t is decomposed into two parts:
- `t1`: Represents the higher-order bits of `t`.
- `t0`: Represents the lower-order bits of `t`.

One can see that to reconstruct the public key from the private key, one
must:
1. Extract all elements from `sk`, using the existing function in
`/ml_dsa_ref/packing.c`: `ml_dsa_unpack_sk`
    1. This will provide `sk = (rho, tr, key, t0, s1, s2)`.
2. Reconstruct `A` using `rho` with the existing function in
`/ml_dsa_ref/polyvec.c`: `ml_dsa_polyvec_matrix_expand`
3. Reconstruct `t` from `t = A*s1 + s2`
4. Drop `d` lower bits from `t` to get `t1`
5. Pack `rho`, `t1` into public key.
6. Verify `pk` matches expected value, by comparing SHAKE256(pk) + `tr`
(unpacked from secret key).

This has been implemented in `ml_dsa_pack_pk_from_sk` -- not tied to the
name, just using what I've seen so far in common nomenclature.

As the values of `d` differ for each parameter set of ML-DSA, we must
create packing functions for each parameter size. As such,
`ml_dsa_44_pack_pk_from_sk``, `ml_dsa_65_pack_pk_from_sk``, and
`ml_dsa_87_pack_pk_from_sk`` have been added to `ml_dsa.h` to serve as
utility functions in higher level EVP APIs.

### Call-outs:

The scope of this PR is only the algorithm level, using these functions
for useful tasks such as populating the public key automatically on
private key import -- will be added in subsequent PRs.

### Testing:
A new test has been added to `PQDSAParameterTest`, namely,
`KeyConsistencyTest` that will assert that packing the key is
successful, and that the key produced matches the original public key.

By submitting this pull request, I confirm that my contribution is made
under the terms of the Apache 2.0 license and the ISC license.
  • Loading branch information
jakemas authored Jan 28, 2025
1 parent 37c2b5e commit 1f48000
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 10 deletions.
36 changes: 33 additions & 3 deletions crypto/evp_extra/p_pqdsa_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -967,6 +967,8 @@ struct PQDSATestVector {
const uint8_t *sig, size_t sig_len,
const uint8_t *message, size_t message_len,
const uint8_t *pre, size_t pre_len);

int (*pack_key)(uint8_t *public_key, const uint8_t *private_key);
};


Expand Down Expand Up @@ -1004,7 +1006,8 @@ static const struct PQDSATestVector parameterSet[] = {
1334,
ml_dsa_44_keypair_internal,
ml_dsa_44_sign_internal,
ml_dsa_44_verify_internal
ml_dsa_44_verify_internal,
ml_dsa_44_pack_pk_from_sk,
},
{
"MLDSA65",
Expand All @@ -1018,7 +1021,8 @@ static const struct PQDSATestVector parameterSet[] = {
1974,
ml_dsa_65_keypair_internal,
ml_dsa_65_sign_internal,
ml_dsa_65_verify_internal
ml_dsa_65_verify_internal,
ml_dsa_65_pack_pk_from_sk
},
{
"MLDSA87",
Expand All @@ -1032,7 +1036,8 @@ static const struct PQDSATestVector parameterSet[] = {
2614,
ml_dsa_87_keypair_internal,
ml_dsa_87_sign_internal,
ml_dsa_87_verify_internal
ml_dsa_87_verify_internal,
ml_dsa_87_pack_pk_from_sk
},
};

Expand Down Expand Up @@ -1516,6 +1521,31 @@ TEST_P(PQDSAParameterTest, ParsePublicKey) {
ASSERT_TRUE(pkey_from_der);
}

TEST_P(PQDSAParameterTest, KeyConsistencyTest) {
// This test: generates a random PQDSA key pair extracts the private key, and
// runs the public key calculator function to populate the coresponding public key.
// The test is sucessful when the calculated public key is equal to the original
// public key generated.

// ---- 1. Setup phase: generate a key and key buffers ----
int nid = GetParam().nid;
size_t pk_len = GetParam().public_key_len;
size_t sk_len = GetParam().private_key_len;

std::vector<uint8_t> pk(pk_len);
std::vector<uint8_t> sk(sk_len);
bssl::UniquePtr<EVP_PKEY> pkey(generate_key_pair(nid));

// ---- 2. Extract raw private key from the generated PKEY ----
EVP_PKEY_get_raw_private_key(pkey.get(), sk.data(), &sk_len);

// ---- 3. Generate a raw public key from the raw private key ----
ASSERT_TRUE(GetParam().pack_key(pk.data(), sk.data()));

// ---- 4. Generate a raw public key from the raw private key ----
CMP_VEC_AND_PKEY_PUBLIC(pk, pkey, pk_len);
}

// ML-DSA specific test framework to test pre-hash modes only applicable to ML-DSA
struct KnownMLDSA {
const char name[20];
Expand Down
25 changes: 24 additions & 1 deletion crypto/ml_dsa/ml_dsa.c
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,14 @@ int ml_dsa_44_keypair(uint8_t *public_key /* OUT */,
return (ml_dsa_keypair(&params, public_key, private_key) == 0);
}

int ml_dsa_44_pack_pk_from_sk(uint8_t *public_key /* OUT */,
const uint8_t *private_key /* IN */) {

ml_dsa_params params;
ml_dsa_44_params_init(&params);
return ml_dsa_pack_pk_from_sk(&params, public_key, private_key) == 0;
}

int ml_dsa_44_keypair_internal(uint8_t *public_key /* OUT */,
uint8_t *private_key /* OUT */,
const uint8_t *seed /* IN */) {
Expand Down Expand Up @@ -145,6 +153,14 @@ int ml_dsa_65_keypair(uint8_t *public_key /* OUT */,
return (ml_dsa_keypair(&params, public_key, private_key) == 0);
}

int ml_dsa_65_pack_pk_from_sk(uint8_t *public_key /* OUT */,
const uint8_t *private_key /* IN */) {

ml_dsa_params params;
ml_dsa_65_params_init(&params);
return ml_dsa_pack_pk_from_sk(&params, public_key, private_key) == 0;
}

int ml_dsa_65_keypair_internal(uint8_t *public_key /* OUT */,
uint8_t *private_key /* OUT */,
const uint8_t *seed /* IN */) {
Expand Down Expand Up @@ -260,6 +276,14 @@ int ml_dsa_87_keypair(uint8_t *public_key /* OUT */,
return (ml_dsa_keypair(&params, public_key, private_key) == 0);
}

int ml_dsa_87_pack_pk_from_sk(uint8_t *public_key /* OUT */,
const uint8_t *private_key /* IN */) {

ml_dsa_params params;
ml_dsa_87_params_init(&params);
return ml_dsa_pack_pk_from_sk(&params, public_key, private_key) == 0;
}

int ml_dsa_87_keypair_internal(uint8_t *public_key /* OUT */,
uint8_t *private_key /* OUT */,
const uint8_t *seed /* IN */) {
Expand Down Expand Up @@ -367,4 +391,3 @@ int ml_dsa_extmu_87_verify_internal(const uint8_t *public_key /* IN */,
return ml_dsa_verify_internal(&params, sig, sig_len, mu, mu_len,
pre, pre_len, public_key, 1) == 0;
}

9 changes: 9 additions & 0 deletions crypto/ml_dsa/ml_dsa.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ extern "C" {
OPENSSL_EXPORT int ml_dsa_44_keypair(uint8_t *public_key,
uint8_t *secret_key);

OPENSSL_EXPORT int ml_dsa_44_pack_pk_from_sk(uint8_t *public_key,
const uint8_t *private_key);

OPENSSL_EXPORT int ml_dsa_44_keypair_internal(uint8_t *public_key,
uint8_t *private_key,
const uint8_t *seed);
Expand Down Expand Up @@ -80,6 +83,9 @@ OPENSSL_EXPORT int ml_dsa_extmu_44_verify_internal(const uint8_t *public_key,
OPENSSL_EXPORT int ml_dsa_65_keypair(uint8_t *public_key,
uint8_t *secret_key);

OPENSSL_EXPORT int ml_dsa_65_pack_pk_from_sk(uint8_t *public_key,
const uint8_t *private_key);

OPENSSL_EXPORT int ml_dsa_65_keypair_internal(uint8_t *public_key,
uint8_t *private_key,
const uint8_t *seed);
Expand Down Expand Up @@ -127,6 +133,9 @@ OPENSSL_EXPORT int ml_dsa_extmu_65_verify_internal(const uint8_t *public_key,
OPENSSL_EXPORT int ml_dsa_87_keypair(uint8_t *public_key,
uint8_t *secret_key);

OPENSSL_EXPORT int ml_dsa_87_pack_pk_from_sk(uint8_t *public_key,
const uint8_t *private_key);

OPENSSL_EXPORT int ml_dsa_87_keypair_internal(uint8_t *public_key,
uint8_t *private_key,
const uint8_t *seed);
Expand Down
69 changes: 63 additions & 6 deletions crypto/ml_dsa/ml_dsa_ref/packing.c
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,63 @@
#include "packing.h"
#include "polyvec.h"
#include "poly.h"
#include "../../fipsmodule/sha/internal.h"

/*************************************************
* Name: ml_dsa_pack_pk_from_sk
*
* Description: Takes a private key and constructs the corresponding public key.
* The hash of the contructed public key is then compared with
* the value of tr unpacked from the provided private key.
*
* Arguments: - ml_dsa_params: parameter struct
* - uint8_t pk: pointer to output byte array
* - const uint8_t sk: pointer to byte array containing bit-packed sk
*
* Returns 0 (when SHAKE256 hash of constructed pk matches tr)
**************************************************/
int ml_dsa_pack_pk_from_sk(ml_dsa_params *params,
uint8_t *pk,
const uint8_t *sk)
{
uint8_t rho[ML_DSA_SEEDBYTES];
uint8_t tr[ML_DSA_TRBYTES];
uint8_t tr_validate[ML_DSA_TRBYTES];
uint8_t key[ML_DSA_SEEDBYTES];
polyvecl mat[ML_DSA_K_MAX];
polyvecl s1;
polyveck s2, t1, t0;

//unpack sk
ml_dsa_unpack_sk(params, rho, tr, key, &t0, &s1, &s2, sk);

// generate matrix A
ml_dsa_polyvec_matrix_expand(params, mat, rho);

// convert s1 into ntt representation
ml_dsa_polyvecl_ntt(params, &s1);

// construct t1 = A * s1
ml_dsa_polyvec_matrix_pointwise_montgomery(params, &t1, mat, &s1);

// reduce t1 modulo field
ml_dsa_polyveck_reduce(params, &t1);

// take t1 out of ntt representation
ml_dsa_polyveck_invntt_tomont(params, &t1);

// construct t1 = A * s1 + s2
ml_dsa_polyveck_add(params, &t1, &t1, &s2);

// cxtract t1 and write public key
ml_dsa_polyveck_caddq(params, &t1);
ml_dsa_polyveck_power2round(params, &t1, &t0, &t1);
ml_dsa_pack_pk(params, pk, rho, &t1);

// we hash pk to reproduce tr, check it with unpacked value to verify
SHAKE256(pk, params->public_key_bytes, tr_validate, ML_DSA_TRBYTES);
return OPENSSL_memcmp(tr_validate, tr, ML_DSA_TRBYTES);
}

/*************************************************
* Name: ml_dsa_pack_pk
Expand Down Expand Up @@ -122,12 +179,12 @@ void ml_dsa_pack_sk(ml_dsa_params *params,
* Unpack secret key sk = (rho, tr, key, t0, s1, s2).
*
* Arguments: - ml_dsa_params: parameter struct
* - const uint8_t rho[]: output byte array for rho
* - const uint8_t tr[]: output byte array for tr
* - const uint8_t key[]: output byte array for key
* - const polyveck *t0: pointer to output vector t0
* - const polyvecl *s1: pointer to output vector s1
* - const polyveck *s2: pointer to output vector s2
* - uint8_t rho[]: output byte array for rho
* - uint8_t tr[]: output byte array for tr
* - uint8_t key[]: output byte array for key
* - polyveck *t0: pointer to output vector t0
* - polyvecl *s1: pointer to output vector s1
* - polyveck *s2: pointer to output vector s2
* - uint8_t sk[]: pointer to byte array containing bit-packed sk
**************************************************/
void ml_dsa_unpack_sk(ml_dsa_params *params,
Expand Down
4 changes: 4 additions & 0 deletions crypto/ml_dsa/ml_dsa_ref/packing.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
#include "params.h"
#include "polyvec.h"

int ml_dsa_pack_pk_from_sk(ml_dsa_params *params,
uint8_t *pk,
const uint8_t *sk);

void ml_dsa_pack_pk(ml_dsa_params *params,
uint8_t *pk,
const uint8_t rho[ML_DSA_SEEDBYTES],
Expand Down

0 comments on commit 1f48000

Please sign in to comment.