Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rsa support #12

Open
wants to merge 15 commits into
base: master
Choose a base branch
from
39 changes: 39 additions & 0 deletions examples/simple_rsarep.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""
Proof of knowledge of a discrete logarithm in a subgroup of an RSA group:
PK{ (x): y = x * g}.

Here, the group operation is written additively, so x * g is g multiplied by itself x times.
"""

from petlib.bn import Bn

from zksk import Secret, DLRep
from zksk.rsa_group import rsa_dlrep_trusted_setup

# Create a generator for a subgroup of an RSA group.
[g] = rsa_dlrep_trusted_setup(bits=1024, num=1)

# Preparing the secret.
# In practice, this should probably be a big integer (petlib.bn.Bn)
x = Secret()

# Setup the proof statement.

# First, compute the "left-hand side".
y = 3 * g

# Next, create the proof statement.
stmt = DLRep(y, x * g)

# Simulate the prover and the verifier interacting.
prover = stmt.get_prover({x: 3})
verifier = stmt.get_verifier()

commitment = prover.commit()
challenge = verifier.send_challenge(commitment)
response = prover.compute_response(challenge)
assert verifier.verify(response)

# Non-interactive proof.
nizk = stmt.prove({x: 3})
assert stmt.verify(nizk)
2 changes: 1 addition & 1 deletion tests/test_pairings.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def test_additive_point(bp_group, group_pair):
assert AdditivePoint(g ** (g.group.order()), group_pair) == group_pair.GT.infinite()

r = bp_group.order().random()
g1, g1mg = g ** r, r * gmg
g1, g1mg = g**r, r * gmg
assert g1 == g1mg.pt
assert g1 * g1 * g1 == (g1mg + g1mg + g1mg).pt
assert g1.export() == g1mg.export()
Expand Down
56 changes: 56 additions & 0 deletions tests/test_rsagroup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import pytest

from petlib.bn import Bn

from zksk import Secret
from zksk.primitives.dlrep import DLRep
from zksk.utils.debug import SigmaProtocol
from zksk.rsa_group import RSAGroup, IntPt, rsa_dlrep_trusted_setup


def test_rsagroup_interactive_1():
[g, h] = rsa_dlrep_trusted_setup(bits=1024, num=2)
n = g.group.modulus
sk1, sk2 = n.random(), n.random()
pk = sk1 * g + sk2 * h

x1 = Secret()
x2 = Secret()
p = DLRep(pk, x1 * g + x2 * h)
prover = p.get_prover({x1: sk1, x2: sk2})
verifier = p.get_verifier()
protocol = SigmaProtocol(verifier, prover)
assert protocol.verify()


def test_rsagroup_and_interactive_1():
[g, h] = rsa_dlrep_trusted_setup(bits=1024, num=2)
n = g.group.modulus
sk1, sk2 = n.random(), n.random()
pk1 = sk1 * g
pk2 = sk2 * h

x1 = Secret()
x2 = Secret()
p = DLRep(pk1, x1 * g) & DLRep(pk2, x2 * h)
prover = p.get_prover({x1: sk1, x2: sk2})
verifier = p.get_verifier()
protocol = SigmaProtocol(verifier, prover)
assert protocol.verify()


def test_rsagroup_or_interactive_1():
[g, h] = rsa_dlrep_trusted_setup(bits=1024, num=2)
n = g.group.modulus
sk1, sk2 = n.random(), n.random()
pk1 = sk1 * g
pk2 = IntPt(Bn(1), RSAGroup(n))

x1 = Secret()
x2 = Secret()
p = DLRep(pk1, x1 * g) | DLRep(pk2, x2 * h)
p.subproofs[1].set_simulated()
prover = p.get_prover({x1: sk1, x2: sk2})
verifier = p.get_verifier()
protocol = SigmaProtocol(verifier, prover)
assert protocol.verify()
40 changes: 30 additions & 10 deletions zksk/composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from zksk.consts import CHALLENGE_LENGTH
from zksk.base import Prover, Verifier, SimulationTranscript
from zksk.expr import Secret, update_secret_values
from zksk.rsa_group import IntPt
from zksk.utils import get_random_num, sum_bn_array
from zksk.utils.misc import get_default_attr
from zksk.exceptions import StatementSpecError, StatementMismatch
Expand Down Expand Up @@ -372,14 +373,22 @@ def validate_group_orders(self):
# the same group
for (word, gen_idx) in mydict.items():
# Word is the key, gen_idx is the value = a list of indices
ref_order = bases[gen_idx[0]].group.order()

for index in gen_idx:
if bases[index].group.order() != ref_order:
raise GroupMismatchError(
"A shared secret has bases which yield different group orders: %s"
% word
)
if isinstance(bases[0], IntPt):
ref_modulus = bases[gen_idx[0]].group.modulus
for index in gen_idx:
if bases[index].group.modulus != ref_modulus:
raise GroupMismatchError(
"A shared secret has bases which yield different group orders: %s"
% word
)
else:
ref_order = bases[gen_idx[0]].group.order()
for index in gen_idx:
if bases[index].group.order() != ref_order:
raise GroupMismatchError(
"A shared secret has bases which yield different group orders: %s"
% word
)

def get_proof_id(self, secret_id_map=None):
secret_vars = self.get_secret_vars()
Expand Down Expand Up @@ -770,8 +779,19 @@ def get_randomizers(self):
dict_name_gen = {s: g for s, g in zip(self.get_secret_vars(), self.get_bases())}

# Pair each Secret to a randomizer.
for u in dict_name_gen:
random_vals[u] = dict_name_gen[u].group.order().random()

if isinstance(self.get_bases()[0], IntPt):
rand_range = self.get_bases()[0].group.modulus * pow(
2, 2 * CHALLENGE_LENGTH
)
for u in dict_name_gen:
val = rand_range.random()
if Bn(2).random() == 0:
val = -val
random_vals[u] = val
else:
for u in dict_name_gen:
random_vals[u] = dict_name_gen[u].group.order().random()

return random_vals

Expand Down
36 changes: 28 additions & 8 deletions zksk/primitives/dlrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,14 @@

"""
from hashlib import sha256
from black import out

