diff --git a/ecdsa/keygen/dln_verifier.go b/ecdsa/keygen/dln_verifier.go new file mode 100644 index 00000000..cc6be8bf --- /dev/null +++ b/ecdsa/keygen/dln_verifier.go @@ -0,0 +1,73 @@ +// Copyright © 2019 Binance +// +// This file is part of Binance. The full Binance copyright notice, including +// terms governing use, modification, and redistribution, is contained in the +// file LICENSE at the root of the source code distribution tree. + +package keygen + +import ( + "errors" + "math/big" + + "github.com/bnb-chain/tss-lib/crypto/dlnproof" +) + +type DlnProofVerifier struct { + semaphore chan interface{} +} + +type message interface { + UnmarshalDLNProof1() (*dlnproof.Proof, error) + UnmarshalDLNProof2() (*dlnproof.Proof, error) +} + +func NewDlnProofVerifier(concurrency int) *DlnProofVerifier { + if concurrency == 0 { + panic(errors.New("NewDlnProofverifier: concurrency level must not be zero")) + } + + semaphore := make(chan interface{}, concurrency) + + return &DlnProofVerifier{ + semaphore: semaphore, + } +} + +func (dpv *DlnProofVerifier) VerifyDLNProof1( + m message, + h1, h2, n *big.Int, + onDone func(bool), +) { + dpv.semaphore <- struct{}{} + go func() { + defer func() { <-dpv.semaphore }() + + dlnProof, err := m.UnmarshalDLNProof1() + if err != nil { + onDone(false) + return + } + + onDone(dlnProof.Verify(h1, h2, n)) + }() +} + +func (dpv *DlnProofVerifier) VerifyDLNProof2( + m message, + h1, h2, n *big.Int, + onDone func(bool), +) { + dpv.semaphore <- struct{}{} + go func() { + defer func() { <-dpv.semaphore }() + + dlnProof, err := m.UnmarshalDLNProof2() + if err != nil { + onDone(false) + return + } + + onDone(dlnProof.Verify(h1, h2, n)) + }() +} diff --git a/ecdsa/keygen/dln_verifier_test.go b/ecdsa/keygen/dln_verifier_test.go new file mode 100644 index 00000000..738bbf6b --- /dev/null +++ b/ecdsa/keygen/dln_verifier_test.go @@ -0,0 +1,241 @@ +// Copyright © 2019 Binance +// +// This file is part of Binance. The full Binance copyright notice, including +// terms governing use, modification, and redistribution, is contained in the +// file LICENSE at the root of the source code distribution tree. + +package keygen + +import ( + "math/big" + "runtime" + "testing" + + "github.com/bnb-chain/tss-lib/crypto/dlnproof" +) + +func BenchmarkDlnProof_Verify(b *testing.B) { + localPartySaveData, _, err := LoadKeygenTestFixtures(1) + if err != nil { + b.Fatal(err) + } + + params := localPartySaveData[0].LocalPreParams + + proof := dlnproof.NewDLNProof( + params.H1i, + params.H2i, + params.Alpha, + params.P, + params.Q, + params.NTildei, + ) + + b.ResetTimer() + for n := 0; n < b.N; n++ { + proof.Verify(params.H1i, params.H2i, params.NTildei) + } +} + +func BenchmarkDlnVerifier_VerifyProof1(b *testing.B) { + preParams, proof := prepareProofB(b) + message := &KGRound1Message{ + Dlnproof_1: proof, + } + + verifier := NewDlnProofVerifier(runtime.GOMAXPROCS(0)) + + b.ResetTimer() + for n := 0; n < b.N; n++ { + resultChan := make(chan bool) + verifier.VerifyDLNProof1(message, preParams.H1i, preParams.H2i, preParams.NTildei, func(result bool) { + resultChan <- result + }) + <-resultChan + } +} + +func BenchmarkDlnVerifier_VerifyProof2(b *testing.B) { + preParams, proof := prepareProofB(b) + message := &KGRound1Message{ + Dlnproof_2: proof, + } + + verifier := NewDlnProofVerifier(runtime.GOMAXPROCS(0)) + + b.ResetTimer() + for n := 0; n < b.N; n++ { + resultChan := make(chan bool) + verifier.VerifyDLNProof2(message, preParams.H1i, preParams.H2i, preParams.NTildei, func(result bool) { + resultChan <- result + }) + <-resultChan + } +} + +func TestVerifyDLNProof1_Success(t *testing.T) { + preParams, proof := prepareProofT(t) + message := &KGRound1Message{ + Dlnproof_1: proof, + } + + verifier := NewDlnProofVerifier(runtime.GOMAXPROCS(0)) + + resultChan := make(chan bool) + + verifier.VerifyDLNProof1(message, preParams.H1i, preParams.H2i, preParams.NTildei, func(result bool) { + resultChan <- result + }) + + success := <-resultChan + if !success { + t.Fatal("expected positive verification") + } +} + +func TestVerifyDLNProof1_MalformedMessage(t *testing.T) { + preParams, proof := prepareProofT(t) + message := &KGRound1Message{ + Dlnproof_1: proof[:len(proof)-1], // truncate + } + + verifier := NewDlnProofVerifier(runtime.GOMAXPROCS(0)) + + resultChan := make(chan bool) + + verifier.VerifyDLNProof1(message, preParams.H1i, preParams.H2i, preParams.NTildei, func(result bool) { + resultChan <- result + }) + + success := <-resultChan + if success { + t.Fatal("expected negative verification") + } +} + +func TestVerifyDLNProof1_IncorrectProof(t *testing.T) { + preParams, proof := prepareProofT(t) + message := &KGRound1Message{ + Dlnproof_1: proof, + } + + verifier := NewDlnProofVerifier(runtime.GOMAXPROCS(0)) + + resultChan := make(chan bool) + + wrongH1i := preParams.H1i.Sub(preParams.H1i, big.NewInt(1)) + verifier.VerifyDLNProof1(message, wrongH1i, preParams.H2i, preParams.NTildei, func(result bool) { + resultChan <- result + }) + + success := <-resultChan + if success { + t.Fatal("expected negative verification") + } +} + +func TestVerifyDLNProof2_Success(t *testing.T) { + preParams, proof := prepareProofT(t) + message := &KGRound1Message{ + Dlnproof_2: proof, + } + + verifier := NewDlnProofVerifier(runtime.GOMAXPROCS(0)) + + resultChan := make(chan bool) + + verifier.VerifyDLNProof2(message, preParams.H1i, preParams.H2i, preParams.NTildei, func(result bool) { + resultChan <- result + }) + + success := <-resultChan + if !success { + t.Fatal("expected positive verification") + } +} + +func TestVerifyDLNProof2_MalformedMessage(t *testing.T) { + preParams, proof := prepareProofT(t) + message := &KGRound1Message{ + Dlnproof_2: proof[:len(proof)-1], // truncate + } + + verifier := NewDlnProofVerifier(runtime.GOMAXPROCS(0)) + + resultChan := make(chan bool) + + verifier.VerifyDLNProof2(message, preParams.H1i, preParams.H2i, preParams.NTildei, func(result bool) { + resultChan <- result + }) + + success := <-resultChan + if success { + t.Fatal("expected negative verification") + } +} + +func TestVerifyDLNProof2_IncorrectProof(t *testing.T) { + preParams, proof := prepareProofT(t) + message := &KGRound1Message{ + Dlnproof_2: proof, + } + + verifier := NewDlnProofVerifier(runtime.GOMAXPROCS(0)) + + resultChan := make(chan bool) + + wrongH2i := preParams.H2i.Add(preParams.H2i, big.NewInt(1)) + verifier.VerifyDLNProof2(message, preParams.H1i, wrongH2i, preParams.NTildei, func(result bool) { + resultChan <- result + }) + + success := <-resultChan + if success { + t.Fatal("expected negative verification") + } +} + +func prepareProofT(t *testing.T) (*LocalPreParams, [][]byte) { + preParams, serialized, err := prepareProof() + if err != nil { + t.Fatal(err) + } + + return preParams, serialized +} + +func prepareProofB(b *testing.B) (*LocalPreParams, [][]byte) { + preParams, serialized, err := prepareProof() + if err != nil { + b.Fatal(err) + } + + return preParams, serialized +} + +func prepareProof() (*LocalPreParams, [][]byte, error) { + localPartySaveData, _, err := LoadKeygenTestFixtures(1) + if err != nil { + return nil, [][]byte{}, err + } + + preParams := localPartySaveData[0].LocalPreParams + + proof := dlnproof.NewDLNProof( + preParams.H1i, + preParams.H2i, + preParams.Alpha, + preParams.P, + preParams.Q, + preParams.NTildei, + ) + + serialized, err := proof.Serialize() + if err != nil { + if err != nil { + return nil, [][]byte{}, err + } + } + + return &preParams, serialized, nil +} diff --git a/ecdsa/keygen/round_1.go b/ecdsa/keygen/round_1.go index 8eea72d4..9cfda2df 100644 --- a/ecdsa/keygen/round_1.go +++ b/ecdsa/keygen/round_1.go @@ -74,7 +74,7 @@ func (round *round1) Start() *tss.Error { } else if round.save.LocalPreParams.ValidateWithProof() { preParams = &round.save.LocalPreParams } else { - preParams, err = GeneratePreParams(round.SafePrimeGenTimeout(), 3) + preParams, err = GeneratePreParams(round.SafePrimeGenTimeout(), round.Concurrency()) if err != nil { return round.WrapError(errors.New("pre-params generation failed"), Pi) } diff --git a/ecdsa/keygen/round_2.go b/ecdsa/keygen/round_2.go index b41ec78c..e72b6be2 100644 --- a/ecdsa/keygen/round_2.go +++ b/ecdsa/keygen/round_2.go @@ -9,9 +9,9 @@ package keygen import ( "encoding/hex" "errors" - "math/big" "sync" + "github.com/bnb-chain/tss-lib/common" "github.com/bnb-chain/tss-lib/tss" ) @@ -27,6 +27,13 @@ func (round *round2) Start() *tss.Error { round.started = true round.resetOK() + common.Logger.Debugf( + "%s Setting up DLN verification with concurrency level of %d", + round.PartyID(), + round.Concurrency(), + ) + dlnVerifier := NewDlnProofVerifier(round.Concurrency()) + i := round.PartyID().Index // 6. verify dln proofs, store r1 message pieces, ensure uniqueness of h1j, h2j @@ -58,19 +65,23 @@ func (round *round2) Start() *tss.Error { return round.WrapError(errors.New("this h2j was already used by another party"), msg.GetFrom()) } h1H2Map[h1JHex], h1H2Map[h2JHex] = struct{}{}, struct{}{} + wg.Add(2) - go func(j int, msg tss.ParsedMessage, r1msg *KGRound1Message, H1j, H2j, NTildej *big.Int) { - if dlnProof1, err := r1msg.UnmarshalDLNProof1(); err != nil || !dlnProof1.Verify(H1j, H2j, NTildej) { - dlnProof1FailCulprits[j] = msg.GetFrom() + _j := j + _msg := msg + + dlnVerifier.VerifyDLNProof1(r1msg, H1j, H2j, NTildej, func(isValid bool) { + if !isValid { + dlnProof1FailCulprits[_j] = _msg.GetFrom() } wg.Done() - }(j, msg, r1msg, H1j, H2j, NTildej) - go func(j int, msg tss.ParsedMessage, r1msg *KGRound1Message, H1j, H2j, NTildej *big.Int) { - if dlnProof2, err := r1msg.UnmarshalDLNProof2(); err != nil || !dlnProof2.Verify(H2j, H1j, NTildej) { - dlnProof2FailCulprits[j] = msg.GetFrom() + }) + dlnVerifier.VerifyDLNProof2(r1msg, H2j, H1j, NTildej, func(isValid bool) { + if !isValid { + dlnProof2FailCulprits[_j] = _msg.GetFrom() } wg.Done() - }(j, msg, r1msg, H1j, H2j, NTildej) + }) } wg.Wait() for _, culprit := range append(dlnProof1FailCulprits, dlnProof2FailCulprits...) { diff --git a/ecdsa/resharing/round_2_new_step_1.go b/ecdsa/resharing/round_2_new_step_1.go index 365e20a3..c93c2861 100644 --- a/ecdsa/resharing/round_2_new_step_1.go +++ b/ecdsa/resharing/round_2_new_step_1.go @@ -49,7 +49,7 @@ func (round *round2) Start() *tss.Error { preParams = &round.save.LocalPreParams } else { var err error - preParams, err = keygen.GeneratePreParams(round.SafePrimeGenTimeout()) + preParams, err = keygen.GeneratePreParams(round.SafePrimeGenTimeout(), round.Concurrency()) if err != nil { return round.WrapError(errors.New("pre-params generation failed"), Pi) } diff --git a/ecdsa/resharing/round_4_new_step_2.go b/ecdsa/resharing/round_4_new_step_2.go index 40ef9448..9bff552d 100644 --- a/ecdsa/resharing/round_4_new_step_2.go +++ b/ecdsa/resharing/round_4_new_step_2.go @@ -18,6 +18,7 @@ import ( "github.com/bnb-chain/tss-lib/crypto" "github.com/bnb-chain/tss-lib/crypto/commitments" "github.com/bnb-chain/tss-lib/crypto/vss" + "github.com/bnb-chain/tss-lib/ecdsa/keygen" "github.com/bnb-chain/tss-lib/tss" ) @@ -36,6 +37,13 @@ func (round *round4) Start() *tss.Error { return nil } + common.Logger.Debugf( + "%s Setting up DLN verification with concurrency level of %d", + round.PartyID(), + round.Concurrency(), + ) + dlnVerifier := keygen.NewDlnProofVerifier(round.Concurrency()) + Pi := round.PartyID() i := Pi.Index @@ -71,20 +79,22 @@ func (round *round4) Start() *tss.Error { } wg.Done() }(j, msg, r2msg1) - go func(j int, msg tss.ParsedMessage, r2msg1 *DGRound2Message1, H1j, H2j, NTildej *big.Int) { - if dlnProof1, err := r2msg1.UnmarshalDLNProof1(); err != nil || !dlnProof1.Verify(H1j, H2j, NTildej) { - dlnProof1FailCulprits[j] = msg.GetFrom() - common.Logger.Warningf("dln proof 1 verify failed for party %s", msg.GetFrom(), err) + _j := j + _msg := msg + dlnVerifier.VerifyDLNProof1(r2msg1, H1j, H2j, NTildej, func(isValid bool) { + if !isValid { + dlnProof1FailCulprits[_j] = _msg.GetFrom() + common.Logger.Warningf("dln proof 1 verify failed for party %s", _msg.GetFrom()) } wg.Done() - }(j, msg, r2msg1, H1j, H2j, NTildej) - go func(j int, msg tss.ParsedMessage, r2msg1 *DGRound2Message1, H1j, H2j, NTildej *big.Int) { - if dlnProof2, err := r2msg1.UnmarshalDLNProof2(); err != nil || !dlnProof2.Verify(H2j, H1j, NTildej) { - dlnProof2FailCulprits[j] = msg.GetFrom() - common.Logger.Warningf("dln proof 2 verify failed for party %s", msg.GetFrom(), err) + }) + dlnVerifier.VerifyDLNProof2(r2msg1, H2j, H1j, NTildej, func(isValid bool) { + if !isValid { + dlnProof2FailCulprits[_j] = _msg.GetFrom() + common.Logger.Warningf("dln proof 2 verify failed for party %s", _msg.GetFrom()) } wg.Done() - }(j, msg, r2msg1, H1j, H2j, NTildej) + }) } wg.Wait() for _, culprit := range append(append(paiProofCulprits, dlnProof1FailCulprits...), dlnProof2FailCulprits...) { diff --git a/tss/params.go b/tss/params.go index 8cb33368..8bf74148 100644 --- a/tss/params.go +++ b/tss/params.go @@ -8,7 +8,7 @@ package tss import ( "crypto/elliptic" - "errors" + "runtime" "time" ) @@ -19,6 +19,7 @@ type ( parties *PeerContext partyCount int threshold int + concurrency int safePrimeGenTimeout time.Duration } @@ -35,23 +36,15 @@ const ( ) // Exported, used in `tss` client -func NewParameters(ec elliptic.Curve, ctx *PeerContext, partyID *PartyID, partyCount, threshold int, optionalSafePrimeGenTimeout ...time.Duration) *Parameters { - var safePrimeGenTimeout time.Duration - if 0 < len(optionalSafePrimeGenTimeout) { - if 1 < len(optionalSafePrimeGenTimeout) { - panic(errors.New("GeneratePreParams: expected 0 or 1 item in `optionalSafePrimeGenTimeout`")) - } - safePrimeGenTimeout = optionalSafePrimeGenTimeout[0] - } else { - safePrimeGenTimeout = defaultSafePrimeGenTimeout - } +func NewParameters(ec elliptic.Curve, ctx *PeerContext, partyID *PartyID, partyCount, threshold int) *Parameters { return &Parameters{ ec: ec, parties: ctx, partyID: partyID, partyCount: partyCount, threshold: threshold, - safePrimeGenTimeout: safePrimeGenTimeout, + concurrency: runtime.GOMAXPROCS(0), + safePrimeGenTimeout: defaultSafePrimeGenTimeout, } } @@ -75,10 +68,23 @@ func (params *Parameters) Threshold() int { return params.threshold } +func (params *Parameters) Concurrency() int { + return params.concurrency +} + func (params *Parameters) SafePrimeGenTimeout() time.Duration { return params.safePrimeGenTimeout } +// The concurrency level must be >= 1. +func (params *Parameters) SetConcurrency(concurrency int) { + params.concurrency = concurrency +} + +func (params *Parameters) SetSafePrimeGenTimeout(timeout time.Duration) { + params.safePrimeGenTimeout = timeout +} + // ----- // // Exported, used in `tss` client