From 09e7f0d91f9f28344e424c63bcb89252e136e927 Mon Sep 17 00:00:00 2001 From: Luke Plaster Date: Mon, 30 Mar 2020 08:55:31 +0800 Subject: [PATCH] */resharing: allow resharing when more than T+1 of the old committee participates --- ecdsa/resharing/local_party.go | 32 ++++++++++++++++++-------- ecdsa/resharing/local_party_test.go | 4 ++-- ecdsa/resharing/round_1_old_step_1.go | 8 +++++-- ecdsa/resharing/round_4_new_step_2.go | 6 ++--- eddsa/resharing/local_party.go | 33 +++++++++++++++++++-------- eddsa/resharing/local_party_test.go | 4 ++-- eddsa/resharing/round_1_old_step_1.go | 8 +++++-- eddsa/resharing/round_4_new_step_2.go | 6 ++--- 8 files changed, 69 insertions(+), 32 deletions(-) diff --git a/ecdsa/resharing/local_party.go b/ecdsa/resharing/local_party.go index fa5f6a36..3b6ddaf0 100644 --- a/ecdsa/resharing/local_party.go +++ b/ecdsa/resharing/local_party.go @@ -70,6 +70,7 @@ func NewLocalParty( out chan<- tss.Message, end chan<- keygen.LocalPartySaveData, ) tss.Party { + oldPartyCount := len(params.OldParties().IDs()) subset := key if params.IsOldCommittee() { subset = keygen.BuildLocalSaveDataSubset(key, params.OldParties().IDs()) @@ -84,11 +85,11 @@ func NewLocalParty( end: end, } // msgs init - p.temp.dgRound1Messages = make([]tss.ParsedMessage, params.Threshold()+1) // from t+1 of Old Committee + p.temp.dgRound1Messages = make([]tss.ParsedMessage, oldPartyCount) // from t+1 of Old Committee p.temp.dgRound2Message1s = make([]tss.ParsedMessage, params.NewPartyCount()) // from n of New Committee p.temp.dgRound2Message2s = make([]tss.ParsedMessage, params.NewPartyCount()) // " - p.temp.dgRound3Message1s = make([]tss.ParsedMessage, params.Threshold()+1) // from t+1 of Old Committee - p.temp.dgRound3Message2s = make([]tss.ParsedMessage, params.Threshold()+1) // " + p.temp.dgRound3Message1s = make([]tss.ParsedMessage, oldPartyCount) // from t+1 of Old Committee + p.temp.dgRound3Message2s = make([]tss.ParsedMessage, oldPartyCount) // " p.temp.dgRound4Messages = make([]tss.ParsedMessage, params.NewPartyCount()) // from n of New Committee // save data init if key.LocalPreParams.ValidateWithProof() { @@ -117,6 +118,25 @@ func (p *LocalParty) UpdateFromBytes(wireBytes []byte, from *tss.PartyID, isBroa return p.Update(msg) } +func (p *LocalParty) ValidateMessage(msg tss.ParsedMessage) (bool, *tss.Error) { + if ok, err := p.BaseParty.ValidateMessage(msg); !ok || err != nil { + return ok, err + } + // check that the message's "from index" will fit into the array + var maxFromIdx int + switch msg.Content().(type) { + case *DGRound2Message1, *DGRound2Message2, *DGRound4Message: + maxFromIdx = len(p.params.NewParties().IDs()) - 1 + default: + maxFromIdx = len(p.params.OldParties().IDs()) - 1 + } + if maxFromIdx < msg.GetFrom().Index { + return false, p.WrapError(fmt.Errorf("received msg with a sender index too great (%d <= %d)", + maxFromIdx, msg.GetFrom().Index), msg.GetFrom()) + } + return true, nil +} + func (p *LocalParty) StoreMessage(msg tss.ParsedMessage) (bool, *tss.Error) { // ValidateBasic is cheap; double-check the message here in case the public StoreMessage was called externally if ok, err := p.ValidateMessage(msg); !ok || err != nil { @@ -129,22 +149,16 @@ func (p *LocalParty) StoreMessage(msg tss.ParsedMessage) (bool, *tss.Error) { switch msg.Content().(type) { case *DGRound1Message: p.temp.dgRound1Messages[fromPIdx] = msg - case *DGRound2Message1: p.temp.dgRound2Message1s[fromPIdx] = msg - case *DGRound2Message2: p.temp.dgRound2Message2s[fromPIdx] = msg - case *DGRound3Message1: p.temp.dgRound3Message1s[fromPIdx] = msg - case *DGRound3Message2: p.temp.dgRound3Message2s[fromPIdx] = msg - case *DGRound4Message: p.temp.dgRound4Messages[fromPIdx] = msg - default: // unrecognised message, just ignore! common.Logger.Warningf("unrecognised message ignored: %v", msg) return false, nil diff --git a/ecdsa/resharing/local_party_test.go b/ecdsa/resharing/local_party_test.go index d7cc82c9..35395b41 100644 --- a/ecdsa/resharing/local_party_test.go +++ b/ecdsa/resharing/local_party_test.go @@ -45,8 +45,8 @@ func TestE2EConcurrent(t *testing.T) { threshold, newThreshold := testThreshold, testThreshold // PHASE: load keygen fixtures - firstPartyIdx := 5 - oldKeys, oldPIDs, err := keygen.LoadKeygenTestFixtures(testThreshold+1+firstPartyIdx, firstPartyIdx) + firstPartyIdx, extraParties := 5, 1 // extra can be 0 to N-first + oldKeys, oldPIDs, err := keygen.LoadKeygenTestFixtures(testThreshold+1+extraParties+firstPartyIdx, firstPartyIdx) assert.NoError(t, err, "should load keygen fixtures") // PHASE: resharing diff --git a/ecdsa/resharing/round_1_old_step_1.go b/ecdsa/resharing/round_1_old_step_1.go index 248ab1c4..868c21ff 100644 --- a/ecdsa/resharing/round_1_old_step_1.go +++ b/ecdsa/resharing/round_1_old_step_1.go @@ -8,6 +8,7 @@ package resharing import ( "errors" + "fmt" "github.com/binance-chain/tss-lib/crypto" "github.com/binance-chain/tss-lib/crypto/commitments" @@ -20,7 +21,7 @@ import ( // round 1 represents round 1 of the keygen part of the GG18 ECDSA TSS spec (Gennaro, Goldfeder; 2018) func newRound1(params *tss.ReSharingParameters, input, save *keygen.LocalPartySaveData, temp *localTempData, out chan<- tss.Message, end chan<- keygen.LocalPartySaveData) tss.Round { return &round1{ - &base{params, temp, input, save, out, end, make([]bool, params.Threshold()+1), make([]bool, len(params.NewParties().IDs())), false, 1}} + &base{params, temp, input, save, out, end, make([]bool, len(params.OldParties().IDs())), make([]bool, len(params.NewParties().IDs())), false, 1}} } func (round *round1) Start() *tss.Error { @@ -42,8 +43,11 @@ func (round *round1) Start() *tss.Error { // 1. PrepareForSigning() -> w_i xi, ks, bigXj := round.input.Xi, round.input.Ks, round.input.BigXj + if round.Threshold()+1 > len(ks) { + return round.WrapError(fmt.Errorf("t+1=%d is not satisfied by the key count of %d", round.Threshold()+1, len(ks)), round.PartyID()) + } newKs := round.NewParties().IDs().Keys() - wi, _ := signing.PrepareForSigning(i, round.Threshold()+1, xi, ks, bigXj) + wi, _ := signing.PrepareForSigning(i, len(round.OldParties().IDs()), xi, ks, bigXj) // 2. vi, shares, err := vss.Create(round.NewThreshold(), wi, newKs) diff --git a/ecdsa/resharing/round_4_new_step_2.go b/ecdsa/resharing/round_4_new_step_2.go index 3cc658d0..f8247260 100644 --- a/ecdsa/resharing/round_4_new_step_2.go +++ b/ecdsa/resharing/round_4_new_step_2.go @@ -108,8 +108,8 @@ func (round *round4) Start() *tss.Error { // 5-9. modQ := common.ModInt(tss.EC().Params().N) - vjc := make([][]*crypto.ECPoint, round.Threshold()+1) - for j := 0; j <= round.Threshold(); j++ { // P1..P_t+1. Ps are indexed from 0 here + vjc := make([][]*crypto.ECPoint, len(round.OldParties().IDs())) + for j := 0; j <= len(vjc)-1; j++ { // P1..P_t+1. Ps are indexed from 0 here // 6-7. r1msg := round.temp.dgRound1Messages[j].Content().(*DGRound1Message) r3msg2 := round.temp.dgRound3Message2s[j].Content().(*DGRound3Message2) @@ -150,7 +150,7 @@ func (round *round4) Start() *tss.Error { Vc := make([]*crypto.ECPoint, round.NewThreshold()+1) for c := 0; c <= round.NewThreshold(); c++ { Vc[c] = vjc[0][c] - for j := 1; j <= round.Threshold(); j++ { + for j := 1; j <= len(vjc)-1; j++ { Vc[c], err = Vc[c].Add(vjc[j][c]) if err != nil { return round.WrapError(errors2.Wrapf(err, "Vc[c].Add(vjc[j][c])")) diff --git a/eddsa/resharing/local_party.go b/eddsa/resharing/local_party.go index bcbbaada..2b7b7469 100644 --- a/eddsa/resharing/local_party.go +++ b/eddsa/resharing/local_party.go @@ -69,6 +69,7 @@ func NewLocalParty( out chan<- tss.Message, end chan<- keygen.LocalPartySaveData, ) tss.Party { + oldPartyCount := len(params.OldParties().IDs()) subset := key if params.IsOldCommittee() { subset = keygen.BuildLocalSaveDataSubset(key, params.OldParties().IDs()) @@ -83,10 +84,10 @@ func NewLocalParty( end: end, } // msgs init - p.temp.dgRound1Messages = make([]tss.ParsedMessage, params.Threshold()+1) // from t+1 of Old Committee - p.temp.dgRound2Messages = make([]tss.ParsedMessage, params.NewPartyCount()) // " - p.temp.dgRound3Message1s = make([]tss.ParsedMessage, params.Threshold()+1) // from t+1 of Old Committee - p.temp.dgRound3Message2s = make([]tss.ParsedMessage, params.Threshold()+1) // " + p.temp.dgRound1Messages = make([]tss.ParsedMessage, oldPartyCount) // from t+1 of Old Committee + p.temp.dgRound2Messages = make([]tss.ParsedMessage, params.NewPartyCount()) // from n of New Committee + p.temp.dgRound3Message1s = make([]tss.ParsedMessage, oldPartyCount) // from t+1 of Old Committee + p.temp.dgRound3Message2s = make([]tss.ParsedMessage, oldPartyCount) // " p.temp.dgRound4Messages = make([]tss.ParsedMessage, params.NewPartyCount()) // from n of New Committee return p @@ -112,6 +113,25 @@ func (p *LocalParty) UpdateFromBytes(wireBytes []byte, from *tss.PartyID, isBroa return p.Update(msg) } +func (p *LocalParty) ValidateMessage(msg tss.ParsedMessage) (bool, *tss.Error) { + if ok, err := p.BaseParty.ValidateMessage(msg); !ok || err != nil { + return ok, err + } + // check that the message's "from index" will fit into the array + var maxFromIdx int + switch msg.Content().(type) { + case *DGRound2Message, *DGRound4Message: + maxFromIdx = len(p.params.NewParties().IDs()) - 1 + default: + maxFromIdx = len(p.params.OldParties().IDs()) - 1 + } + if maxFromIdx < msg.GetFrom().Index { + return false, p.WrapError(fmt.Errorf("received msg with a sender index too great (%d <= %d)", + maxFromIdx, msg.GetFrom().Index), msg.GetFrom()) + } + return true, nil +} + func (p *LocalParty) StoreMessage(msg tss.ParsedMessage) (bool, *tss.Error) { // ValidateBasic is cheap; double-check the message here in case the public StoreMessage was called externally if ok, err := p.ValidateMessage(msg); !ok || err != nil { @@ -124,19 +144,14 @@ func (p *LocalParty) StoreMessage(msg tss.ParsedMessage) (bool, *tss.Error) { switch msg.Content().(type) { case *DGRound1Message: p.temp.dgRound1Messages[fromPIdx] = msg - case *DGRound2Message: p.temp.dgRound2Messages[fromPIdx] = msg - case *DGRound3Message1: p.temp.dgRound3Message1s[fromPIdx] = msg - case *DGRound3Message2: p.temp.dgRound3Message2s[fromPIdx] = msg - case *DGRound4Message: p.temp.dgRound4Messages[fromPIdx] = msg - default: // unrecognised message, just ignore! common.Logger.Warningf("unrecognised message ignored: %v", msg) return false, nil diff --git a/eddsa/resharing/local_party_test.go b/eddsa/resharing/local_party_test.go index 90d721cb..3a33db37 100644 --- a/eddsa/resharing/local_party_test.go +++ b/eddsa/resharing/local_party_test.go @@ -43,8 +43,8 @@ func TestE2EConcurrent(t *testing.T) { threshold, newThreshold := testThreshold, testThreshold // PHASE: load keygen fixtures - firstPartyIdx := 5 - oldKeys, oldPIDs, err := keygen.LoadKeygenTestFixtures(testThreshold+1+firstPartyIdx, firstPartyIdx) + firstPartyIdx, extraParties := 5, 1 // // extra can be 0 to N-first + oldKeys, oldPIDs, err := keygen.LoadKeygenTestFixtures(testThreshold+1+extraParties+firstPartyIdx, firstPartyIdx) assert.NoError(t, err, "should load keygen fixtures") // PHASE: resharing diff --git a/eddsa/resharing/round_1_old_step_1.go b/eddsa/resharing/round_1_old_step_1.go index d8a6f2db..56fb4351 100644 --- a/eddsa/resharing/round_1_old_step_1.go +++ b/eddsa/resharing/round_1_old_step_1.go @@ -8,6 +8,7 @@ package resharing import ( "errors" + "fmt" "github.com/binance-chain/tss-lib/crypto" "github.com/binance-chain/tss-lib/crypto/commitments" @@ -20,7 +21,7 @@ import ( // round 1 represents round 1 of the keygen part of the EDDSA TSS spec func newRound1(params *tss.ReSharingParameters, input, save *keygen.LocalPartySaveData, temp *localTempData, out chan<- tss.Message, end chan<- keygen.LocalPartySaveData) tss.Round { return &round1{ - &base{params, temp, input, save, out, end, make([]bool, params.Threshold()+1), make([]bool, len(params.NewParties().IDs())), false, 1}} + &base{params, temp, input, save, out, end, make([]bool, len(params.OldParties().IDs())), make([]bool, len(params.NewParties().IDs())), false, 1}} } func (round *round1) Start() *tss.Error { @@ -42,8 +43,11 @@ func (round *round1) Start() *tss.Error { // 1. PrepareForSigning() -> w_i xi, ks := round.input.Xi, round.input.Ks + if round.Threshold()+1 > len(ks) { + return round.WrapError(fmt.Errorf("t+1=%d is not satisfied by the key count of %d", round.Threshold()+1, len(ks)), round.PartyID()) + } newKs := round.NewParties().IDs().Keys() - wi := signing.PrepareForSigning(i, round.Threshold()+1, xi, ks) + wi := signing.PrepareForSigning(i, len(round.OldParties().IDs()), xi, ks) // 2. vi, shares, err := vss.Create(round.NewThreshold(), wi, newKs) diff --git a/eddsa/resharing/round_4_new_step_2.go b/eddsa/resharing/round_4_new_step_2.go index cd322203..fccbaf51 100644 --- a/eddsa/resharing/round_4_new_step_2.go +++ b/eddsa/resharing/round_4_new_step_2.go @@ -41,8 +41,8 @@ func (round *round4) Start() *tss.Error { // 2-8. modQ := common.ModInt(tss.EC().Params().N) - vjc := make([][]*crypto.ECPoint, round.Threshold()+1) - for j := 0; j <= round.Threshold(); j++ { // P1..P_t+1. Ps are indexed from 0 here + vjc := make([][]*crypto.ECPoint, len(round.OldParties().IDs())) + for j := 0; j <= len(vjc)-1; j++ { // P1..P_t+1. Ps are indexed from 0 here r1msg := round.temp.dgRound1Messages[j].Content().(*DGRound1Message) r3msg2 := round.temp.dgRound3Message2s[j].Content().(*DGRound3Message2) @@ -79,7 +79,7 @@ func (round *round4) Start() *tss.Error { Vc := make([]*crypto.ECPoint, round.NewThreshold()+1) for c := 0; c <= round.NewThreshold(); c++ { Vc[c] = vjc[0][c] - for j := 1; j <= round.Threshold(); j++ { + for j := 1; j <= len(vjc)-1; j++ { Vc[c], err = Vc[c].Add(vjc[j][c]) if err != nil { return round.WrapError(errors.Wrapf(err, "Vc[c].Add(vjc[j][c])"))