from petlib.bn import Bn
from petlib.ec import EcGroup

from zksk.base import Verifier, Prover, SimulationTranscript
from zksk.expr import Secret, Expression
from zksk.rsa_group import RSAGroup
from zksk.utils import get_random_num
from zksk.consts import CHALLENGE_LENGTH
from zksk.composition import ComposableProofStmt
Expand Down Expand Up @@ -164,9 +167,18 @@ def get_randomizers(self):
of the proof.
"""
output = {}
order = self.bases[0].group.order()
for sec in set(self.secret_vars):
output.update({sec: order.random()})
if isinstance(self.bases[0].group, RSAGroup):
rand_range = self.bases[0].group.modulus * pow(2, 2 * CHALLENGE_LENGTH)
for sec in set(self.secret_vars):
val = rand_range.random()
if Bn(2).random() == 0:
val = -val
output.update({sec: val})
else:
order = self.bases[0].group.order()
for sec in set(self.secret_vars):
output.update({sec: order.random()})

return output

def recompute_commitment(self, challenge, responses):
Expand Down Expand Up @@ -240,9 +252,17 @@ def compute_response(self, challenge):
Returns:
A list of responses
"""
order = self.stmt.bases[0].group.order()
resps = [
(self.secret_values[self.stmt.secret_vars[i]] * challenge + k) % order
for i, k in enumerate(self.ks)
]
resps = []
if isinstance(self.stmt.bases[0].group, RSAGroup):
resps = [
(self.secret_values[self.stmt.secret_vars[i]] * challenge + k)
for i, k in enumerate(self.ks)
]
else:
order = self.stmt.bases[0].group.order()
resps = [
(self.secret_values[self.stmt.secret_vars[i]] * challenge + k) % order
for i, k in enumerate(self.ks)
]

return resps
102 changes: 102 additions & 0 deletions zksk/rsa_group.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
"""
Allow zero knowledge proofs in subgroups of RSA groups (groups of integers modulo the product of two safe primes), instead of only in groups of prime order.

Example:
PK{(alpha): y = alpha * g}
where y, and g are elements of a subgroup of an RSA group of order p * q, for two safe primes p and q. We assume there is a trusted setup which keeps p and q secret from both the Prover and the Verifier, which sends y, g, and p * q to both parties, and which sends alpha to the Prover.
We use additive notation for the group operation, so alpha * g is g raised to the exponent alpha.

The protocol we follow for proofs of discrete logarithm representations is inspired from page 34 of Boneh, Bünz and Fisch, Batching Techniques for Accumulators with Applications to IOPs and Stateless Blockchains, Crypto 2019.
"""
# To do: Allow ZKPs of other cryptographic primitives in RSA groups.
import math

from petlib.bn import Bn
from petlib.pack import *

# This sets up the RSA group and the subgroup generators
# Example:
# [g,h] = rsa_dlrep_trusted_setup(bits=1024,num = 2)
# g and h are two generators of the subgroup of quadratic residues of an RSA group of order the product of two 1024 bit primes.
def rsa_dlrep_trusted_setup(bits=1024, num=1):
p = Bn.get_prime(bits, safe=1)
q = Bn.get_prime(bits, safe=1)
n = p * q
b = n.num_bits()
while True:
q = Bn.from_num(Bn(2).pow(bits).random())
if q < n and math.gcd(int(q), int(n)) == 1:
break
g = IntPt((q * q) % n, RSAGroup(n))
res = [g]

num -= 1
while num != 0:
res.append(((p - 1) * (q - 1)).random() * g)
num -= 1
return res


# This class mimics petlib.ec.EcGroup, but for RSA groups.
class RSAGroup:
# Must take a Bignum as argument
def __init__(self, modulus):
self.modulus = modulus

def infinite(self):
return IntPt(1, self)

def wsum(self, weights, elems):
res = IntPt(Bn(1), self)
for i in range(0, len(elems)):
res = res + (weights[i] * elems[i])
return res

def __eq__(self, other):
return self.modulus == other.modulus


# This class mimics petlib.ec.EcPt, but for elements of RSA groups.
class IntPt:
# Must take one bignum and one RSAGroup as arguments
def __init__(self, value, modulus):
self.pt = value
self.group = modulus

# We use additive notation for the group operation, so IntPt.__add__ is actually multiplication.
def __add__(self, o):
return IntPt((self.pt * o.pt) % self.group.modulus, self.group)

# Similarly, IntPt.__rmul__ is actually exponentiation.
def __rmul__(self, o):
if o < 0:
return IntPt(
pow(self.pt.mod_inverse(self.group.modulus), -o, self.group.modulus),
self.group,
)
else:
return IntPt(pow(self.pt, o, self.group.modulus), self.group)

def __eq__(self, other):
return (self.pt == other.pt) and (self.group == other.group)


def enc_RSAGroup(obj):
return encode(obj.modulus)


def dec_RSAGroup(data):
return RSAGroup(decode(data))


def enc_IntPt(obj):
return encode([obj.pt, obj.group])


def dec_IntPt(data):
d = decode(data)
return IntPt(d[0], d[1])


register_coders(RSAGroup, 10, enc_RSAGroup, dec_RSAGroup)
register_coders(IntPt, 11, enc_IntPt, dec_IntPt)