diff --git a/cmd/node/config/enableEpochs.toml b/cmd/node/config/enableEpochs.toml index 87f4a6b9a09..4307c976acc 100644 --- a/cmd/node/config/enableEpochs.toml +++ b/cmd/node/config/enableEpochs.toml @@ -319,10 +319,10 @@ CryptoOpcodesV2EnableEpoch = 4 # EquivalentMessagesEnableEpoch represents the epoch when the equivalent messages are enabled - EquivalentMessagesEnableEpoch = 4 + EquivalentMessagesEnableEpoch = 8 # the chain simulator tests for staking v4 fail if this is set earlier, as they test the transition in epochs 4-7 # FixedOrderInConsensusEnableEpoch represents the epoch when the fixed order in consensus is enabled - FixedOrderInConsensusEnableEpoch = 4 + FixedOrderInConsensusEnableEpoch = 8 # the chain simulator tests for staking v4 fail if this is set earlier, as they test the transition in epochs 4-7 # BLSMultiSignerEnableEpoch represents the activation epoch for different types of BLS multi-signers BLSMultiSignerEnableEpoch = [ diff --git a/common/common.go b/common/common.go index d5624d7777a..3d5e874c231 100644 --- a/common/common.go +++ b/common/common.go @@ -1,8 +1,13 @@ package common import ( + "fmt" + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-go/storage" + "github.com/multiversx/mx-chain-vm-v1_2-go/ipc/marshaling" ) // IsEpochChangeBlockForFlagActivation returns true if the provided header is the first one after the specified flag's activation @@ -19,3 +24,55 @@ func IsFlagEnabledAfterEpochsStartBlock(header data.HeaderHandler, enableEpochsH isEpochStartBlock := IsEpochChangeBlockForFlagActivation(header, enableEpochsHandler, flag) return isFlagEnabled && !isEpochStartBlock } + +// ShouldBlockHavePrevProof returns true if the block should have a proof +func ShouldBlockHavePrevProof(header data.HeaderHandler, enableEpochsHandler EnableEpochsHandler, flag core.EnableEpochFlag) bool { + return IsFlagEnabledAfterEpochsStartBlock(header, enableEpochsHandler, flag) && header.GetNonce() > 1 +} + +// VerifyProofAgainstHeader verifies the fields on the proof match the ones on the header +func VerifyProofAgainstHeader(proof data.HeaderProofHandler, header data.HeaderHandler) error { + if check.IfNilReflect(proof) { + return ErrInvalidHeaderProof + } + + if proof.GetHeaderNonce() != header.GetNonce() { + return fmt.Errorf("%w, nonce mismatch", ErrInvalidHeaderProof) + } + if proof.GetHeaderShardId() != header.GetShardID() { + return fmt.Errorf("%w, shard id mismatch", ErrInvalidHeaderProof) + } + if proof.GetHeaderEpoch() != header.GetEpoch() { + return fmt.Errorf("%w, epoch mismatch", ErrInvalidHeaderProof) + } + if proof.GetHeaderRound() != header.GetRound() { + return fmt.Errorf("%w, round mismatch", ErrInvalidHeaderProof) + } + + return nil +} + +// GetHeader tries to get the header from pool first and if not found, searches for it through storer +func GetHeader( + headerHash []byte, + headersPool HeadersPool, + headersStorer storage.Storer, + marshaller marshaling.Marshalizer, +) (data.HeaderHandler, error) { + header, err := headersPool.GetHeaderByHash(headerHash) + if err == nil { + return header, nil + } + + headerBytes, err := headersStorer.SearchFirst(headerHash) + if err != nil { + return nil, err + } + + err = marshaller.Unmarshal(header, headerBytes) + if err != nil { + return nil, err + } + + return header, nil +} diff --git a/common/errors.go b/common/errors.go index 47b976de9a8..eeeaf94c804 100644 --- a/common/errors.go +++ b/common/errors.go @@ -10,3 +10,6 @@ var ErrNilWasmChangeLocker = errors.New("nil wasm change locker") // ErrNilStateSyncNotifierSubscriber signals that a nil state sync notifier subscriber has been provided var ErrNilStateSyncNotifierSubscriber = errors.New("nil state sync notifier subscriber") + +// ErrInvalidHeaderProof signals that an invalid equivalent proof has been provided +var ErrInvalidHeaderProof = errors.New("invalid equivalent proof") diff --git a/common/interface.go b/common/interface.go index 72a2cba2628..e3a42e0ee45 100644 --- a/common/interface.go +++ b/common/interface.go @@ -379,3 +379,8 @@ type ChainParametersSubscriptionHandler interface { ChainParametersChanged(chainParameters config.ChainParametersByEpochConfig) IsInterfaceNil() bool } + +// HeadersPool defines what a headers pool structure can perform +type HeadersPool interface { + GetHeaderByHash(hash []byte) (data.HeaderHandler, error) +} diff --git a/consensus/spos/bls/proxy/subroundsHandler.go b/consensus/spos/bls/proxy/subroundsHandler.go index 2b284db5144..79cca8e04bb 100644 --- a/consensus/spos/bls/proxy/subroundsHandler.go +++ b/consensus/spos/bls/proxy/subroundsHandler.go @@ -3,7 +3,6 @@ package proxy import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" - "github.com/multiversx/mx-chain-core-go/data" logger "github.com/multiversx/mx-chain-logger-go" "github.com/multiversx/mx-chain-go/common" @@ -57,6 +56,14 @@ type SubroundsHandler struct { currentConsensusType consensusStateMachineType } +// EpochConfirmed is called when the epoch is confirmed (this is registered as callback) +func (s *SubroundsHandler) EpochConfirmed(epoch uint32, _ uint64) { + err := s.initSubroundsForEpoch(epoch) + if err != nil { + log.Error("SubroundsHandler.EpochConfirmed: cannot initialize subrounds", "error", err) + } +} + const ( consensusNone consensusStateMachineType = iota consensusV1 @@ -85,7 +92,7 @@ func NewSubroundsHandler(args *SubroundsHandlerArgs) (*SubroundsHandler, error) currentConsensusType: consensusNone, } - subroundHandler.consensusCoreHandler.EpochStartRegistrationHandler().RegisterHandler(subroundHandler) + subroundHandler.consensusCoreHandler.EpochNotifier().RegisterNotifyHandler(subroundHandler) return subroundHandler, nil } @@ -189,28 +196,6 @@ func (s *SubroundsHandler) initSubroundsForEpoch(epoch uint32) error { return nil } -// EpochStartAction is called when the epoch starts -func (s *SubroundsHandler) EpochStartAction(hdr data.HeaderHandler) { - if check.IfNil(hdr) { - log.Error("SubroundsHandler.EpochStartAction: nil header") - return - } - - err := s.initSubroundsForEpoch(hdr.GetEpoch()) - if err != nil { - log.Error("SubroundsHandler.EpochStartAction: cannot initialize subrounds", "error", err) - } -} - -// EpochStartPrepare prepares the subrounds handler for the epoch start -func (s *SubroundsHandler) EpochStartPrepare(_ data.HeaderHandler, _ data.BodyHandler) { -} - -// NotifyOrder returns the order of the subrounds handler -func (s *SubroundsHandler) NotifyOrder() uint32 { - return common.ConsensusHandlerOrder -} - // IsInterfaceNil returns true if there is no value under the interface func (s *SubroundsHandler) IsInterfaceNil() bool { return s == nil diff --git a/consensus/spos/bls/proxy/subroundsHandler_test.go b/consensus/spos/bls/proxy/subroundsHandler_test.go index 403dc2c7826..68e72dba30f 100644 --- a/consensus/spos/bls/proxy/subroundsHandler_test.go +++ b/consensus/spos/bls/proxy/subroundsHandler_test.go @@ -8,7 +8,6 @@ import ( crypto "github.com/multiversx/mx-chain-crypto-go" "github.com/stretchr/testify/require" - chainCommon "github.com/multiversx/mx-chain-go/common" mock2 "github.com/multiversx/mx-chain-go/consensus/mock" "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/bootstrapperStubs" @@ -17,6 +16,7 @@ import ( "github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" + epochNotifierMock "github.com/multiversx/mx-chain-go/testscommon/epochNotifier" mock "github.com/multiversx/mx-chain-go/testscommon/epochstartmock" outportStub "github.com/multiversx/mx-chain-go/testscommon/outport" "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" @@ -29,6 +29,7 @@ func getDefaultArgumentsSubroundHandler() (*SubroundsHandlerArgs, *consensus.Con epochsEnable := &enableEpochsHandlerMock.EnableEpochsHandlerStub{} epochStartNotifier := &mock.EpochStartNotifierStub{} consensusState := &consensus.ConsensusStateMock{} + epochNotifier := &epochNotifierMock.EpochNotifierStub{} worker := &consensus.SposWorkerMock{ RemoveAllReceivedMessagesCallsCalled: func() {}, GetConsensusStateChangedChannelsCalled: func() chan bool { @@ -78,6 +79,7 @@ func getDefaultArgumentsSubroundHandler() (*SubroundsHandlerArgs, *consensus.Con consensusCore.SetSigningHandler(&consensus.SigningHandlerStub{}) consensusCore.SetEnableEpochsHandler(epochsEnable) consensusCore.SetEquivalentProofsPool(&dataRetriever.ProofsPoolMock{}) + consensusCore.SetEpochNotifier(epochNotifier) handlerArgs.ConsensusCoreHandler = consensusCore return handlerArgs, consensusCore @@ -221,12 +223,14 @@ func TestSubroundsHandler_initSubroundsForEpoch(t *testing.T) { sh, err := NewSubroundsHandler(handlerArgs) require.Nil(t, err) require.NotNil(t, sh) + // first call on register to EpochNotifier + require.Equal(t, int32(1), startCalled.Load()) sh.currentConsensusType = consensusNone err = sh.initSubroundsForEpoch(0) require.Nil(t, err) require.Equal(t, consensusV1, sh.currentConsensusType) - require.Equal(t, int32(1), startCalled.Load()) + require.Equal(t, int32(2), startCalled.Load()) }) t.Run("equivalent messages not enabled, with previous consensus type consensusV1", func(t *testing.T) { t.Parallel() @@ -251,12 +255,15 @@ func TestSubroundsHandler_initSubroundsForEpoch(t *testing.T) { sh, err := NewSubroundsHandler(handlerArgs) require.Nil(t, err) require.NotNil(t, sh) + // first call on register to EpochNotifier + require.Equal(t, int32(1), startCalled.Load()) sh.currentConsensusType = consensusV1 err = sh.initSubroundsForEpoch(0) require.Nil(t, err) require.Equal(t, consensusV1, sh.currentConsensusType) - require.Equal(t, int32(0), startCalled.Load()) + require.Equal(t, int32(1), startCalled.Load()) + }) t.Run("equivalent messages enabled, with previous consensus type consensusNone", func(t *testing.T) { t.Parallel() @@ -280,12 +287,14 @@ func TestSubroundsHandler_initSubroundsForEpoch(t *testing.T) { sh, err := NewSubroundsHandler(handlerArgs) require.Nil(t, err) require.NotNil(t, sh) + // first call on register to EpochNotifier + require.Equal(t, int32(1), startCalled.Load()) sh.currentConsensusType = consensusNone err = sh.initSubroundsForEpoch(0) require.Nil(t, err) require.Equal(t, consensusV2, sh.currentConsensusType) - require.Equal(t, int32(1), startCalled.Load()) + require.Equal(t, int32(2), startCalled.Load()) }) t.Run("equivalent messages enabled, with previous consensus type consensusV1", func(t *testing.T) { t.Parallel() @@ -309,12 +318,14 @@ func TestSubroundsHandler_initSubroundsForEpoch(t *testing.T) { sh, err := NewSubroundsHandler(handlerArgs) require.Nil(t, err) require.NotNil(t, sh) + // first call on register to EpochNotifier + require.Equal(t, int32(1), startCalled.Load()) sh.currentConsensusType = consensusV1 err = sh.initSubroundsForEpoch(0) require.Nil(t, err) require.Equal(t, consensusV2, sh.currentConsensusType) - require.Equal(t, int32(1), startCalled.Load()) + require.Equal(t, int32(2), startCalled.Load()) }) t.Run("equivalent messages enabled, with previous consensus type consensusV2", func(t *testing.T) { t.Parallel() @@ -339,12 +350,14 @@ func TestSubroundsHandler_initSubroundsForEpoch(t *testing.T) { sh, err := NewSubroundsHandler(handlerArgs) require.Nil(t, err) require.NotNil(t, sh) + // first call on register to EpochNotifier + require.Equal(t, int32(1), startCalled.Load()) sh.currentConsensusType = consensusV2 err = sh.initSubroundsForEpoch(0) require.Nil(t, err) require.Equal(t, consensusV2, sh.currentConsensusType) - require.Equal(t, int32(0), startCalled.Load()) + require.Equal(t, int32(1), startCalled.Load()) }) } @@ -375,27 +388,17 @@ func TestSubroundsHandler_Start(t *testing.T) { sh, err := NewSubroundsHandler(handlerArgs) require.Nil(t, err) require.NotNil(t, sh) + // first call on init of EpochNotifier + require.Equal(t, int32(1), startCalled.Load()) sh.currentConsensusType = consensusNone err = sh.Start(0) require.Nil(t, err) require.Equal(t, consensusV1, sh.currentConsensusType) - require.Equal(t, int32(1), startCalled.Load()) + require.Equal(t, int32(2), startCalled.Load()) }) } -func TestSubroundsHandler_NotifyOrder(t *testing.T) { - t.Parallel() - - handlerArgs, _ := getDefaultArgumentsSubroundHandler() - sh, err := NewSubroundsHandler(handlerArgs) - require.Nil(t, err) - require.NotNil(t, sh) - - order := sh.NotifyOrder() - require.Equal(t, uint32(chainCommon.ConsensusHandlerOrder), order) -} - func TestSubroundsHandler_IsInterfaceNil(t *testing.T) { t.Parallel() @@ -417,7 +420,7 @@ func TestSubroundsHandler_IsInterfaceNil(t *testing.T) { }) } -func TestSubroundsHandler_EpochStartAction(t *testing.T) { +func TestSubroundsHandler_EpochConfirmed(t *testing.T) { t.Parallel() t.Run("nil handler does not panic", func(t *testing.T) { @@ -431,7 +434,7 @@ func TestSubroundsHandler_EpochStartAction(t *testing.T) { handlerArgs, _ := getDefaultArgumentsSubroundHandler() sh, err := NewSubroundsHandler(handlerArgs) require.Nil(t, err) - sh.EpochStartAction(&testscommon.HeaderHandlerStub{}) + sh.EpochConfirmed(0, 0) }) // tested through initSubroundsForEpoch @@ -458,11 +461,13 @@ func TestSubroundsHandler_EpochStartAction(t *testing.T) { sh, err := NewSubroundsHandler(handlerArgs) require.Nil(t, err) require.NotNil(t, sh) + // first call on register to EpochNotifier + require.Equal(t, int32(1), startCalled.Load()) sh.currentConsensusType = consensusNone - sh.EpochStartAction(&testscommon.HeaderHandlerStub{}) + sh.EpochConfirmed(0, 0) require.Nil(t, err) require.Equal(t, consensusV1, sh.currentConsensusType) - require.Equal(t, int32(1), startCalled.Load()) + require.Equal(t, int32(2), startCalled.Load()) }) } diff --git a/consensus/spos/bls/v1/blsSubroundsFactory.go b/consensus/spos/bls/v1/blsSubroundsFactory.go index 70915c5f30b..99b8e9260d4 100644 --- a/consensus/spos/bls/v1/blsSubroundsFactory.go +++ b/consensus/spos/bls/v1/blsSubroundsFactory.go @@ -104,6 +104,7 @@ func (fct *factory) GenerateSubrounds() error { fct.initConsensusThreshold() fct.consensusCore.Chronology().RemoveAllSubrounds() fct.worker.RemoveAllReceivedMessagesCalls() + fct.worker.RemoveAllReceivedHeaderHandlers() err := fct.generateStartRoundSubround() if err != nil { @@ -206,6 +207,7 @@ func (fct *factory) generateBlockSubround() error { fct.worker.AddReceivedMessageCall(bls.MtBlockBodyAndHeader, subroundBlockInstance.receivedBlockBodyAndHeader) fct.worker.AddReceivedMessageCall(bls.MtBlockBody, subroundBlockInstance.receivedBlockBody) fct.worker.AddReceivedMessageCall(bls.MtBlockHeader, subroundBlockInstance.receivedBlockHeader) + fct.worker.AddReceivedHeaderHandler(subroundBlockInstance.receivedFullHeader) fct.consensusCore.Chronology().AddSubround(subroundBlockInstance) return nil diff --git a/consensus/spos/bls/v1/errors.go b/consensus/spos/bls/v1/errors.go index 05c55b9592c..55a822b53dd 100644 --- a/consensus/spos/bls/v1/errors.go +++ b/consensus/spos/bls/v1/errors.go @@ -4,3 +4,6 @@ import "errors" // ErrNilSentSignatureTracker defines the error for setting a nil SentSignatureTracker var ErrNilSentSignatureTracker = errors.New("nil sent signature tracker") + +// ErrEquivalentMessagesFlagEnabledWithConsensusV1 defines the error for running with the equivalent messages flag enabled under v1 consensus +var ErrEquivalentMessagesFlagEnabledWithConsensusV1 = errors.New("equivalent messages flag enabled with consensus v1") diff --git a/consensus/spos/bls/v1/subroundBlock.go b/consensus/spos/bls/v1/subroundBlock.go index 504cb82a180..4d649c0a602 100644 --- a/consensus/spos/bls/v1/subroundBlock.go +++ b/consensus/spos/bls/v1/subroundBlock.go @@ -1,6 +1,7 @@ package v1 import ( + "bytes" "context" "time" @@ -334,6 +335,10 @@ func (sr *subroundBlock) createHeader() (data.HeaderHandler, error) { return nil, err } + if sr.EnableEpochsHandler().IsFlagEnabledInEpoch(common.EquivalentMessagesFlag, hdr.GetEpoch()) { + return nil, ErrEquivalentMessagesFlagEnabledWithConsensusV1 + } + err = hdr.SetPrevHash(prevHash) if err != nil { return nil, err @@ -491,6 +496,27 @@ func (sr *subroundBlock) receivedBlockBody(ctx context.Context, cnsDta *consensu return blockProcessedWithSuccess } +func (sr *subroundBlock) receivedFullHeader(headerHandler data.HeaderHandler) { + if sr.ShardCoordinator().SelfId() != headerHandler.GetShardID() { + log.Debug("subroundBlock.ReceivedFullHeader early exit", "headerShardID", headerHandler.GetShardID(), "selfShardID", sr.ShardCoordinator().SelfId()) + return + } + + if !sr.EnableEpochsHandler().IsFlagEnabledInEpoch(common.EquivalentMessagesFlag, headerHandler.GetEpoch()) { + log.Debug("subroundBlock.ReceivedFullHeader early exit", "flagNotEnabled in header epoch", headerHandler.GetEpoch()) + return + } + + log.Debug("subroundBlock.ReceivedFullHeader", "nonce", headerHandler.GetNonce(), "epoch", headerHandler.GetEpoch()) + + lastCommittedBlockHash := sr.Blockchain().GetCurrentBlockHeaderHash() + if bytes.Equal(lastCommittedBlockHash, headerHandler.GetPrevHash()) { + // Need to switch to consensus v2 + log.Debug("subroundBlock.ReceivedFullHeader switching epoch") + go sr.EpochNotifier().CheckEpoch(headerHandler) + } +} + // receivedBlockHeader method is called when a block header is received through the block header channel. // If the block header is valid, then the validatorRoundStates map corresponding to the node which sent it, // is set on true for the subround Block diff --git a/consensus/spos/bls/v2/blsSubroundsFactory.go b/consensus/spos/bls/v2/blsSubroundsFactory.go index 2c9ade325a0..53e2b608019 100644 --- a/consensus/spos/bls/v2/blsSubroundsFactory.go +++ b/consensus/spos/bls/v2/blsSubroundsFactory.go @@ -112,6 +112,7 @@ func (fct *factory) GenerateSubrounds() error { fct.initConsensusThreshold() fct.consensusCore.Chronology().RemoveAllSubrounds() fct.worker.RemoveAllReceivedMessagesCalls() + fct.worker.RemoveAllReceivedHeaderHandlers() err := fct.generateStartRoundSubround() if err != nil { diff --git a/consensus/spos/bls/v2/subroundBlock.go b/consensus/spos/bls/v2/subroundBlock.go index 2454ad3643e..27bf49a7af7 100644 --- a/consensus/spos/bls/v2/subroundBlock.go +++ b/consensus/spos/bls/v2/subroundBlock.go @@ -3,7 +3,6 @@ package v2 import ( "bytes" "context" - "encoding/hex" "time" "github.com/multiversx/mx-chain-core-go/core" @@ -105,12 +104,6 @@ func (sr *subroundBlock) doBlockJob(ctx context.Context) bool { return false } - // This must be done after createBlock, in order to have the proper epoch set - wasProofAdded := sr.addProofOnHeader(header) - if !wasProofAdded { - return false - } - // block proof verification should be done over the header that contains the leader signature leaderSignature, err := sr.signBlockHeader(header) if err != nil { @@ -360,36 +353,7 @@ func (sr *subroundBlock) createHeader() (data.HeaderHandler, error) { return hdr, nil } -func (sr *subroundBlock) addProofOnHeader(header data.HeaderHandler) bool { - prevBlockProof, err := sr.EquivalentProofsPool().GetProof(sr.ShardCoordinator().SelfId(), header.GetPrevHash()) - if err != nil { - // for the first block after activation we won't add the proof - // TODO: fix this on verifications as well - return common.IsEpochChangeBlockForFlagActivation(header, sr.EnableEpochsHandler(), common.EquivalentMessagesFlag) - } - - if !isProofEmpty(prevBlockProof) { - header.SetPreviousProof(prevBlockProof) - return true - } - - hash, err := core.CalculateHash(sr.Marshalizer(), sr.Hasher(), header) - if err != nil { - hash = []byte("") - } - - log.Debug("addProofOnHeader: no proof found", "header hash", hex.EncodeToString(hash)) - - return false -} - -func isProofEmpty(proof data.HeaderProofHandler) bool { - return len(proof.GetAggregatedSignature()) == 0 || - len(proof.GetPubKeysBitmap()) == 0 || - len(proof.GetHeaderHash()) == 0 -} - -func (sr *subroundBlock) saveProofForPreviousHeaderIfNeeded(header data.HeaderHandler) { +func (sr *subroundBlock) saveProofForPreviousHeaderIfNeeded(header data.HeaderHandler, prevHeader data.HeaderHandler) { hasProof := sr.EquivalentProofsPool().HasProof(sr.ShardCoordinator().SelfId(), header.GetPrevHash()) if hasProof { log.Debug("saveProofForPreviousHeaderIfNeeded: no need to set proof since it is already saved") @@ -397,11 +361,16 @@ func (sr *subroundBlock) saveProofForPreviousHeaderIfNeeded(header data.HeaderHa } proof := header.GetPreviousProof() - err := sr.EquivalentProofsPool().AddProof(proof) + err := common.VerifyProofAgainstHeader(proof, prevHeader) if err != nil { - log.Debug("saveProofForPreviousHeaderIfNeeded: failed to add proof, %w", err) + log.Debug("saveProofForPreviousHeaderIfNeeded: invalid proof, %w", err) return } + + err = sr.EquivalentProofsPool().AddProof(proof) + if err != nil { + log.Debug("saveProofForPreviousHeaderIfNeeded: failed to add proof, %w", err) + } } // receivedBlockBody method is called when a block body is received through the block body channel @@ -445,39 +414,47 @@ func (sr *subroundBlock) receivedBlockBody(ctx context.Context, cnsDta *consensu return blockProcessedWithSuccess } -func (sr *subroundBlock) isHeaderForCurrentConsensus(header data.HeaderHandler) bool { +func (sr *subroundBlock) isHeaderForCurrentConsensus(header data.HeaderHandler) (bool, data.HeaderHandler) { if check.IfNil(header) { - return false + return false, nil } if header.GetShardID() != sr.ShardCoordinator().SelfId() { - return false + return false, nil } if header.GetRound() != uint64(sr.RoundHandler().Index()) { - return false + return false, nil } prevHeader, prevHash := sr.getPrevHeaderAndHash() if check.IfNil(prevHeader) { - return false + return false, nil } if !bytes.Equal(header.GetPrevHash(), prevHash) { - return false + return false, nil } if header.GetNonce() != prevHeader.GetNonce()+1 { - return false + return false, nil } prevRandSeed := prevHeader.GetRandSeed() - return bytes.Equal(header.GetPrevRandSeed(), prevRandSeed) + return bytes.Equal(header.GetPrevRandSeed(), prevRandSeed), prevHeader } func (sr *subroundBlock) getLeaderForHeader(headerHandler data.HeaderHandler) ([]byte, error) { nc := sr.NodesCoordinator() + + prevBlockEpoch := sr.Blockchain().GetCurrentBlockHeader().GetEpoch() + // TODO: remove this if first block in new epoch will be validated by epoch validators + // first block in epoch is validated by previous epoch validators + selectionEpoch := headerHandler.GetEpoch() + if selectionEpoch != prevBlockEpoch { + selectionEpoch = prevBlockEpoch + } leader, _, err := nc.ComputeConsensusGroup( headerHandler.GetPrevRandSeed(), headerHandler.GetRound(), headerHandler.GetShardID(), - headerHandler.GetEpoch(), + selectionEpoch, ) if err != nil { return nil, err @@ -491,25 +468,31 @@ func (sr *subroundBlock) receivedBlockHeader(headerHandler data.HeaderHandler) { return } + log.Debug("subroundBlock.receivedBlockHeader", "nonce", headerHandler.GetNonce(), "round", headerHandler.GetRound()) if headerHandler.CheckFieldsForNil() != nil { return } - if !sr.isHeaderForCurrentConsensus(headerHandler) { + isHeaderForCurrentConsensus, prevHeader := sr.isHeaderForCurrentConsensus(headerHandler) + if !isHeaderForCurrentConsensus { + log.Debug("subroundBlock.receivedBlockHeader - header is not for current consensus") return } isLeader := sr.IsSelfLeader() if sr.ConsensusGroup() == nil || isLeader { + log.Debug("subroundBlock.receivedBlockHeader - consensus group is nil or is leader") return } if sr.IsConsensusDataSet() { + log.Debug("subroundBlock.receivedBlockHeader - consensus data is set") return } headerLeader, err := sr.getLeaderForHeader(headerHandler) if err != nil { + log.Debug("subroundBlock.receivedBlockHeader - error getting leader for header", err.Error()) return } @@ -520,26 +503,30 @@ func (sr *subroundBlock) receivedBlockHeader(headerHandler data.HeaderHandler) { spos.LeaderPeerHonestyDecreaseFactor, ) + log.Debug("subroundBlock.receivedBlockHeader - leader is not the leader in current round") return } if sr.IsHeaderAlreadyReceived() { + log.Debug("subroundBlock.receivedBlockHeader - header is already received") return } if !sr.CanProcessReceivedHeader(string(headerLeader)) { + log.Debug("subroundBlock.receivedBlockHeader - can not process received header") return } marshalledHeader, err := sr.Marshalizer().Marshal(headerHandler) if err != nil { + log.Debug("subroundBlock.receivedBlockHeader", "error", err.Error()) return } sr.SetData(sr.Hasher().Compute(string(marshalledHeader))) sr.SetHeader(headerHandler) - sr.saveProofForPreviousHeaderIfNeeded(headerHandler) + sr.saveProofForPreviousHeaderIfNeeded(headerHandler, prevHeader) log.Debug("step 1: block header has been received", "nonce", sr.GetHeader().GetNonce(), diff --git a/consensus/spos/bls/v2/subroundBlock_test.go b/consensus/spos/bls/v2/subroundBlock_test.go index d22d5e2f1ca..56fb0e4f83f 100644 --- a/consensus/spos/bls/v2/subroundBlock_test.go +++ b/consensus/spos/bls/v2/subroundBlock_test.go @@ -553,10 +553,6 @@ func TestSubroundBlock_DoBlockJob(t *testing.T) { r := sr.DoBlockJob() assert.True(t, r) assert.Equal(t, uint64(1), sr.GetHeader().GetNonce()) - - proof := sr.GetHeader().GetPreviousProof() - assert.Equal(t, providedSignature, proof.GetAggregatedSignature()) - assert.Equal(t, providedBitmap, proof.GetPubKeysBitmap()) }) } diff --git a/consensus/spos/bls/v2/subroundEndRound.go b/consensus/spos/bls/v2/subroundEndRound.go index b5e6440685f..b91001bb2a7 100644 --- a/consensus/spos/bls/v2/subroundEndRound.go +++ b/consensus/spos/bls/v2/subroundEndRound.go @@ -264,7 +264,6 @@ func (sr *subroundEndRound) doEndRoundJobByNode() bool { err = sr.EquivalentProofsPool().AddProof(proof) if err != nil { log.Debug("doEndRoundJobByNode.AddProof", "error", err) - return false } } @@ -533,6 +532,7 @@ func (sr *subroundEndRound) createAndBroadcastProof(signature []byte, bitmap []b HeaderEpoch: sr.GetHeader().GetEpoch(), HeaderNonce: sr.GetHeader().GetNonce(), HeaderShardId: sr.GetHeader().GetShardID(), + HeaderRound: sr.GetHeader().GetRound(), } err := sr.BroadcastMessenger().BroadcastEquivalentProof(headerProof, []byte(sr.SelfPubKey())) diff --git a/consensus/spos/consensusCore.go b/consensus/spos/consensusCore.go index 1f263a0af9d..39dfa3e1c66 100644 --- a/consensus/spos/consensusCore.go +++ b/consensus/spos/consensusCore.go @@ -4,6 +4,7 @@ import ( "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/multiversx/mx-chain-go/common" cryptoCommon "github.com/multiversx/mx-chain-go/common/crypto" "github.com/multiversx/mx-chain-go/consensus" @@ -41,6 +42,7 @@ type ConsensusCore struct { signingHandler consensus.SigningHandler enableEpochsHandler common.EnableEpochsHandler equivalentProofsPool consensus.EquivalentProofsPool + epochNotifier process.EpochNotifier } // ConsensusCoreArgs store all arguments that are needed to create a ConsensusCore object @@ -69,6 +71,7 @@ type ConsensusCoreArgs struct { SigningHandler consensus.SigningHandler EnableEpochsHandler common.EnableEpochsHandler EquivalentProofsPool consensus.EquivalentProofsPool + EpochNotifier process.EpochNotifier } // NewConsensusCore creates a new ConsensusCore instance @@ -100,6 +103,7 @@ func NewConsensusCore( signingHandler: args.SigningHandler, enableEpochsHandler: args.EnableEpochsHandler, equivalentProofsPool: args.EquivalentProofsPool, + epochNotifier: args.EpochNotifier, } err := ValidateConsensusCore(consensusCore) @@ -180,6 +184,11 @@ func (cc *ConsensusCore) EpochStartRegistrationHandler() epochStart.Registration return cc.epochStartRegistrationHandler } +// EpochNotifier returns the epoch notifier +func (cc *ConsensusCore) EpochNotifier() process.EpochNotifier { + return cc.epochNotifier +} + // PeerHonestyHandler will return the peer honesty handler which will be used in subrounds func (cc *ConsensusCore) PeerHonestyHandler() consensus.PeerHonestyHandler { return cc.peerHonestyHandler diff --git a/consensus/spos/consensusCoreValidator.go b/consensus/spos/consensusCoreValidator.go index 0eee3039007..c5905d10fd1 100644 --- a/consensus/spos/consensusCoreValidator.go +++ b/consensus/spos/consensusCoreValidator.go @@ -80,6 +80,9 @@ func ValidateConsensusCore(container ConsensusCoreHandler) error { if check.IfNil(container.EquivalentProofsPool()) { return ErrNilEquivalentProofPool } + if check.IfNil(container.EpochNotifier()) { + return ErrNilEpochNotifier + } return nil } diff --git a/consensus/spos/consensusCoreValidator_test.go b/consensus/spos/consensusCoreValidator_test.go index 5594b831311..47c5a66ab78 100644 --- a/consensus/spos/consensusCoreValidator_test.go +++ b/consensus/spos/consensusCoreValidator_test.go @@ -12,6 +12,7 @@ import ( "github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" + epochNotifierMock "github.com/multiversx/mx-chain-go/testscommon/epochNotifier" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" ) @@ -41,6 +42,7 @@ func initConsensusDataContainer() *ConsensusCore { signingHandler := &consensusMocks.SigningHandlerStub{} enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{} proofsPool := &dataRetriever.ProofsPoolMock{} + epochNotifier := &epochNotifierMock.EpochNotifierStub{} return &ConsensusCore{ blockChain: blockChain, @@ -66,6 +68,7 @@ func initConsensusDataContainer() *ConsensusCore { signingHandler: signingHandler, enableEpochsHandler: enableEpochsHandler, equivalentProofsPool: proofsPool, + epochNotifier: epochNotifier, } } diff --git a/consensus/spos/consensusCore_test.go b/consensus/spos/consensusCore_test.go index ef860956152..70e372266b8 100644 --- a/consensus/spos/consensusCore_test.go +++ b/consensus/spos/consensusCore_test.go @@ -40,6 +40,7 @@ func createDefaultConsensusCoreArgs() *spos.ConsensusCoreArgs { SigningHandler: consensusCoreMock.SigningHandler(), EnableEpochsHandler: consensusCoreMock.EnableEpochsHandler(), EquivalentProofsPool: consensusCoreMock.EquivalentProofsPool(), + EpochNotifier: consensusCoreMock.EpochNotifier(), } return args } diff --git a/consensus/spos/errors.go b/consensus/spos/errors.go index 62f9c23ad17..9aa69060fed 100644 --- a/consensus/spos/errors.go +++ b/consensus/spos/errors.go @@ -270,3 +270,6 @@ var ErrHeaderProofNotExpected = errors.New("header proof not expected") // ErrConsensusMessageNotExpected signals that a consensus message was not expected var ErrConsensusMessageNotExpected = errors.New("consensus message not expected") + +// ErrNilEpochNotifier signals that a nil epoch notifier has been provided +var ErrNilEpochNotifier = errors.New("nil epoch notifier") diff --git a/consensus/spos/interface.go b/consensus/spos/interface.go index d85c94f2b7a..071790979c9 100644 --- a/consensus/spos/interface.go +++ b/consensus/spos/interface.go @@ -47,6 +47,7 @@ type ConsensusCoreHandler interface { SigningHandler() consensus.SigningHandler EnableEpochsHandler() common.EnableEpochsHandler EquivalentProofsPool() consensus.EquivalentProofsPool + EpochNotifier() process.EpochNotifier IsInterfaceNil() bool } @@ -104,6 +105,8 @@ type WorkerHandler interface { AddReceivedMessageCall(messageType consensus.MessageType, receivedMessageCall func(ctx context.Context, cnsDta *consensus.Message) bool) // AddReceivedHeaderHandler adds a new handler function for a received header AddReceivedHeaderHandler(handler func(data.HeaderHandler)) + // RemoveAllReceivedHeaderHandlers removes all the functions handlers + RemoveAllReceivedHeaderHandlers() // AddReceivedProofHandler adds a new handler function for a received proof AddReceivedProofHandler(handler func(consensus.ProofHandler)) // RemoveAllReceivedMessagesCalls removes all the functions handlers diff --git a/consensus/spos/worker.go b/consensus/spos/worker.go index e539071331d..804bec83715 100644 --- a/consensus/spos/worker.go +++ b/consensus/spos/worker.go @@ -310,6 +310,13 @@ func (wrk *Worker) AddReceivedHeaderHandler(handler func(data.HeaderHandler)) { wrk.mutReceivedHeadersHandler.Unlock() } +// RemoveAllReceivedHeaderHandlers removes all the functions handlers +func (wrk *Worker) RemoveAllReceivedHeaderHandlers() { + wrk.mutReceivedHeadersHandler.Lock() + wrk.receivedHeadersHandlers = make([]func(data.HeaderHandler), 0) + wrk.mutReceivedHeadersHandler.Unlock() +} + // ReceivedProof process the received proof, calling each received proof handler registered in worker instance func (wrk *Worker) ReceivedProof(proofHandler consensus.ProofHandler) { if check.IfNilReflect(proofHandler) { diff --git a/dataRetriever/dataPool/headersCache/headersPool.go b/dataRetriever/dataPool/headersCache/headersPool.go index cf824cc6e10..8b2e044b432 100644 --- a/dataRetriever/dataPool/headersCache/headersPool.go +++ b/dataRetriever/dataPool/headersCache/headersPool.go @@ -5,9 +5,10 @@ import ( "sync" "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-logger-go" + "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/dataRetriever" - "github.com/multiversx/mx-chain-logger-go" ) var log = logger.GetOrCreate("dataRetriever/headersCache") @@ -64,6 +65,7 @@ func (pool *headersPool) AddHeader(headerHash []byte, header data.HeaderHandler) added := pool.cache.addHeader(headerHash, header) if added { + log.Debug("TOREMOVE - added header to pool", "cache ptr", fmt.Sprintf("%p", pool.cache), "header shard", header.GetShardID(), "header nonce", header.GetNonce()) pool.callAddedDataHandlers(header, headerHash) } } diff --git a/dataRetriever/interface.go b/dataRetriever/interface.go index ade580bd985..c28843491d0 100644 --- a/dataRetriever/interface.go +++ b/dataRetriever/interface.go @@ -6,6 +6,7 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/counting" "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-go/p2p" "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/storage" @@ -362,6 +363,7 @@ type PeerAuthenticationPayloadValidator interface { // ProofsPool defines the behaviour of a proofs pool components type ProofsPool interface { AddProof(headerProof data.HeaderProofHandler) error + RegisterHandler(handler func(headerProof data.HeaderProofHandler)) CleanupProofsBehindNonce(shardID uint32, nonce uint64) error GetProof(shardID uint32, headerHash []byte) (data.HeaderProofHandler, error) HasProof(shardID uint32, headerHash []byte) bool diff --git a/epochStart/errors.go b/epochStart/errors.go index e022064c472..8ef800aebfb 100644 --- a/epochStart/errors.go +++ b/epochStart/errors.go @@ -74,6 +74,9 @@ var ErrNilRequestHandler = errors.New("nil request handler") // ErrNilMetaBlocksPool signals that nil metablock pools holder has been provided var ErrNilMetaBlocksPool = errors.New("nil metablocks pool") +// ErrNilProofsPool signals that nil proofs pool has been provided +var ErrNilProofsPool = errors.New("nil proofs pool") + // ErrNilValidatorInfoProcessor signals that a nil validator info processor has been provided var ErrNilValidatorInfoProcessor = errors.New("nil validator info processor") diff --git a/epochStart/interface.go b/epochStart/interface.go index a5cb5881ddf..6f05bc409cf 100644 --- a/epochStart/interface.go +++ b/epochStart/interface.go @@ -7,9 +7,10 @@ import ( "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" + vmcommon "github.com/multiversx/mx-chain-vm-common-go" + "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/state" - vmcommon "github.com/multiversx/mx-chain-vm-common-go" ) // TriggerHandler defines the functionalities for an start of epoch trigger diff --git a/epochStart/metachain/trigger.go b/epochStart/metachain/trigger.go index 5dedb1f1cda..d503699bb7f 100644 --- a/epochStart/metachain/trigger.go +++ b/epochStart/metachain/trigger.go @@ -15,13 +15,14 @@ import ( "github.com/multiversx/mx-chain-core-go/display" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/multiversx/mx-chain-logger-go" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/epochStart" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/storage" - "github.com/multiversx/mx-chain-logger-go" ) var log = logger.GetOrCreate("epochStart/metachain") diff --git a/epochStart/notifier/common.go b/epochStart/notifier/common.go index b535bd54589..bf36da4b45e 100644 --- a/epochStart/notifier/common.go +++ b/epochStart/notifier/common.go @@ -2,6 +2,7 @@ package notifier import ( "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-go/epochStart" ) diff --git a/epochStart/notifier/epochStartSubscriptionHandler.go b/epochStart/notifier/epochStartSubscriptionHandler.go index 1e4141a96dd..3d2041189ce 100644 --- a/epochStart/notifier/epochStartSubscriptionHandler.go +++ b/epochStart/notifier/epochStartSubscriptionHandler.go @@ -7,6 +7,7 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-go/epochStart" ) diff --git a/epochStart/shardchain/trigger.go b/epochStart/shardchain/trigger.go index 496702b8d81..f2166f7d5b9 100644 --- a/epochStart/shardchain/trigger.go +++ b/epochStart/shardchain/trigger.go @@ -19,12 +19,13 @@ import ( "github.com/multiversx/mx-chain-core-go/display" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/multiversx/mx-chain-logger-go" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/epochStart" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/storage" - "github.com/multiversx/mx-chain-logger-go" ) var log = logger.GetOrCreate("epochStart/shardchain") @@ -79,6 +80,7 @@ type trigger struct { mapFinalizedEpochs map[uint32]string headersPool dataRetriever.HeadersPool + proofsPool dataRetriever.ProofsPool miniBlocksPool storage.Cacher validatorInfoPool dataRetriever.ShardedDataCacherNotifier currentEpochValidatorInfoPool epochStart.ValidatorInfoCacher @@ -170,6 +172,9 @@ func NewEpochStartTrigger(args *ArgsShardEpochStartTrigger) (*trigger, error) { if check.IfNil(args.DataPool.Headers()) { return nil, epochStart.ErrNilMetaBlocksPool } + if check.IfNil(args.DataPool.Proofs()) { + return nil, epochStart.ErrNilProofsPool + } if check.IfNil(args.DataPool.MiniBlocks()) { return nil, epochStart.ErrNilMiniBlockPool } @@ -247,6 +252,7 @@ func NewEpochStartTrigger(args *ArgsShardEpochStartTrigger) (*trigger, error) { mapEpochStartHdrs: make(map[string]data.HeaderHandler), mapFinalizedEpochs: make(map[uint32]string), headersPool: args.DataPool.Headers(), + proofsPool: args.DataPool.Proofs(), miniBlocksPool: args.DataPool.MiniBlocks(), validatorInfoPool: args.DataPool.ValidatorsInfo(), currentEpochValidatorInfoPool: args.DataPool.CurrentEpochValidatorInfo(), @@ -271,6 +277,7 @@ func NewEpochStartTrigger(args *ArgsShardEpochStartTrigger) (*trigger, error) { } t.headersPool.RegisterHandler(t.receivedMetaBlock) + t.proofsPool.RegisterHandler(t.receivedProof) err = t.saveState(t.triggerStateKey) if err != nil { @@ -555,12 +562,51 @@ func (t *trigger) changeEpochFinalityAttestingRoundIfNeeded( t.epochFinalityAttestingRound = metaHdr.GetRound() } +func (t *trigger) receivedProof(headerProof data.HeaderProofHandler) { + if check.IfNilReflect(headerProof) { + return + } + if headerProof.GetHeaderShardId() != core.MetachainShardId { + return + } + t.mutTrigger.Lock() + defer t.mutTrigger.Unlock() + + header, err := t.headersPool.GetHeaderByHash(headerProof.GetHeaderHash()) + if err != nil { + return + } + + t.checkMetaHeaderForEpochTriggerEquivalentProofs(header, headerProof.GetHeaderHash()) +} + // receivedMetaBlock is a callback function when a new metablock was received // upon receiving checks if trigger can be updated func (t *trigger) receivedMetaBlock(headerHandler data.HeaderHandler, metaBlockHash []byte) { + if t.enableEpochsHandler.IsFlagEnabledInEpoch(common.EquivalentMessagesFlag, headerHandler.GetEpoch()) { + return + } + t.mutTrigger.Lock() defer t.mutTrigger.Unlock() + t.checkMetaHeaderForEpochTriggerLegacy(headerHandler, metaBlockHash) +} + +func (t *trigger) checkMetaHeaderForEpochTriggerEquivalentProofs(headerHandler data.HeaderHandler, metaBlockHash []byte) { + metaHdr, ok := headerHandler.(*block.MetaBlock) + if !ok { + return + } + if !t.shouldUpdateTrigger(metaHdr, metaBlockHash) { + return + } + + t.updateTriggerHeaderData(metaHdr, metaBlockHash) + t.updateTriggerFromMeta() +} + +func (t *trigger) checkMetaHeaderForEpochTriggerLegacy(headerHandler data.HeaderHandler, metaBlockHash []byte) { metaHdr, ok := headerHandler.(*block.MetaBlock) if !ok { return @@ -574,28 +620,41 @@ func (t *trigger) receivedMetaBlock(headerHandler data.HeaderHandler, metaBlockH } } - if !t.newEpochHdrReceived && !metaHdr.IsStartOfEpochBlock() { + if !t.shouldUpdateTrigger(metaHdr, metaBlockHash) { return } + t.updateTriggerHeaderData(metaHdr, metaBlockHash) + t.updateTriggerFromMeta() +} + +func (t *trigger) shouldUpdateTrigger(metaHdr *block.MetaBlock, metaBlockHash []byte) bool { + if !t.newEpochHdrReceived && !metaHdr.IsStartOfEpochBlock() { + return false + } + isMetaStartOfEpochForCurrentEpoch := metaHdr.Epoch == t.epoch && metaHdr.IsStartOfEpochBlock() if isMetaStartOfEpochForCurrentEpoch { - return + return false } - if _, ok = t.mapHashHdr[string(metaBlockHash)]; ok { - return + if _, ok := t.mapHashHdr[string(metaBlockHash)]; ok { + return false } - if _, ok = t.mapEpochStartHdrs[string(metaBlockHash)]; ok { - return + if _, ok := t.mapEpochStartHdrs[string(metaBlockHash)]; ok { + return false } + return true +} + +func (t *trigger) updateTriggerHeaderData(metaHdr *block.MetaBlock, metaBlockHash []byte) { if metaHdr.IsStartOfEpochBlock() { t.newEpochHdrReceived = true t.mapEpochStartHdrs[string(metaBlockHash)] = metaHdr // waiting for late broadcast of mini blocks and transactions to be done and received wait := t.extraDelayForRequestBlockInfo - roundDifferences := t.roundHandler.Index() - int64(headerHandler.GetRound()) + roundDifferences := t.roundHandler.Index() - int64(metaHdr.GetRound()) if roundDifferences > 1 { wait = 0 } @@ -605,8 +664,6 @@ func (t *trigger) receivedMetaBlock(headerHandler data.HeaderHandler, metaBlockH t.mapHashHdr[string(metaBlockHash)] = metaHdr t.mapNonceHashes[metaHdr.Nonce] = append(t.mapNonceHashes[metaHdr.Nonce], string(metaBlockHash)) - - t.updateTriggerFromMeta() } // call only if mutex is locked before @@ -722,10 +779,24 @@ func (t *trigger) isMetaBlockValid(hash string, metaHdr data.HeaderHandler) bool return true } -func (t *trigger) isMetaBlockFinal(_ string, metaHdr data.HeaderHandler) (bool, uint64) { +func (t *trigger) isMetaBlockFinal(hash string, metaHdr data.HeaderHandler) (bool, uint64) { + if !t.enableEpochsHandler.IsFlagEnabledInEpoch(common.EquivalentMessagesFlag, metaHdr.GetEpoch()) { + return t.isMetaBlockFinalLegacy(hash, metaHdr) + } + + hasProof := t.proofsPool.HasProof(metaHdr.GetShardID(), []byte(hash)) + if !hasProof { + return false, 0 + } + + return true, metaHdr.GetRound() +} + +func (t *trigger) isMetaBlockFinalLegacy(_ string, metaHdr data.HeaderHandler) (bool, uint64) { nextBlocksVerified := uint64(0) finalityAttestingRound := metaHdr.GetRound() currHdr := metaHdr + for nonce := metaHdr.GetNonce() + 1; nonce <= metaHdr.GetNonce()+t.finality; nonce++ { currHash, err := core.CalculateHash(t.marshaller, t.hasher, currHdr) if err != nil { diff --git a/epochStart/shardchain/triggerRegistry.go b/epochStart/shardchain/triggerRegistry.go index 899e99e83bc..d3f5e8d18c6 100644 --- a/epochStart/shardchain/triggerRegistry.go +++ b/epochStart/shardchain/triggerRegistry.go @@ -8,6 +8,7 @@ import ( "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/epochStart" ) diff --git a/epochStart/shardchain/triggerRegistry_test.go b/epochStart/shardchain/triggerRegistry_test.go index 5adccc849e1..970f48f6a73 100644 --- a/epochStart/shardchain/triggerRegistry_test.go +++ b/epochStart/shardchain/triggerRegistry_test.go @@ -6,13 +6,14 @@ import ( "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon/genericMocks" storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func cloneTrigger(t *trigger) *trigger { @@ -42,6 +43,7 @@ func cloneTrigger(t *trigger) *trigger { rt.requestHandler = t.requestHandler rt.epochStartNotifier = t.epochStartNotifier rt.headersPool = t.headersPool + rt.proofsPool = t.proofsPool rt.epochStartShardHeader = t.epochStartShardHeader rt.epochStartMeta = t.epochStartMeta rt.shardHdrStorage = t.shardHdrStorage diff --git a/epochStart/shardchain/trigger_test.go b/epochStart/shardchain/trigger_test.go index fcb7edc0ad2..15b7ffbda88 100644 --- a/epochStart/shardchain/trigger_test.go +++ b/epochStart/shardchain/trigger_test.go @@ -50,6 +50,9 @@ func createMockShardEpochStartTriggerArguments() *ArgsShardEpochStartTrigger { CurrEpochValidatorInfoCalled: func() dataRetriever.ValidatorInfoCacher { return &vic.ValidatorInfoCacherStub{} }, + ProofsCalled: func() dataRetriever.ProofsPool { + return &dataRetrieverMock.ProofsPoolMock{} + }, }, Storage: &storageStubs.ChainStorerStub{ GetStorerCalled: func(unitType dataRetriever.UnitType) (storage.Storer, error) { @@ -390,6 +393,9 @@ func TestTrigger_ReceivedHeaderIsEpochStartTrueWithPeerMiniblocks(t *testing.T) CurrEpochValidatorInfoCalled: func() dataRetriever.ValidatorInfoCacher { return &vic.ValidatorInfoCacherStub{} }, + ProofsCalled: func() dataRetriever.ProofsPool { + return &dataRetrieverMock.ProofsPoolMock{} + }, } args.Uint64Converter = &mock.Uint64ByteSliceConverterMock{ ToByteSliceCalled: func(u uint64) []byte { @@ -700,6 +706,9 @@ func TestTrigger_UpdateMissingValidatorsInfo(t *testing.T) { }, } }, + ProofsCalled: func() dataRetriever.ProofsPool { + return &dataRetrieverMock.ProofsPoolMock{} + }, } epochStartTrigger, _ := NewEpochStartTrigger(args) diff --git a/factory/consensus/consensusComponents.go b/factory/consensus/consensusComponents.go index 170638a7268..d1482498819 100644 --- a/factory/consensus/consensusComponents.go +++ b/factory/consensus/consensusComponents.go @@ -257,6 +257,7 @@ func (ccf *consensusComponentsFactory) Create() (*consensusComponents, error) { SigningHandler: ccf.cryptoComponents.ConsensusSigningHandler(), EnableEpochsHandler: ccf.coreComponents.EnableEpochsHandler(), EquivalentProofsPool: ccf.dataComponents.Datapool().Proofs(), + EpochNotifier: ccf.coreComponents.EpochNotifier(), } consensusDataContainer, err := spos.NewConsensusCore( diff --git a/factory/interface.go b/factory/interface.go index 85331045ecc..04849acb352 100644 --- a/factory/interface.go +++ b/factory/interface.go @@ -386,6 +386,8 @@ type ConsensusWorker interface { AddReceivedMessageCall(messageType consensus.MessageType, receivedMessageCall func(ctx context.Context, cnsDta *consensus.Message) bool) // AddReceivedHeaderHandler adds a new handler function for a received header AddReceivedHeaderHandler(handler func(data.HeaderHandler)) + // RemoveAllReceivedHeaderHandlers removes all the functions handlers + RemoveAllReceivedHeaderHandlers() // AddReceivedProofHandler adds a new handler function for a received proof AddReceivedProofHandler(handler func(proofHandler consensus.ProofHandler)) // RemoveAllReceivedMessagesCalls removes all the functions handlers diff --git a/factory/mock/epochStartNotifierStub.go b/factory/mock/epochStartNotifierStub.go index 128242e1203..7e29fbae327 100644 --- a/factory/mock/epochStartNotifierStub.go +++ b/factory/mock/epochStartNotifierStub.go @@ -2,6 +2,7 @@ package mock import ( "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-go/epochStart" ) diff --git a/factory/mock/forkDetectorMock.go b/factory/mock/forkDetectorMock.go index 4a041bc814a..eebe8b3efd0 100644 --- a/factory/mock/forkDetectorMock.go +++ b/factory/mock/forkDetectorMock.go @@ -2,6 +2,7 @@ package mock import ( "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-go/process" ) @@ -19,6 +20,7 @@ type ForkDetectorMock struct { RestoreToGenesisCalled func() ResetProbableHighestNonceCalled func() SetFinalToLastCheckpointCalled func() + ReceivedProofCalled func(proof data.HeaderProofHandler) } // RestoreToGenesis - @@ -111,6 +113,13 @@ func (fdm *ForkDetectorMock) SetFinalToLastCheckpoint() { } } +// ReceivedProof - +func (fdm *ForkDetectorMock) ReceivedProof(proof data.HeaderProofHandler) { + if fdm.ReceivedProofCalled != nil { + fdm.ReceivedProofCalled(proof) + } +} + // IsInterfaceNil returns true if there is no value under the interface func (fdm *ForkDetectorMock) IsInterfaceNil() bool { return fdm == nil diff --git a/factory/processing/processComponents.go b/factory/processing/processComponents.go index dd5075d5dfd..9fd1d330d19 100644 --- a/factory/processing/processComponents.go +++ b/factory/processing/processComponents.go @@ -465,8 +465,9 @@ func (pcf *processComponentsFactory) Create() (*processComponents, error) { } argsHeaderValidator := block.ArgsHeaderValidator{ - Hasher: pcf.coreData.Hasher(), - Marshalizer: pcf.coreData.InternalMarshalizer(), + Hasher: pcf.coreData.Hasher(), + Marshalizer: pcf.coreData.InternalMarshalizer(), + EnableEpochsHandler: pcf.coreData.EnableEpochsHandler(), } headerValidator, err := block.NewHeaderValidator(argsHeaderValidator) if err != nil { @@ -822,8 +823,9 @@ func (pcf *processComponentsFactory) newEpochStartTrigger(requestHandler epochSt shardCoordinator := pcf.bootstrapComponents.ShardCoordinator() if shardCoordinator.SelfId() < shardCoordinator.NumberOfShards() { argsHeaderValidator := block.ArgsHeaderValidator{ - Hasher: pcf.coreData.Hasher(), - Marshalizer: pcf.coreData.InternalMarshalizer(), + Hasher: pcf.coreData.Hasher(), + Marshalizer: pcf.coreData.InternalMarshalizer(), + EnableEpochsHandler: pcf.coreData.EnableEpochsHandler(), } headerValidator, err := block.NewHeaderValidator(argsHeaderValidator) if err != nil { @@ -1782,10 +1784,22 @@ func (pcf *processComponentsFactory) newForkDetector( ) (process.ForkDetector, error) { shardCoordinator := pcf.bootstrapComponents.ShardCoordinator() if shardCoordinator.SelfId() < shardCoordinator.NumberOfShards() { - return sync.NewShardForkDetector(pcf.coreData.RoundHandler(), headerBlackList, blockTracker, pcf.coreData.GenesisNodesSetup().GetStartTime()) + return sync.NewShardForkDetector( + pcf.coreData.RoundHandler(), + headerBlackList, + blockTracker, + pcf.coreData.GenesisNodesSetup().GetStartTime(), + pcf.coreData.EnableEpochsHandler(), + pcf.data.Datapool().Proofs()) } if shardCoordinator.SelfId() == core.MetachainShardId { - return sync.NewMetaForkDetector(pcf.coreData.RoundHandler(), headerBlackList, blockTracker, pcf.coreData.GenesisNodesSetup().GetStartTime()) + return sync.NewMetaForkDetector( + pcf.coreData.RoundHandler(), + headerBlackList, + blockTracker, + pcf.coreData.GenesisNodesSetup().GetStartTime(), + pcf.coreData.EnableEpochsHandler(), + pcf.data.Datapool().Proofs()) } return nil, errors.New("could not create fork detector") diff --git a/go.mod b/go.mod index 895eb3ea982..eb5b2222350 100644 --- a/go.mod +++ b/go.mod @@ -15,7 +15,7 @@ require ( github.com/klauspost/cpuid/v2 v2.2.5 github.com/mitchellh/mapstructure v1.5.0 github.com/multiversx/mx-chain-communication-go v1.0.15-0.20240508074652-e128a1c05c8e - github.com/multiversx/mx-chain-core-go v1.2.21-0.20241204105459-ddd46264c030 + github.com/multiversx/mx-chain-core-go v1.2.21-0.20250109123731-7ff31f3e3af6 github.com/multiversx/mx-chain-crypto-go v1.2.12-0.20240508074452-cc21c1b505df github.com/multiversx/mx-chain-es-indexer-go v1.7.2-0.20240619122842-05143459c554 github.com/multiversx/mx-chain-logger-go v1.0.15-0.20240508072523-3f00a726af57 diff --git a/go.sum b/go.sum index 7391ce4b459..f30ef4dc815 100644 --- a/go.sum +++ b/go.sum @@ -387,8 +387,8 @@ github.com/multiversx/concurrent-map v0.1.4 h1:hdnbM8VE4b0KYJaGY5yJS2aNIW9TFFsUY github.com/multiversx/concurrent-map v0.1.4/go.mod h1:8cWFRJDOrWHOTNSqgYCUvwT7c7eFQ4U2vKMOp4A/9+o= github.com/multiversx/mx-chain-communication-go v1.0.15-0.20240508074652-e128a1c05c8e h1:Tsmwhu+UleE+l3buPuqXSKTqfu5FbPmzQ4MjMoUvCWA= github.com/multiversx/mx-chain-communication-go v1.0.15-0.20240508074652-e128a1c05c8e/go.mod h1:2yXl18wUbuV3cRZr7VHxM1xo73kTaC1WUcu2kx8R034= -github.com/multiversx/mx-chain-core-go v1.2.21-0.20241204105459-ddd46264c030 h1:4XI4z1ceZC3OUXxTeMQD+6gmTgu9I934nsYlV6P8X4A= -github.com/multiversx/mx-chain-core-go v1.2.21-0.20241204105459-ddd46264c030/go.mod h1:B5zU4MFyJezmEzCsAHE9YNULmGCm2zbPHvl9hazNxmE= +github.com/multiversx/mx-chain-core-go v1.2.21-0.20250109123731-7ff31f3e3af6 h1:y6qLlkmLp+H2pztSmJDJkf0j9HlpkXvaRd9xjx3J360= +github.com/multiversx/mx-chain-core-go v1.2.21-0.20250109123731-7ff31f3e3af6/go.mod h1:B5zU4MFyJezmEzCsAHE9YNULmGCm2zbPHvl9hazNxmE= github.com/multiversx/mx-chain-crypto-go v1.2.12-0.20240508074452-cc21c1b505df h1:clihfi78bMEOWk/qw6WA4uQbCM2e2NGliqswLAvw19k= github.com/multiversx/mx-chain-crypto-go v1.2.12-0.20240508074452-cc21c1b505df/go.mod h1:gtJYB4rR21KBSqJlazn+2z6f9gFSqQP3KvAgL7Qgxw4= github.com/multiversx/mx-chain-es-indexer-go v1.7.2-0.20240619122842-05143459c554 h1:Fv8BfzJSzdovmoh9Jh/by++0uGsOVBlMP3XiN5Svkn4= diff --git a/integrationTests/chainSimulator/vm/esdtImprovements_test.go b/integrationTests/chainSimulator/vm/esdtImprovements_test.go index 6b1b6690d12..0dab09a20d9 100644 --- a/integrationTests/chainSimulator/vm/esdtImprovements_test.go +++ b/integrationTests/chainSimulator/vm/esdtImprovements_test.go @@ -11,6 +11,9 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data/esdt" "github.com/multiversx/mx-chain-core-go/data/transaction" + logger "github.com/multiversx/mx-chain-logger-go" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/config" testsChainSimulator "github.com/multiversx/mx-chain-go/integrationTests/chainSimulator" "github.com/multiversx/mx-chain-go/integrationTests/vm/txsFee" @@ -22,8 +25,6 @@ import ( "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/vm" - logger "github.com/multiversx/mx-chain-logger-go" - "github.com/stretchr/testify/require" ) const ( diff --git a/integrationTests/mock/epochStartNotifier.go b/integrationTests/mock/epochStartNotifier.go index c4675a37401..8c6fb4c51e8 100644 --- a/integrationTests/mock/epochStartNotifier.go +++ b/integrationTests/mock/epochStartNotifier.go @@ -2,6 +2,7 @@ package mock import ( "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-go/epochStart" ) diff --git a/integrationTests/mock/forkDetectorStub.go b/integrationTests/mock/forkDetectorStub.go index 950dd2b2e21..4e8751eb263 100644 --- a/integrationTests/mock/forkDetectorStub.go +++ b/integrationTests/mock/forkDetectorStub.go @@ -2,6 +2,7 @@ package mock import ( "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-go/process" ) @@ -19,6 +20,7 @@ type ForkDetectorStub struct { SetRollBackNonceCalled func(nonce uint64) ResetProbableHighestNonceCalled func() SetFinalToLastCheckpointCalled func() + ReceivedProofCalled func(proof data.HeaderProofHandler) } // RestoreToGenesis - @@ -114,6 +116,13 @@ func (fdm *ForkDetectorStub) SetFinalToLastCheckpoint() { } } +// ReceivedProof - +func (fdm *ForkDetectorStub) ReceivedProof(proof data.HeaderProofHandler) { + if fdm.ReceivedProofCalled != nil { + fdm.ReceivedProofCalled(proof) + } +} + // IsInterfaceNil returns true if there is no value under the interface func (fdm *ForkDetectorStub) IsInterfaceNil() bool { return fdm == nil diff --git a/integrationTests/multiShard/endOfEpoch/startInEpoch/startInEpoch_test.go b/integrationTests/multiShard/endOfEpoch/startInEpoch/startInEpoch_test.go index 27c963a9747..0b15e3e59ca 100644 --- a/integrationTests/multiShard/endOfEpoch/startInEpoch/startInEpoch_test.go +++ b/integrationTests/multiShard/endOfEpoch/startInEpoch/startInEpoch_test.go @@ -14,6 +14,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/common/enablers" + "github.com/multiversx/mx-chain-go/common/forking" "github.com/multiversx/mx-chain-go/common/statistics/disabled" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/dataRetriever" @@ -36,7 +38,6 @@ import ( "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/chainParameters" "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" - epochNotifierMock "github.com/multiversx/mx-chain-go/testscommon/epochNotifier" "github.com/multiversx/mx-chain-go/testscommon/genericMocks" "github.com/multiversx/mx-chain-go/testscommon/genesisMocks" "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" @@ -227,7 +228,9 @@ func testNodeStartsInEpoch(t *testing.T, shardID uint32, expectedHighestRound ui cryptoComponents.BlKeyGen = &mock.KeyGenMock{} cryptoComponents.TxKeyGen = &mock.KeyGenMock{} - coreComponents := integrationTests.GetDefaultCoreComponents(integrationTests.CreateEnableEpochsConfig()) + genericEpochNotifier := forking.NewGenericEpochNotifier() + enableEpochsHandler, _ := enablers.NewEnableEpochsHandler(enableEpochsConfig, genericEpochNotifier) + coreComponents := integrationTests.GetDefaultCoreComponents(enableEpochsHandler, genericEpochNotifier) coreComponents.InternalMarshalizerField = integrationTests.TestMarshalizer coreComponents.TxMarshalizerField = integrationTests.TestMarshalizer coreComponents.HasherField = integrationTests.TestHasher @@ -358,7 +361,7 @@ func testNodeStartsInEpoch(t *testing.T, shardID uint32, expectedHighestRound ui ChainID: string(integrationTests.ChainID), ScheduledTxsExecutionHandler: &testscommon.ScheduledTxsExecutionStub{}, MiniblocksProvider: &mock.MiniBlocksProviderStub{}, - EpochNotifier: &epochNotifierMock.EpochNotifierStub{}, + EpochNotifier: genericEpochNotifier, ProcessedMiniBlocksTracker: &testscommon.ProcessedMiniBlocksTrackerStub{}, AppStatusHandler: &statusHandlerMock.AppStatusHandlerMock{}, } diff --git a/integrationTests/multiShard/hardFork/hardFork_test.go b/integrationTests/multiShard/hardFork/hardFork_test.go index 7da61a4dcc3..642cb0e267a 100644 --- a/integrationTests/multiShard/hardFork/hardFork_test.go +++ b/integrationTests/multiShard/hardFork/hardFork_test.go @@ -17,6 +17,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/common/enablers" + "github.com/multiversx/mx-chain-go/common/forking" "github.com/multiversx/mx-chain-go/common/statistics/disabled" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/dataRetriever" @@ -390,7 +392,9 @@ func hardForkImport( defaults.FillGasMapInternal(gasSchedule, 1) log.Warn("started import process") - coreComponents := integrationTests.GetDefaultCoreComponents(integrationTests.CreateEnableEpochsConfig()) + genericEpochNotifier := forking.NewGenericEpochNotifier() + enableEpochsHandler, _ := enablers.NewEnableEpochsHandler(integrationTests.CreateEnableEpochsConfig(), genericEpochNotifier) + coreComponents := integrationTests.GetDefaultCoreComponents(enableEpochsHandler, genericEpochNotifier) coreComponents.InternalMarshalizerField = integrationTests.TestMarshalizer coreComponents.TxMarshalizerField = integrationTests.TestMarshalizer coreComponents.HasherField = integrationTests.TestHasher @@ -570,7 +574,9 @@ func createHardForkExporter( returnedConfigs[node.ShardCoordinator.SelfId()] = append(returnedConfigs[node.ShardCoordinator.SelfId()], exportConfig) returnedConfigs[node.ShardCoordinator.SelfId()] = append(returnedConfigs[node.ShardCoordinator.SelfId()], keysConfig) - coreComponents := integrationTests.GetDefaultCoreComponents(integrationTests.CreateEnableEpochsConfig()) + genericEpochNotifier := forking.NewGenericEpochNotifier() + enableEpochsHandler, _ := enablers.NewEnableEpochsHandler(integrationTests.CreateEnableEpochsConfig(), genericEpochNotifier) + coreComponents := integrationTests.GetDefaultCoreComponents(enableEpochsHandler, genericEpochNotifier) coreComponents.InternalMarshalizerField = integrationTests.TestMarshalizer coreComponents.TxMarshalizerField = integrationTests.TestTxSignMarshalizer coreComponents.HasherField = integrationTests.TestHasher diff --git a/integrationTests/multiShard/relayedTx/relayedTxV2_test.go b/integrationTests/multiShard/relayedTx/relayedTxV2_test.go index 511bb80f638..172b9dfda1c 100644 --- a/integrationTests/multiShard/relayedTx/relayedTxV2_test.go +++ b/integrationTests/multiShard/relayedTx/relayedTxV2_test.go @@ -97,7 +97,7 @@ func TestRelayedTransactionV2InMultiShardEnvironmentWithSmartContractTX(t *testi time.Sleep(time.Second) finalBalance := big.NewInt(0).Mul(big.NewInt(int64(len(players))), big.NewInt(nrRoundsToTest)) - finalBalance.Mul(finalBalance, sendValue) + finalBalance = big.NewInt(0).Mul(finalBalance, sendValue) checkSCBalance(t, ownerNode, scAddress, receiverAddress1, finalBalance) checkSCBalance(t, ownerNode, scAddress, receiverAddress1, finalBalance) diff --git a/integrationTests/node/getAccount/getAccount_test.go b/integrationTests/node/getAccount/getAccount_test.go index 487c8b1a15a..acb4e92fd75 100644 --- a/integrationTests/node/getAccount/getAccount_test.go +++ b/integrationTests/node/getAccount/getAccount_test.go @@ -7,13 +7,16 @@ import ( chainData "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/api" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/multiversx/mx-chain-go/common/enablers" + "github.com/multiversx/mx-chain-go/common/forking" "github.com/multiversx/mx-chain-go/integrationTests" "github.com/multiversx/mx-chain-go/node" "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/state/blockInfoProviders" "github.com/multiversx/mx-chain-go/testscommon" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func createAccountsRepository(accDB state.AccountsAdapter, blockchain chainData.ChainHandler) state.AccountsRepository { @@ -39,7 +42,9 @@ func TestNode_GetAccountAccountDoesNotExistsShouldRetEmpty(t *testing.T) { accDB, _ := integrationTests.CreateAccountsDB(0, trieStorage) rootHash, _ := accDB.Commit() - coreComponents := integrationTests.GetDefaultCoreComponents(integrationTests.CreateEnableEpochsConfig()) + genericEpochNotifier := forking.NewGenericEpochNotifier() + enableEpochsHandler, _ := enablers.NewEnableEpochsHandler(integrationTests.CreateEnableEpochsConfig(), genericEpochNotifier) + coreComponents := integrationTests.GetDefaultCoreComponents(enableEpochsHandler, genericEpochNotifier) coreComponents.AddressPubKeyConverterField = integrationTests.TestAddressPubkeyConverter dataComponents := integrationTests.GetDefaultDataComponents() @@ -81,7 +86,9 @@ func TestNode_GetAccountAccountExistsShouldReturn(t *testing.T) { testPubkey := integrationTests.CreateAccount(accDB, testNonce, testBalance) rootHash, _ := accDB.Commit() - coreComponents := integrationTests.GetDefaultCoreComponents(integrationTests.CreateEnableEpochsConfig()) + genericEpochNotifier := forking.NewGenericEpochNotifier() + enableEpochsHandler, _ := enablers.NewEnableEpochsHandler(integrationTests.CreateEnableEpochsConfig(), genericEpochNotifier) + coreComponents := integrationTests.GetDefaultCoreComponents(enableEpochsHandler, genericEpochNotifier) coreComponents.AddressPubKeyConverterField = testscommon.RealWorldBech32PubkeyConverter dataComponents := integrationTests.GetDefaultDataComponents() diff --git a/integrationTests/realcomponents/processorRunner_test.go b/integrationTests/realcomponents/processorRunner_test.go index 78d0013597e..ce2e60a48d3 100644 --- a/integrationTests/realcomponents/processorRunner_test.go +++ b/integrationTests/realcomponents/processorRunner_test.go @@ -3,8 +3,9 @@ package realcomponents import ( "testing" - "github.com/multiversx/mx-chain-go/testscommon" "github.com/stretchr/testify/require" + + "github.com/multiversx/mx-chain-go/testscommon" ) func TestNewProcessorRunnerAndClose(t *testing.T) { diff --git a/integrationTests/sync/basicSync/basicSync_test.go b/integrationTests/sync/basicSync/basicSync_test.go index 1dfb82dcf80..408262f2297 100644 --- a/integrationTests/sync/basicSync/basicSync_test.go +++ b/integrationTests/sync/basicSync/basicSync_test.go @@ -199,16 +199,21 @@ func testAllNodesHaveSameLastBlock(t *testing.T, nodes []*integrationTests.TestP } func TestSyncWorksInShard_EmptyBlocksNoForks_With_EquivalentProofs(t *testing.T) { + // TODO: remove skip after test is fixed + t.Skip("will be fixed in another PR") + if testing.Short() { t.Skip("this is not a short test") } + // 3 shard nodes and 1 metachain node maxShards := uint32(1) shardId := uint32(0) numNodesPerShard := 3 enableEpochs := integrationTests.CreateEnableEpochsConfig() enableEpochs.EquivalentMessagesEnableEpoch = uint32(0) + enableEpochs.FixedOrderInConsensusEnableEpoch = uint32(0) nodes := make([]*integrationTests.TestProcessorNode, numNodesPerShard+1) connectableNodes := make([]integrationTests.Connectable, 0) @@ -228,6 +233,7 @@ func TestSyncWorksInShard_EmptyBlocksNoForks_With_EquivalentProofs(t *testing.T) NodeShardId: core.MetachainShardId, TxSignPrivKeyShardId: shardId, WithSync: true, + EpochsConfig: &enableEpochs, }) idxProposerMeta := numNodesPerShard nodes[idxProposerMeta] = metachainNode diff --git a/integrationTests/testConsensusNode.go b/integrationTests/testConsensusNode.go index 8651045eb7e..0d213118e11 100644 --- a/integrationTests/testConsensusNode.go +++ b/integrationTests/testConsensusNode.go @@ -16,7 +16,10 @@ import ( crypto "github.com/multiversx/mx-chain-crypto-go" mclMultiSig "github.com/multiversx/mx-chain-crypto-go/signing/mcl/multisig" "github.com/multiversx/mx-chain-crypto-go/signing/multisig" + "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/common/enablers" + "github.com/multiversx/mx-chain-go/common/forking" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/consensus/round" "github.com/multiversx/mx-chain-go/dataRetriever" @@ -223,13 +226,6 @@ func (tcn *TestConsensusNode) initNode(args ArgsTestConsensusNode) { } epochStartTrigger, _ := metachain.NewEpochStartTrigger(argsNewMetaEpochStart) - forkDetector, _ := syncFork.NewShardForkDetector( - roundHandler, - cache.NewTimeCache(time.Second), - &mock.BlockTrackerStub{}, - args.StartTime, - ) - tcn.initRequestersFinder() peerSigCache, _ := storageunit.NewCache(storageunit.CacheConfig{Type: storageunit.LRUCache, Capacity: 1000}) @@ -241,7 +237,9 @@ func (tcn *TestConsensusNode) initNode(args ArgsTestConsensusNode) { tcn.initAccountsDB() - coreComponents := GetDefaultCoreComponents(args.EnableEpochsConfig) + genericEpochNotifier := forking.NewGenericEpochNotifier() + enableEpochsHandler, _ := enablers.NewEnableEpochsHandler(args.EnableEpochsConfig, genericEpochNotifier) + coreComponents := GetDefaultCoreComponents(enableEpochsHandler, genericEpochNotifier) coreComponents.SyncTimerField = syncer coreComponents.RoundHandlerField = roundHandler coreComponents.InternalMarshalizerField = TestMarshalizer @@ -309,6 +307,20 @@ func (tcn *TestConsensusNode) initNode(args ArgsTestConsensusNode) { cryptoComponents.SigHandler = sigHandler cryptoComponents.KeysHandlerField = keysHandler + dataComponents := GetDefaultDataComponents() + dataComponents.BlockChain = tcn.ChainHandler + dataComponents.DataPool = dataPool + dataComponents.Store = createTestStore() + + forkDetector, _ := syncFork.NewShardForkDetector( + roundHandler, + cache.NewTimeCache(time.Second), + &mock.BlockTrackerStub{}, + args.StartTime, + enableEpochsHandler, + dataPool.Proofs(), + ) + processComponents := GetDefaultProcessComponents() processComponents.ForkDetect = forkDetector processComponents.ShardCoord = tcn.ShardCoordinator @@ -329,11 +341,6 @@ func (tcn *TestConsensusNode) initNode(args ArgsTestConsensusNode) { processComponents.ProcessedMiniBlocksTrackerInternal = &testscommon.ProcessedMiniBlocksTrackerStub{} processComponents.SentSignaturesTrackerInternal = &testscommon.SentSignatureTrackerStub{} - dataComponents := GetDefaultDataComponents() - dataComponents.BlockChain = tcn.ChainHandler - dataComponents.DataPool = dataPool - dataComponents.Store = createTestStore() - stateComponents := GetDefaultStateComponents() stateComponents.Accounts = tcn.AccountsDB stateComponents.AccountsAPI = tcn.AccountsDB diff --git a/integrationTests/testInitializer.go b/integrationTests/testInitializer.go index 57af859a8df..c74566ef5ca 100644 --- a/integrationTests/testInitializer.go +++ b/integrationTests/testInitializer.go @@ -34,6 +34,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/common/enablers" + "github.com/multiversx/mx-chain-go/common/forking" "github.com/multiversx/mx-chain-go/common/statistics" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/dataRetriever" @@ -648,7 +650,9 @@ func CreateFullGenesisBlocks( gasSchedule := wasmConfig.MakeGasMapForTests() defaults.FillGasMapInternal(gasSchedule, 1) - coreComponents := GetDefaultCoreComponents(enableEpochsConfig) + genericEpochNotifier := forking.NewGenericEpochNotifier() + enableEpochsHandler, _ := enablers.NewEnableEpochsHandler(enableEpochsConfig, genericEpochNotifier) + coreComponents := GetDefaultCoreComponents(enableEpochsHandler, genericEpochNotifier) coreComponents.InternalMarshalizerField = TestMarshalizer coreComponents.TxMarshalizerField = TestTxSignMarshalizer coreComponents.HasherField = TestHasher @@ -772,7 +776,9 @@ func CreateGenesisMetaBlock( gasSchedule := wasmConfig.MakeGasMapForTests() defaults.FillGasMapInternal(gasSchedule, 1) - coreComponents := GetDefaultCoreComponents(enableEpochsConfig) + genericEpochNotifier := forking.NewGenericEpochNotifier() + enableEpochsHandler, _ := enablers.NewEnableEpochsHandler(enableEpochsConfig, genericEpochNotifier) + coreComponents := GetDefaultCoreComponents(enableEpochsHandler, genericEpochNotifier) coreComponents.InternalMarshalizerField = marshalizer coreComponents.HasherField = hasher coreComponents.Uint64ByteSliceConverterField = uint64Converter @@ -2184,7 +2190,9 @@ func generateValidTx( _ = accnts.SaveAccount(acc) _, _ = accnts.Commit() - coreComponents := GetDefaultCoreComponents(CreateEnableEpochsConfig()) + genericEpochNotifier := forking.NewGenericEpochNotifier() + enableEpochsHandler, _ := enablers.NewEnableEpochsHandler(CreateEnableEpochsConfig(), genericEpochNotifier) + coreComponents := GetDefaultCoreComponents(enableEpochsHandler, genericEpochNotifier) coreComponents.InternalMarshalizerField = TestMarshalizer coreComponents.TxMarshalizerField = TestTxSignMarshalizer coreComponents.VmMarshalizerField = TestMarshalizer diff --git a/integrationTests/testProcessorNode.go b/integrationTests/testProcessorNode.go index 6416f8b6c7c..ae812a195c6 100644 --- a/integrationTests/testProcessorNode.go +++ b/integrationTests/testProcessorNode.go @@ -1272,7 +1272,13 @@ func (tpn *TestProcessorNode) initInterceptors(heartbeatPk string) { tpn.EpochStartNotifier = notifier.NewEpochStartSubscriptionHandler() } - coreComponents := GetDefaultCoreComponents(CreateEnableEpochsConfig()) + if tpn.EpochNotifier == nil { + tpn.EpochNotifier = forking.NewGenericEpochNotifier() + } + if tpn.EnableEpochsHandler == nil { + tpn.EnableEpochsHandler, _ = enablers.NewEnableEpochsHandler(CreateEnableEpochsConfig(), tpn.EpochNotifier) + } + coreComponents := GetDefaultCoreComponents(tpn.EnableEpochsHandler, tpn.EpochNotifier) coreComponents.InternalMarshalizerField = TestMarshalizer coreComponents.TxMarshalizerField = TestTxSignMarshalizer coreComponents.HasherField = TestHasher @@ -2167,23 +2173,21 @@ func (tpn *TestProcessorNode) addMockVm(blockchainHook vmcommon.BlockchainHook) func (tpn *TestProcessorNode) initBlockProcessor() { var err error - if tpn.ShardCoordinator.SelfId() != core.MetachainShardId { - tpn.ForkDetector, _ = processSync.NewShardForkDetector(tpn.RoundHandler, tpn.BlockBlackListHandler, tpn.BlockTracker, tpn.NodesSetup.GetStartTime()) - } else { - tpn.ForkDetector, _ = processSync.NewMetaForkDetector(tpn.RoundHandler, tpn.BlockBlackListHandler, tpn.BlockTracker, tpn.NodesSetup.GetStartTime()) - } - accountsDb := make(map[state.AccountsDbIdentifier]state.AccountsAdapter) accountsDb[state.UserAccountsState] = tpn.AccntState accountsDb[state.PeerAccountsState] = tpn.PeerState - coreComponents := GetDefaultCoreComponents(CreateEnableEpochsConfig()) + if tpn.EpochNotifier == nil { + tpn.EpochNotifier = forking.NewGenericEpochNotifier() + } + if tpn.EnableEpochsHandler == nil { + tpn.EnableEpochsHandler, _ = enablers.NewEnableEpochsHandler(CreateEnableEpochsConfig(), tpn.EpochNotifier) + } + coreComponents := GetDefaultCoreComponents(tpn.EnableEpochsHandler, tpn.EpochNotifier) coreComponents.InternalMarshalizerField = TestMarshalizer coreComponents.HasherField = TestHasher coreComponents.Uint64ByteSliceConverterField = TestUint64Converter coreComponents.RoundHandlerField = tpn.RoundHandler - coreComponents.EnableEpochsHandlerField = tpn.EnableEpochsHandler - coreComponents.EpochNotifierField = tpn.EpochNotifier coreComponents.EconomicsDataField = tpn.EconomicsData coreComponents.RoundNotifierField = tpn.RoundNotifier @@ -2192,7 +2196,25 @@ func (tpn *TestProcessorNode) initBlockProcessor() { dataComponents.DataPool = tpn.DataPool dataComponents.BlockChain = tpn.BlockChain - bootstrapComponents := getDefaultBootstrapComponents(tpn.ShardCoordinator) + if tpn.ShardCoordinator.SelfId() != core.MetachainShardId { + tpn.ForkDetector, _ = processSync.NewShardForkDetector( + tpn.RoundHandler, + tpn.BlockBlackListHandler, + tpn.BlockTracker, + tpn.NodesSetup.GetStartTime(), + tpn.EnableEpochsHandler, + tpn.DataPool.Proofs()) + } else { + tpn.ForkDetector, _ = processSync.NewMetaForkDetector( + tpn.RoundHandler, + tpn.BlockBlackListHandler, + tpn.BlockTracker, + tpn.NodesSetup.GetStartTime(), + tpn.EnableEpochsHandler, + tpn.DataPool.Proofs()) + } + + bootstrapComponents := getDefaultBootstrapComponents(tpn.ShardCoordinator, tpn.EnableEpochsHandler) bootstrapComponents.HdrIntegrityVerifier = tpn.HeaderIntegrityVerifier statusComponents := GetDefaultStatusComponents() @@ -2478,8 +2500,13 @@ func (tpn *TestProcessorNode) initNode() { StatusMetricsField: tpn.StatusMetrics, AppStatusHandlerField: tpn.AppStatusHandler, } - - coreComponents := GetDefaultCoreComponents(CreateEnableEpochsConfig()) + if tpn.EpochNotifier == nil { + tpn.EpochNotifier = forking.NewGenericEpochNotifier() + } + if tpn.EnableEpochsHandler == nil { + tpn.EnableEpochsHandler, _ = enablers.NewEnableEpochsHandler(CreateEnableEpochsConfig(), tpn.EpochNotifier) + } + coreComponents := GetDefaultCoreComponents(tpn.EnableEpochsHandler, tpn.EpochNotifier) coreComponents.InternalMarshalizerField = TestMarshalizer coreComponents.VmMarshalizerField = TestVmMarshalizer coreComponents.TxMarshalizerField = TestTxSignMarshalizer @@ -2509,7 +2536,7 @@ func (tpn *TestProcessorNode) initNode() { dataComponents.DataPool = tpn.DataPool dataComponents.Store = tpn.Storage - bootstrapComponents := getDefaultBootstrapComponents(tpn.ShardCoordinator) + bootstrapComponents := getDefaultBootstrapComponents(tpn.ShardCoordinator, tpn.EnableEpochsHandler) processComponents := GetDefaultProcessComponents() processComponents.BlockProcess = tpn.BlockProcessor @@ -2706,12 +2733,6 @@ func (tpn *TestProcessorNode) ProposeBlock(round uint64, nonce uint64) (data.Bod return nil, nil, nil } - err = blockHeader.SetPubKeysBitmap([]byte{1}) - if err != nil { - log.Warn("blockHeader.SetPubKeysBitmap", "error", err.Error()) - return nil, nil, nil - } - currHdr := tpn.BlockChain.GetCurrentBlockHeader() currHdrHash := tpn.BlockChain.GetCurrentBlockHeaderHash() if check.IfNil(currHdr) { @@ -2730,22 +2751,10 @@ func (tpn *TestProcessorNode) ProposeBlock(round uint64, nonce uint64) (data.Bod log.Warn("blockHeader.SetPrevRandSeed", "error", err.Error()) return nil, nil, nil } - sig := []byte("aggregated signature") - err = blockHeader.SetSignature(sig) - if err != nil { - log.Warn("blockHeader.SetSignature", "error", err.Error()) - return nil, nil, nil - } - err = blockHeader.SetRandSeed(sig) + err = tpn.setBlockSignatures(blockHeader) if err != nil { - log.Warn("blockHeader.SetRandSeed", "error", err.Error()) - return nil, nil, nil - } - - err = blockHeader.SetLeaderSignature([]byte("leader sign")) - if err != nil { - log.Warn("blockHeader.SetLeaderSignature", "error", err.Error()) + log.Warn("setBlockSignatures", "error", err.Error()) return nil, nil, nil } @@ -2761,20 +2770,6 @@ func (tpn *TestProcessorNode) ProposeBlock(round uint64, nonce uint64) (data.Bod return nil, nil, nil } - previousProof := &dataBlock.HeaderProof{ - PubKeysBitmap: []byte{1}, - AggregatedSignature: sig, - HeaderHash: currHdrHash, - HeaderEpoch: currHdr.GetEpoch(), - HeaderNonce: currHdr.GetNonce(), - HeaderShardId: currHdr.GetShardID(), - } - blockHeader.SetPreviousProof(previousProof) - - _ = tpn.ProofsPool.AddProof(previousProof) - - log.Error("added proof", "currHdrHash", currHdrHash, "node", tpn.OwnAccount.Address) - genesisRound := tpn.BlockChain.GetGenesisHeader().GetRound() err = blockHeader.SetTimeStamp((round - genesisRound) * uint64(tpn.RoundHandler.TimeDuration().Seconds())) if err != nil { @@ -2808,6 +2803,57 @@ func (tpn *TestProcessorNode) ProposeBlock(round uint64, nonce uint64) (data.Bod return blockBody, blockHeader, txHashes } +func (tpn *TestProcessorNode) setBlockSignatures(blockHeader data.HeaderHandler) error { + currHdrHash := tpn.BlockChain.GetCurrentBlockHeaderHash() + currHdr := tpn.BlockChain.GetCurrentBlockHeader() + sig := []byte("aggregated signature") + pubKeysBitmap := []byte{1} + + err := blockHeader.SetRandSeed(sig) + if err != nil { + log.Warn("blockHeader.SetRandSeed", "error", err.Error()) + return err + } + + leaderSignature := []byte("leader signature") + err = blockHeader.SetLeaderSignature(leaderSignature) + if err != nil { + log.Warn("blockHeader.SetLeaderSignature", "error", err.Error()) + return err + } + + if common.ShouldBlockHavePrevProof(blockHeader, tpn.EnableEpochsHandler, common.EquivalentMessagesFlag) { + previousProof := &dataBlock.HeaderProof{ + PubKeysBitmap: pubKeysBitmap, + AggregatedSignature: sig, + HeaderHash: currHdrHash, + HeaderEpoch: currHdr.GetEpoch(), + HeaderNonce: currHdr.GetNonce(), + HeaderShardId: currHdr.GetShardID(), + } + blockHeader.SetPreviousProof(previousProof) + + err = tpn.ProofsPool.AddProof(previousProof) + if err != nil { + log.Warn("ProofsPool.AddProof", "currHdrHash", currHdrHash, "node", tpn.OwnAccount.Address, "err", err.Error()) + } + return err + } + + err = blockHeader.SetPubKeysBitmap(pubKeysBitmap) + if err != nil { + log.Warn("blockHeader.SetPubKeysBitmap", "error", err.Error()) + return err + } + + err = blockHeader.SetSignature(sig) + if err != nil { + log.Warn("blockHeader.SetSignature", "error", err.Error()) + return err + } + return err +} + // BroadcastBlock broadcasts the block and body to the connected peers func (tpn *TestProcessorNode) BroadcastBlock(body data.BodyHandler, header data.HeaderHandler, publicKey crypto.PublicKey) { _ = tpn.BroadcastMessenger.BroadcastBlock(body, header) @@ -3092,8 +3138,9 @@ func (tpn *TestProcessorNode) initBlockTracker() { func (tpn *TestProcessorNode) initHeaderValidator() { argsHeaderValidator := block.ArgsHeaderValidator{ - Hasher: TestHasher, - Marshalizer: TestMarshalizer, + Hasher: TestHasher, + Marshalizer: TestMarshalizer, + EnableEpochsHandler: tpn.EnableEpochsHandler, } tpn.HeaderValidator, _ = block.NewHeaderValidator(argsHeaderValidator) @@ -3293,10 +3340,7 @@ func CreateEnableEpochsConfig() config.EnableEpochs { } // GetDefaultCoreComponents - -func GetDefaultCoreComponents(enableEpochsConfig config.EnableEpochs) *mock.CoreComponentsStub { - genericEpochNotifier := forking.NewGenericEpochNotifier() - enableEpochsHandler, _ := enablers.NewEnableEpochsHandler(enableEpochsConfig, genericEpochNotifier) - +func GetDefaultCoreComponents(enableEpochsHandler common.EnableEpochsHandler, epochNotifier process.EpochNotifier) *mock.CoreComponentsStub { return &mock.CoreComponentsStub{ InternalMarshalizerField: TestMarshalizer, TxMarshalizerField: TestTxSignMarshalizer, @@ -3323,7 +3367,7 @@ func GetDefaultCoreComponents(enableEpochsConfig config.EnableEpochs) *mock.Core RaterField: &testscommon.RaterMock{}, GenesisNodesSetupField: &genesisMocks.NodesSetupStub{}, GenesisTimeField: time.Time{}, - EpochNotifierField: genericEpochNotifier, + EpochNotifierField: epochNotifier, EnableRoundsHandlerField: &testscommon.EnableRoundsHandlerStub{}, TxVersionCheckField: versioning.NewTxVersionChecker(MinTransactionVersion), ProcessStatusHandlerInternal: &testscommon.ProcessStatusHandlerStub{}, @@ -3446,12 +3490,15 @@ func GetDefaultStatusComponents() *mock.StatusComponentsStub { } // getDefaultBootstrapComponents - -func getDefaultBootstrapComponents(shardCoordinator sharding.Coordinator) *mainFactoryMocks.BootstrapComponentsStub { +func getDefaultBootstrapComponents(shardCoordinator sharding.Coordinator, handler common.EnableEpochsHandler) *mainFactoryMocks.BootstrapComponentsStub { var versionedHeaderFactory nodeFactory.VersionedHeaderFactory headerVersionHandler := &testscommon.HeaderVersionHandlerStub{ GetVersionCalled: func(epoch uint32) string { - return "2" + if handler.IsFlagEnabledInEpoch(common.EquivalentMessagesFlag, epoch) { + return "2" + } + return "1" }, } versionedHeaderFactory, _ = hdrFactory.NewShardHeaderFactory(headerVersionHandler) diff --git a/integrationTests/testSyncNode.go b/integrationTests/testSyncNode.go index 31c2ac46111..385034d6304 100644 --- a/integrationTests/testSyncNode.go +++ b/integrationTests/testSyncNode.go @@ -5,7 +5,9 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data" - "github.com/multiversx/mx-chain-go/common" + + "github.com/multiversx/mx-chain-go/common/enablers" + "github.com/multiversx/mx-chain-go/common/forking" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/dataRetriever/provider" @@ -45,27 +47,25 @@ func (tpn *TestProcessorNode) initBlockProcessorWithSync() { accountsDb[state.UserAccountsState] = tpn.AccntState accountsDb[state.PeerAccountsState] = tpn.PeerState - coreComponents := GetDefaultCoreComponents(CreateEnableEpochsConfig()) + if tpn.EpochNotifier == nil { + tpn.EpochNotifier = forking.NewGenericEpochNotifier() + } + if tpn.EnableEpochsHandler == nil { + tpn.EnableEpochsHandler, _ = enablers.NewEnableEpochsHandler(CreateEnableEpochsConfig(), tpn.EpochNotifier) + } + coreComponents := GetDefaultCoreComponents(tpn.EnableEpochsHandler, tpn.EpochNotifier) coreComponents.InternalMarshalizerField = TestMarshalizer coreComponents.HasherField = TestHasher coreComponents.Uint64ByteSliceConverterField = TestUint64Converter coreComponents.EpochNotifierField = tpn.EpochNotifier coreComponents.RoundNotifierField = tpn.RoundNotifier - coreComponents.EnableEpochsHandlerField = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ - GetActivationEpochCalled: func(flag core.EnableEpochFlag) uint32 { - if flag == common.RefactorPeersMiniBlocksFlag { - return UnreachableEpoch - } - return 0 - }, - } dataComponents := GetDefaultDataComponents() dataComponents.Store = tpn.Storage dataComponents.DataPool = tpn.DataPool dataComponents.BlockChain = tpn.BlockChain - bootstrapComponents := getDefaultBootstrapComponents(tpn.ShardCoordinator) + bootstrapComponents := getDefaultBootstrapComponents(tpn.ShardCoordinator, tpn.EnableEpochsHandler) bootstrapComponents.HdrIntegrityVerifier = tpn.HeaderIntegrityVerifier statusComponents := GetDefaultStatusComponents() @@ -108,7 +108,13 @@ func (tpn *TestProcessorNode) initBlockProcessorWithSync() { } if tpn.ShardCoordinator.SelfId() == core.MetachainShardId { - tpn.ForkDetector, _ = sync.NewMetaForkDetector(tpn.RoundHandler, tpn.BlockBlackListHandler, tpn.BlockTracker, 0) + tpn.ForkDetector, _ = sync.NewMetaForkDetector( + tpn.RoundHandler, + tpn.BlockBlackListHandler, + tpn.BlockTracker, + 0, + tpn.EnableEpochsHandler, + tpn.DataPool.Proofs()) argumentsBase.ForkDetector = tpn.ForkDetector argumentsBase.TxCoordinator = &mock.TransactionCoordinatorMock{} arguments := block.ArgMetaProcessor{ @@ -129,7 +135,13 @@ func (tpn *TestProcessorNode) initBlockProcessorWithSync() { tpn.BlockProcessor, err = block.NewMetaProcessor(arguments) } else { - tpn.ForkDetector, _ = sync.NewShardForkDetector(tpn.RoundHandler, tpn.BlockBlackListHandler, tpn.BlockTracker, 0) + tpn.ForkDetector, _ = sync.NewShardForkDetector( + tpn.RoundHandler, + tpn.BlockBlackListHandler, + tpn.BlockTracker, + 0, + tpn.EnableEpochsHandler, + tpn.DataPool.Proofs()) argumentsBase.ForkDetector = tpn.ForkDetector argumentsBase.BlockChainHook = tpn.BlockchainHook argumentsBase.TxCoordinator = tpn.TxCoordinator diff --git a/integrationTests/vm/delegation/changeOwner_test.go b/integrationTests/vm/delegation/changeOwner_test.go index c634452ea9c..47b0ffa6d12 100644 --- a/integrationTests/vm/delegation/changeOwner_test.go +++ b/integrationTests/vm/delegation/changeOwner_test.go @@ -6,12 +6,13 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data/block" - "github.com/multiversx/mx-chain-go/integrationTests" - "github.com/multiversx/mx-chain-go/state" - "github.com/multiversx/mx-chain-go/testscommon" vmcommon "github.com/multiversx/mx-chain-vm-common-go" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/multiversx/mx-chain-go/integrationTests" + "github.com/multiversx/mx-chain-go/state" + "github.com/multiversx/mx-chain-go/testscommon" ) var ( @@ -55,7 +56,7 @@ func TestDelegationChangeOwnerOnAccountHandler(t *testing.T) { // verify the new owner is still the delegator verifyDelegatorsStake(t, tpn, "getUserActiveStake", [][]byte{newOwner}, userAccount.AddressBytes(), big.NewInt(2000)) - //get the SC delegation account + // get the SC delegation account account, err := tpn.AccntState.LoadAccount(scAddress) require.Nil(t, err) @@ -91,7 +92,7 @@ func testDelegationChangeOwnerOnAccountHandler(t *testing.T, epochToTest uint32) changeOwner(t, tpn, firstOwner, newOwner, delegationScAddress) verifyDelegatorsStake(t, tpn, "getUserActiveStake", [][]byte{newOwner}, delegationScAddress, big.NewInt(2000)) - //get the SC delegation account + // get the SC delegation account account, err := tpn.AccntState.LoadAccount(delegationScAddress) require.Nil(t, err) diff --git a/integrationTests/vm/staking/metaBlockProcessorCreator.go b/integrationTests/vm/staking/metaBlockProcessorCreator.go index 759458cf30e..cbb52101531 100644 --- a/integrationTests/vm/staking/metaBlockProcessorCreator.go +++ b/integrationTests/vm/staking/metaBlockProcessorCreator.go @@ -6,6 +6,7 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/epochStart" "github.com/multiversx/mx-chain-go/epochStart/metachain" @@ -217,8 +218,9 @@ func createGenesisMetaBlock() *block.MetaBlock { func createHeaderValidator(coreComponents factory.CoreComponentsHolder) epochStart.HeaderValidator { argsHeaderValidator := blproc.ArgsHeaderValidator{ - Hasher: coreComponents.Hasher(), - Marshalizer: coreComponents.InternalMarshalizer(), + Hasher: coreComponents.Hasher(), + Marshalizer: coreComponents.InternalMarshalizer(), + EnableEpochsHandler: coreComponents.EnableEpochsHandler(), } headerValidator, _ := blproc.NewHeaderValidator(argsHeaderValidator) return headerValidator diff --git a/integrationTests/vm/wasm/upgrades/upgrades_test.go b/integrationTests/vm/wasm/upgrades/upgrades_test.go index c6313d65e73..09b3c0ee49c 100644 --- a/integrationTests/vm/wasm/upgrades/upgrades_test.go +++ b/integrationTests/vm/wasm/upgrades/upgrades_test.go @@ -6,12 +6,13 @@ import ( "math/big" "testing" + vmcommon "github.com/multiversx/mx-chain-vm-common-go" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/integrationTests" "github.com/multiversx/mx-chain-go/integrationTests/vm/wasm" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/factory" - vmcommon "github.com/multiversx/mx-chain-vm-common-go" - "github.com/stretchr/testify/require" ) func TestUpgrades_Hello(t *testing.T) { @@ -212,7 +213,7 @@ func TestUpgrades_HelloTrialAndError(t *testing.T) { require.Nil(t, err) scAddress, _ := network.ShardNode.BlockchainHook.NewAddress(alice.Address, 0, factory.WasmVirtualMachine) - network.Continue(t, 1) + network.Continue(t, 2) require.Equal(t, []byte{24}, query(t, network.ShardNode, scAddress, "getUltimateAnswer")) // Upgrade as Bob - upgrade should fail, since Alice is the owner @@ -225,7 +226,7 @@ func TestUpgrades_HelloTrialAndError(t *testing.T) { ) require.Nil(t, err) - network.Continue(t, 1) + network.Continue(t, 2) require.Equal(t, []byte{24}, query(t, network.ShardNode, scAddress, "getUltimateAnswer")) // Now upgrade as Alice, should work @@ -238,7 +239,7 @@ func TestUpgrades_HelloTrialAndError(t *testing.T) { ) require.Nil(t, err) - network.Continue(t, 1) + network.Continue(t, 2) require.Equal(t, []byte{42}, query(t, network.ShardNode, scAddress, "getUltimateAnswer")) } @@ -269,7 +270,7 @@ func TestUpgrades_CounterTrialAndError(t *testing.T) { require.Nil(t, err) scAddress, _ := network.ShardNode.BlockchainHook.NewAddress(alice.Address, 0, factory.WasmVirtualMachine) - network.Continue(t, 1) + network.Continue(t, 2) require.Equal(t, []byte{1}, query(t, network.ShardNode, scAddress, "get")) // Increment the counter (could be either Bob or Alice) @@ -282,7 +283,7 @@ func TestUpgrades_CounterTrialAndError(t *testing.T) { ) require.Nil(t, err) - network.Continue(t, 1) + network.Continue(t, 2) require.Equal(t, []byte{2}, query(t, network.ShardNode, scAddress, "get")) // Upgrade as Bob - upgrade should fail, since Alice is the owner (counter.init() not executed, state not reset) @@ -295,7 +296,7 @@ func TestUpgrades_CounterTrialAndError(t *testing.T) { ) require.Nil(t, err) - network.Continue(t, 1) + network.Continue(t, 2) require.Equal(t, []byte{2}, query(t, network.ShardNode, scAddress, "get")) // Now upgrade as Alice, should work (state is reset by counter.init()) @@ -308,7 +309,7 @@ func TestUpgrades_CounterTrialAndError(t *testing.T) { ) require.Nil(t, err) - network.Continue(t, 1) + network.Continue(t, 2) require.Equal(t, []byte{1}, query(t, network.ShardNode, scAddress, "get")) } diff --git a/node/mock/epochStartNotifier.go b/node/mock/epochStartNotifier.go index c4675a37401..8c6fb4c51e8 100644 --- a/node/mock/epochStartNotifier.go +++ b/node/mock/epochStartNotifier.go @@ -2,6 +2,7 @@ package mock import ( "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-go/epochStart" ) diff --git a/node/mock/forkDetectorMock.go b/node/mock/forkDetectorMock.go index d681b976d7d..52ffeba3384 100644 --- a/node/mock/forkDetectorMock.go +++ b/node/mock/forkDetectorMock.go @@ -2,6 +2,7 @@ package mock import ( "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-go/process" ) @@ -19,6 +20,7 @@ type ForkDetectorMock struct { RestoreToGenesisCalled func() ResetProbableHighestNonceCalled func() SetFinalToLastCheckpointCalled func() + ReceivedProofCalled func(proof data.HeaderProofHandler) } // RestoreToGenesis - @@ -87,6 +89,13 @@ func (fdm *ForkDetectorMock) SetFinalToLastCheckpoint() { } } +// ReceivedProof - +func (fdm *ForkDetectorMock) ReceivedProof(proof data.HeaderProofHandler) { + if fdm.ReceivedProofCalled != nil { + fdm.ReceivedProofCalled(proof) + } +} + // IsInterfaceNil returns true if there is no value under the interface func (fdm *ForkDetectorMock) IsInterfaceNil() bool { return fdm == nil diff --git a/process/block/baseProcess.go b/process/block/baseProcess.go index 4f2a3661ece..e29f507661e 100644 --- a/process/block/baseProcess.go +++ b/process/block/baseProcess.go @@ -222,7 +222,16 @@ func (bp *baseProcessor) checkBlockValidity( return process.ErrEpochDoesNotMatch } - return nil + return bp.checkPrevProofValidity(currentBlockHeader, headerHandler) +} + +func (bp *baseProcessor) checkPrevProofValidity(prevHeader, headerHandler data.HeaderHandler) error { + if !common.ShouldBlockHavePrevProof(headerHandler, bp.enableEpochsHandler, common.EquivalentMessagesFlag) { + return nil + } + + prevProof := headerHandler.GetPreviousProof() + return common.VerifyProofAgainstHeader(prevProof, prevHeader) } // checkScheduledRootHash checks if the scheduled root hash from the given header is the same with the current user accounts state root hash @@ -538,6 +547,7 @@ func checkProcessorParameters(arguments ArgBaseProcessor) error { common.ScheduledMiniBlocksFlag, common.StakingV2Flag, common.CurrentRandomnessOnSortingFlag, + common.EquivalentMessagesFlag, }) if err != nil { return err @@ -625,7 +635,7 @@ func (bp *baseProcessor) sortHeadersForCurrentBlockByNonce(usedInBlock bool) (ma } if bp.hasMissingProof(headerInfo, hdrHash) { - return nil, fmt.Errorf("%w for header with hash %s", process.ErrMissingHeaderProof, hdrHash) + return nil, fmt.Errorf("%w for header with hash %s", process.ErrMissingHeaderProof, hex.EncodeToString([]byte(hdrHash))) } hdrsForCurrentBlock[headerInfo.hdr.GetShardID()] = append(hdrsForCurrentBlock[headerInfo.hdr.GetShardID()], headerInfo.hdr) @@ -650,7 +660,7 @@ func (bp *baseProcessor) sortHeaderHashesForCurrentBlockByNonce(usedInBlock bool } if bp.hasMissingProof(headerInfo, metaBlockHash) { - return nil, fmt.Errorf("%w for header with hash %s", process.ErrMissingHeaderProof, metaBlockHash) + return nil, fmt.Errorf("%w for header with hash %s", process.ErrMissingHeaderProof, hex.EncodeToString([]byte(metaBlockHash))) } hdrsForCurrentBlockInfo[headerInfo.hdr.GetShardID()] = append(hdrsForCurrentBlockInfo[headerInfo.hdr.GetShardID()], @@ -831,7 +841,6 @@ func isPartiallyExecuted( ) bool { processedMiniBlockInfo := processedMiniBlocksDestMeInfo[string(miniBlockHeaderHandler.GetHash())] return processedMiniBlockInfo != nil && !processedMiniBlockInfo.FullyProcessed - } // check if header has the same miniblocks as presented in body @@ -2191,3 +2200,17 @@ func (bp *baseProcessor) checkSentSignaturesAtCommitTime(header data.HeaderHandl return nil } + +func (bp *baseProcessor) addPrevProofIfNeeded(header data.HeaderHandler) error { + if !common.ShouldBlockHavePrevProof(header, bp.enableEpochsHandler, common.EquivalentMessagesFlag) { + return nil + } + + prevBlockProof, err := bp.proofsPool.GetProof(bp.shardCoordinator.SelfId(), header.GetPrevHash()) + if err != nil { + return err + } + + header.SetPreviousProof(prevBlockProof) + return nil +} diff --git a/process/block/baseProcess_test.go b/process/block/baseProcess_test.go index 017f7b3e1d0..2e53e5aa699 100644 --- a/process/block/baseProcess_test.go +++ b/process/block/baseProcess_test.go @@ -79,8 +79,9 @@ func createArgBaseProcessor( ) blproc.ArgBaseProcessor { nodesCoordinatorInstance := shardingMocks.NewNodesCoordinatorMock() argsHeaderValidator := blproc.ArgsHeaderValidator{ - Hasher: &hashingMocks.HasherMock{}, - Marshalizer: &mock.MarshalizerMock{}, + Hasher: &hashingMocks.HasherMock{}, + Marshalizer: &mock.MarshalizerMock{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, } headerValidator, _ := blproc.NewHeaderValidator(argsHeaderValidator) @@ -143,8 +144,8 @@ func createTestBlockchain() *testscommon.ChainHandlerStub { } func generateTestCache() storage.Cacher { - cache, _ := storageunit.NewCache(storageunit.CacheConfig{Type: storageunit.LRUCache, Capacity: 1000, Shards: 1, SizeInBytes: 0}) - return cache + c, _ := storageunit.NewCache(storageunit.CacheConfig{Type: storageunit.LRUCache, Capacity: 1000, Shards: 1, SizeInBytes: 0}) + return c } func generateTestUnit() storage.Storer { @@ -1268,8 +1269,9 @@ func TestBaseProcessor_SaveLastNotarizedHdrShardGood(t *testing.T) { sp, _ := blproc.NewShardProcessor(arguments) argsHeaderValidator := blproc.ArgsHeaderValidator{ - Hasher: coreComponents.Hasher(), - Marshalizer: coreComponents.InternalMarshalizer(), + Hasher: coreComponents.Hasher(), + Marshalizer: coreComponents.InternalMarshalizer(), + EnableEpochsHandler: coreComponents.EnableEpochsHandler(), } headerValidator, _ := blproc.NewHeaderValidator(argsHeaderValidator) sp.SetHeaderValidator(headerValidator) @@ -1302,8 +1304,9 @@ func TestBaseProcessor_SaveLastNotarizedHdrMetaGood(t *testing.T) { sp, _ := blproc.NewShardProcessor(arguments) argsHeaderValidator := blproc.ArgsHeaderValidator{ - Hasher: coreComponents.Hasher(), - Marshalizer: coreComponents.InternalMarshalizer(), + Hasher: coreComponents.Hasher(), + Marshalizer: coreComponents.InternalMarshalizer(), + EnableEpochsHandler: coreComponents.EnableEpochsHandler(), } headerValidator, _ := blproc.NewHeaderValidator(argsHeaderValidator) sp.SetHeaderValidator(headerValidator) diff --git a/process/block/export_test.go b/process/block/export_test.go index 2332115613c..56365b44a22 100644 --- a/process/block/export_test.go +++ b/process/block/export_test.go @@ -99,8 +99,9 @@ func NewShardProcessorEmptyWith3shards( nodesCoordinator := shardingMocks.NewNodesCoordinatorMock() argsHeaderValidator := ArgsHeaderValidator{ - Hasher: &hashingMocks.HasherMock{}, - Marshalizer: &mock.MarshalizerMock{}, + Hasher: &hashingMocks.HasherMock{}, + Marshalizer: &mock.MarshalizerMock{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, } hdrValidator, _ := NewHeaderValidator(argsHeaderValidator) diff --git a/process/block/headerValidator.go b/process/block/headerValidator.go index 9459280c847..dd5a22d9d24 100644 --- a/process/block/headerValidator.go +++ b/process/block/headerValidator.go @@ -8,6 +8,8 @@ import ( "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" + + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/process" ) @@ -15,13 +17,15 @@ var _ process.HeaderConstructionValidator = (*headerValidator)(nil) // ArgsHeaderValidator are the arguments needed to create a new header validator type ArgsHeaderValidator struct { - Hasher hashing.Hasher - Marshalizer marshal.Marshalizer + Hasher hashing.Hasher + Marshalizer marshal.Marshalizer + EnableEpochsHandler core.EnableEpochsHandler } type headerValidator struct { - hasher hashing.Hasher - marshalizer marshal.Marshalizer + hasher hashing.Hasher + marshalizer marshal.Marshalizer + enableEpochsHandler core.EnableEpochsHandler } // NewHeaderValidator returns a new header validator @@ -32,10 +36,14 @@ func NewHeaderValidator(args ArgsHeaderValidator) (*headerValidator, error) { if check.IfNil(args.Marshalizer) { return nil, process.ErrNilMarshalizer } + if check.IfNil(args.EnableEpochsHandler) { + return nil, process.ErrNilEnableEpochsHandler + } return &headerValidator{ - hasher: args.Hasher, - marshalizer: args.Marshalizer, + hasher: args.Hasher, + marshalizer: args.Marshalizer, + enableEpochsHandler: args.EnableEpochsHandler, }, nil } @@ -87,9 +95,15 @@ func (h *headerValidator) IsHeaderConstructionValid(currHeader, prevHeader data. return process.ErrRandSeedDoesNotMatch } - // TODO: check here if proof from currHeader is valid for prevHeader + return h.verifyProofForBlock(prevHeader, currHeader.GetPreviousProof()) +} + +func (h *headerValidator) verifyProofForBlock(header data.HeaderHandler, proof data.HeaderProofHandler) error { + if !h.enableEpochsHandler.IsFlagEnabledInEpoch(common.EquivalentMessagesFlag, header.GetEpoch()) { + return nil + } - return nil + return common.VerifyProofAgainstHeader(proof, header) } // IsInterfaceNil returns if underlying object is true diff --git a/process/block/interceptedBlocks/common.go b/process/block/interceptedBlocks/common.go index 90a604dba23..b70e48e9426 100644 --- a/process/block/interceptedBlocks/common.go +++ b/process/block/interceptedBlocks/common.go @@ -101,21 +101,20 @@ func checkHeaderHandler(hdr data.HeaderHandler, enableEpochsHandler common.Enabl } func checkProofIntegrity(hdr data.HeaderHandler, enableEpochsHandler common.EnableEpochsHandler) error { - equivalentMessagesEnabled := enableEpochsHandler.IsFlagEnabledInEpoch(common.EquivalentMessagesFlag, hdr.GetEpoch()) - prevHeaderProof := hdr.GetPreviousProof() nilPreviousProof := check.IfNilReflect(prevHeaderProof) - missingProof := nilPreviousProof && equivalentMessagesEnabled - unexpectedProof := !nilPreviousProof && !equivalentMessagesEnabled - hasProof := !nilPreviousProof && equivalentMessagesEnabled + shouldHavePrevProof := common.ShouldBlockHavePrevProof(hdr, enableEpochsHandler, common.EquivalentMessagesFlag) + missingPrevProof := nilPreviousProof && shouldHavePrevProof + unexpectedPrevProof := !nilPreviousProof && !shouldHavePrevProof + hasPrevProof := !nilPreviousProof && !missingPrevProof - if missingProof { - return process.ErrMissingHeaderProof + if missingPrevProof { + return process.ErrMissingPrevHeaderProof } - if unexpectedProof { + if unexpectedPrevProof { return process.ErrUnexpectedHeaderProof } - if hasProof && isIncompleteProof(prevHeaderProof) { + if hasPrevProof && isIncompleteProof(prevHeaderProof) { return process.ErrInvalidHeaderProof } diff --git a/process/block/interceptedBlocks/interceptedBlockHeader.go b/process/block/interceptedBlocks/interceptedBlockHeader.go index cde4be46170..9aac8ceabc7 100644 --- a/process/block/interceptedBlocks/interceptedBlockHeader.go +++ b/process/block/interceptedBlocks/interceptedBlockHeader.go @@ -68,6 +68,9 @@ func (inHdr *InterceptedHeader) processFields(txBuff []byte) { // CheckValidity checks if the received header is valid (not nil fields, valid sig and so on) func (inHdr *InterceptedHeader) CheckValidity() error { + // TODO: remove this log after debugging + log.Debug("CheckValidity for header with", "epoch", inHdr.hdr.GetEpoch(), "hash", logger.DisplayByteSlice(inHdr.hash)) + err := inHdr.integrityVerifier.Verify(inHdr.hdr) if err != nil { return err diff --git a/process/block/interceptedBlocks/interceptedEquivalentProof.go b/process/block/interceptedBlocks/interceptedEquivalentProof.go index a7937a5aef2..2ba09e30a04 100644 --- a/process/block/interceptedBlocks/interceptedEquivalentProof.go +++ b/process/block/interceptedBlocks/interceptedEquivalentProof.go @@ -7,13 +7,16 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" + logger "github.com/multiversx/mx-chain-logger-go" + "github.com/multiversx/mx-chain-vm-v1_2-go/ipc/marshaling" + "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/dataRetriever" proofscache "github.com/multiversx/mx-chain-go/dataRetriever/dataPool/proofsCache" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/sharding" - logger "github.com/multiversx/mx-chain-logger-go" ) const interceptedEquivalentProofType = "intercepted equivalent proof" @@ -22,9 +25,11 @@ const interceptedEquivalentProofType = "intercepted equivalent proof" type ArgInterceptedEquivalentProof struct { DataBuff []byte Marshaller marshal.Marshalizer + Hasher hashing.Hasher ShardCoordinator sharding.Coordinator HeaderSigVerifier consensus.HeaderSigVerifier Proofs dataRetriever.ProofsPool + Headers dataRetriever.HeadersPool } type interceptedEquivalentProof struct { @@ -32,6 +37,10 @@ type interceptedEquivalentProof struct { isForCurrentShard bool headerSigVerifier consensus.HeaderSigVerifier proofsPool dataRetriever.ProofsPool + headersPool dataRetriever.HeadersPool + marshaller marshaling.Marshalizer + hasher hashing.Hasher + hash []byte } // NewInterceptedEquivalentProof returns a new instance of interceptedEquivalentProof @@ -46,11 +55,17 @@ func NewInterceptedEquivalentProof(args ArgInterceptedEquivalentProof) (*interce return nil, err } + hash := args.Hasher.Compute(string(args.DataBuff)) + return &interceptedEquivalentProof{ proof: equivalentProof, isForCurrentShard: extractIsForCurrentShard(args.ShardCoordinator, equivalentProof), headerSigVerifier: args.HeaderSigVerifier, proofsPool: args.Proofs, + headersPool: args.Headers, + marshaller: args.Marshaller, + hasher: args.Hasher, + hash: hash, }, nil } @@ -70,6 +85,12 @@ func checkArgInterceptedEquivalentProof(args ArgInterceptedEquivalentProof) erro if check.IfNil(args.Proofs) { return process.ErrNilProofsPool } + if check.IfNil(args.Headers) { + return process.ErrNilHeadersDataPool + } + if check.IfNil(args.Hasher) { + return process.ErrNilHasher + } return nil } @@ -86,6 +107,7 @@ func createEquivalentProof(marshaller marshal.Marshalizer, buff []byte) (*block. "header shard", headerProof.HeaderShardId, "header epoch", headerProof.HeaderEpoch, "header nonce", headerProof.HeaderNonce, + "header round", headerProof.HeaderRound, "bitmap", logger.DisplayByteSlice(headerProof.PubKeysBitmap), "signature", logger.DisplayByteSlice(headerProof.AggregatedSignature), ) @@ -95,6 +117,10 @@ func createEquivalentProof(marshaller marshal.Marshalizer, buff []byte) (*block. func extractIsForCurrentShard(shardCoordinator sharding.Coordinator, equivalentProof *block.HeaderProof) bool { proofShardId := equivalentProof.GetHeaderShardId() + if shardCoordinator.SelfId() == core.MetachainShardId { + return true + } + if proofShardId == core.MetachainShardId { return true } @@ -114,6 +140,8 @@ func (iep *interceptedEquivalentProof) CheckValidity() error { return proofscache.ErrAlreadyExistingEquivalentProof } + // TODO: make sure proof fields (besides ones used to verify signature) should be checked on processing. + return iep.headerSigVerifier.VerifyHeaderProof(iep.proof) } @@ -140,7 +168,7 @@ func (iep *interceptedEquivalentProof) IsForCurrentShard() bool { // Hash returns the header hash the proof belongs to func (iep *interceptedEquivalentProof) Hash() []byte { - return iep.proof.HeaderHash + return iep.hash } // Type returns the type of this intercepted data diff --git a/process/block/interceptedBlocks/interceptedEquivalentProof_test.go b/process/block/interceptedBlocks/interceptedEquivalentProof_test.go index b0a8cd6c9c9..85262c297bc 100644 --- a/process/block/interceptedBlocks/interceptedEquivalentProof_test.go +++ b/process/block/interceptedBlocks/interceptedEquivalentProof_test.go @@ -1,26 +1,34 @@ package interceptedBlocks import ( - "bytes" "errors" "fmt" "testing" "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" + logger "github.com/multiversx/mx-chain-logger-go" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/consensus/mock" proofscache "github.com/multiversx/mx-chain-go/dataRetriever/dataPool/proofsCache" "github.com/multiversx/mx-chain-go/process" + "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/consensus" "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" + "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" - logger "github.com/multiversx/mx-chain-logger-go" - "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/testscommon/pool" ) var ( expectedErr = errors.New("expected error") testMarshaller = &marshallerMock.MarshalizerMock{} + providedEpoch = uint32(123) + providedNonce = uint64(345) + providedShard = uint32(0) + providedRound = uint64(123456) ) func createMockDataBuff() []byte { @@ -28,9 +36,10 @@ func createMockDataBuff() []byte { PubKeysBitmap: []byte("bitmap"), AggregatedSignature: []byte("sig"), HeaderHash: []byte("hash"), - HeaderEpoch: 123, - HeaderNonce: 345, - HeaderShardId: 0, + HeaderEpoch: providedEpoch, + HeaderNonce: providedNonce, + HeaderShardId: providedShard, + HeaderRound: providedRound, } dataBuff, _ := testMarshaller.Marshal(proof) @@ -44,6 +53,21 @@ func createMockArgInterceptedEquivalentProof() ArgInterceptedEquivalentProof { ShardCoordinator: &mock.ShardCoordinatorMock{}, HeaderSigVerifier: &consensus.HeaderSigVerifierMock{}, Proofs: &dataRetriever.ProofsPoolMock{}, + Hasher: &hashingMocks.HasherMock{}, + Headers: &pool.HeadersPoolStub{ + GetHeaderByHashCalled: func(hash []byte) (data.HeaderHandler, error) { + return &testscommon.HeaderHandlerStub{ + EpochField: providedEpoch, + RoundField: providedRound, + GetNonceCalled: func() uint64 { + return providedNonce + }, + GetShardIDCalled: func() uint32 { + return providedShard + }, + }, nil + }, + }, } } @@ -105,6 +129,24 @@ func TestNewInterceptedEquivalentProof(t *testing.T) { require.Equal(t, process.ErrNilProofsPool, err) require.Nil(t, iep) }) + t.Run("nil headers pool should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgInterceptedEquivalentProof() + args.Headers = nil + iep, err := NewInterceptedEquivalentProof(args) + require.Equal(t, process.ErrNilHeadersDataPool, err) + require.Nil(t, iep) + }) + t.Run("nil Hasher should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgInterceptedEquivalentProof() + args.Hasher = nil + iep, err := NewInterceptedEquivalentProof(args) + require.Equal(t, process.ErrNilHasher, err) + require.Nil(t, iep) + }) t.Run("unmarshal error should error", func(t *testing.T) { t.Parallel() @@ -146,7 +188,6 @@ func TestInterceptedEquivalentProof_CheckValidity(t *testing.T) { err = iep.CheckValidity() require.Equal(t, ErrInvalidProof, err) }) - t.Run("already exiting proof should error", func(t *testing.T) { t.Parallel() @@ -163,7 +204,6 @@ func TestInterceptedEquivalentProof_CheckValidity(t *testing.T) { err = iep.CheckValidity() require.Equal(t, proofscache.ErrAlreadyExistingEquivalentProof, err) }) - t.Run("should work", func(t *testing.T) { t.Parallel() @@ -245,11 +285,12 @@ func TestInterceptedEquivalentProof_Getters(t *testing.T) { } args := createMockArgInterceptedEquivalentProof() args.DataBuff, _ = args.Marshaller.Marshal(proof) + hash := args.Hasher.Compute(string(args.DataBuff)) iep, err := NewInterceptedEquivalentProof(args) require.NoError(t, err) require.Equal(t, proof, iep.GetProof()) // pointer testing - require.True(t, bytes.Equal(proof.HeaderHash, iep.Hash())) + require.Equal(t, hash, iep.Hash()) require.Equal(t, [][]byte{proof.HeaderHash}, iep.Identifiers()) require.Equal(t, interceptedEquivalentProofType, iep.Type()) expectedStr := fmt.Sprintf("bitmap=%s, signature=%s, hash=%s, epoch=123, shard=0, nonce=345", diff --git a/process/block/interceptedBlocks/interceptedMetaBlockHeader.go b/process/block/interceptedBlocks/interceptedMetaBlockHeader.go index c3f92781e7e..d57732f56f1 100644 --- a/process/block/interceptedBlocks/interceptedMetaBlockHeader.go +++ b/process/block/interceptedBlocks/interceptedMetaBlockHeader.go @@ -88,6 +88,8 @@ func (imh *InterceptedMetaHeader) HeaderHandler() data.HeaderHandler { // CheckValidity checks if the received meta header is valid (not nil fields, valid sig and so on) func (imh *InterceptedMetaHeader) CheckValidity() error { + log.Debug("CheckValidity for header with", "epoch", imh.hdr.GetEpoch(), "hash", logger.DisplayByteSlice(imh.hash)) + err := imh.integrity() if err != nil { return err diff --git a/process/block/metablock.go b/process/block/metablock.go index fbd963f4da4..5026766749d 100644 --- a/process/block/metablock.go +++ b/process/block/metablock.go @@ -3,6 +3,7 @@ package block import ( "bytes" "encoding/hex" + "errors" "fmt" "math/big" "sync" @@ -13,6 +14,8 @@ import ( "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-core-go/data/headerVersionData" + logger "github.com/multiversx/mx-chain-logger-go" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/common/holders" "github.com/multiversx/mx-chain-go/dataRetriever" @@ -22,7 +25,6 @@ import ( "github.com/multiversx/mx-chain-go/process/block/helpers" "github.com/multiversx/mx-chain-go/process/block/processedMb" "github.com/multiversx/mx-chain-go/state" - logger "github.com/multiversx/mx-chain-logger-go" ) const firstHeaderNonce = uint64(1) @@ -202,7 +204,7 @@ func (mp *metaProcessor) ProcessBlock( err := mp.checkBlockValidity(headerHandler, bodyHandler) if err != nil { - if err == process.ErrBlockHashDoesNotMatch { + if errors.Is(err, process.ErrBlockHashDoesNotMatch) { log.Debug("requested missing meta header", "hash", headerHandler.GetPrevHash(), "for shard", headerHandler.GetShardID(), @@ -425,9 +427,36 @@ func (mp *metaProcessor) checkProofsForShardData(header *block.MetaBlock) error // TODO: consider the validation of the proof: // compare the one from proofsPool with what shardData.CurrentSignature and shardData.CurrentPubKeysBitmap hold // if they are different, verify the proof received on header + + shardHeader, ok := mp.hdrsForCurrBlock.hdrHashAndInfo[string(shardData.HeaderHash)] + if !ok { + return fmt.Errorf("%w for header hash %s", process.ErrMissingHeader, hex.EncodeToString(shardData.HeaderHash)) + } + + if !mp.enableEpochsHandler.IsFlagEnabledInEpoch(common.EquivalentMessagesFlag, shardHeader.hdr.GetEpoch()) { + continue + } + if !mp.proofsPool.HasProof(shardData.ShardID, shardData.HeaderHash) { return fmt.Errorf("%w for header hash %s", process.ErrMissingHeaderProof, hex.EncodeToString(shardData.HeaderHash)) } + + shardHeadersStorer, err := mp.store.GetStorer(dataRetriever.BlockHeaderUnit) + if err != nil { + return err + } + + prevProof := shardData.GetPreviousProof() + headersPool := mp.dataPool.Headers() + prevHeader, err := common.GetHeader(prevProof.GetHeaderHash(), headersPool, shardHeadersStorer, mp.marshalizer) + if err != nil { + return err + } + + err = common.VerifyProofAgainstHeader(prevProof, prevHeader) + if err != nil { + return err + } } return nil @@ -737,6 +766,33 @@ func (mp *metaProcessor) RestoreBlockIntoPools(headerHandler data.HeaderHandler, return nil } +func (mp *metaProcessor) updateHeaderForEpochStartIfNeeded(metaHdr *block.MetaBlock) error { + isEpochStart := mp.epochStartTrigger.IsEpochStart() + if !isEpochStart { + return nil + } + return mp.updateEpochStartHeader(metaHdr) +} + +func (mp *metaProcessor) createBody(metaHdr *block.MetaBlock, haveTime func() bool) (data.BodyHandler, error) { + isEpochStart := mp.epochStartTrigger.IsEpochStart() + var body data.BodyHandler + var err error + if isEpochStart { + body, err = mp.createEpochStartBody(metaHdr) + if err != nil { + return nil, err + } + } else { + body, err = mp.createBlockBody(metaHdr, haveTime) + if err != nil { + return nil, err + } + } + + return body, nil +} + // CreateBlock creates the final block and header for the current round func (mp *metaProcessor) CreateBlock( initialHdr data.HeaderHandler, @@ -770,21 +826,19 @@ func (mp *metaProcessor) CreateBlock( return nil, nil, err } - if mp.epochStartTrigger.IsEpochStart() { - err = mp.updateEpochStartHeader(metaHdr) - if err != nil { - return nil, nil, err - } + err = mp.updateHeaderForEpochStartIfNeeded(metaHdr) + if err != nil { + return nil, nil, err + } - body, err = mp.createEpochStartBody(metaHdr) - if err != nil { - return nil, nil, err - } - } else { - body, err = mp.createBlockBody(metaHdr, haveTime) - if err != nil { - return nil, nil, err - } + err = mp.addPrevProofIfNeeded(metaHdr) + if err != nil { + return nil, nil, err + } + + body, err = mp.createBody(metaHdr, haveTime) + if err != nil { + return nil, nil, err } body, err = mp.applyBodyToHeader(metaHdr, body) @@ -1149,7 +1203,7 @@ func (mp *metaProcessor) createAndProcessCrossMiniBlocksDstMe( // shard header must be processed completely errAccountState := mp.accountsDB[state.UserAccountsState].RevertToSnapshot(snapshot) if errAccountState != nil { - // TODO: evaluate if reloading the trie from disk will might solve the problem + // TODO: evaluate if reloading the trie from disk might solve the problem log.Warn("accounts.RevertToSnapshot", "error", errAccountState.Error()) } continue @@ -2142,11 +2196,11 @@ func (mp *metaProcessor) createShardInfo() ([]data.ShardDataHandler, error) { continue } - isBlockAfterEquivalentMessagesFlag := check.IfNil(headerInfo.hdr) && + isBlockAfterEquivalentMessagesFlag := !check.IfNil(headerInfo.hdr) && mp.enableEpochsHandler.IsFlagEnabledInEpoch(common.EquivalentMessagesFlag, headerInfo.hdr.GetEpoch()) hasMissingShardHdrProof := isBlockAfterEquivalentMessagesFlag && !mp.proofsPool.HasProof(headerInfo.hdr.GetShardID(), []byte(hdrHash)) if hasMissingShardHdrProof { - return nil, fmt.Errorf("%w for shard header with hash %s", process.ErrMissingHeaderProof, hdrHash) + return nil, fmt.Errorf("%w for shard header with hash %s", process.ErrMissingHeaderProof, hex.EncodeToString([]byte(hdrHash))) } shardHdr, ok := headerInfo.hdr.(data.ShardHeaderHandler) @@ -2463,6 +2517,7 @@ func (mp *metaProcessor) CreateNewHeader(round uint64, nonce uint64) (data.Heade } mp.roundNotifier.CheckRound(header) + mp.epochNotifier.CheckEpoch(header) err = metaHeader.SetNonce(nonce) if err != nil { diff --git a/process/block/metablock_test.go b/process/block/metablock_test.go index c78f2c5b039..d26f5074eae 100644 --- a/process/block/metablock_test.go +++ b/process/block/metablock_test.go @@ -13,6 +13,9 @@ import ( "github.com/multiversx/mx-chain-core-go/core/atomic" "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/dataRetriever/blockchain" @@ -36,8 +39,6 @@ import ( stateMock "github.com/multiversx/mx-chain-go/testscommon/state" statusHandlerMock "github.com/multiversx/mx-chain-go/testscommon/statusHandler" storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func createMockComponentHolders() ( @@ -91,8 +92,9 @@ func createMockMetaArguments( ) blproc.ArgMetaProcessor { argsHeaderValidator := blproc.ArgsHeaderValidator{ - Hasher: &mock.HasherStub{}, - Marshalizer: &mock.MarshalizerMock{}, + Hasher: &mock.HasherStub{}, + Marshalizer: &mock.MarshalizerMock{}, + EnableEpochsHandler: coreComponents.EnableEpochsHandler(), } headerValidator, _ := blproc.NewHeaderValidator(argsHeaderValidator) @@ -1956,8 +1958,9 @@ func TestMetaProcessor_CheckShardHeadersValidity(t *testing.T) { arguments.BlockTracker = mock.NewBlockTrackerMock(bootstrapComponents.ShardCoordinator(), startHeaders) argsHeaderValidator := blproc.ArgsHeaderValidator{ - Hasher: coreComponents.Hash, - Marshalizer: coreComponents.InternalMarshalizer(), + Hasher: coreComponents.Hash, + Marshalizer: coreComponents.InternalMarshalizer(), + EnableEpochsHandler: coreComponents.EnableEpochsHandler(), } arguments.HeaderValidator, _ = blproc.NewHeaderValidator(argsHeaderValidator) diff --git a/process/block/shardblock.go b/process/block/shardblock.go index d35ed73aa6b..8541f24f060 100644 --- a/process/block/shardblock.go +++ b/process/block/shardblock.go @@ -296,6 +296,15 @@ func (sp *shardProcessor) ProcessBlock( if sp.enableEpochsHandler.IsFlagEnabledInEpoch(common.EquivalentMessagesFlag, header.GetEpoch()) { // check proofs for cross notarized metablocks for _, metaBlockHash := range header.GetMetaBlockHashes() { + hInfo, ok := sp.hdrsForCurrBlock.hdrHashAndInfo[string(metaBlockHash)] + if !ok { + return fmt.Errorf("%w for header hash %s", process.ErrMissingHeader, hex.EncodeToString(metaBlockHash)) + } + + if !sp.enableEpochsHandler.IsFlagEnabledInEpoch(common.EquivalentMessagesFlag, hInfo.hdr.GetEpoch()) { + continue + } + if !sp.proofsPool.HasProof(core.MetachainShardId, metaBlockHash) { return fmt.Errorf("%w for header hash %s", process.ErrMissingHeaderProof, hex.EncodeToString(metaBlockHash)) } @@ -878,6 +887,11 @@ func (sp *shardProcessor) CreateBlock( } } + err = sp.addPrevProofIfNeeded(shardHdr) + if err != nil { + return nil, nil, err + } + sp.epochNotifier.CheckEpoch(shardHdr) sp.blockChainHook.SetCurrentHeader(shardHdr) body, processedMiniBlocksDestMeInfo, err := sp.createBlockBody(shardHdr, haveTime) @@ -1436,6 +1450,7 @@ func (sp *shardProcessor) CreateNewHeader(round uint64, nonce uint64) (data.Head } sp.roundNotifier.CheckRound(header) + sp.epochNotifier.CheckEpoch(header) err = shardHeader.SetNonce(nonce) if err != nil { diff --git a/process/errors.go b/process/errors.go index 395ebf17620..4f786e86e39 100644 --- a/process/errors.go +++ b/process/errors.go @@ -1251,9 +1251,6 @@ var ErrEmptyChainParametersConfiguration = errors.New("empty chain parameters co // ErrNoMatchingConfigForProvidedEpoch signals that there is no matching configuration for the provided epoch var ErrNoMatchingConfigForProvidedEpoch = errors.New("no matching configuration") -// ErrInvalidHeader is raised when header is invalid -var ErrInvalidHeader = errors.New("header is invalid") - // ErrNilHeaderProof signals that a nil header proof has been provided var ErrNilHeaderProof = errors.New("nil header proof") @@ -1269,6 +1266,9 @@ var ErrInvalidInterceptedData = errors.New("invalid intercepted data") // ErrMissingHeaderProof signals that the proof for the header is missing var ErrMissingHeaderProof = errors.New("missing header proof") +// ErrMissingPrevHeaderProof signals that the proof for the previous header is missing +var ErrMissingPrevHeaderProof = errors.New("missing previous header proof") + // ErrInvalidHeaderProof signals that an invalid equivalent proof has been provided var ErrInvalidHeaderProof = errors.New("invalid equivalent proof") diff --git a/process/factory/interceptorscontainer/baseInterceptorsContainerFactory.go b/process/factory/interceptorscontainer/baseInterceptorsContainerFactory.go index bc167e0dab5..271f2ac26aa 100644 --- a/process/factory/interceptorscontainer/baseInterceptorsContainerFactory.go +++ b/process/factory/interceptorscontainer/baseInterceptorsContainerFactory.go @@ -913,7 +913,13 @@ func (bicf *baseInterceptorsContainerFactory) generateValidatorInfoInterceptor() } func (bicf *baseInterceptorsContainerFactory) createOneShardEquivalentProofsInterceptor(topic string) (process.Interceptor, error) { - equivalentProofsFactory := interceptorFactory.NewInterceptedEquivalentProofsFactory(*bicf.argInterceptorFactory, bicf.dataPool.Proofs()) + args := interceptorFactory.ArgInterceptedEquivalentProofsFactory{ + ArgInterceptedDataFactory: *bicf.argInterceptorFactory, + ProofsPool: bicf.dataPool.Proofs(), + HeadersPool: bicf.dataPool.Headers(), + Storage: bicf.store, + } + equivalentProofsFactory := interceptorFactory.NewInterceptedEquivalentProofsFactory(args) marshaller := bicf.argInterceptorFactory.CoreComponents.InternalMarshalizer() argProcessor := processor.ArgEquivalentProofsInterceptorProcessor{ diff --git a/process/headerCheck/errors.go b/process/headerCheck/errors.go index b808de98518..e33c2b56d94 100644 --- a/process/headerCheck/errors.go +++ b/process/headerCheck/errors.go @@ -29,3 +29,6 @@ var ErrProofShardMismatch = errors.New("proof shard mismatch") // ErrProofHeaderHashMismatch signals that the proof header hash does not match the header hash var ErrProofHeaderHashMismatch = errors.New("proof header hash mismatch") + +// ErrProofNotExpected signals that the proof is not expected +var ErrProofNotExpected = errors.New("proof not expected") diff --git a/process/headerCheck/headerSignatureVerify.go b/process/headerCheck/headerSignatureVerify.go index 50bc3ff42ac..a1f838981ad 100644 --- a/process/headerCheck/headerSignatureVerify.go +++ b/process/headerCheck/headerSignatureVerify.go @@ -209,7 +209,7 @@ func (hsv *HeaderSigVerifier) VerifySignature(header data.HeaderHandler) error { return hsv.VerifySignatureForHash(headerCopy, hash, bitmap, sig) } -func verifyPrevProofForHeader(header data.HeaderHandler) error { +func verifyPrevProofForHeaderIntegrity(header data.HeaderHandler) error { prevProof := header.GetPreviousProof() if check.IfNilReflect(prevProof) { return process.ErrNilHeaderProof @@ -255,7 +255,15 @@ func (hsv *HeaderSigVerifier) VerifySignatureForHash(header data.HeaderHandler, // VerifyHeaderWithProof checks if the proof on the header is correct func (hsv *HeaderSigVerifier) VerifyHeaderWithProof(header data.HeaderHandler) error { - err := verifyPrevProofForHeader(header) + // first block for transition to equivalent proofs consensus does not have a previous proof + if !common.ShouldBlockHavePrevProof(header, hsv.enableEpochsHandler, common.EquivalentMessagesFlag) { + if prevProof := header.GetPreviousProof(); !check.IfNilReflect(prevProof) { + return ErrProofNotExpected + } + return nil + } + + err := verifyPrevProofForHeaderIntegrity(header) if err != nil { return err } @@ -285,6 +293,7 @@ func (hsv *HeaderSigVerifier) VerifyHeaderProof(proofHandler data.HeaderProofHan return err } + // TODO: add a new method to get consensus signers that does not require the header and only works with the proof // round, prevHash and prevRandSeed could be removed when we remove fallback validation and we don't need backwards compatibility // (e.g new binary from epoch x forward) consensusPubKeys, err := hsv.getConsensusSigners( diff --git a/process/interceptors/factory/interceptedEquivalentProofsFactory.go b/process/interceptors/factory/interceptedEquivalentProofsFactory.go index 4c5694d1e4d..736dc17c0e5 100644 --- a/process/interceptors/factory/interceptedEquivalentProofsFactory.go +++ b/process/interceptors/factory/interceptedEquivalentProofsFactory.go @@ -1,7 +1,9 @@ package factory import ( + "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/process" @@ -9,20 +11,34 @@ import ( "github.com/multiversx/mx-chain-go/sharding" ) +// ArgInterceptedEquivalentProofsFactory is the DTO used to create a new instance of interceptedEquivalentProofsFactory +type ArgInterceptedEquivalentProofsFactory struct { + ArgInterceptedDataFactory + ProofsPool dataRetriever.ProofsPool + HeadersPool dataRetriever.HeadersPool + Storage dataRetriever.StorageService +} + type interceptedEquivalentProofsFactory struct { marshaller marshal.Marshalizer shardCoordinator sharding.Coordinator headerSigVerifier consensus.HeaderSigVerifier proofsPool dataRetriever.ProofsPool + headersPool dataRetriever.HeadersPool + storage dataRetriever.StorageService + hasher hashing.Hasher } // NewInterceptedEquivalentProofsFactory creates a new instance of interceptedEquivalentProofsFactory -func NewInterceptedEquivalentProofsFactory(args ArgInterceptedDataFactory, proofsPool dataRetriever.ProofsPool) *interceptedEquivalentProofsFactory { +func NewInterceptedEquivalentProofsFactory(args ArgInterceptedEquivalentProofsFactory) *interceptedEquivalentProofsFactory { return &interceptedEquivalentProofsFactory{ marshaller: args.CoreComponents.InternalMarshalizer(), shardCoordinator: args.ShardCoordinator, headerSigVerifier: args.HeaderSigVerifier, - proofsPool: proofsPool, + proofsPool: args.ProofsPool, + headersPool: args.HeadersPool, + storage: args.Storage, + hasher: args.CoreComponents.Hasher(), } } @@ -34,6 +50,8 @@ func (factory *interceptedEquivalentProofsFactory) Create(buff []byte) (process. ShardCoordinator: factory.shardCoordinator, HeaderSigVerifier: factory.headerSigVerifier, Proofs: factory.proofsPool, + Headers: factory.headersPool, + Hasher: factory.hasher, } return interceptedBlocks.NewInterceptedEquivalentProof(args) } diff --git a/process/interceptors/factory/interceptedEquivalentProofsFactory_test.go b/process/interceptors/factory/interceptedEquivalentProofsFactory_test.go index c96ade9528b..930a7bab7b4 100644 --- a/process/interceptors/factory/interceptedEquivalentProofsFactory_test.go +++ b/process/interceptors/factory/interceptedEquivalentProofsFactory_test.go @@ -5,20 +5,30 @@ import ( "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/consensus/mock" processMock "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/testscommon/consensus" "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" - "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/testscommon/genericMocks" + "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" + "github.com/multiversx/mx-chain-go/testscommon/pool" ) -func createMockArgInterceptedDataFactory() ArgInterceptedDataFactory { - return ArgInterceptedDataFactory{ - CoreComponents: &processMock.CoreComponentsMock{ - IntMarsh: &mock.MarshalizerMock{}, +func createMockArgInterceptedEquivalentProofsFactory() ArgInterceptedEquivalentProofsFactory { + return ArgInterceptedEquivalentProofsFactory{ + ArgInterceptedDataFactory: ArgInterceptedDataFactory{ + CoreComponents: &processMock.CoreComponentsMock{ + IntMarsh: &mock.MarshalizerMock{}, + Hash: &hashingMocks.HasherMock{}, + }, + ShardCoordinator: &mock.ShardCoordinatorMock{}, + HeaderSigVerifier: &consensus.HeaderSigVerifierMock{}, }, - ShardCoordinator: &mock.ShardCoordinatorMock{}, - HeaderSigVerifier: &consensus.HeaderSigVerifierMock{}, + ProofsPool: &dataRetriever.ProofsPoolMock{}, + HeadersPool: &pool.HeadersPoolStub{}, + Storage: &genericMocks.ChainStorerMock{}, } } @@ -28,22 +38,22 @@ func TestInterceptedEquivalentProofsFactory_IsInterfaceNil(t *testing.T) { var factory *interceptedEquivalentProofsFactory require.True(t, factory.IsInterfaceNil()) - factory = NewInterceptedEquivalentProofsFactory(createMockArgInterceptedDataFactory(), &dataRetriever.ProofsPoolMock{}) + factory = NewInterceptedEquivalentProofsFactory(createMockArgInterceptedEquivalentProofsFactory()) require.False(t, factory.IsInterfaceNil()) } func TestNewInterceptedEquivalentProofsFactory(t *testing.T) { t.Parallel() - factory := NewInterceptedEquivalentProofsFactory(createMockArgInterceptedDataFactory(), &dataRetriever.ProofsPoolMock{}) + factory := NewInterceptedEquivalentProofsFactory(createMockArgInterceptedEquivalentProofsFactory()) require.NotNil(t, factory) } func TestInterceptedEquivalentProofsFactory_Create(t *testing.T) { t.Parallel() - args := createMockArgInterceptedDataFactory() - factory := NewInterceptedEquivalentProofsFactory(args, &dataRetriever.ProofsPoolMock{}) + args := createMockArgInterceptedEquivalentProofsFactory() + factory := NewInterceptedEquivalentProofsFactory(args) require.NotNil(t, factory) providedProof := &block.HeaderProof{ diff --git a/process/interceptors/interceptedDataVerifier.go b/process/interceptors/interceptedDataVerifier.go index 0accf41d3fc..7f230b24ae2 100644 --- a/process/interceptors/interceptedDataVerifier.go +++ b/process/interceptors/interceptedDataVerifier.go @@ -11,10 +11,8 @@ import ( type interceptedDataStatus int8 const ( - validInterceptedData interceptedDataStatus = iota - invalidInterceptedData - - interceptedDataStatusBytesSize = 8 + validInterceptedData interceptedDataStatus = iota + interceptedDataStatusBytesSize = 8 ) type interceptedDataVerifier struct { @@ -56,7 +54,9 @@ func (idv *interceptedDataVerifier) Verify(interceptedData process.InterceptedDa err := interceptedData.CheckValidity() if err != nil { - idv.cache.Put(interceptedData.Hash(), invalidInterceptedData, interceptedDataStatusBytesSize) + log.Debug("Intercepted data is invalid", "hash", interceptedData.Hash(), "err", err) + // TODO: investigate to selectively add as invalid intercepted data only when data is indeed invalid instead of missing + // idv.cache.Put(interceptedData.Hash(), invalidInterceptedData, interceptedDataStatusBytesSize) return process.ErrInvalidInterceptedData } diff --git a/process/interceptors/interceptedDataVerifier_test.go b/process/interceptors/interceptedDataVerifier_test.go index 8913f5828d8..503eb790f32 100644 --- a/process/interceptors/interceptedDataVerifier_test.go +++ b/process/interceptors/interceptedDataVerifier_test.go @@ -117,7 +117,6 @@ func TestInterceptedDataVerifier_CheckValidityShouldNotWork(t *testing.T) { t.Parallel() checkValidityCounter := atomic.Counter{} - interceptedData := &testscommon.InterceptedDataStub{ CheckValidityCalled: func() error { checkValidityCounter.Add(1) @@ -151,7 +150,7 @@ func TestInterceptedDataVerifier_CheckValidityShouldNotWork(t *testing.T) { wg.Wait() require.Equal(t, int64(100), errCount.Get()) - require.Equal(t, int64(1), checkValidityCounter.Get()) + require.Equal(t, int64(101), checkValidityCounter.Get()) } func TestInterceptedDataVerifier_CheckExpiryTime(t *testing.T) { @@ -221,10 +220,10 @@ func TestInterceptedDataVerifier_CheckExpiryTime(t *testing.T) { require.Equal(t, process.ErrInvalidInterceptedData, err) require.Equal(t, int64(1), checkValidityCounter.Get()) - // Second retrieval should be from the cache. + // Second retrieval err = verifier.Verify(interceptedData) require.Equal(t, process.ErrInvalidInterceptedData, err) - require.Equal(t, int64(1), checkValidityCounter.Get()) + require.Equal(t, int64(2), checkValidityCounter.Get()) // Wait for the cache expiry <-time.After(expiryTestDuration + 100*time.Millisecond) @@ -232,6 +231,6 @@ func TestInterceptedDataVerifier_CheckExpiryTime(t *testing.T) { // Third retrieval should reach validity check again. err = verifier.Verify(interceptedData) require.Equal(t, process.ErrInvalidInterceptedData, err) - require.Equal(t, int64(2), checkValidityCounter.Get()) + require.Equal(t, int64(3), checkValidityCounter.Get()) }) } diff --git a/process/interceptors/processor/equivalentProofsInterceptorProcessor_test.go b/process/interceptors/processor/equivalentProofsInterceptorProcessor_test.go index b11eca03aec..94c22b00544 100644 --- a/process/interceptors/processor/equivalentProofsInterceptorProcessor_test.go +++ b/process/interceptors/processor/equivalentProofsInterceptorProcessor_test.go @@ -5,14 +5,17 @@ import ( "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/consensus/mock" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/block/interceptedBlocks" "github.com/multiversx/mx-chain-go/process/transaction" "github.com/multiversx/mx-chain-go/testscommon/consensus" "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" + "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" - "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/testscommon/pool" ) func createMockArgEquivalentProofsInterceptorProcessor() ArgEquivalentProofsInterceptorProcessor { @@ -105,6 +108,8 @@ func TestEquivalentProofsInterceptorProcessor_Save(t *testing.T) { ShardCoordinator: &mock.ShardCoordinatorMock{}, HeaderSigVerifier: &consensus.HeaderSigVerifierMock{}, Proofs: &dataRetriever.ProofsPoolMock{}, + Headers: &pool.HeadersPoolStub{}, + Hasher: &hashingMocks.HasherMock{}, } argInterceptedEquivalentProof.DataBuff, _ = argInterceptedEquivalentProof.Marshaller.Marshal(&block.HeaderProof{ PubKeysBitmap: []byte("bitmap"), diff --git a/process/interceptors/processor/hdrInterceptorProcessor.go b/process/interceptors/processor/hdrInterceptorProcessor.go index e60489c2ae5..c49f2bb9703 100644 --- a/process/interceptors/processor/hdrInterceptorProcessor.go +++ b/process/interceptors/processor/hdrInterceptorProcessor.go @@ -7,6 +7,7 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/process" @@ -80,7 +81,7 @@ func (hip *HdrInterceptorProcessor) Save(data process.InterceptedData, _ core.Pe hip.headers.AddHeader(interceptedHdr.Hash(), interceptedHdr.HeaderHandler()) - if common.IsFlagEnabledAfterEpochsStartBlock(interceptedHdr.HeaderHandler(), hip.enableEpochsHandler, common.EquivalentMessagesFlag) { + if common.ShouldBlockHavePrevProof(interceptedHdr.HeaderHandler(), hip.enableEpochsHandler, common.EquivalentMessagesFlag) { err := hip.proofs.AddProof(interceptedHdr.HeaderHandler().GetPreviousProof()) if err != nil { log.Error("failed to add proof", "error", err, "intercepted header hash", interceptedHdr.Hash(), "header type", reflect.TypeOf(interceptedHdr.HeaderHandler())) diff --git a/process/interceptors/processor/hdrInterceptorProcessor_test.go b/process/interceptors/processor/hdrInterceptorProcessor_test.go index cc35b04d06b..6b611f3a1c5 100644 --- a/process/interceptors/processor/hdrInterceptorProcessor_test.go +++ b/process/interceptors/processor/hdrInterceptorProcessor_test.go @@ -7,6 +7,8 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/interceptors/processor" @@ -14,7 +16,6 @@ import ( "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" - "github.com/stretchr/testify/assert" ) func createMockHdrArgument() *processor.ArgHdrInterceptorProcessor { @@ -169,6 +170,7 @@ func TestHdrInterceptorProcessor_SaveNilDataShouldErr(t *testing.T) { func TestHdrInterceptorProcessor_SaveShouldWork(t *testing.T) { t.Parallel() + minNonceWithProof := uint64(2) hdrInterceptedData := &struct { testscommon.InterceptedDataStub mock.GetHdrHandlerStub @@ -180,7 +182,11 @@ func TestHdrInterceptorProcessor_SaveShouldWork(t *testing.T) { }, GetHdrHandlerStub: mock.GetHdrHandlerStub{ HeaderHandlerCalled: func() data.HeaderHandler { - return &testscommon.HeaderHandlerStub{} + return &testscommon.HeaderHandlerStub{ + GetNonceCalled: func() uint64 { + return minNonceWithProof + }, + } }, }, } diff --git a/process/interface.go b/process/interface.go index d7cbe87825b..f983e09e1cd 100644 --- a/process/interface.go +++ b/process/interface.go @@ -384,6 +384,7 @@ type ForkDetector interface { GetNotarizedHeaderHash(nonce uint64) []byte ResetProbableHighestNonce() SetFinalToLastCheckpoint() + ReceivedProof(proof data.HeaderProofHandler) IsInterfaceNil() bool } diff --git a/process/mock/forkDetectorMock.go b/process/mock/forkDetectorMock.go index 51e79af246f..b6345f1a7aa 100644 --- a/process/mock/forkDetectorMock.go +++ b/process/mock/forkDetectorMock.go @@ -2,6 +2,7 @@ package mock import ( "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-go/process" ) @@ -19,6 +20,7 @@ type ForkDetectorMock struct { RestoreToGenesisCalled func() ResetProbableHighestNonceCalled func() SetFinalToLastCheckpointCalled func() + ReceivedProofCalled func(proof data.HeaderProofHandler) } // RestoreToGenesis - @@ -114,6 +116,13 @@ func (fdm *ForkDetectorMock) SetFinalToLastCheckpoint() { } } +// ReceivedProof - +func (fdm *ForkDetectorMock) ReceivedProof(proof data.HeaderProofHandler) { + if fdm.ReceivedProofCalled != nil { + fdm.ReceivedProofCalled(proof) + } +} + // IsInterfaceNil returns true if there is no value under the interface func (fdm *ForkDetectorMock) IsInterfaceNil() bool { return fdm == nil diff --git a/process/peer/process.go b/process/peer/process.go index c5ebb890d8a..9146a7605b1 100644 --- a/process/peer/process.go +++ b/process/peer/process.go @@ -406,10 +406,7 @@ func (vs *validatorStatistics) UpdatePeerState(header data.MetaHeaderHandler, ca log.Debug("UpdatePeerState - registering meta previous leader fees", "metaNonce", previousHeader.GetNonce()) bitmap := previousHeader.GetPubKeysBitmap() if vs.enableEpochsHandler.IsFlagEnabledInEpoch(common.EquivalentMessagesFlag, previousHeader.GetEpoch()) { - proof := previousHeader.GetPreviousProof() - if !check.IfNilReflect(proof) { - bitmap = proof.GetPubKeysBitmap() - } + bitmap = vs.getBitmapForFullConsensus(previousHeader.GetShardID(), previousHeader.GetEpoch()) } err = vs.updateValidatorInfoOnSuccessfulBlock( leader, @@ -903,6 +900,17 @@ func (vs *validatorStatistics) RevertPeerState(header data.MetaHeaderHandler) er return vs.peerAdapter.RecreateTrie(rootHashHolder) } +// TODO: check if this can be taken from somewhere else +func (vs *validatorStatistics) getBitmapForFullConsensus(shardID uint32, epoch uint32) []byte { + consensusSize := vs.nodesCoordinator.ConsensusGroupSizeForShardAndEpoch(shardID, epoch) + bitmap := make([]byte, consensusSize/8+1) + for i := 0; i < consensusSize; i++ { + bitmap[i/8] |= 1 << (uint16(i) % 8) + } + + return bitmap +} + func (vs *validatorStatistics) updateShardDataPeerState( header data.HeaderHandler, cacheMap map[string]data.HeaderHandler, @@ -935,10 +943,14 @@ func (vs *validatorStatistics) updateShardDataPeerState( } log.Debug("updateShardDataPeerState - registering shard leader fees", "shard headerHash", h.HeaderHash, "accumulatedFees", h.AccumulatedFees.String(), "developerFees", h.DeveloperFees.String()) + bitmap := h.PubKeysBitmap + if vs.enableEpochsHandler.IsFlagEnabledInEpoch(common.EquivalentMessagesFlag, h.Epoch) { + bitmap = vs.getBitmapForFullConsensus(h.ShardID, h.Epoch) + } shardInfoErr = vs.updateValidatorInfoOnSuccessfulBlock( leader, shardConsensus, - h.PubKeysBitmap, + bitmap, big.NewInt(0).Sub(h.AccumulatedFees, h.DeveloperFees), h.ShardID, ) diff --git a/process/sync/baseForkDetector.go b/process/sync/baseForkDetector.go index db5a601524a..cfe0a675569 100644 --- a/process/sync/baseForkDetector.go +++ b/process/sync/baseForkDetector.go @@ -7,16 +7,19 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" + + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/process" ) type headerInfo struct { - epoch uint32 - nonce uint64 - round uint64 - hash []byte - state process.BlockHeaderState + epoch uint32 + nonce uint64 + round uint64 + hash []byte + state process.BlockHeaderState + hasProof bool } type checkpointInfo struct { @@ -43,14 +46,16 @@ type baseForkDetector struct { fork forkInfo mutFork sync.RWMutex - blackListHandler process.TimeCacher - genesisTime int64 - blockTracker process.BlockTracker - forkDetector forkDetector - genesisNonce uint64 - genesisRound uint64 - maxForkHeaderEpoch uint32 - genesisEpoch uint32 + blackListHandler process.TimeCacher + genesisTime int64 + blockTracker process.BlockTracker + forkDetector forkDetector + genesisNonce uint64 + genesisRound uint64 + maxForkHeaderEpoch uint32 + genesisEpoch uint32 + enableEpochsHandler common.EnableEpochsHandler + proofsPool process.ProofsPool } // SetRollBackNonce sets the nonce where the chain should roll back @@ -102,7 +107,7 @@ func (bfd *baseForkDetector) checkBlockBasicValidity( roundDif := int64(header.GetRound()) - int64(bfd.finalCheckpoint().round) nonceDif := int64(header.GetNonce()) - int64(bfd.finalCheckpoint().nonce) - //TODO: Analyze if the acceptance of some headers which came for the next round could generate some attack vectors + // TODO: Analyze if the acceptance of some headers which came for the next round could generate some attack vectors nextRound := bfd.roundHandler.Index() + 1 genesisTimeFromHeader := bfd.computeGenesisTimeFromHeader(header) @@ -111,7 +116,7 @@ func (bfd *baseForkDetector) checkBlockBasicValidity( process.AddHeaderToBlackList(bfd.blackListHandler, headerHash) return process.ErrHeaderIsBlackListed } - //TODO: This check could be removed when this protection mechanism would be implemented on interceptors side + // TODO: This check could be removed when this protection mechanism would be implemented on interceptors side if genesisTimeFromHeader != bfd.genesisTime { process.AddHeaderToBlackList(bfd.blackListHandler, headerHash) return ErrGenesisTimeMissmatch @@ -197,11 +202,17 @@ func (bfd *baseForkDetector) computeProbableHighestNonce() uint64 { probableHighestNonce := bfd.finalCheckpoint().nonce bfd.mutHeaders.RLock() - for nonce := range bfd.headers { + for nonce, headers := range bfd.headers { if nonce <= probableHighestNonce { continue } - probableHighestNonce = nonce + + for _, hInfo := range headers { + if hInfo.hasProof { + probableHighestNonce = nonce + break + } + } } bfd.mutHeaders.RUnlock() @@ -286,8 +297,10 @@ func (bfd *baseForkDetector) append(hdrInfo *headerInfo) bool { return true } + bfd.adjustHeadersWithInfo(hdrInfo) + for _, hdrInfoStored := range hdrInfos { - if bytes.Equal(hdrInfoStored.hash, hdrInfo.hash) && hdrInfoStored.state == hdrInfo.state { + if bytes.Equal(hdrInfoStored.hash, hdrInfo.hash) && hdrInfoStored.state == hdrInfo.state && hdrInfoStored.hasProof == hdrInfo.hasProof { return false } } @@ -296,6 +309,23 @@ func (bfd *baseForkDetector) append(hdrInfo *headerInfo) bool { return true } +func (bfd *baseForkDetector) adjustHeadersWithInfo(hInfo *headerInfo) { + if !bfd.enableEpochsHandler.IsFlagEnabledInEpoch(common.EquivalentMessagesFlag, hInfo.epoch) { + return + } + + if !hInfo.hasProof { + return + } + + hdrInfos := bfd.headers[hInfo.nonce] + for i := range hdrInfos { + if bytes.Equal(hdrInfos[i].hash, hInfo.hash) { + hdrInfos[i].hasProof = true + } + } +} + // GetHighestFinalBlockNonce gets the highest nonce of the block which is final, and it can not be reverted anymore func (bfd *baseForkDetector) GetHighestFinalBlockNonce() uint64 { return bfd.finalCheckpoint().nonce @@ -682,6 +712,27 @@ func (bfd *baseForkDetector) addHeader( return nil } +// ReceivedProof is called when a proof is received +func (bfd *baseForkDetector) ReceivedProof(proof data.HeaderProofHandler) { + bfd.processReceivedProof(proof) +} + +func (bfd *baseForkDetector) processReceivedProof(proof data.HeaderProofHandler) { + bfd.setHighestNonceReceived(proof.GetHeaderNonce()) + + hInfo := &headerInfo{ + epoch: proof.GetHeaderEpoch(), + nonce: proof.GetHeaderNonce(), + round: proof.GetHeaderRound(), + hash: proof.GetHeaderHash(), + state: process.BHReceived, + hasProof: true, + } + + _ = bfd.append(hInfo) + +} + func (bfd *baseForkDetector) processReceivedBlock( header data.HeaderHandler, headerHash []byte, @@ -690,9 +741,13 @@ func (bfd *baseForkDetector) processReceivedBlock( selfNotarizedHeadersHashes [][]byte, doJobOnBHProcessed func(data.HeaderHandler, []byte, []data.HeaderHandler, [][]byte), ) { + hasProof := true // old blocks have consensus proof on them + if bfd.enableEpochsHandler.IsFlagEnabledInEpoch(common.EquivalentMessagesFlag, header.GetEpoch()) { + hasProof = bfd.proofsPool.HasProof(header.GetShardID(), headerHash) + } bfd.setHighestNonceReceived(header.GetNonce()) - if state == process.BHProposed { + if state == process.BHProposed || !hasProof { return } @@ -701,14 +756,16 @@ func (bfd *baseForkDetector) processReceivedBlock( state = process.BHReceivedTooLate } - appended := bfd.append(&headerInfo{ - epoch: header.GetEpoch(), - nonce: header.GetNonce(), - round: header.GetRound(), - hash: headerHash, - state: state, - }) - if !appended { + hInfo := &headerInfo{ + epoch: header.GetEpoch(), + nonce: header.GetNonce(), + round: header.GetRound(), + hash: headerHash, + state: state, + hasProof: hasProof, + } + + if !bfd.append(hInfo) { return } @@ -719,14 +776,15 @@ func (bfd *baseForkDetector) processReceivedBlock( probableHighestNonce := bfd.computeProbableHighestNonce() bfd.setProbableHighestNonce(probableHighestNonce) - log.Debug("forkDetector.AddHeader", - "round", header.GetRound(), - "nonce", header.GetNonce(), - "hash", headerHash, - "state", state, + log.Debug("forkDetector.appendHeaderInfo", + "round", hInfo.round, + "nonce", hInfo.nonce, + "hash", hInfo.hash, + "state", hInfo.state, "probable highest nonce", bfd.probableHighestNonce(), "last checkpoint nonce", bfd.lastCheckpoint().nonce, - "final checkpoint nonce", bfd.finalCheckpoint().nonce) + "final checkpoint nonce", bfd.finalCheckpoint().nonce, + "has proof", hInfo.hasProof) } // SetFinalToLastCheckpoint sets the final checkpoint to the last checkpoint added in list diff --git a/process/sync/baseForkDetector_test.go b/process/sync/baseForkDetector_test.go index 10f857bfbce..50cf509c47b 100644 --- a/process/sync/baseForkDetector_test.go +++ b/process/sync/baseForkDetector_test.go @@ -8,10 +8,15 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" + + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/process/sync" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" + "github.com/stretchr/testify/assert" ) @@ -23,6 +28,8 @@ func TestNewBasicForkDetector_ShouldErrNilRoundHandler(t *testing.T) { &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) assert.Equal(t, process.ErrNilRoundHandler, err) assert.Nil(t, bfd) @@ -37,6 +44,8 @@ func TestNewBasicForkDetector_ShouldErrNilBlackListHandler(t *testing.T) { nil, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) assert.Equal(t, process.ErrNilBlackListCacher, err) assert.Nil(t, bfd) @@ -51,11 +60,45 @@ func TestNewBasicForkDetector_ShouldErrNilBlockTracker(t *testing.T) { &testscommon.TimeCacheStub{}, nil, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) assert.Equal(t, process.ErrNilBlockTracker, err) assert.Nil(t, bfd) } +func TestNewBasicForkDetector_ShouldErrNilEnableEpochsHandler(t *testing.T) { + t.Parallel() + + roundHandlerMock := &mock.RoundHandlerMock{RoundIndex: 100} + bfd, err := sync.NewShardForkDetector( + roundHandlerMock, + &testscommon.TimeCacheStub{}, + &mock.BlockTrackerMock{}, + 0, + nil, + &dataRetriever.ProofsPoolMock{}, + ) + assert.Equal(t, process.ErrNilEnableEpochsHandler, err) + assert.Nil(t, bfd) +} + +func TestNewBasicForkDetector_ShouldErrNilProofsPool(t *testing.T) { + t.Parallel() + + roundHandlerMock := &mock.RoundHandlerMock{RoundIndex: 100} + bfd, err := sync.NewShardForkDetector( + roundHandlerMock, + &testscommon.TimeCacheStub{}, + &mock.BlockTrackerMock{}, + 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + nil, + ) + assert.Equal(t, process.ErrNilProofsPool, err) + assert.Nil(t, bfd) +} + func TestNewBasicForkDetector_ShouldWork(t *testing.T) { t.Parallel() @@ -65,6 +108,8 @@ func TestNewBasicForkDetector_ShouldWork(t *testing.T) { &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) assert.Nil(t, err) assert.NotNil(t, bfd) @@ -84,6 +129,8 @@ func TestBasicForkDetector_CheckBlockValidityShouldErrGenesisTimeMissmatch(t *te &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, genesisTime, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) err := bfd.CheckBlockValidity(&block.Header{Nonce: 1, Round: round, TimeStamp: incorrectTimeStamp}, []byte("hash")) @@ -102,6 +149,8 @@ func TestBasicForkDetector_CheckBlockValidityShouldErrLowerRoundInBlock(t *testi &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) bfd.SetFinalCheckpoint(1, 1, nil) err := bfd.CheckBlockValidity(&block.Header{PubKeysBitmap: []byte("X")}, []byte("hash")) @@ -117,6 +166,8 @@ func TestBasicForkDetector_CheckBlockValidityShouldErrLowerNonceInBlock(t *testi &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) bfd.SetFinalCheckpoint(2, 2, nil) err := bfd.CheckBlockValidity(&block.Header{Nonce: 1, Round: 3, PubKeysBitmap: []byte("X")}, []byte("hash")) @@ -132,6 +183,8 @@ func TestBasicForkDetector_CheckBlockValidityShouldErrHigherRoundInBlock(t *test &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) err := bfd.CheckBlockValidity(&block.Header{Nonce: 1, Round: 2, PubKeysBitmap: []byte("X")}, []byte("hash")) assert.Equal(t, sync.ErrHigherRoundInBlock, err) @@ -146,6 +199,8 @@ func TestBasicForkDetector_CheckBlockValidityShouldErrHigherNonceInBlock(t *test &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) err := bfd.CheckBlockValidity(&block.Header{Nonce: 2, Round: 1, PubKeysBitmap: []byte("X")}, []byte("hash")) assert.Equal(t, sync.ErrHigherNonceInBlock, err) @@ -160,6 +215,8 @@ func TestBasicForkDetector_CheckBlockValidityShouldWork(t *testing.T) { &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) err := bfd.CheckBlockValidity(&block.Header{Nonce: 1, Round: 1, PubKeysBitmap: []byte("X")}, []byte("hash")) assert.Nil(t, err) @@ -178,6 +235,8 @@ func TestBasicForkDetector_RemoveHeadersShouldWork(t *testing.T) { &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) roundHandlerMock.RoundIndex = 1 @@ -209,6 +268,8 @@ func TestBasicForkDetector_CheckForkOnlyOneShardHeaderOnANonceShouldReturnFalse( &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) _ = bfd.AddHeader( &block.Header{Nonce: 0, PubKeysBitmap: []byte("X")}, @@ -237,6 +298,8 @@ func TestBasicForkDetector_CheckForkOnlyReceivedHeadersShouldReturnFalse(t *test &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) _ = bfd.AddHeader( &block.Header{Nonce: 0, PubKeysBitmap: []byte("X")}, @@ -267,6 +330,8 @@ func TestBasicForkDetector_CheckForkOnlyOneShardHeaderOnANonceReceivedAndProcess &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) _ = bfd.AddHeader( &block.Header{Nonce: 0, PubKeysBitmap: []byte("X")}, @@ -297,6 +362,8 @@ func TestBasicForkDetector_CheckForkMetaHeaderProcessedShouldReturnFalse(t *test &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) _ = bfd.AddHeader( &block.MetaBlock{Nonce: 1, Round: 3, PubKeysBitmap: []byte("X")}, @@ -325,6 +392,8 @@ func TestBasicForkDetector_CheckForkMetaHeaderProcessedShouldReturnFalseWhenLowe &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) roundHandlerMock.RoundIndex = 5 _ = bfd.AddHeader( @@ -369,6 +438,8 @@ func TestBasicForkDetector_CheckForkMetaHeaderProcessedShouldReturnFalseWhenEqua &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) roundHandlerMock.RoundIndex = 5 _ = bfd.AddHeader( @@ -412,6 +483,8 @@ func TestBasicForkDetector_CheckForkShardHeaderProcessedShouldReturnTrueWhenEqua &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) hdr1 := &block.Header{Nonce: 1, Round: 4, PubKeysBitmap: []byte("X")} @@ -476,6 +549,8 @@ func TestBasicForkDetector_CheckForkMetaHeaderProcessedShouldReturnTrueWhenEqual &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) roundHandlerMock.RoundIndex = 5 _ = bfd.AddHeader( @@ -518,6 +593,8 @@ func TestBasicForkDetector_CheckForkShardHeaderProcessedShouldReturnTrueWhenEqua &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) hdr1 := &block.Header{Nonce: 1, Round: 4, PubKeysBitmap: []byte("X")} @@ -581,6 +658,8 @@ func TestBasicForkDetector_CheckForkShouldReturnTrue(t *testing.T) { &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) roundHandlerMock.RoundIndex = 4 _ = bfd.AddHeader( @@ -625,6 +704,8 @@ func TestBasicForkDetector_CheckForkShouldReturnFalseWhenForkIsOnFinalCheckpoint &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) roundHandlerMock.RoundIndex = 1 _ = bfd.AddHeader( @@ -661,6 +742,8 @@ func TestBasicForkDetector_CheckForkShouldReturnFalseWhenForkIsOnHigherEpochBloc &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) roundHandlerMock.RoundIndex = 2 _ = bfd.AddHeader( @@ -703,6 +786,8 @@ func TestBasicForkDetector_RemovePastHeadersShouldWork(t *testing.T) { &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) _ = bfd.AddHeader(hdr1, hash1, process.BHReceived, nil, nil) _ = bfd.AddHeader(hdr2, hash2, process.BHReceived, nil, nil) @@ -737,6 +822,8 @@ func TestBasicForkDetector_RemoveInvalidReceivedHeadersShouldWork(t *testing.T) &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) roundHandlerMock.RoundIndex = 11 _ = bfd.AddHeader(hdr0, hash0, process.BHReceived, nil, nil) @@ -775,6 +862,8 @@ func TestBasicForkDetector_RemoveCheckpointHeaderNonceShouldResetCheckpoint(t *t &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) _ = bfd.AddHeader(hdr1, hash1, process.BHProcessed, nil, nil) @@ -794,6 +883,8 @@ func TestBasicForkDetector_GetHighestFinalBlockNonce(t *testing.T) { &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) hdr1 := &block.MetaBlock{Nonce: 2, Round: 1, PubKeysBitmap: []byte("X")} @@ -821,6 +912,7 @@ func TestBasicForkDetector_GetHighestFinalBlockNonce(t *testing.T) { assert.Equal(t, uint64(3), bfd.GetHighestFinalBlockNonce()) } +// TODO: add specific tests for equivalent proofs func TestBasicForkDetector_ProbableHighestNonce(t *testing.T) { t.Parallel() @@ -830,6 +922,12 @@ func TestBasicForkDetector_ProbableHighestNonce(t *testing.T) { &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return flag != common.EquivalentMessagesFlag + }, + }, + &dataRetriever.ProofsPoolMock{}, ) roundHandlerMock.RoundIndex = 11 @@ -882,7 +980,14 @@ func TestShardForkDetector_ShouldAddBlockInForkDetectorShouldWork(t *testing.T) t.Parallel() roundHandlerMock := &mock.RoundHandlerMock{RoundIndex: 10} - sfd, _ := sync.NewShardForkDetector(roundHandlerMock, &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0) + sfd, _ := sync.NewShardForkDetector( + roundHandlerMock, + &testscommon.TimeCacheStub{}, + &mock.BlockTrackerMock{}, + 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, + ) hdr := &block.Header{Nonce: 1, Round: 1} receivedTooLate := sfd.IsHeaderReceivedTooLate(hdr, process.BHProcessed, process.BlockFinality) @@ -900,7 +1005,14 @@ func TestShardForkDetector_ShouldAddBlockInForkDetectorShouldErrLowerRoundInBloc t.Parallel() roundHandlerMock := &mock.RoundHandlerMock{RoundIndex: 10} - sfd, _ := sync.NewShardForkDetector(roundHandlerMock, &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0) + sfd, _ := sync.NewShardForkDetector( + roundHandlerMock, + &testscommon.TimeCacheStub{}, + &mock.BlockTrackerMock{}, + 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, + ) hdr := &block.Header{Nonce: 1, Round: 1} hdr.Round = uint64(roundHandlerMock.RoundIndex - process.BlockFinality - 1) @@ -912,7 +1024,14 @@ func TestMetaForkDetector_ShouldAddBlockInForkDetectorShouldWork(t *testing.T) { t.Parallel() roundHandlerMock := &mock.RoundHandlerMock{RoundIndex: 10} - mfd, _ := sync.NewMetaForkDetector(roundHandlerMock, &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0) + mfd, _ := sync.NewMetaForkDetector( + roundHandlerMock, + &testscommon.TimeCacheStub{}, + &mock.BlockTrackerMock{}, + 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, + ) hdr := &block.MetaBlock{Nonce: 1, Round: 1} receivedTooLate := mfd.IsHeaderReceivedTooLate(hdr, process.BHProcessed, process.BlockFinality) @@ -930,7 +1049,14 @@ func TestMetaForkDetector_ShouldAddBlockInForkDetectorShouldErrLowerRoundInBlock t.Parallel() roundHandlerMock := &mock.RoundHandlerMock{RoundIndex: 10} - mfd, _ := sync.NewMetaForkDetector(roundHandlerMock, &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0) + mfd, _ := sync.NewMetaForkDetector( + roundHandlerMock, + &testscommon.TimeCacheStub{}, + &mock.BlockTrackerMock{}, + 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, + ) hdr := &block.MetaBlock{Nonce: 1, Round: 1} hdr.Round = uint64(roundHandlerMock.RoundIndex - process.BlockFinality - 1) @@ -942,7 +1068,14 @@ func TestShardForkDetector_AddNotarizedHeadersShouldNotChangeTheFinalCheckpoint( t.Parallel() roundHandlerMock := &mock.RoundHandlerMock{RoundIndex: 10} - sfd, _ := sync.NewShardForkDetector(roundHandlerMock, &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0) + sfd, _ := sync.NewShardForkDetector( + roundHandlerMock, + &testscommon.TimeCacheStub{}, + &mock.BlockTrackerMock{}, + 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, + ) hdr1 := &block.Header{Nonce: 3, Round: 3} hash1 := []byte("hash1") hdr2 := &block.Header{Nonce: 4, Round: 4} @@ -988,7 +1121,14 @@ func TestBaseForkDetector_IsConsensusStuckNotSyncingShouldReturnFalse(t *testing t.Parallel() roundHandlerMock := &mock.RoundHandlerMock{} - bfd, _ := sync.NewShardForkDetector(roundHandlerMock, &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0) + bfd, _ := sync.NewShardForkDetector( + roundHandlerMock, + &testscommon.TimeCacheStub{}, + &mock.BlockTrackerMock{}, + 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, + ) bfd.SetProbableHighestNonce(1) @@ -1004,6 +1144,8 @@ func TestBaseForkDetector_IsConsensusStuckNoncesDifferencesNotEnoughShouldReturn &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) roundHandlerMock.RoundIndex = 10 @@ -1019,6 +1161,8 @@ func TestBaseForkDetector_IsConsensusStuckNotInProperRoundShouldReturnFalse(t *t &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) roundHandlerMock.RoundIndex = 11 @@ -1034,6 +1178,8 @@ func TestBaseForkDetector_IsConsensusStuckShouldReturnTrue(t *testing.T) { &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) // last checkpoint will be (round = 0 , nonce = 0) @@ -1060,6 +1206,8 @@ func TestBaseForkDetector_ComputeTimeDuration(t *testing.T) { &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, genesisTime, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) hdr1 := &block.Header{Nonce: 1, Round: hdrRound, PubKeysBitmap: []byte("X"), TimeStamp: hdrTimeStamp} @@ -1073,7 +1221,14 @@ func TestShardForkDetector_RemoveHeaderShouldComputeFinalCheckpoint(t *testing.T t.Parallel() roundHandlerMock := &mock.RoundHandlerMock{RoundIndex: 10} - sfd, _ := sync.NewShardForkDetector(roundHandlerMock, &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0) + sfd, _ := sync.NewShardForkDetector( + roundHandlerMock, + &testscommon.TimeCacheStub{}, + &mock.BlockTrackerMock{}, + 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, + ) hdr1 := &block.Header{Nonce: 3, Round: 3} hash1 := []byte("hash1") hdr2 := &block.Header{Nonce: 4, Round: 4} @@ -1114,6 +1269,8 @@ func TestBasicForkDetector_CheckForkMetaHeaderProcessedShouldWorkOnEqualRoundWit &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) roundHandlerMock.RoundIndex = 5 _ = bfd.AddHeader( @@ -1162,6 +1319,8 @@ func TestBasicForkDetector_SetFinalToLastCheckpointShouldWork(t *testing.T) { &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) roundHandlerMock.RoundIndex = 1000 diff --git a/process/sync/baseSync.go b/process/sync/baseSync.go index cf13638912f..4c291ba349b 100644 --- a/process/sync/baseSync.go +++ b/process/sync/baseSync.go @@ -18,6 +18,8 @@ import ( "github.com/multiversx/mx-chain-core-go/data/typeConverters" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" + logger "github.com/multiversx/mx-chain-logger-go" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/dataRetriever" @@ -30,7 +32,6 @@ import ( "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/trie/storageMarker" - logger "github.com/multiversx/mx-chain-logger-go" ) var log = logger.GetOrCreate("process/sync") @@ -161,6 +162,14 @@ func (boot *baseBootstrap) requestedHeaderHash() []byte { return boot.headerhash } +func (boot *baseBootstrap) processReceivedProof(headerProof data.HeaderProofHandler) { + if boot.shardCoordinator.SelfId() != headerProof.GetHeaderShardId() { + return + } + + boot.forkDetector.ReceivedProof(headerProof) +} + func (boot *baseBootstrap) processReceivedHeader(headerHandler data.HeaderHandler, headerHash []byte) { if boot.shardCoordinator.SelfId() != headerHandler.GetShardID() { return @@ -711,17 +720,6 @@ func (boot *baseBootstrap) handleEquivalentProof( return nil } - prevHeader, err := boot.blockBootstrapper.getHeaderWithHashRequestingIfMissing(header.GetPrevHash()) - if err != nil { - return err - } - - if !boot.enableEpochsHandler.IsFlagEnabledInEpoch(common.EquivalentMessagesFlag, prevHeader.GetEpoch()) { - // no need to check proof for first block after activation - log.Info("handleEquivalentProof: no need to check equivalent proof for first activation block") - return nil - } - // process block only if there is a proof for it hasProof := boot.proofs.HasProof(header.GetShardID(), headerHash) if hasProof { @@ -730,7 +728,7 @@ func (boot *baseBootstrap) handleEquivalentProof( log.Trace("baseBootstrap.handleEquivalentProof: did not have proof for header, will try again", "headerHash", headerHash) - _, _, err = boot.blockBootstrapper.getHeaderWithNonceRequestingIfMissing(header.GetNonce() + 1) + _, _, err := boot.blockBootstrapper.getHeaderWithNonceRequestingIfMissing(header.GetNonce() + 1) if err != nil { return err } @@ -1220,6 +1218,7 @@ func (boot *baseBootstrap) init() { boot.poolsHolder.MiniBlocks().RegisterHandler(boot.receivedMiniblock, core.UniqueIdentifier()) boot.headers.RegisterHandler(boot.processReceivedHeader) + boot.proofs.RegisterHandler(boot.processReceivedProof) boot.syncStateListeners = make([]func(bool), 0) boot.requestedHashes = process.RequiredDataPool{} diff --git a/process/sync/metaForkDetector.go b/process/sync/metaForkDetector.go index 178e4e96042..f6a285fb3bc 100644 --- a/process/sync/metaForkDetector.go +++ b/process/sync/metaForkDetector.go @@ -6,6 +6,8 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" + + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/process" ) @@ -23,6 +25,8 @@ func NewMetaForkDetector( blackListHandler process.TimeCacher, blockTracker process.BlockTracker, genesisTime int64, + enableEpochsHandler common.EnableEpochsHandler, + proofsPool process.ProofsPool, ) (*metaForkDetector, error) { if check.IfNil(roundHandler) { @@ -34,6 +38,12 @@ func NewMetaForkDetector( if check.IfNil(blockTracker) { return nil, process.ErrNilBlockTracker } + if check.IfNil(enableEpochsHandler) { + return nil, process.ErrNilEnableEpochsHandler + } + if check.IfNil(proofsPool) { + return nil, process.ErrNilProofsPool + } genesisHdr, _, err := blockTracker.GetSelfNotarizedHeader(core.MetachainShardId, 0) if err != nil { @@ -41,13 +51,15 @@ func NewMetaForkDetector( } bfd := &baseForkDetector{ - roundHandler: roundHandler, - blackListHandler: blackListHandler, - genesisTime: genesisTime, - blockTracker: blockTracker, - genesisNonce: genesisHdr.GetNonce(), - genesisRound: genesisHdr.GetRound(), - genesisEpoch: genesisHdr.GetEpoch(), + roundHandler: roundHandler, + blackListHandler: blackListHandler, + genesisTime: genesisTime, + blockTracker: blockTracker, + genesisNonce: genesisHdr.GetNonce(), + genesisRound: genesisHdr.GetRound(), + genesisEpoch: genesisHdr.GetEpoch(), + enableEpochsHandler: enableEpochsHandler, + proofsPool: proofsPool, } bfd.headers = make(map[uint64][]*headerInfo) diff --git a/process/sync/metaForkDetector_test.go b/process/sync/metaForkDetector_test.go index 5db5855c6a4..05e5a98481f 100644 --- a/process/sync/metaForkDetector_test.go +++ b/process/sync/metaForkDetector_test.go @@ -5,10 +5,14 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/process/sync" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" + "github.com/stretchr/testify/assert" ) @@ -20,6 +24,8 @@ func TestNewMetaForkDetector_NilRoundHandlerShouldErr(t *testing.T) { &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) assert.True(t, check.IfNil(sfd)) assert.Equal(t, process.ErrNilRoundHandler, err) @@ -33,6 +39,8 @@ func TestNewMetaForkDetector_NilBlackListShouldErr(t *testing.T) { nil, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) assert.True(t, check.IfNil(sfd)) assert.Equal(t, process.ErrNilBlackListCacher, err) @@ -46,11 +54,43 @@ func TestNewMetaForkDetector_NilBlockTrackerShouldErr(t *testing.T) { &testscommon.TimeCacheStub{}, nil, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) assert.True(t, check.IfNil(sfd)) assert.Equal(t, process.ErrNilBlockTracker, err) } +func TestNewMetaForkDetector_NilEnableEpochsHandlerShouldErr(t *testing.T) { + t.Parallel() + + sfd, err := sync.NewMetaForkDetector( + &mock.RoundHandlerMock{}, + &testscommon.TimeCacheStub{}, + &mock.BlockTrackerMock{}, + 0, + nil, + &dataRetriever.ProofsPoolMock{}, + ) + assert.True(t, check.IfNil(sfd)) + assert.Equal(t, process.ErrNilEnableEpochsHandler, err) +} + +func TestNewMetaForkDetector_NilProofsPoolShouldErr(t *testing.T) { + t.Parallel() + + sfd, err := sync.NewMetaForkDetector( + &mock.RoundHandlerMock{}, + &testscommon.TimeCacheStub{}, + &mock.BlockTrackerMock{}, + 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + nil, + ) + assert.True(t, check.IfNil(sfd)) + assert.Equal(t, process.ErrNilProofsPool, err) +} + func TestNewMetaForkDetector_OkParamsShouldWork(t *testing.T) { t.Parallel() @@ -59,6 +99,8 @@ func TestNewMetaForkDetector_OkParamsShouldWork(t *testing.T) { &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) assert.Nil(t, err) assert.False(t, check.IfNil(sfd)) @@ -73,7 +115,14 @@ func TestMetaForkDetector_AddHeaderNilHeaderShouldErr(t *testing.T) { t.Parallel() roundHandlerMock := &mock.RoundHandlerMock{RoundIndex: 100} - bfd, _ := sync.NewMetaForkDetector(roundHandlerMock, &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0) + bfd, _ := sync.NewMetaForkDetector( + roundHandlerMock, + &testscommon.TimeCacheStub{}, + &mock.BlockTrackerMock{}, + 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, + ) err := bfd.AddHeader(nil, make([]byte, 0), process.BHProcessed, nil, nil) assert.Equal(t, sync.ErrNilHeader, err) } @@ -82,7 +131,14 @@ func TestMetaForkDetector_AddHeaderNilHashShouldErr(t *testing.T) { t.Parallel() roundHandlerMock := &mock.RoundHandlerMock{RoundIndex: 100} - bfd, _ := sync.NewMetaForkDetector(roundHandlerMock, &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0) + bfd, _ := sync.NewMetaForkDetector( + roundHandlerMock, + &testscommon.TimeCacheStub{}, + &mock.BlockTrackerMock{}, + 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, + ) err := bfd.AddHeader(&block.Header{}, nil, process.BHProcessed, nil, nil) assert.Equal(t, sync.ErrNilHash, err) } @@ -93,7 +149,14 @@ func TestMetaForkDetector_AddHeaderNotPresentShouldWork(t *testing.T) { hdr := &block.Header{Nonce: 1, Round: 1, PubKeysBitmap: []byte("X")} hash := make([]byte, 0) roundHandlerMock := &mock.RoundHandlerMock{RoundIndex: 1} - bfd, _ := sync.NewMetaForkDetector(roundHandlerMock, &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0) + bfd, _ := sync.NewMetaForkDetector( + roundHandlerMock, + &testscommon.TimeCacheStub{}, + &mock.BlockTrackerMock{}, + 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, + ) err := bfd.AddHeader(hdr, hash, process.BHProcessed, nil, nil) assert.Nil(t, err) @@ -111,7 +174,14 @@ func TestMetaForkDetector_AddHeaderPresentShouldAppend(t *testing.T) { hdr2 := &block.Header{Nonce: 1, Round: 1, PubKeysBitmap: []byte("X")} hash2 := []byte("hash2") roundHandlerMock := &mock.RoundHandlerMock{RoundIndex: 1} - bfd, _ := sync.NewMetaForkDetector(roundHandlerMock, &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0) + bfd, _ := sync.NewMetaForkDetector( + roundHandlerMock, + &testscommon.TimeCacheStub{}, + &mock.BlockTrackerMock{}, + 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, + ) _ = bfd.AddHeader(hdr1, hash1, process.BHProcessed, nil, nil) err := bfd.AddHeader(hdr2, hash2, process.BHProcessed, nil, nil) @@ -129,7 +199,14 @@ func TestMetaForkDetector_AddHeaderWithProcessedBlockShouldSetCheckpoint(t *test hdr1 := &block.Header{Nonce: 69, Round: 72, PubKeysBitmap: []byte("X")} hash1 := []byte("hash1") roundHandlerMock := &mock.RoundHandlerMock{RoundIndex: 73} - bfd, _ := sync.NewMetaForkDetector(roundHandlerMock, &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0) + bfd, _ := sync.NewMetaForkDetector( + roundHandlerMock, + &testscommon.TimeCacheStub{}, + &mock.BlockTrackerMock{}, + 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, + ) _ = bfd.AddHeader(hdr1, hash1, process.BHProcessed, nil, nil) assert.Equal(t, hdr1.Nonce, bfd.LastCheckpointNonce()) } @@ -141,7 +218,14 @@ func TestMetaForkDetector_AddHeaderPresentShouldNotRewriteState(t *testing.T) { hash := []byte("hash1") hdr2 := &block.Header{Nonce: 1, Round: 1, PubKeysBitmap: []byte("X")} roundHandlerMock := &mock.RoundHandlerMock{RoundIndex: 1} - bfd, _ := sync.NewMetaForkDetector(roundHandlerMock, &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0) + bfd, _ := sync.NewMetaForkDetector( + roundHandlerMock, + &testscommon.TimeCacheStub{}, + &mock.BlockTrackerMock{}, + 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, + ) _ = bfd.AddHeader(hdr1, hash, process.BHReceived, nil, nil) err := bfd.AddHeader(hdr2, hash, process.BHProcessed, nil, nil) @@ -158,7 +242,14 @@ func TestMetaForkDetector_AddHeaderHigherNonceThanRoundShouldErr(t *testing.T) { t.Parallel() roundHandlerMock := &mock.RoundHandlerMock{RoundIndex: 100} - bfd, _ := sync.NewMetaForkDetector(roundHandlerMock, &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0) + bfd, _ := sync.NewMetaForkDetector( + roundHandlerMock, + &testscommon.TimeCacheStub{}, + &mock.BlockTrackerMock{}, + 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, + ) err := bfd.AddHeader( &block.Header{Nonce: 1, Round: 0, PubKeysBitmap: []byte("X")}, []byte("hash1"), diff --git a/process/sync/metablock_test.go b/process/sync/metablock_test.go index 73386a021f1..1f5230b3f6e 100644 --- a/process/sync/metablock_test.go +++ b/process/sync/metablock_test.go @@ -968,6 +968,8 @@ func TestMetaBootstrap_GetNodeStateShouldReturnNotSynchronizedWhenForkIsDetected &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetrieverMock.ProofsPoolMock{}, ) bs, _ := sync.NewMetaBootstrap(args) @@ -1033,6 +1035,8 @@ func TestMetaBootstrap_GetNodeStateShouldReturnSynchronizedWhenForkIsDetectedAnd &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetrieverMock.ProofsPoolMock{}, ) bs, _ := sync.NewMetaBootstrap(args) @@ -1887,7 +1891,7 @@ func TestMetaBootstrap_HandleEquivalentProof(t *testing.T) { require.Nil(t, err) }) - t.Run("should return nil if first block after activation", func(t *testing.T) { + t.Run("should fail if first block after activation and no proof for it", func(t *testing.T) { t.Parallel() prevHeader := &block.MetaBlock{ @@ -1932,7 +1936,7 @@ func TestMetaBootstrap_HandleEquivalentProof(t *testing.T) { require.Nil(t, err) err = bs.HandleEquivalentProof(header, headerHash1) - require.Nil(t, err) + require.Error(t, err) }) t.Run("should work, proof already in pool", func(t *testing.T) { diff --git a/process/sync/shardForkDetector.go b/process/sync/shardForkDetector.go index 52715f36163..d688c548a2e 100644 --- a/process/sync/shardForkDetector.go +++ b/process/sync/shardForkDetector.go @@ -7,6 +7,8 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" + + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/process" ) @@ -24,6 +26,8 @@ func NewShardForkDetector( blackListHandler process.TimeCacher, blockTracker process.BlockTracker, genesisTime int64, + enableEpochsHandler common.EnableEpochsHandler, + proofsPool process.ProofsPool, ) (*shardForkDetector, error) { if check.IfNil(roundHandler) { @@ -35,6 +39,12 @@ func NewShardForkDetector( if check.IfNil(blockTracker) { return nil, process.ErrNilBlockTracker } + if check.IfNil(enableEpochsHandler) { + return nil, process.ErrNilEnableEpochsHandler + } + if check.IfNil(proofsPool) { + return nil, process.ErrNilProofsPool + } genesisHdr, _, err := blockTracker.GetSelfNotarizedHeader(core.MetachainShardId, 0) if err != nil { @@ -42,13 +52,15 @@ func NewShardForkDetector( } bfd := &baseForkDetector{ - roundHandler: roundHandler, - blackListHandler: blackListHandler, - genesisTime: genesisTime, - blockTracker: blockTracker, - genesisNonce: genesisHdr.GetNonce(), - genesisRound: genesisHdr.GetRound(), - genesisEpoch: genesisHdr.GetEpoch(), + roundHandler: roundHandler, + blackListHandler: blackListHandler, + genesisTime: genesisTime, + blockTracker: blockTracker, + genesisNonce: genesisHdr.GetNonce(), + genesisRound: genesisHdr.GetRound(), + genesisEpoch: genesisHdr.GetEpoch(), + enableEpochsHandler: enableEpochsHandler, + proofsPool: proofsPool, } bfd.headers = make(map[uint64][]*headerInfo) @@ -136,11 +148,13 @@ func (sfd *shardForkDetector) appendSelfNotarizedHeaders( continue } + hasProof := sfd.proofsPool.HasProof(selfNotarizedHeaders[i].GetShardID(), selfNotarizedHeadersHashes[i]) appended := sfd.append(&headerInfo{ - nonce: selfNotarizedHeaders[i].GetNonce(), - round: selfNotarizedHeaders[i].GetRound(), - hash: selfNotarizedHeadersHashes[i], - state: process.BHNotarized, + nonce: selfNotarizedHeaders[i].GetNonce(), + round: selfNotarizedHeaders[i].GetRound(), + hash: selfNotarizedHeadersHashes[i], + state: process.BHNotarized, + hasProof: hasProof, }) if appended { log.Debug("added self notarized header in fork detector", diff --git a/process/sync/shardForkDetector_test.go b/process/sync/shardForkDetector_test.go index 98412430e71..d3b37a0dfd1 100644 --- a/process/sync/shardForkDetector_test.go +++ b/process/sync/shardForkDetector_test.go @@ -5,10 +5,14 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/process/sync" "github.com/multiversx/mx-chain-go/testscommon" + dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" + "github.com/stretchr/testify/assert" ) @@ -20,6 +24,8 @@ func TestNewShardForkDetector_NilRoundHandlerShouldErr(t *testing.T) { &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetrieverMock.ProofsPoolMock{}, ) assert.True(t, check.IfNil(sfd)) assert.Equal(t, process.ErrNilRoundHandler, err) @@ -33,6 +39,8 @@ func TestNewShardForkDetector_NilBlackListShouldErr(t *testing.T) { nil, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetrieverMock.ProofsPoolMock{}, ) assert.True(t, check.IfNil(sfd)) assert.Equal(t, process.ErrNilBlackListCacher, err) @@ -46,11 +54,43 @@ func TestNewShardForkDetector_NilBlockTrackerShouldErr(t *testing.T) { &testscommon.TimeCacheStub{}, nil, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetrieverMock.ProofsPoolMock{}, ) assert.True(t, check.IfNil(sfd)) assert.Equal(t, process.ErrNilBlockTracker, err) } +func TestNewShardForkDetector_NilEnableEpochsHandlerShouldErr(t *testing.T) { + t.Parallel() + + sfd, err := sync.NewShardForkDetector( + &mock.RoundHandlerMock{}, + &testscommon.TimeCacheStub{}, + &mock.BlockTrackerMock{}, + 0, + nil, + &dataRetrieverMock.ProofsPoolMock{}, + ) + assert.True(t, check.IfNil(sfd)) + assert.Equal(t, process.ErrNilEnableEpochsHandler, err) +} + +func TestNewShardForkDetector_NilProofsPoolShouldErr(t *testing.T) { + t.Parallel() + + sfd, err := sync.NewShardForkDetector( + &mock.RoundHandlerMock{}, + &testscommon.TimeCacheStub{}, + &mock.BlockTrackerMock{}, + 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + nil, + ) + assert.True(t, check.IfNil(sfd)) + assert.Equal(t, process.ErrNilProofsPool, err) +} + func TestNewShardForkDetector_OkParamsShouldWork(t *testing.T) { t.Parallel() @@ -59,6 +99,8 @@ func TestNewShardForkDetector_OkParamsShouldWork(t *testing.T) { &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetrieverMock.ProofsPoolMock{}, ) assert.Nil(t, err) assert.False(t, check.IfNil(sfd)) @@ -78,6 +120,8 @@ func TestShardForkDetector_AddHeaderNilHeaderShouldErr(t *testing.T) { &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetrieverMock.ProofsPoolMock{}, ) err := bfd.AddHeader(nil, make([]byte, 0), process.BHProcessed, nil, nil) assert.Equal(t, sync.ErrNilHeader, err) @@ -92,6 +136,8 @@ func TestShardForkDetector_AddHeaderNilHashShouldErr(t *testing.T) { &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetrieverMock.ProofsPoolMock{}, ) err := bfd.AddHeader(&block.Header{}, nil, process.BHProcessed, nil, nil) assert.Equal(t, sync.ErrNilHash, err) @@ -108,6 +154,8 @@ func TestShardForkDetector_AddHeaderNotPresentShouldWork(t *testing.T) { &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetrieverMock.ProofsPoolMock{}, ) err := bfd.AddHeader(hdr, hash, process.BHProcessed, nil, nil) assert.Nil(t, err) @@ -130,6 +178,8 @@ func TestShardForkDetector_AddHeaderPresentShouldAppend(t *testing.T) { &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetrieverMock.ProofsPoolMock{}, ) _ = bfd.AddHeader(hdr1, hash1, process.BHProcessed, nil, nil) err := bfd.AddHeader(hdr2, hash2, process.BHProcessed, nil, nil) @@ -152,6 +202,8 @@ func TestShardForkDetector_AddHeaderWithProcessedBlockShouldSetCheckpoint(t *tes &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetrieverMock.ProofsPoolMock{}, ) _ = bfd.AddHeader(hdr1, hash1, process.BHProcessed, nil, nil) assert.Equal(t, hdr1.Nonce, bfd.LastCheckpointNonce()) @@ -169,6 +221,8 @@ func TestShardForkDetector_AddHeaderPresentShouldNotRewriteState(t *testing.T) { &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetrieverMock.ProofsPoolMock{}, ) _ = bfd.AddHeader(hdr1, hash, process.BHReceived, nil, nil) err := bfd.AddHeader(hdr2, hash, process.BHProcessed, nil, nil) @@ -190,6 +244,8 @@ func TestShardForkDetector_AddHeaderHigherNonceThanRoundShouldErr(t *testing.T) &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetrieverMock.ProofsPoolMock{}, ) err := bfd.AddHeader( &block.Header{Nonce: 1, Round: 0, PubKeysBitmap: []byte("X")}, []byte("hash1"), process.BHProcessed, nil, nil) diff --git a/process/sync/shardblock_test.go b/process/sync/shardblock_test.go index fbf974c1ee4..e33a2b83cd9 100644 --- a/process/sync/shardblock_test.go +++ b/process/sync/shardblock_test.go @@ -1168,6 +1168,8 @@ func TestBootstrap_GetNodeStateShouldReturnNotSynchronizedWhenForkIsDetectedAndI &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetrieverMock.ProofsPoolMock{}, ) bs, _ := sync.NewShardBootstrap(args) @@ -1243,6 +1245,8 @@ func TestBootstrap_GetNodeStateShouldReturnSynchronizedWhenForkIsDetectedAndItRe &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetrieverMock.ProofsPoolMock{}, ) bs, _ := sync.NewShardBootstrap(args) @@ -2229,3 +2233,333 @@ func TestShardBootstrap_NilInnerBootstrapperClose(t *testing.T) { bootstrapper := &sync.ShardBootstrap{} assert.Nil(t, bootstrapper.Close()) } + +func TestShardBootstrap_HandleEquivalentProof(t *testing.T) { + t.Parallel() + + prevHeaderHash1 := []byte("prevHeaderHash") + headerHash1 := []byte("headerHash") + + t.Run("flag not activated, should return direclty", func(t *testing.T) { + t.Parallel() + + header := &block.Header{ + Nonce: 11, + } + + args := CreateShardBootstrapMockArguments() + args.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return false + }, + } + + bs, err := sync.NewShardBootstrap(args) + require.Nil(t, err) + + err = bs.HandleEquivalentProof(header, headerHash1) + require.Nil(t, err) + }) + + t.Run("should fail if first block after activation and no proof for it", func(t *testing.T) { + t.Parallel() + + prevHeader := &block.Header{ + Epoch: 3, + Nonce: 10, + } + + header := &block.Header{ + Epoch: 4, + Nonce: 11, + PrevHash: prevHeaderHash1, + } + + args := CreateShardBootstrapMockArguments() + args.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + if epoch == 4 { + return flag == common.EquivalentMessagesFlag + } + + return false + }, + } + + pools := createMockPools() + pools.HeadersCalled = func() dataRetriever.HeadersPool { + sds := &mock.HeadersCacherStub{} + sds.GetHeaderByHashCalled = func(hash []byte) (data.HeaderHandler, error) { + if bytes.Equal(hash, prevHeaderHash1) { + return prevHeader, nil + } + + return nil, sync.ErrHeaderNotFound + } + + return sds + } + + args.PoolsHolder = pools + + bs, err := sync.NewShardBootstrap(args) + require.Nil(t, err) + + err = bs.HandleEquivalentProof(header, headerHash1) + require.Error(t, err) + }) + + t.Run("should work, proof already in pool", func(t *testing.T) { + t.Parallel() + + prevHeader := &block.Header{ + Nonce: 10, + } + + header := &block.Header{ + Nonce: 11, + PrevHash: prevHeaderHash1, + } + + args := CreateShardBootstrapMockArguments() + args.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return flag == common.EquivalentMessagesFlag + }, + } + + pools := createMockPools() + pools.HeadersCalled = func() dataRetriever.HeadersPool { + sds := &mock.HeadersCacherStub{} + sds.GetHeaderByHashCalled = func(hash []byte) (data.HeaderHandler, error) { + if bytes.Equal(hash, prevHeaderHash1) { + return prevHeader, nil + } + + return nil, sync.ErrHeaderNotFound + } + + return sds + } + + pools.ProofsCalled = func() dataRetriever.ProofsPool { + return &dataRetrieverMock.ProofsPoolMock{ + HasProofCalled: func(shardID uint32, headerHash []byte) bool { + return true + }, + } + } + + args.PoolsHolder = pools + + bs, err := sync.NewShardBootstrap(args) + require.Nil(t, err) + + err = bs.HandleEquivalentProof(header, headerHash1) + require.Nil(t, err) + }) + + t.Run("should work, by checking for next header", func(t *testing.T) { + t.Parallel() + + headerHash1 := []byte("headerHash1") + headerHash2 := []byte("headerHash2") + + header1 := &block.Header{ + Nonce: 10, + } + + header2 := &block.Header{ + Nonce: 11, + PrevHash: headerHash1, + } + + header3 := &block.Header{ + Nonce: 12, + PrevHash: headerHash2, + } + + args := CreateShardBootstrapMockArguments() + args.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return flag == common.EquivalentMessagesFlag + }, + } + + pools := createMockPools() + pools.HeadersCalled = func() dataRetriever.HeadersPool { + sds := &mock.HeadersCacherStub{} + sds.GetHeaderByHashCalled = func(hash []byte) (data.HeaderHandler, error) { + if bytes.Equal(hash, headerHash1) { + return header1, nil + } + + return nil, sync.ErrHeaderNotFound + } + sds.GetHeaderByNonceAndShardIdCalled = func(hdrNonce uint64, shardId uint32) ([]data.HeaderHandler, [][]byte, error) { + if hdrNonce == header2.GetNonce()+1 { + return []data.HeaderHandler{header3}, [][]byte{headerHash2}, nil + } + + return nil, nil, process.ErrMissingHeader + } + + return sds + } + + hasProofCalled := 0 + pools.ProofsCalled = func() dataRetriever.ProofsPool { + return &dataRetrieverMock.ProofsPoolMock{ + HasProofCalled: func(shardID uint32, headerHash []byte) bool { + if hasProofCalled == 0 { + hasProofCalled++ + return false + } + + return true + }, + } + } + + args.PoolsHolder = pools + + bs, err := sync.NewShardBootstrap(args) + require.Nil(t, err) + + err = bs.HandleEquivalentProof(header2, headerHash2) + require.Nil(t, err) + }) + + t.Run("should return err if failing to get proof after second request", func(t *testing.T) { + t.Parallel() + + headerHash1 := []byte("headerHash1") + headerHash2 := []byte("headerHash2") + + header1 := &block.Header{ + Nonce: 10, + } + + header2 := &block.Header{ + Nonce: 11, + PrevHash: headerHash1, + } + + header3 := &block.Header{ + Nonce: 12, + PrevHash: headerHash2, + } + + args := CreateShardBootstrapMockArguments() + args.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return flag == common.EquivalentMessagesFlag + }, + } + + pools := createMockPools() + pools.HeadersCalled = func() dataRetriever.HeadersPool { + sds := &mock.HeadersCacherStub{} + sds.GetHeaderByHashCalled = func(hash []byte) (data.HeaderHandler, error) { + if bytes.Equal(hash, headerHash1) { + return header1, nil + } + + return nil, sync.ErrHeaderNotFound + } + sds.GetHeaderByNonceAndShardIdCalled = func(hdrNonce uint64, shardId uint32) ([]data.HeaderHandler, [][]byte, error) { + if hdrNonce == header2.GetNonce()+1 { + return []data.HeaderHandler{header3}, [][]byte{headerHash2}, nil + } + + return nil, nil, process.ErrMissingHeader + } + + return sds + } + + hasProofCalled := 0 + pools.ProofsCalled = func() dataRetriever.ProofsPool { + return &dataRetrieverMock.ProofsPoolMock{ + HasProofCalled: func(shardID uint32, headerHash []byte) bool { + if hasProofCalled < 2 { + hasProofCalled++ + return false + } + + return true + }, + } + } + + args.PoolsHolder = pools + + bs, err := sync.NewShardBootstrap(args) + require.Nil(t, err) + + err = bs.HandleEquivalentProof(header2, headerHash2) + require.Error(t, err) + }) + + t.Run("should return err if failing to request next header", func(t *testing.T) { + t.Parallel() + + headerHash1 := []byte("headerHash1") + headerHash2 := []byte("headerHash2") + + header1 := &block.Header{ + Nonce: 10, + } + + header2 := &block.Header{ + Nonce: 11, + PrevHash: headerHash1, + } + + args := CreateShardBootstrapMockArguments() + args.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return flag == common.EquivalentMessagesFlag + }, + } + + pools := createMockPools() + pools.HeadersCalled = func() dataRetriever.HeadersPool { + sds := &mock.HeadersCacherStub{} + sds.GetHeaderByHashCalled = func(hash []byte) (data.HeaderHandler, error) { + if bytes.Equal(hash, headerHash1) { + return header1, nil + } + + return nil, sync.ErrHeaderNotFound + } + sds.GetHeaderByNonceAndShardIdCalled = func(hdrNonce uint64, shardId uint32) ([]data.HeaderHandler, [][]byte, error) { + return nil, nil, process.ErrMissingHeader + } + + return sds + } + + hasProofCalled := 0 + pools.ProofsCalled = func() dataRetriever.ProofsPool { + return &dataRetrieverMock.ProofsPoolMock{ + HasProofCalled: func(shardID uint32, headerHash []byte) bool { + if hasProofCalled < 2 { + hasProofCalled++ + return false + } + + return true + }, + } + } + + args.PoolsHolder = pools + + bs, err := sync.NewShardBootstrap(args) + require.Nil(t, err) + + err = bs.HandleEquivalentProof(header2, headerHash2) + require.Error(t, err) + }) +} diff --git a/process/track/baseBlockTrack_test.go b/process/track/baseBlockTrack_test.go index b32b943faf9..aeb6e0d0e59 100644 --- a/process/track/baseBlockTrack_test.go +++ b/process/track/baseBlockTrack_test.go @@ -10,6 +10,10 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" + logger "github.com/multiversx/mx-chain-logger-go" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/process" processBlock "github.com/multiversx/mx-chain-go/process/block" @@ -24,9 +28,6 @@ import ( "github.com/multiversx/mx-chain-go/testscommon/economicsmocks" "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" - logger "github.com/multiversx/mx-chain-logger-go" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) const maxGasLimitPerBlock = uint64(1500000000) @@ -107,8 +108,9 @@ func CreateShardTrackerMockArguments() track.ArgShardTracker { shardCoordinatorMock := mock.NewMultipleShardsCoordinatorMock() genesisBlocks := createGenesisBlocks(shardCoordinatorMock) argsHeaderValidator := processBlock.ArgsHeaderValidator{ - Hasher: &hashingMocks.HasherMock{}, - Marshalizer: &mock.MarshalizerMock{}, + Hasher: &hashingMocks.HasherMock{}, + Marshalizer: &mock.MarshalizerMock{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, } headerValidator, _ := processBlock.NewHeaderValidator(argsHeaderValidator) whitelistHandler := &testscommon.WhiteListHandlerStub{} @@ -147,8 +149,9 @@ func CreateMetaTrackerMockArguments() track.ArgMetaTracker { shardCoordinatorMock.CurrentShard = core.MetachainShardId genesisBlocks := createGenesisBlocks(shardCoordinatorMock) argsHeaderValidator := processBlock.ArgsHeaderValidator{ - Hasher: &hashingMocks.HasherMock{}, - Marshalizer: &mock.MarshalizerMock{}, + Hasher: &hashingMocks.HasherMock{}, + Marshalizer: &mock.MarshalizerMock{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, } headerValidator, _ := processBlock.NewHeaderValidator(argsHeaderValidator) whitelistHandler := &testscommon.WhiteListHandlerStub{} @@ -186,8 +189,9 @@ func CreateBaseTrackerMockArguments() track.ArgBaseTracker { shardCoordinatorMock := mock.NewMultipleShardsCoordinatorMock() genesisBlocks := createGenesisBlocks(shardCoordinatorMock) argsHeaderValidator := processBlock.ArgsHeaderValidator{ - Hasher: &hashingMocks.HasherMock{}, - Marshalizer: &mock.MarshalizerMock{}, + Hasher: &hashingMocks.HasherMock{}, + Marshalizer: &mock.MarshalizerMock{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, } headerValidator, _ := processBlock.NewHeaderValidator(argsHeaderValidator) feeHandler := &economicsmocks.EconomicsHandlerStub{ @@ -2257,7 +2261,7 @@ func TestComputeLongestChain_ShouldWorkWithLongestChain(t *testing.T) { assert.Equal(t, longestChain+chains-1, uint64(len(headers))) } -//------- CheckBlockAgainstRoundHandler +// ------- CheckBlockAgainstRoundHandler func TestBaseBlockTrack_CheckBlockAgainstRoundHandlerNilHeaderShouldErr(t *testing.T) { t.Parallel() @@ -2306,7 +2310,7 @@ func TestBaseBlockTrack_CheckBlockAgainstRoundHandlerShouldWork(t *testing.T) { assert.Nil(t, err) } -//------- CheckBlockAgainstFinal +// ------- CheckBlockAgainstFinal func TestBaseBlockTrack_CheckBlockAgainstFinalNilHeaderShouldErr(t *testing.T) { t.Parallel() diff --git a/process/track/blockProcessor.go b/process/track/blockProcessor.go index 11b1d9aef3f..72fd9993283 100644 --- a/process/track/blockProcessor.go +++ b/process/track/blockProcessor.go @@ -8,6 +8,7 @@ import ( "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/sharding" @@ -309,13 +310,13 @@ func (bp *blockProcessor) checkHeaderFinality( sortedHeadersHashes [][]byte, index int, ) error { - if check.IfNil(header) { return process.ErrNilBlockHeader } if common.IsFlagEnabledAfterEpochsStartBlock(header, bp.enableEpochsHandler, common.EquivalentMessagesFlag) { - if bp.proofsPool.HasProof(header.GetShardID(), sortedHeadersHashes[index]) { + // the index in argument is for the next block after header + if bp.proofsPool.HasProof(header.GetShardID(), sortedHeadersHashes[index-1]) { return nil } @@ -324,7 +325,6 @@ func (bp *blockProcessor) checkHeaderFinality( prevHeader := header numFinalityAttestingHeaders := uint64(0) - for i := index; i < len(sortedHeaders); i++ { currHeader := sortedHeaders[i] if numFinalityAttestingHeaders >= bp.blockFinality || currHeader.GetNonce() > prevHeader.GetNonce()+1 { diff --git a/process/track/blockProcessor_test.go b/process/track/blockProcessor_test.go index 05d6275047f..04e1849e75b 100644 --- a/process/track/blockProcessor_test.go +++ b/process/track/blockProcessor_test.go @@ -7,6 +7,7 @@ import ( "time" "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" @@ -14,20 +15,22 @@ import ( "github.com/multiversx/mx-chain-core-go/data" dataBlock "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/process" processBlock "github.com/multiversx/mx-chain-go/process/block" "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/process/track" - "github.com/pkg/errors" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func CreateBlockProcessorMockArguments() track.ArgBlockProcessor { shardCoordinatorMock := mock.NewMultipleShardsCoordinatorMock() argsHeaderValidator := processBlock.ArgsHeaderValidator{ - Hasher: &hashingMocks.HasherMock{}, - Marshalizer: &mock.MarshalizerMock{}, + Hasher: &hashingMocks.HasherMock{}, + Marshalizer: &mock.MarshalizerMock{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, } headerValidator, _ := processBlock.NewHeaderValidator(argsHeaderValidator) diff --git a/sharding/chainParametersHolder.go b/sharding/chainParametersHolder.go index 982d41679d7..341460d2dbd 100644 --- a/sharding/chainParametersHolder.go +++ b/sharding/chainParametersHolder.go @@ -8,6 +8,7 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/config" ) diff --git a/testscommon/consensus/consensusDataContainerMock.go b/testscommon/consensus/consensusDataContainerMock.go index ad00574ca6b..4c32064e14c 100644 --- a/testscommon/consensus/consensusDataContainerMock.go +++ b/testscommon/consensus/consensusDataContainerMock.go @@ -15,8 +15,8 @@ import ( "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" ) -// TODO: remove this mock component; implement setters for main component in export_test.go // ConsensusCoreMock - +// TODO: remove this mock component; implement setters for main component in export_test.go type ConsensusCoreMock struct { blockChain data.ChainHandler blockProcessor process.BlockProcessor @@ -43,6 +43,7 @@ type ConsensusCoreMock struct { signingHandler consensus.SigningHandler enableEpochsHandler common.EnableEpochsHandler equivalentProofsPool consensus.EquivalentProofsPool + epochNotifier process.EpochNotifier } // GetAntiFloodHandler - @@ -295,6 +296,16 @@ func (ccm *ConsensusCoreMock) SetEquivalentProofsPool(proofPool consensus.Equiva ccm.equivalentProofsPool = proofPool } +// EpochNotifier - +func (ccm *ConsensusCoreMock) EpochNotifier() process.EpochNotifier { + return ccm.epochNotifier +} + +// SetEpochNotifier - +func (ccm *ConsensusCoreMock) SetEpochNotifier(epochNotifier process.EpochNotifier) { + ccm.epochNotifier = epochNotifier +} + // IsInterfaceNil returns true if there is no value under the interface func (ccm *ConsensusCoreMock) IsInterfaceNil() bool { return ccm == nil diff --git a/testscommon/consensus/mockTestInitializer.go b/testscommon/consensus/mockTestInitializer.go index 4cdd7174618..ba5db410a18 100644 --- a/testscommon/consensus/mockTestInitializer.go +++ b/testscommon/consensus/mockTestInitializer.go @@ -16,6 +16,7 @@ import ( "github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" + epochNotifierMock "github.com/multiversx/mx-chain-go/testscommon/epochNotifier" epochstartmock "github.com/multiversx/mx-chain-go/testscommon/epochstartmock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/pool" @@ -217,6 +218,7 @@ func InitConsensusCoreWithMultiSigner(multiSigner crypto.MultiSigner) *Consensus signingHandler := &SigningHandlerStub{} enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{} equivalentProofsPool := &dataRetriever.ProofsPoolMock{} + epochNotifier := &epochNotifierMock.EpochNotifierStub{} container := &ConsensusCoreMock{ blockChain: blockChain, @@ -244,6 +246,7 @@ func InitConsensusCoreWithMultiSigner(multiSigner crypto.MultiSigner) *Consensus signingHandler: signingHandler, enableEpochsHandler: enableEpochsHandler, equivalentProofsPool: equivalentProofsPool, + epochNotifier: epochNotifier, } return container diff --git a/testscommon/consensus/sposWorkerMock.go b/testscommon/consensus/sposWorkerMock.go index 3a7e1ef384b..f3a6e1cded9 100644 --- a/testscommon/consensus/sposWorkerMock.go +++ b/testscommon/consensus/sposWorkerMock.go @@ -17,6 +17,7 @@ type SposWorkerMock struct { receivedMessageCall func(ctx context.Context, cnsDta *consensus.Message) bool, ) AddReceivedHeaderHandlerCalled func(handler func(data.HeaderHandler)) + RemoveAllReceivedHeaderHandlersCalled func() AddReceivedProofHandlerCalled func(handler func(proofHandler consensus.ProofHandler)) RemoveAllReceivedMessagesCallsCalled func() ProcessReceivedMessageCalled func(message p2p.MessageP2P) error @@ -48,6 +49,13 @@ func (sposWorkerMock *SposWorkerMock) AddReceivedHeaderHandler(handler func(data } } +// RemoveAllReceivedHeaderHandlers - +func (sposWorkerMock *SposWorkerMock) RemoveAllReceivedHeaderHandlers() { + if sposWorkerMock.RemoveAllReceivedHeaderHandlersCalled != nil { + sposWorkerMock.RemoveAllReceivedHeaderHandlersCalled() + } +} + func (sposWorkerMock *SposWorkerMock) AddReceivedProofHandler(handler func(proofHandler consensus.ProofHandler)) { if sposWorkerMock.AddReceivedProofHandlerCalled != nil { sposWorkerMock.AddReceivedProofHandlerCalled(handler) diff --git a/testscommon/dataRetriever/proofsPoolMock.go b/testscommon/dataRetriever/proofsPoolMock.go index 8154659a134..09a5d4646b9 100644 --- a/testscommon/dataRetriever/proofsPoolMock.go +++ b/testscommon/dataRetriever/proofsPoolMock.go @@ -11,6 +11,7 @@ type ProofsPoolMock struct { CleanupProofsBehindNonceCalled func(shardID uint32, nonce uint64) error GetProofCalled func(shardID uint32, headerHash []byte) (data.HeaderProofHandler, error) HasProofCalled func(shardID uint32, headerHash []byte) bool + RegisterHandlerCalled func(handler func(headerProof data.HeaderProofHandler)) } // AddProof - @@ -49,6 +50,13 @@ func (p *ProofsPoolMock) HasProof(shardID uint32, headerHash []byte) bool { return false } +// RegisterHandler - +func (p *ProofsPoolMock) RegisterHandler(handler func(headerProof data.HeaderProofHandler)) { + if p.RegisterHandlerCalled != nil { + p.RegisterHandlerCalled(handler) + } +} + // IsInterfaceNil - func (p *ProofsPoolMock) IsInterfaceNil() bool { return p == nil diff --git a/testscommon/headerHandlerStub.go b/testscommon/headerHandlerStub.go index 00613c26d4d..733c8b5c167 100644 --- a/testscommon/headerHandlerStub.go +++ b/testscommon/headerHandlerStub.go @@ -40,6 +40,7 @@ type HeaderHandlerStub struct { SetLeaderSignatureCalled func(signature []byte) error GetPreviousProofCalled func() data.HeaderProofHandler SetPreviousProofCalled func(proof data.HeaderProofHandler) + GetShardIDCalled func() uint32 } // GetAccumulatedFees - @@ -91,6 +92,9 @@ func (hhs *HeaderHandlerStub) ShallowClone() data.HeaderHandler { // GetShardID - func (hhs *HeaderHandlerStub) GetShardID() uint32 { + if hhs.GetShardIDCalled != nil { + return hhs.GetShardIDCalled() + } return 1 } diff --git a/testscommon/processMocks/forkDetectorStub.go b/testscommon/processMocks/forkDetectorStub.go index 80ddc4d2ebf..9dbfe65c059 100644 --- a/testscommon/processMocks/forkDetectorStub.go +++ b/testscommon/processMocks/forkDetectorStub.go @@ -2,6 +2,7 @@ package processMocks import ( "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-go/process" ) @@ -19,6 +20,7 @@ type ForkDetectorStub struct { RestoreToGenesisCalled func() ResetProbableHighestNonceCalled func() SetFinalToLastCheckpointCalled func() + ReceivedProofCalled func(proof data.HeaderProofHandler) } // RestoreToGenesis - @@ -93,6 +95,13 @@ func (fdm *ForkDetectorStub) SetFinalToLastCheckpoint() { } } +// ReceivedProof - +func (fdm *ForkDetectorStub) ReceivedProof(proof data.HeaderProofHandler) { + if fdm.ReceivedProofCalled != nil { + fdm.ReceivedProofCalled(proof) + } +} + // IsInterfaceNil returns true if there is no value under the interface func (fdm *ForkDetectorStub) IsInterfaceNil() bool { return fdm == nil diff --git a/update/mock/epochStartNotifierStub.go b/update/mock/epochStartNotifierStub.go index 0a7b89387f5..96bb821a1f1 100644 --- a/update/mock/epochStartNotifierStub.go +++ b/update/mock/epochStartNotifierStub.go @@ -2,6 +2,7 @@ package mock import ( "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-go/epochStart" )