Skip to content

Commit

Permalink
Optimize BDN Signature/Key Aggregation (#546)
Browse files Browse the repository at this point in the history
* Add BDN test fixtures

* Remove n^2 algorithm from signature/key aggregation

CountEnabled and IndexOfNthEnabled are both O(n) in the size of the
mask, making this loop n^2. The BLS operations still tend to be the slow
part, but the n^2 factor will start to show up with thousands of keys.

* Remove an unnecessary loop from hashPointToR

* Introduce a new CachedMask for BDN

This new mask will pre-compute reusable values, speeding up repeated
verification and aggregation of aggregate signatures (mostly the former).

* Ignore golangci lint

* Move Mask into BDN and remove the interface

* fix docs

Co-authored-by: AnomalRoil <[email protected]>

* Document mutability of Mask fields

---------

Co-authored-by: AnomalRoil <[email protected]>
  • Loading branch information
Stebalien and AnomalRoil authored Sep 24, 2024
1 parent a318fba commit 0ba2750
Show file tree
Hide file tree
Showing 4 changed files with 236 additions and 60 deletions.
84 changes: 37 additions & 47 deletions sign/bdn/bdn.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ package bdn
import (
"crypto/cipher"
"errors"
"fmt"
"math/big"

"go.dedis.ch/kyber/v4"
Expand All @@ -31,23 +32,16 @@ var modulus128 = new(big.Int).Sub(new(big.Int).Lsh(big.NewInt(1), 128), big.NewI
// We also use the entire roster so that the coefficient will vary for the same
// public key used in different roster
func hashPointToR(pubs []kyber.Point) ([]kyber.Scalar, error) {
peers := make([][]byte, len(pubs))
for i, pub := range pubs {
peer, err := pub.MarshalBinary()
if err != nil {
return nil, err
}

peers[i] = peer
}

h, err := blake2s.NewXOF(blake2s.OutputLengthUnknown, nil)
if err != nil {
return nil, err
}

for _, peer := range peers {
_, err := h.Write(peer)
for _, pub := range pubs {
peer, err := pub.MarshalBinary()
if err != nil {
return nil, err
}
_, err = h.Write(peer)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -128,62 +122,58 @@ func (scheme *Scheme) Verify(x kyber.Point, msg, sig []byte) error {

// AggregateSignatures aggregates the signatures using a coefficient for each
// one of them where c = H(pk) and H: keyGroup -> R with R = {1, ..., 2^128}
func (scheme *Scheme) AggregateSignatures(sigs [][]byte, mask *sign.Mask) (kyber.Point, error) {
if len(sigs) != mask.CountEnabled() {
return nil, errors.New("length of signatures and public keys must match")
}

coefs, err := hashPointToR(mask.Publics())
if err != nil {
return nil, err
}

func (scheme *Scheme) AggregateSignatures(sigs [][]byte, mask *Mask) (kyber.Point, error) {
agg := scheme.sigGroup.Point()
for i, buf := range sigs {
peerIndex := mask.IndexOfNthEnabled(i)
if peerIndex < 0 {
// this should never happen as we check the lenths at the beginning
// an error here is probably a bug in the mask
return nil, errors.New("couldn't find the index")
for i := range mask.publics {
if enabled, err := mask.GetBit(i); err != nil {
// this should never happen because of the loop boundary
// an error here is probably a bug in the mask implementation
return nil, fmt.Errorf("couldn't find the index %d: %w", i, err)
} else if !enabled {
continue
}

if len(sigs) == 0 {
return nil, errors.New("length of signatures and public keys must match")
}

buf := sigs[0]
sigs = sigs[1:]

sig := scheme.sigGroup.Point()
err = sig.UnmarshalBinary(buf)
err := sig.UnmarshalBinary(buf)
if err != nil {
return nil, err
}

sigC := sig.Clone().Mul(coefs[peerIndex], sig)
sigC := sig.Clone().Mul(mask.publicCoefs[i], sig)
// c+1 because R is in the range [1, 2^128] and not [0, 2^128-1]
sigC = sigC.Add(sigC, sig)
agg = agg.Add(agg, sigC)
}

if len(sigs) > 0 {
return nil, errors.New("length of signatures and public keys must match")
}

return agg, nil
}

// AggregatePublicKeys aggregates a set of public keys (similarly to
// AggregateSignatures for signatures) using the hash function
// H: keyGroup -> R with R = {1, ..., 2^128}.
func (scheme *Scheme) AggregatePublicKeys(mask *sign.Mask) (kyber.Point, error) {
coefs, err := hashPointToR(mask.Publics())
if err != nil {
return nil, err
}

func (scheme *Scheme) AggregatePublicKeys(mask *Mask) (kyber.Point, error) {
agg := scheme.keyGroup.Point()
for i := 0; i < mask.CountEnabled(); i++ {
peerIndex := mask.IndexOfNthEnabled(i)
if peerIndex < 0 {
for i := range mask.publics {
if enabled, err := mask.GetBit(i); err != nil {
// this should never happen because of the loop boundary
// an error here is probably a bug in the mask implementation
return nil, errors.New("couldn't find the index")
return nil, fmt.Errorf("couldn't find the index %d: %w", i, err)
} else if !enabled {
continue
}

pub := mask.Publics()[peerIndex]
pubC := pub.Clone().Mul(coefs[peerIndex], pub)
pubC = pubC.Add(pubC, pub)
agg = agg.Add(agg, pubC)
agg = agg.Add(agg, mask.publicTerms[i])
}

return agg, nil
Expand Down Expand Up @@ -217,14 +207,14 @@ func Verify(suite pairing.Suite, x kyber.Point, msg, sig []byte) error {
// AggregateSignatures aggregates the signatures using a coefficient for each
// one of them where c = H(pk) and H: G2 -> R with R = {1, ..., 2^128}
// Deprecated: use the new scheme methods instead.
func AggregateSignatures(suite pairing.Suite, sigs [][]byte, mask *sign.Mask) (kyber.Point, error) {
func AggregateSignatures(suite pairing.Suite, sigs [][]byte, mask *Mask) (kyber.Point, error) {
return NewSchemeOnG1(suite).AggregateSignatures(sigs, mask)
}

// AggregatePublicKeys aggregates a set of public keys (similarly to
// AggregateSignatures for signatures) using the hash function
// H: G2 -> R with R = {1, ..., 2^128}.
// Deprecated: use the new scheme methods instead.
func AggregatePublicKeys(suite pairing.Suite, mask *sign.Mask) (kyber.Point, error) {
func AggregatePublicKeys(suite pairing.Suite, mask *Mask) (kyber.Point, error) {
return NewSchemeOnG1(suite).AggregatePublicKeys(mask)
}
110 changes: 104 additions & 6 deletions sign/bdn/bdn_test.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
package bdn

import (
"encoding"
"encoding/hex"
"fmt"
"testing"

"github.com/stretchr/testify/require"
"go.dedis.ch/kyber/v4"
"go.dedis.ch/kyber/v4/pairing/bls12381/kilic"
"go.dedis.ch/kyber/v4/pairing/bn256"
"go.dedis.ch/kyber/v4/sign"
"go.dedis.ch/kyber/v4/sign/bls"
"go.dedis.ch/kyber/v4/util/random"
)
Expand All @@ -30,7 +32,7 @@ func TestBDN_HashPointToR_BN256(t *testing.T) {
require.Equal(t, "933f6013eb3f654f9489d6d45ad04eaf", coefs[2].String())
require.Equal(t, 16, coefs[0].MarshalSize())

mask, _ := sign.NewMask([]kyber.Point{p1, p2, p3}, nil)
mask, _ := NewMask([]kyber.Point{p1, p2, p3}, nil)
mask.SetBit(0, true)
mask.SetBit(1, true)
mask.SetBit(2, true)
Expand All @@ -54,7 +56,7 @@ func TestBDN_AggregateSignatures(t *testing.T) {
sig2, err := Sign(suite, private2, msg)
require.NoError(t, err)

mask, _ := sign.NewMask([]kyber.Point{public1, public2}, nil)
mask, _ := NewMask([]kyber.Point{public1, public2}, nil)
mask.SetBit(0, true)
mask.SetBit(1, true)

Expand Down Expand Up @@ -92,7 +94,7 @@ func TestBDN_SubsetSignature(t *testing.T) {
sig2, err := Sign(suite, private2, msg)
require.NoError(t, err)

mask, _ := sign.NewMask([]kyber.Point{public1, public3, public2}, nil)
mask, _ := NewMask([]kyber.Point{public1, public3, public2}, nil)
mask.SetBit(0, true)
mask.SetBit(2, true)

Expand Down Expand Up @@ -131,7 +133,7 @@ func TestBDN_RogueAttack(t *testing.T) {
require.NoError(t, scheme.Verify(agg, msg, sig))

// New scheme that should detect
mask, _ := sign.NewMask(pubs, nil)
mask, _ := NewMask(pubs, nil)
mask.SetBit(0, true)
mask.SetBit(1, true)
agg, err = AggregatePublicKeys(suite, mask)
Expand All @@ -149,7 +151,7 @@ func Benchmark_BDN_AggregateSigs(b *testing.B) {
sig2, err := Sign(suite, private2, msg)
require.Nil(b, err)

mask, _ := sign.NewMask([]kyber.Point{public1, public2}, nil)
mask, _ := NewMask([]kyber.Point{public1, public2}, nil)
mask.SetBit(0, true)
mask.SetBit(1, false)

Expand All @@ -158,3 +160,99 @@ func Benchmark_BDN_AggregateSigs(b *testing.B) {
AggregateSignatures(suite, [][]byte{sig1, sig2}, mask)
}
}

func Benchmark_BDN_BLS12381_AggregateVerify(b *testing.B) {
suite := kilic.NewBLS12381Suite()
schemeOnG2 := NewSchemeOnG2(suite)

rng := random.New()
pubKeys := make([]kyber.Point, 3000)
privKeys := make([]kyber.Scalar, 3000)
for i := range pubKeys {
privKeys[i], pubKeys[i] = schemeOnG2.NewKeyPair(rng)
}

mask, err := NewMask(pubKeys, nil)
require.NoError(b, err)
for i := range pubKeys {
require.NoError(b, mask.SetBit(i, true))
}

msg := []byte("Hello many times Boneh-Lynn-Shacham")
sigs := make([][]byte, len(privKeys))
for i, k := range privKeys {
s, err := schemeOnG2.Sign(k, msg)
require.NoError(b, err)
sigs[i] = s
}

sig, err := schemeOnG2.AggregateSignatures(sigs, mask)
require.NoError(b, err)
sigb, err := sig.MarshalBinary()
require.NoError(b, err)

b.ResetTimer()
for i := 0; i < b.N; i++ {
pk, err := schemeOnG2.AggregatePublicKeys(mask)
require.NoError(b, err)
require.NoError(b, schemeOnG2.Verify(pk, msg, sigb))
}
}

func unmarshalHex[T encoding.BinaryUnmarshaler](t *testing.T, into T, s string) T {
t.Helper()
b, err := hex.DecodeString(s)
require.NoError(t, err)
require.NoError(t, into.UnmarshalBinary(b))
return into
}

// This tests exists to make sure we don't accidentally make breaking changes to signature
// aggregation by using checking against known aggregated signatures and keys.
func TestBDNFixtures(t *testing.T) {
suite := bn256.NewSuite()
schemeOnG1 := NewSchemeOnG1(suite)

public1 := unmarshalHex(t, suite.G2().Point(), "1a30714035c7a161e286e54c191b8c68345bd8239c74925a26290e8e1ae97ed6657958a17dca12c943fadceb11b824402389ff427179e0f10194da3c1b771c6083797d2b5915ea78123cbdb99ea6389d6d6b67dcb512a2b552c373094ee5693524e3ebb4a176f7efa7285c25c80081d8cb598745978f1a63b886c09a316b1493")
private1 := unmarshalHex(t, suite.G2().Scalar(), "49cfe5e9f4532670137184d43c0299f8b635bcacf6b0af7cab262494602d9f38")
public2 := unmarshalHex(t, suite.G2().Point(), "603bc61466ec8762ec6de2ba9a80b9d302d08f580d1685ac45a8e404a6ed549719dc0faf94d896a9983ff23423772720e3de5d800bc200de6f7d7e146162d3183b8880c5c0d8b71ca4b3b40f30c12d8cc0679c81a47c239c6aa7e9cc2edab4a927fe865cd413c1c17e3df8f74108e784cd77dd3e161bdaf30019a55826a32a1f")
private2 := unmarshalHex(t, suite.G2().Scalar(), "493abea4bb35b74c78ad9245f9d37883aeb6ee91f7fb0d8a8e11abf7aa2be581")
public3 := unmarshalHex(t, suite.G2().Point(), "56118769a1f0b6286abacaa32109c1497ab0819c5d21f27317e184b6681c283007aa981cb4760de044946febdd6503ab77a4586bc29c04159e53a6fa5dcb9c0261ccd1cb2e28db5204ca829ac9f6be95f957a626544adc34ba3bc542533b6e2f5cbd0567e343641a61a42b63f26c3625f74b66f6f46d17b3bf1688fae4d455ec")
private3 := unmarshalHex(t, suite.G2().Scalar(), "7fb0ebc317e161502208c3c16a4af890dedc3c7b275e8a04e99c0528aa6a19aa")

sig1Exp, err := hex.DecodeString("0913b76987be19f943be23b636cab9a2484507717326bd8bbdcdbbb6b8d5eb9253cfb3597c3fa550ee4972a398813650825a871f8e0b242ae5ddbce1b7c0e2a8")
require.NoError(t, err)
sig2Exp, err := hex.DecodeString("21195d29b1863bca1559e24375211d1411d8a28a8f4c772870b07f4ccda2fd5e337c1315c210475c683e3aa8b87d3aed3f7255b3087daa30d1e1432dd61d7484")
require.NoError(t, err)
sig3Exp, err := hex.DecodeString("3c1ac80345c1733630dbdc8106925c867544b521c259f9fa9678d477e6e5d3d212b09bc0d95137c3dbc0af2241415156c56e757d5577a609293584d045593195")
require.NoError(t, err)

aggSigExp := unmarshalHex(t, suite.G1().Point(), "43c1d2ad5a7d71a08f3cd7495db6b3c81a4547af1b76438b2f215e85ec178fea048f93f6ffed65a69ea757b47761e7178103bb347fd79689652e55b6e0054af2")
aggKeyExp := unmarshalHex(t, suite.G2().Point(), "43b5161ede207b9a69fc93114b0c5022b76cc22e813ba739c7e622d826b132333cd637505399963b94e393ec7f5d4875f82391620b34be1fde1f232204fa4f723935d4dbfb725f059456bcf2557f846c03190969f7b800e904d25b0b5bcbdd421c9877d443f0313c3425dfc1e7e646b665d27b9e649faadef1129f95670d70e1")

msg := []byte("Hello many times Boneh-Lynn-Shacham")
sig1, err := schemeOnG1.Sign(private1, msg)
require.Nil(t, err)
require.Equal(t, sig1Exp, sig1)

sig2, err := schemeOnG1.Sign(private2, msg)
require.Nil(t, err)
require.Equal(t, sig2Exp, sig2)

sig3, err := schemeOnG1.Sign(private3, msg)
require.Nil(t, err)
require.Equal(t, sig3Exp, sig3)

mask, _ := NewMask([]kyber.Point{public1, public2, public3}, nil)
mask.SetBit(0, true)
mask.SetBit(1, false)
mask.SetBit(2, true)

aggSig, err := schemeOnG1.AggregateSignatures([][]byte{sig1, sig3}, mask)
require.NoError(t, err)
require.True(t, aggSigExp.Equal(aggSig))

aggKey, err := schemeOnG1.AggregatePublicKeys(mask)
require.NoError(t, err)
require.True(t, aggKeyExp.Equal(aggKey))
}
55 changes: 52 additions & 3 deletions sign/mask.go → sign/bdn/mask.go
Original file line number Diff line number Diff line change
@@ -1,21 +1,36 @@
// Package sign contains useful tools for the different signing algorithms.
package sign
package bdn

import (
"errors"
"fmt"
"slices"

"go.dedis.ch/kyber/v4"
)

// Mask is a bitmask of the participation to a collective signature.
type Mask struct {
mask []byte
// The bitmask indicating which public keys are enabled/disabled for aggregation. This is
// the only mutable field.
mask []byte

// The following fields are immutable and should not be changed after the mask is created.
// They may be shared between multiple masks.

// Public keys for aggregation & signature verification.
publics []kyber.Point
// Coefficients used when aggregating signatures.
publicCoefs []kyber.Scalar
// Terms used to aggregate public keys
publicTerms []kyber.Point
}

// NewMask creates a new mask from a list of public keys. If a key is provided, it
// will set the bit of the key to 1 or return an error if it is not found.
//
// The returned Mask will contain pre-computed terms and coefficients for all provided public
// keys, so it should be re-used for optimal performance (e.g., by creating a "base" mask and
// cloning it whenever aggregating signatures and/or public keys).
func NewMask(publics []kyber.Point, myKey kyber.Point) (*Mask, error) {
m := &Mask{
publics: publics,
Expand All @@ -33,6 +48,18 @@ func NewMask(publics []kyber.Point, myKey kyber.Point) (*Mask, error) {
return nil, errors.New("key not found")
}

var err error
m.publicCoefs, err = hashPointToR(publics)
if err != nil {
return nil, fmt.Errorf("failed to hash public keys: %w", err)
}

m.publicTerms = make([]kyber.Point, len(publics))
for i, pub := range publics {
pubC := pub.Clone().Mul(m.publicCoefs[i], pub)
m.publicTerms[i] = pubC.Add(pubC, pub)
}

return m, nil
}

Expand All @@ -58,6 +85,17 @@ func (m *Mask) SetMask(mask []byte) error {
return nil
}

// GetBit returns true if the given bit is set.
func (m *Mask) GetBit(i int) (bool, error) {
if i >= len(m.publics) || i < 0 {
return false, errors.New("index out of range")
}

byteIndex := i / 8
mask := byte(1) << uint(i&7)
return m.mask[byteIndex]&mask != 0, nil
}

// SetBit turns on or off the bit at the given index.
func (m *Mask) SetBit(i int, enable bool) error {
if i >= len(m.publics) || i < 0 {
Expand Down Expand Up @@ -170,3 +208,14 @@ func (m *Mask) Merge(mask []byte) error {

return nil
}

// Clone copies the mask while keeping the precomputed coefficients, etc. This method is thread safe
// and does not modify the original mask. Modifications to the new Mask will not affect the original.
func (m *Mask) Clone() *Mask {
return &Mask{
mask: slices.Clone(m.mask),
publics: m.publics,
publicCoefs: m.publicCoefs,
publicTerms: m.publicTerms,
}
}
Loading

0 comments on commit 0ba2750

Please sign in to comment.