Skip to content

Commit

Permalink
*/resharing: allow resharing when more than T+1 of the old committee …
Browse files Browse the repository at this point in the history
…participates
  • Loading branch information
notatestuser committed Mar 30, 2020
1 parent a2e67ec commit 09e7f0d
Show file tree
Hide file tree
Showing 8 changed files with 69 additions and 32 deletions.
32 changes: 23 additions & 9 deletions ecdsa/resharing/local_party.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ func NewLocalParty(
out chan<- tss.Message,
end chan<- keygen.LocalPartySaveData,
) tss.Party {
oldPartyCount := len(params.OldParties().IDs())
subset := key
if params.IsOldCommittee() {
subset = keygen.BuildLocalSaveDataSubset(key, params.OldParties().IDs())
Expand All @@ -84,11 +85,11 @@ func NewLocalParty(
end: end,
}
// msgs init
p.temp.dgRound1Messages = make([]tss.ParsedMessage, params.Threshold()+1) // from t+1 of Old Committee
p.temp.dgRound1Messages = make([]tss.ParsedMessage, oldPartyCount) // from t+1 of Old Committee
p.temp.dgRound2Message1s = make([]tss.ParsedMessage, params.NewPartyCount()) // from n of New Committee
p.temp.dgRound2Message2s = make([]tss.ParsedMessage, params.NewPartyCount()) // "
p.temp.dgRound3Message1s = make([]tss.ParsedMessage, params.Threshold()+1) // from t+1 of Old Committee
p.temp.dgRound3Message2s = make([]tss.ParsedMessage, params.Threshold()+1) // "
p.temp.dgRound3Message1s = make([]tss.ParsedMessage, oldPartyCount) // from t+1 of Old Committee
p.temp.dgRound3Message2s = make([]tss.ParsedMessage, oldPartyCount) // "
p.temp.dgRound4Messages = make([]tss.ParsedMessage, params.NewPartyCount()) // from n of New Committee
// save data init
if key.LocalPreParams.ValidateWithProof() {
Expand Down Expand Up @@ -117,6 +118,25 @@ func (p *LocalParty) UpdateFromBytes(wireBytes []byte, from *tss.PartyID, isBroa
return p.Update(msg)
}

func (p *LocalParty) ValidateMessage(msg tss.ParsedMessage) (bool, *tss.Error) {
if ok, err := p.BaseParty.ValidateMessage(msg); !ok || err != nil {
return ok, err
}
// check that the message's "from index" will fit into the array
var maxFromIdx int
switch msg.Content().(type) {
case *DGRound2Message1, *DGRound2Message2, *DGRound4Message:
maxFromIdx = len(p.params.NewParties().IDs()) - 1
default:
maxFromIdx = len(p.params.OldParties().IDs()) - 1
}
if maxFromIdx < msg.GetFrom().Index {
return false, p.WrapError(fmt.Errorf("received msg with a sender index too great (%d <= %d)",
maxFromIdx, msg.GetFrom().Index), msg.GetFrom())
}
return true, nil
}

func (p *LocalParty) StoreMessage(msg tss.ParsedMessage) (bool, *tss.Error) {
// ValidateBasic is cheap; double-check the message here in case the public StoreMessage was called externally
if ok, err := p.ValidateMessage(msg); !ok || err != nil {
Expand All @@ -129,22 +149,16 @@ func (p *LocalParty) StoreMessage(msg tss.ParsedMessage) (bool, *tss.Error) {
switch msg.Content().(type) {
case *DGRound1Message:
p.temp.dgRound1Messages[fromPIdx] = msg

case *DGRound2Message1:
p.temp.dgRound2Message1s[fromPIdx] = msg

case *DGRound2Message2:
p.temp.dgRound2Message2s[fromPIdx] = msg

case *DGRound3Message1:
p.temp.dgRound3Message1s[fromPIdx] = msg

case *DGRound3Message2:
p.temp.dgRound3Message2s[fromPIdx] = msg

case *DGRound4Message:
p.temp.dgRound4Messages[fromPIdx] = msg

default: // unrecognised message, just ignore!
common.Logger.Warningf("unrecognised message ignored: %v", msg)
return false, nil
Expand Down
4 changes: 2 additions & 2 deletions ecdsa/resharing/local_party_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ func TestE2EConcurrent(t *testing.T) {
threshold, newThreshold := testThreshold, testThreshold

// PHASE: load keygen fixtures
firstPartyIdx := 5
oldKeys, oldPIDs, err := keygen.LoadKeygenTestFixtures(testThreshold+1+firstPartyIdx, firstPartyIdx)
firstPartyIdx, extraParties := 5, 1 // extra can be 0 to N-first
oldKeys, oldPIDs, err := keygen.LoadKeygenTestFixtures(testThreshold+1+extraParties+firstPartyIdx, firstPartyIdx)
assert.NoError(t, err, "should load keygen fixtures")

// PHASE: resharing
Expand Down
8 changes: 6 additions & 2 deletions ecdsa/resharing/round_1_old_step_1.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package resharing

import (
"errors"
"fmt"

"github.com/binance-chain/tss-lib/crypto"
"github.com/binance-chain/tss-lib/crypto/commitments"
Expand All @@ -20,7 +21,7 @@ import (
// round 1 represents round 1 of the keygen part of the GG18 ECDSA TSS spec (Gennaro, Goldfeder; 2018)
func newRound1(params *tss.ReSharingParameters, input, save *keygen.LocalPartySaveData, temp *localTempData, out chan<- tss.Message, end chan<- keygen.LocalPartySaveData) tss.Round {
return &round1{
&base{params, temp, input, save, out, end, make([]bool, params.Threshold()+1), make([]bool, len(params.NewParties().IDs())), false, 1}}
&base{params, temp, input, save, out, end, make([]bool, len(params.OldParties().IDs())), make([]bool, len(params.NewParties().IDs())), false, 1}}
}

func (round *round1) Start() *tss.Error {
Expand All @@ -42,8 +43,11 @@ func (round *round1) Start() *tss.Error {

// 1. PrepareForSigning() -> w_i
xi, ks, bigXj := round.input.Xi, round.input.Ks, round.input.BigXj
if round.Threshold()+1 > len(ks) {
return round.WrapError(fmt.Errorf("t+1=%d is not satisfied by the key count of %d", round.Threshold()+1, len(ks)), round.PartyID())
}
newKs := round.NewParties().IDs().Keys()
wi, _ := signing.PrepareForSigning(i, round.Threshold()+1, xi, ks, bigXj)
wi, _ := signing.PrepareForSigning(i, len(round.OldParties().IDs()), xi, ks, bigXj)

// 2.
vi, shares, err := vss.Create(round.NewThreshold(), wi, newKs)
Expand Down
6 changes: 3 additions & 3 deletions ecdsa/resharing/round_4_new_step_2.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,8 @@ func (round *round4) Start() *tss.Error {

// 5-9.
modQ := common.ModInt(tss.EC().Params().N)
vjc := make([][]*crypto.ECPoint, round.Threshold()+1)
for j := 0; j <= round.Threshold(); j++ { // P1..P_t+1. Ps are indexed from 0 here
vjc := make([][]*crypto.ECPoint, len(round.OldParties().IDs()))
for j := 0; j <= len(vjc)-1; j++ { // P1..P_t+1. Ps are indexed from 0 here
// 6-7.
r1msg := round.temp.dgRound1Messages[j].Content().(*DGRound1Message)
r3msg2 := round.temp.dgRound3Message2s[j].Content().(*DGRound3Message2)
Expand Down Expand Up @@ -150,7 +150,7 @@ func (round *round4) Start() *tss.Error {
Vc := make([]*crypto.ECPoint, round.NewThreshold()+1)
for c := 0; c <= round.NewThreshold(); c++ {
Vc[c] = vjc[0][c]
for j := 1; j <= round.Threshold(); j++ {
for j := 1; j <= len(vjc)-1; j++ {
Vc[c], err = Vc[c].Add(vjc[j][c])
if err != nil {
return round.WrapError(errors2.Wrapf(err, "Vc[c].Add(vjc[j][c])"))
Expand Down
33 changes: 24 additions & 9 deletions eddsa/resharing/local_party.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ func NewLocalParty(
out chan<- tss.Message,
end chan<- keygen.LocalPartySaveData,
) tss.Party {
oldPartyCount := len(params.OldParties().IDs())
subset := key
if params.IsOldCommittee() {
subset = keygen.BuildLocalSaveDataSubset(key, params.OldParties().IDs())
Expand All @@ -83,10 +84,10 @@ func NewLocalParty(
end: end,
}
// msgs init
p.temp.dgRound1Messages = make([]tss.ParsedMessage, params.Threshold()+1) // from t+1 of Old Committee
p.temp.dgRound2Messages = make([]tss.ParsedMessage, params.NewPartyCount()) // "
p.temp.dgRound3Message1s = make([]tss.ParsedMessage, params.Threshold()+1) // from t+1 of Old Committee
p.temp.dgRound3Message2s = make([]tss.ParsedMessage, params.Threshold()+1) // "
p.temp.dgRound1Messages = make([]tss.ParsedMessage, oldPartyCount) // from t+1 of Old Committee
p.temp.dgRound2Messages = make([]tss.ParsedMessage, params.NewPartyCount()) // from n of New Committee
p.temp.dgRound3Message1s = make([]tss.ParsedMessage, oldPartyCount) // from t+1 of Old Committee
p.temp.dgRound3Message2s = make([]tss.ParsedMessage, oldPartyCount) // "
p.temp.dgRound4Messages = make([]tss.ParsedMessage, params.NewPartyCount()) // from n of New Committee

return p
Expand All @@ -112,6 +113,25 @@ func (p *LocalParty) UpdateFromBytes(wireBytes []byte, from *tss.PartyID, isBroa
return p.Update(msg)
}

func (p *LocalParty) ValidateMessage(msg tss.ParsedMessage) (bool, *tss.Error) {
if ok, err := p.BaseParty.ValidateMessage(msg); !ok || err != nil {
return ok, err
}
// check that the message's "from index" will fit into the array
var maxFromIdx int
switch msg.Content().(type) {
case *DGRound2Message, *DGRound4Message:
maxFromIdx = len(p.params.NewParties().IDs()) - 1
default:
maxFromIdx = len(p.params.OldParties().IDs()) - 1
}
if maxFromIdx < msg.GetFrom().Index {
return false, p.WrapError(fmt.Errorf("received msg with a sender index too great (%d <= %d)",
maxFromIdx, msg.GetFrom().Index), msg.GetFrom())
}
return true, nil
}

func (p *LocalParty) StoreMessage(msg tss.ParsedMessage) (bool, *tss.Error) {
// ValidateBasic is cheap; double-check the message here in case the public StoreMessage was called externally
if ok, err := p.ValidateMessage(msg); !ok || err != nil {
Expand All @@ -124,19 +144,14 @@ func (p *LocalParty) StoreMessage(msg tss.ParsedMessage) (bool, *tss.Error) {
switch msg.Content().(type) {
case *DGRound1Message:
p.temp.dgRound1Messages[fromPIdx] = msg

case *DGRound2Message:
p.temp.dgRound2Messages[fromPIdx] = msg

case *DGRound3Message1:
p.temp.dgRound3Message1s[fromPIdx] = msg

case *DGRound3Message2:
p.temp.dgRound3Message2s[fromPIdx] = msg

case *DGRound4Message:
p.temp.dgRound4Messages[fromPIdx] = msg

default: // unrecognised message, just ignore!
common.Logger.Warningf("unrecognised message ignored: %v", msg)
return false, nil
Expand Down
4 changes: 2 additions & 2 deletions eddsa/resharing/local_party_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ func TestE2EConcurrent(t *testing.T) {
threshold, newThreshold := testThreshold, testThreshold

// PHASE: load keygen fixtures
firstPartyIdx := 5
oldKeys, oldPIDs, err := keygen.LoadKeygenTestFixtures(testThreshold+1+firstPartyIdx, firstPartyIdx)
firstPartyIdx, extraParties := 5, 1 // // extra can be 0 to N-first
oldKeys, oldPIDs, err := keygen.LoadKeygenTestFixtures(testThreshold+1+extraParties+firstPartyIdx, firstPartyIdx)
assert.NoError(t, err, "should load keygen fixtures")

// PHASE: resharing
Expand Down
8 changes: 6 additions & 2 deletions eddsa/resharing/round_1_old_step_1.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package resharing

import (
"errors"
"fmt"

"github.com/binance-chain/tss-lib/crypto"
"github.com/binance-chain/tss-lib/crypto/commitments"
Expand All @@ -20,7 +21,7 @@ import (
// round 1 represents round 1 of the keygen part of the EDDSA TSS spec
func newRound1(params *tss.ReSharingParameters, input, save *keygen.LocalPartySaveData, temp *localTempData, out chan<- tss.Message, end chan<- keygen.LocalPartySaveData) tss.Round {
return &round1{
&base{params, temp, input, save, out, end, make([]bool, params.Threshold()+1), make([]bool, len(params.NewParties().IDs())), false, 1}}
&base{params, temp, input, save, out, end, make([]bool, len(params.OldParties().IDs())), make([]bool, len(params.NewParties().IDs())), false, 1}}
}

func (round *round1) Start() *tss.Error {
Expand All @@ -42,8 +43,11 @@ func (round *round1) Start() *tss.Error {

// 1. PrepareForSigning() -> w_i
xi, ks := round.input.Xi, round.input.Ks
if round.Threshold()+1 > len(ks) {
return round.WrapError(fmt.Errorf("t+1=%d is not satisfied by the key count of %d", round.Threshold()+1, len(ks)), round.PartyID())
}
newKs := round.NewParties().IDs().Keys()
wi := signing.PrepareForSigning(i, round.Threshold()+1, xi, ks)
wi := signing.PrepareForSigning(i, len(round.OldParties().IDs()), xi, ks)

// 2.
vi, shares, err := vss.Create(round.NewThreshold(), wi, newKs)
Expand Down
6 changes: 3 additions & 3 deletions eddsa/resharing/round_4_new_step_2.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ func (round *round4) Start() *tss.Error {

// 2-8.
modQ := common.ModInt(tss.EC().Params().N)
vjc := make([][]*crypto.ECPoint, round.Threshold()+1)
for j := 0; j <= round.Threshold(); j++ { // P1..P_t+1. Ps are indexed from 0 here
vjc := make([][]*crypto.ECPoint, len(round.OldParties().IDs()))
for j := 0; j <= len(vjc)-1; j++ { // P1..P_t+1. Ps are indexed from 0 here
r1msg := round.temp.dgRound1Messages[j].Content().(*DGRound1Message)
r3msg2 := round.temp.dgRound3Message2s[j].Content().(*DGRound3Message2)

Expand Down Expand Up @@ -79,7 +79,7 @@ func (round *round4) Start() *tss.Error {
Vc := make([]*crypto.ECPoint, round.NewThreshold()+1)
for c := 0; c <= round.NewThreshold(); c++ {
Vc[c] = vjc[0][c]
for j := 1; j <= round.Threshold(); j++ {
for j := 1; j <= len(vjc)-1; j++ {
Vc[c], err = Vc[c].Add(vjc[j][c])
if err != nil {
return round.WrapError(errors.Wrapf(err, "Vc[c].Add(vjc[j][c])"))
Expand Down

0 comments on commit 09e7f0d

Please sign in to comment.