diff --git a/relay/config.go b/relay/config.go index 47e49e0..777e762 100644 --- a/relay/config.go +++ b/relay/config.go @@ -10,6 +10,8 @@ import ( "github.com/hyperledger-labs/yui-relayer/core" ) +const DefaultMessageAggregationBatchSize = 8 + var _ core.ProverConfig = (*ProverConfig)(nil) var _ codectypes.UnpackInterfacesMessage = (*ProverConfig)(nil) @@ -43,6 +45,14 @@ func (pc ProverConfig) GetMrenclave() []byte { return mrenclave } +func (pc ProverConfig) GetMessageAggregationBatchSize() uint64 { + if pc.MessageAggregationBatchSize == 0 { + return DefaultMessageAggregationBatchSize + } else { + return pc.MessageAggregationBatchSize + } +} + func (pc ProverConfig) Validate() error { // origin prover config validation if err := pc.OriginProver.GetCachedValue().(core.ProverConfig).Validate(); err != nil { @@ -60,8 +70,8 @@ func (pc ProverConfig) Validate() error { if pc.KeyExpiration == 0 { return fmt.Errorf("KeyExpiration must be greater than 0") } - if pc.MessageAggregation && pc.MessageAggregationBatchSize < 2 { - return fmt.Errorf("MessageAggregationBatchSize must be greater than 1 if MessageAggregation is true") + if pc.MessageAggregation && pc.MessageAggregationBatchSize == 1 { + return fmt.Errorf("MessageAggregationBatchSize must be greater than 1 if MessageAggregation is true and MessageAggregationBatchSize is set") } return nil } diff --git a/relay/prover.go b/relay/prover.go index 4630851..fa897e0 100644 --- a/relay/prover.go +++ b/relay/prover.go @@ -217,7 +217,7 @@ func (pr *Prover) aggregateMessages(messages [][]byte, signatures [][]byte, sign return nil, fmt.Errorf("aggregateMessages: messages and signatures must have the same length: messages=%v signatures=%v", len(messages), len(signatures)) } for { - batches, err := splitIntoMultiBatch(messages, signatures, signer, pr.config.MessageAggregationBatchSize) + batches, err := splitIntoMultiBatch(messages, signatures, signer, pr.config.GetMessageAggregationBatchSize()) if err != nil { return nil, err } @@ -271,6 +271,9 @@ func splitIntoMultiBatch(messages [][]byte, signatures [][]byte, signer []byte, var res []*elc.MsgAggregateMessages var currentMessages [][]byte var currentBatchStartIndex uint64 = 0 + if messageBatchSize < 2 { + return nil, fmt.Errorf("messageBatchSize must be greater than 1") + } for i := 0; i < len(messages); i++ { currentMessages = append(currentMessages, messages[i]) if uint64(len(currentMessages)) == messageBatchSize {